
Merging upstream version 10.4.2.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:01:55 +01:00
parent de4e42d4d3
commit 0c79f8b507
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
88 changed files with 1637 additions and 436 deletions

View file

@@ -20,7 +20,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install -r dev-requirements.txt
+          make install-dev
       - name: Run checks (linter, code style, tests)
         run: |
-          ./run_checks.sh
+          make check

5
.gitignore vendored
View file

@@ -130,3 +130,8 @@ dmypy.json
 # PyCharm
 .idea/
+# Visual Studio Code
+.vscode
+.DS_STORE

31
.pre-commit-config.yaml Normal file
View file

@@ -0,0 +1,31 @@
repos:
  - repo: local
    hooks:
      - id: autoflake
        name: autoflake
        entry: autoflake -i -r
        language: system
        types: [ python ]
        require_serial: true
        files: ^(sqlglot/|tests/|setup.py)
      - id: isort
        name: isort
        entry: isort
        language: system
        types: [ python ]
        files: ^(sqlglot/|tests/|setup.py)
        require_serial: true
      - id: black
        name: black
        entry: black --line-length 100
        language: system
        types: [ python ]
        require_serial: true
        files: ^(sqlglot/|tests/|setup.py)
      - id: mypy
        name: mypy
        entry: mypy
        language: system
        types: [ python ]
        files: ^(sqlglot/|tests/)
        require_serial: true

View file

@ -1,3 +0,0 @@
{
"python.linting.pylintEnabled": true
}

View file

@@ -1,6 +1,37 @@
 Changelog
 =========
+v10.4.0
+------
+Changes:
+- Breaking: Removed the quote_identities optimizer rule.
+- New: ARRAYAGG, SUM, ARRAYANY support in the engine. SQLGlot is now able to execute all TPC-H queries.
+- Improvement: Transpile DATEDIFF to postgres.
+- Improvement: Right join pushdown fixes.
+- Improvement: Have Snowflake generate VALUES columns without quotes.
+- Improvement: Support NaN values in convert.
+- Improvement: Recursive CTE scope [fixes](https://github.com/tobymao/sqlglot/commit/bec36391d85152fa478222403d06beffa8d6ddfb).
+v10.3.0
+------
+Changes:
+- Breaking: Json ops changed to binary expressions.
+- New: Jinja tokenization.
+- Improvement: More robust type inference.
 v10.2.0
 ------

24
Makefile Normal file
View file

@@ -0,0 +1,24 @@
.PHONY: install install-dev install-pre-commit test style check docs docs-serve

install:
	pip install -e .

install-dev:
	pip install -e ".[dev]"

install-pre-commit:
	pre-commit install

test:
	python -m unittest

style:
	pre-commit run --all-files

check: style test

docs:
	pdoc/cli.py -o pdoc/docs

docs-serve:
	pdoc/cli.py

View file

@@ -1,8 +1,8 @@
 # SQLGlot
-SQLGlot is a no dependency Python SQL parser, transpiler, and optimizer. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
+SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
-It is a very comprehensive generic SQL parser with a robust [test suite](tests). It is also quite [performant](#benchmarks) while being written purely in Python.
+It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks) while being written purely in Python.
 You can easily [customize](#custom-dialects) the parser, [analyze](#metadata) queries, traverse expression trees, and programmatically [build](#build-and-modify-sql) SQL.
@@ -13,8 +13,7 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/
 ## Table of Contents
 * [Install](#install)
-* [Documentation](#documentation)
-* [Run Tests and Lint](#run-tests-and-lint)
+* [Get in Touch](#get-in-touch)
 * [Examples](#examples)
    * [Formatting and Transpiling](#formatting-and-transpiling)
    * [Metadata](#metadata)
@@ -26,6 +25,8 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/
    * [AST Diff](#ast-diff)
    * [Custom Dialects](#custom-dialects)
    * [SQL Execution](#sql-execution)
+* [Documentation](#documentation)
+* [Run Tests and Lint](#run-tests-and-lint)
 * [Benchmarks](#benchmarks)
 * [Optional Dependencies](#optional-dependencies)
@@ -40,30 +41,17 @@ pip3 install sqlglot
 Or with a local checkout:
 ```
-pip3 install -e .
+make install
 ```
 Requirements for development (optional):
 ```
-pip3 install -r dev-requirements.txt
+make install-dev
 ```
-## Documentation
-SQLGlot uses [pdocs](https://pdoc.dev/) to serve its API documentation:
-```
-pdoc sqlglot --docformat google
-```
-## Run Tests and Lint
-```
-# set `SKIP_INTEGRATION=1` to skip integration tests
-./run_checks.sh
-```
+## Get in Touch
+We'd love to hear from you. Join our community [Slack channel](https://join.slack.com/t/tobiko-data/shared_invite/zt-1ma66d79v-a4dbf4DUpLAQJ8ptQrJygg)!
 ## Examples
@@ -163,16 +151,16 @@ from sqlglot import parse_one, exp
 # print all column references (a and b)
 for column in parse_one("SELECT a, b + 1 AS c FROM d").find_all(exp.Column):
     print(column.alias_or_name)
 # find all projections in select statements (a and c)
 for select in parse_one("SELECT a, b + 1 AS c FROM d").find_all(exp.Select):
     for projection in select.expressions:
         print(projection.alias_or_name)
 # find all tables (x, y, z)
 for table in parse_one("SELECT * FROM x JOIN y JOIN z").find_all(exp.Table):
     print(table.name)
 ```
@@ -274,7 +262,7 @@ transformed_tree.sql()
 ### SQL Optimizer
-SQLGlot can rewrite queries into an "optimized" form. It performs a variety of [techniques](sqlglot/optimizer/optimizer.py) to create a new canonical AST. This AST can be used to standardize queries or provide the foundations for implementing an actual engine. For example:
+SQLGlot can rewrite queries into an "optimized" form. It performs a variety of [techniques](https://github.com/tobymao/sqlglot/blob/main/sqlglot/optimizer/optimizer.py) to create a new canonical AST. This AST can be used to standardize queries or provide the foundations for implementing an actual engine. For example:
 ```python
 import sqlglot
@@ -292,7 +280,7 @@ print(
 )
 ```
-```
+```sql
 SELECT
   (
     "x"."A" OR "x"."B" OR "x"."C"
@@ -351,9 +339,11 @@ diff(parse_one("SELECT a + b, c, d"), parse_one("SELECT c, a - b, d"))
 ]
 ```
+See also: [Semantic Diff for SQL](https://github.com/tobymao/sqlglot/blob/main/posts/sql_diff.md).
 ### Custom Dialects
-[Dialects](sqlglot/dialects) can be added by subclassing `Dialect`:
+[Dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) can be added by subclassing `Dialect`:
 ```python
 from sqlglot import exp
@@ -391,7 +381,7 @@ class Custom(Dialect):
 print(Dialect["custom"])
 ```
-```python
+```
 <class '__main__.Custom'>
 ```
@@ -442,9 +432,23 @@ user_id price
       2   3.0
 ```
+## Documentation
+SQLGlot uses [pdocs](https://pdoc.dev/) to serve its API documentation:
+```
+make docs-serve
+```
+## Run Tests and Lint
+```
+make check # Set SKIP_INTEGRATION=1 to skip integration tests
+```
 ## Benchmarks
-[Benchmarks](benchmarks) run on Python 3.10.5 in seconds.
+[Benchmarks](https://github.com/tobymao/sqlglot/blob/main/benchmarks/bench.py) run on Python 3.10.5 in seconds.
 | Query           | sqlglot         | sqlfluff        | sqltree         | sqlparse        | moz_sql_parser  | sqloxide        |
 | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- |

View file

@ -1,9 +0,0 @@
autoflake
black
duckdb
isort
mypy
pandas
pyspark
python-dateutil
pdoc

34
pdoc/cli.py Executable file
View file

@@ -0,0 +1,34 @@
#!/usr/bin/env python3
from importlib import import_module
from pathlib import Path
from unittest import mock

from pdoc.__main__ import cli, parser

# Need this import or else import_module doesn't work
import sqlglot


def mocked_import(*args, **kwargs):
    """Return a MagicMock if import fails for any reason"""
    try:
        return import_module(*args, **kwargs)
    except Exception:
        mocked_module = mock.MagicMock()
        mocked_module.__name__ = args[0]
        return mocked_module


if __name__ == "__main__":
    # Mock uninstalled dependencies so pdoc can still work
    with mock.patch("importlib.import_module", side_effect=mocked_import):
        opts = parser.parse_args()
        opts.docformat = "google"
        opts.modules = ["sqlglot"]
        opts.footer_text = "Copyright (c) 2022 Toby Mao"
        opts.template_directory = Path(__file__).parent.joinpath("templates").absolute()
        opts.edit_url = ["sqlglot=https://github.com/tobymao/sqlglot/"]

        with mock.patch("pdoc.__main__.parser", **{"parse_args.return_value": opts}):
            cli()

41
pdoc/docs/expressions.md Normal file
View file

@ -0,0 +1,41 @@
# Expressions
Every AST node in SQLGlot is represented by a subclass of `Expression`. Each such expression encapsulates any necessary context, such as its child expressions, their names (arg keys), and whether each child expression is optional or not.
Furthermore, the following attributes are common across all expressions:
#### key
A unique key for each class in the `Expression` hierarchy. This is useful for hashing and representing expressions as strings.
#### args
A dictionary that maps child arg keys to the corresponding expressions. A value in this mapping is usually either a single `Expression` instance or a list of them, but SQLGlot doesn't impose any constraints on the actual type of the value.
#### arg_types
A dictionary used for mapping arg keys to booleans that determine whether the corresponding expressions are optional or not. Consider the following example:
```python
class Limit(Expression):
arg_types = {"this": False, "expression": True}
```
Here, `Limit` declares that it expects to have one optional and one required child expression, which can be referenced through `this` and `expression`, respectively. The arg keys are generally arbitrary, but there are helper methods for keys like `this`, `expression` and `expressions` that abstract away dictionary lookups and related checks. For this reason, these keys are common throughout SQLGlot's codebase.
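As a small illustrative sketch (using the public `parse_one` helper; the exact printed representations can vary between versions), these attributes can be inspected directly:

```python
from sqlglot import parse_one

ast = parse_one("SELECT a FROM t LIMIT 5")

# The LIMIT clause is stored under the parent Select's "limit" arg key.
limit = ast.args["limit"]

print(limit.key)            # "limit" - the class key of the Limit expression
print(limit.expression)     # the literal 5, reachable via the `expression` helper
print(limit.arg_key)        # "limit" - the name the parent uses to refer to it
print(limit.parent is ast)  # True
```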
#### parent
A reference to the parent expression (may be `None`).
#### arg_key
The arg key an expression is associated with, i.e. the name its parent expression uses to refer to it.
#### comments
A list of comments that are associated with a given expression. This is used in order to preserve comments when transpiling SQL code.
#### type
The data type of an expression, as inferred by SQLGlot's optimizer.

View file

@ -0,0 +1,6 @@
{% extends "default/module.html.jinja2" %}
{% if module.docstring %}
{% macro module_name() %}
{% endmacro %}
{% endif %}

208
posts/python_sql_engine.md Normal file
View file

@ -0,0 +1,208 @@
# Writing a Python SQL engine from scratch
[Toby Mao](https://www.linkedin.com/in/toby-mao/)
## Introduction
When I first started writing SQLGlot in early 2021, my goal was just to translate SQL queries from SparkSQL to Presto and vice versa. However, over the last year and a half, I've ended up with a full-fledged SQL engine. SQLGlot can now parse and transpile between [18 SQL dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) and can execute all 24 [TPC-H](https://www.tpc.org/tpch/) SQL queries. The parser and engine are all written from scratch using Python.
This post will cover [why](#why) I went through the effort of creating a Python SQL engine and [how](#how) a simple query goes from a string to actually transforming data. The following steps are briefly summarized:
* [Tokenizing](#tokenizing)
* [Parsing](#parsing)
* [Optimizing](#optimizing)
* [Planning](#planning)
* [Executing](#executing)
## Why?
I started working on SQLGlot because of my work on the [experimentation and metrics platform](https://netflixtechblog.com/reimagining-experimentation-analysis-at-netflix-71356393af21) at Netflix, where I built tools that allowed data scientists to define and compute SQL-based metrics. Netflix relied on multiple engines to query data (Spark, Presto, and Druid), so my team built the metrics platform around [PyPika](https://github.com/kayak/pypika), a Python SQL query builder. This way, definitions could be reused across multiple engines. However, it quickly became apparent that writing Python code to programmatically generate SQL was challenging for data scientists, especially those with academic backgrounds, since they were mostly familiar with R and SQL. At the time, the only Python SQL parser was [sqlparse](https://github.com/andialbrecht/sqlparse), which is not actually a parser but a tokenizer, so having users write raw SQL into the platform wasn't really an option. Some time later, I randomly stumbled across [Crafting Interpreters](https://craftinginterpreters.com/) and realized that I could use it as a guide towards creating my own SQL parser/transpiler.
Why did I do this? Isn't a Python SQL engine going to be extremely slow?
The main reason why I ended up building a SQL engine was...just for **entertainment**. It's been fun learning about all the things required to actually run a SQL query, and seeing it actually work is extremely rewarding. Before SQLGlot, I had zero experience with lexers, parsers, or compilers.
In terms of practical use cases, I planned to use the Python SQL engine for unit testing SQL pipelines. Big data pipelines are tough to test because many of the engines are not open source and cannot be run locally. With SQLGlot, you can take a SQL query targeting a warehouse such as [Snowflake](https://www.snowflake.com/en/) and seamlessly run it in CI on mock Python data. It's easy to mock data and create arbitrary [UDFs](https://en.wikipedia.org/wiki/User-defined_function) because everything is just Python. Although the implementation is slow and unsuitable for large amounts of data (> 1 million rows), there's very little overhead/startup and you can run queries on test data in a couple of milliseconds.
Finally, the components that have been built to support execution can be used as a **foundation** for a faster engine. I'm inspired by what [Apache Calcite](https://github.com/apache/calcite) has done for the JVM world. Even though Python is commonly used for data, there hasn't been a Calcite for Python. So, you could say that SQLGlot aims to be that framework. For example, it wouldn't take much work to replace the Python execution engine with numpy/pandas/arrow to become a respectably-performing query engine. The implementation would be able to leverage the parser, optimizer, and logical planner, only needing to implement physical execution. There is a lot of work in the Python ecosystem around high performance vectorized computation, which I think could benefit from a pure Python-based [AST](https://en.wikipedia.org/wiki/Abstract_syntax_tree)/[plan](https://en.wikipedia.org/wiki/Query_plan). Parsing and planning don't have to be fast when the bottleneck of running queries is processing terabytes of data. So, having a Python-based ecosystem around SQL is beneficial given the ease of development in Python, despite not having bare metal performance.
Parts of SQLGlot's toolkit are being used today by the following:
* [Ibis](https://github.com/ibis-project/ibis): A Python library that provides a lightweight, universal interface for data wrangling.
- Uses the Python SQL expression builder and leverages the optimizer/planner to convert SQL into dataframe operations.
* [mysql-mimic](https://github.com/kelsin/mysql-mimic): Pure-Python implementation of the MySQL server wire protocol
- Parses / transforms SQL and executes INFORMATION_SCHEMA queries.
* [Quokka](https://github.com/marsupialtail/quokka): Push-based vectorized query engine
- Parses and optimizes SQL.
* [Splink](https://github.com/moj-analytical-services/splink): Fast, accurate and scalable probabilistic data linkage using your choice of SQL backend.
- Transpiles queries.
## How?
There are many steps involved with actually running a simple query like:
```sql
SELECT
bar.a,
b + 1 AS b
FROM bar
JOIN baz
ON bar.a = baz.a
WHERE bar.a > 1
```
In this post, I'll walk through all the steps SQLGlot takes to run this query over Python objects.
## Tokenizing
The first step is to convert the sql string into a list of tokens. SQLGlot's tokenizer is quite simple and can be found [here](https://github.com/tobymao/sqlglot/blob/main/sqlglot/tokens.py). In a while loop, it checks each character and either appends the character to the current token, or makes a new token.
Running the SQLGlot tokenizer shows the output.
![Tokenizer Output](python_sql_engine_images/tokenizer.png)
Each keyword has been converted to a SQLGlot Token object. Each token has some metadata associated with it, like line/column information for error messages. Comments are also a part of the token, so that comments can be preserved.
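Here's a rough sketch of invoking the tokenizer yourself (assuming the `Tokenizer` class from `sqlglot.tokens` and its `tokenize` method, which is what the library uses internally):

```python
from sqlglot.tokens import Tokenizer

sql = "SELECT bar.a, b + 1 AS b FROM bar"
tokens = Tokenizer().tokenize(sql)

# Each Token carries its type, raw text, and line/column metadata.
for token in tokens[:5]:
    print(token.token_type, token.text, token.line, token.col)
```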
## Parsing
Once a SQL statement is tokenized, we don't need to worry about white space and other formatting, so it's easier to work with. We can now convert the list of tokens into an AST. The SQLGlot [parser](https://github.com/tobymao/sqlglot/blob/main/sqlglot/parser.py) is a handwritten [recursive descent](https://en.wikipedia.org/wiki/Recursive_descent_parser) parser.
Similar to the tokenizer, it consumes the tokens sequentially, but it instead uses a recursive algorithm. The tokens are converted into a single AST node that represents the SQL query. The SQLGlot parser was designed to support various dialects, so it contains many options for overriding parsing functionality.
![Parser Output](python_sql_engine_images/parser.png)
The AST is a generic representation of a given SQL query. Each dialect can override or implement its own generator, which can convert an AST object into syntactically correct SQL.
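A minimal sketch of parsing and regenerating SQL through the public `parse_one` entry point:

```python
from sqlglot import parse_one

ast = parse_one("SELECT bar.a, b + 1 AS b FROM bar JOIN baz ON bar.a = baz.a WHERE bar.a > 1")

print(repr(ast))                 # nested Expression nodes (the AST)
print(ast.sql(dialect="spark"))  # generate dialect-specific SQL back out of the AST
```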
## Optimizing
Once we have our AST, we can transform it into an equivalent query that produces the same results more efficiently. When optimizing queries, most engines first convert the AST into a logical plan and then optimize the plan. However, I chose to **optimize the AST directly** for the following reasons:
1. It's easier to debug and [validate](https://github.com/tobymao/sqlglot/blob/main/tests/fixtures/optimizer) the optimizations when the input and output are both SQL.
2. Rules can be applied a la carte to transform SQL into a more desirable form.
3. I wanted a way to generate "canonical SQL". Having a canonical representation of SQL is useful for understanding if two queries are semantically equivalent (e.g. `SELECT 1 + 1` and `SELECT 2`).
I've yet to find another engine that takes this approach, but I'm quite happy with this decision. The optimizer currently does not perform any "physical optimizations" such as join reordering. Those are left to the execution layer, as additional statistics and information could become relevant.
![Optimizer Output](python_sql_engine_images/optimizer.png)
The optimizer currently has [17 rules](https://github.com/tobymao/sqlglot/tree/main/sqlglot/optimizer). Each of these rules is applied, transforming the AST in place. The combination of these rules creates "canonical" sql that can then be more easily converted into a logical plan and executed.
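A hedged sketch of running the whole rule pipeline through the public `optimize` entry point (the `schema` argument lets the qualification and type rules resolve columns; the query and schema here are made up):

```python
from sqlglot import parse_one
from sqlglot.optimizer import optimize

optimized = optimize(
    parse_one("SELECT a FROM (SELECT a, b FROM x) AS y WHERE a > 1"),
    schema={"x": {"a": "INT", "b": "INT"}},
)

# The result is canonical, fully qualified SQL.
print(optimized.sql(pretty=True))
```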
Some example rules are:
### qualify\_tables and qualify_columns
- Adds all db/catalog qualifiers to tables and forces an alias.
- Ensures each column is unambiguous and expands stars.
```sql
SELECT * FROM x;
SELECT "db"."x" AS "x";
```
### simplify
Boolean and math simplification. Check out all the [test cases](https://github.com/tobymao/sqlglot/blob/main/tests/fixtures/optimizer/simplify.sql).
```sql
((NOT FALSE) AND (x = x)) AND (TRUE OR 1 <> 3);
x = x;
1 + 1;
2;
```
### normalize
Attempts to convert all predicates into [conjunctive normal form](https://en.wikipedia.org/wiki/Conjunctive_normal_form).
```sql
-- DNF
(A AND B) OR (B AND C AND D);
-- CNF
(A OR C) AND (A OR D) AND B;
```
### unnest\_subqueries
Converts subqueries in predicates into joins.
```sql
-- The subquery can be converted into a left join
SELECT *
FROM x AS x
WHERE (
SELECT y.a AS a
FROM y AS y
WHERE x.a = y.a
) = 1;
SELECT *
FROM x AS x
LEFT JOIN (
SELECT y.a AS a
FROM y AS y
WHERE TRUE
GROUP BY y.a
) AS "_u_0"
ON x.a = "_u_0".a
WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)
```
### pushdown_predicates
Pushes down filters into the innermost query.
```sql
SELECT *
FROM (
SELECT *
FROM x AS x
) AS y
WHERE y.a = 1;
SELECT *
FROM (
SELECT *
FROM x AS x
WHERE y.a = 1
) AS y WHERE TRUE
```
### annotate_types
Infers all types throughout the AST given schema information and function type definitions.
## Planning
After the SQL AST has been "optimized", it's much easier to [convert into a logical plan](https://github.com/tobymao/sqlglot/blob/main/sqlglot/planner.py). The AST is traversed and converted into a [DAG](https://en.wikipedia.org/wiki/Directed_acyclic_graph) of steps, each of which is one of five types. The different steps are:
### Scan
Selects columns from a table, applies projections, and finally filters the table.
### Sort
Sorts a table for order by expressions.
### Set
Applies the operators union/union all/except/intersect.
### Aggregate
Applies an aggregation/group by.
### Join
Joins multiple tables together.
![Planner Output](python_sql_engine_images/planner.png)
The logical plan is quite simple and contains the information required to convert it into a physical plan (execution).
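A small sketch of building a plan from an optimized AST (assuming the `Plan` class in `sqlglot.planner` and its `root` attribute, which holds the final step of the DAG):

```python
from sqlglot import parse_one
from sqlglot.optimizer import optimize
from sqlglot.planner import Plan

expression = optimize(
    parse_one("SELECT a, SUM(b) AS b FROM x GROUP BY a"),
    schema={"x": {"a": "INT", "b": "INT"}},
)

plan = Plan(expression)
print(plan.root)  # the final step; its dependencies form the rest of the DAG
```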
## Executing
Finally, we can actually execute the SQL query. The [Python engine](https://github.com/tobymao/sqlglot/blob/main/sqlglot/executor/python.py) is not fast, but it's very small (~400 LOC)! It iterates the DAG with a queue and runs each step, passing each intermediary table to the next step.
In order to keep things simple, it evaluates expressions with `eval`. Because SQLGlot was built primarily to be a transpiler, it was simple to create a "Python SQL" dialect. So a SQL expression `x + 1` can just be converted into `scope['x'] + 1`.
![Executor Output](python_sql_engine_images/executor.png)
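Putting it all together, the engine can run the example query over plain Python dicts. A sketch using `sqlglot.executor.execute` (the table data here is made up):

```python
from sqlglot.executor import execute

tables = {
    "bar": [{"a": 1, "b": 10}, {"a": 2, "b": 20}],
    "baz": [{"a": 1}, {"a": 2}],
}

result = execute(
    """
    SELECT bar.a, b + 1 AS b
    FROM bar
    JOIN baz ON bar.a = baz.a
    WHERE bar.a > 1
    """,
    tables=tables,
)
print(result)  # an in-memory result table with columns a and b
```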
## What's next
SQLGlot's main focus will always be on parsing/transpiling, but I plan to continue development on the execution engine. I'd like to pass [TPC-DS](https://www.tpc.org/tpcds/). If someone doesn't beat me to it, I may even take a stab at writing a Pandas/Arrow execution engine.
I'm hoping that over time, SQLGlot will spark the Python SQL ecosystem just like Calcite has for Java.
## Special thanks
SQLGlot would not be what it is without its core contributors. In particular, the execution engine would not exist without [Barak Alon](https://github.com/barakalon) and [George Sittas](https://github.com/GeorgeSittas).
## Get in touch
If you'd like to chat more about SQLGlot, please join my [Slack Channel](https://join.slack.com/t/tobiko-data/shared_invite/zt-1ma66d79v-a4dbf4DUpLAQJ8ptQrJygg)!

Binary image files added under posts/python_sql_engine_images/ (not shown): 32 KiB, 80 KiB, 154 KiB, 121 KiB, 333 KiB.

View file

@ -1,8 +0,0 @@
#!/bin/bash -e
[[ -z "${GITHUB_ACTIONS}" ]] && RETURN_ERROR_CODE='' || RETURN_ERROR_CODE='--check'
TARGETS="sqlglot/ tests/"
python -m mypy $TARGETS
python -m autoflake -i -r ${RETURN_ERROR_CODE} $TARGETS
python -m isort $TARGETS
python -m black --line-length 100 ${RETURN_ERROR_CODE} $TARGETS
python -m unittest

View file

@@ -22,6 +22,20 @@ setup(
     license="MIT",
     packages=find_packages(include=["sqlglot", "sqlglot.*"]),
     package_data={"sqlglot": ["py.typed"]},
+    extras_require={
+        "dev": [
+            "autoflake",
+            "black",
+            "duckdb",
+            "isort",
+            "mypy",
+            "pandas",
+            "pyspark",
+            "python-dateutil",
+            "pdoc",
+            "pre-commit",
+        ],
+    },
     classifiers=[
         "Development Status :: 5 - Production/Stable",
         "Intended Audience :: Developers",

View file

@@ -1,4 +1,6 @@
-"""## Python SQL parser, transpiler and optimizer."""
+"""
+.. include:: ../README.md
+"""
 from __future__ import annotations
@@ -30,7 +32,7 @@ from sqlglot.parser import Parser
 from sqlglot.schema import MappingSchema
 from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "10.2.9"
+__version__ = "10.4.2"
 pretty = False

View file

@@ -1,9 +1,15 @@
 import argparse
+import sys
 import sqlglot
 parser = argparse.ArgumentParser(description="Transpile SQL")
-parser.add_argument("sql", metavar="sql", type=str, help="SQL string to transpile")
+parser.add_argument(
+    "sql",
+    metavar="sql",
+    type=str,
+    help="SQL statement(s) to transpile, or - to parse stdin.",
+)
 parser.add_argument(
     "--read",
     dest="read",
@@ -48,14 +54,20 @@ parser.add_argument(
 args = parser.parse_args()
 error_level = sqlglot.ErrorLevel[args.error_level.upper()]
+sql = sys.stdin.read() if args.sql == "-" else args.sql
 if args.parse:
     sqls = [
         repr(expression)
-        for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)
+        for expression in sqlglot.parse(
+            sql,
+            read=args.read,
+            error_level=error_level,
+        )
     ]
 else:
     sqls = sqlglot.transpile(
-        args.sql,
+        sql,
         read=args.read,
         write=args.write,
         identify=args.identify,

View file

@ -0,0 +1,3 @@
"""
.. include:: ./README.md
"""

View file

@@ -9,18 +9,8 @@ if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql.column import Column
     from sqlglot.dataframe.sql.types import StructType
-ColumnLiterals = t.TypeVar(
-    "ColumnLiterals",
-    bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
-)
-ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
-ColumnOrLiteral = t.TypeVar(
-    "ColumnOrLiteral",
-    bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
-)
-SchemaInput = t.TypeVar(
-    "SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]]
-)
-OutputExpressionContainer = t.TypeVar(
-    "OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert]
-)
+ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+ColumnOrName = t.Union[Column, str]
+ColumnOrLiteral = t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]]
+OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert]

View file

@@ -634,7 +634,7 @@ class DataFrame:
         all_columns = self._get_outer_select_columns(new_df.expression)
         all_column_mapping = {column.alias_or_name: column for column in all_columns}
         if isinstance(value, dict):
-            values = value.values()
+            values = list(value.values())
             columns = self._ensure_and_normalize_cols(list(value))
             if not columns:
                 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns

View file

@ -1,11 +1,15 @@
"""Supports BigQuery Standard SQL."""
from __future__ import annotations from __future__ import annotations
from sqlglot import exp, generator, parser, tokens from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import ( from sqlglot.dialects.dialect import (
Dialect, Dialect,
datestrtodate_sql,
inline_array_sql, inline_array_sql,
no_ilike_sql, no_ilike_sql,
rename_func, rename_func,
timestrtotime_sql,
) )
from sqlglot.helper import seq_get from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType from sqlglot.tokens import TokenType
@ -120,13 +124,12 @@ class BigQuery(Dialect):
"NOT DETERMINISTIC": TokenType.VOLATILE, "NOT DETERMINISTIC": TokenType.VOLATILE,
"QUALIFY": TokenType.QUALIFY, "QUALIFY": TokenType.QUALIFY,
"UNKNOWN": TokenType.NULL, "UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW,
} }
KEYWORDS.pop("DIV") KEYWORDS.pop("DIV")
class Parser(parser.Parser): class Parser(parser.Parser):
FUNCTIONS = { FUNCTIONS = {
**parser.Parser.FUNCTIONS, **parser.Parser.FUNCTIONS, # type: ignore
"DATE_TRUNC": _date_trunc, "DATE_TRUNC": _date_trunc,
"DATE_ADD": _date_add(exp.DateAdd), "DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd), "DATETIME_ADD": _date_add(exp.DatetimeAdd),
@ -144,31 +147,33 @@ class BigQuery(Dialect):
} }
FUNCTION_PARSERS = { FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS, **parser.Parser.FUNCTION_PARSERS, # type: ignore
"ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]), "ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]),
} }
FUNCTION_PARSERS.pop("TRIM") FUNCTION_PARSERS.pop("TRIM")
NO_PAREN_FUNCTIONS = { NO_PAREN_FUNCTIONS = {
**parser.Parser.NO_PAREN_FUNCTIONS, **parser.Parser.NO_PAREN_FUNCTIONS, # type: ignore
TokenType.CURRENT_DATETIME: exp.CurrentDatetime, TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
TokenType.CURRENT_TIME: exp.CurrentTime, TokenType.CURRENT_TIME: exp.CurrentTime,
} }
NESTED_TYPE_TOKENS = { NESTED_TYPE_TOKENS = {
*parser.Parser.NESTED_TYPE_TOKENS, *parser.Parser.NESTED_TYPE_TOKENS, # type: ignore
TokenType.TABLE, TokenType.TABLE,
} }
class Generator(generator.Generator): class Generator(generator.Generator):
TRANSFORMS = { TRANSFORMS = {
**generator.Generator.TRANSFORMS, **generator.Generator.TRANSFORMS, # type: ignore
exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"), exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"), exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"), exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"), exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})", exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.DateStrToDate: datestrtodate_sql,
exp.GroupConcat: rename_func("STRING_AGG"),
exp.ILike: no_ilike_sql, exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"), exp.IntDiv: rename_func("DIV"),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})", exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
@ -176,6 +181,7 @@ class BigQuery(Dialect):
exp.TimeSub: _date_add_sql("TIME", "SUB"), exp.TimeSub: _date_add_sql("TIME", "SUB"),
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
exp.VariancePop: rename_func("VAR_POP"), exp.VariancePop: rename_func("VAR_POP"),
exp.Values: _derived_table_values_to_unnest, exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql, exp.ReturnsProperty: _returnsproperty_sql,
@ -188,7 +194,7 @@ class BigQuery(Dialect):
} }
TYPE_MAPPING = { TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "INT64", exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.SMALLINT: "INT64", exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.INT: "INT64", exp.DataType.Type.INT: "INT64",

View file

@ -35,13 +35,13 @@ class ClickHouse(Dialect):
class Parser(parser.Parser): class Parser(parser.Parser):
FUNCTIONS = { FUNCTIONS = {
**parser.Parser.FUNCTIONS, **parser.Parser.FUNCTIONS, # type: ignore
"MAP": parse_var_map, "MAP": parse_var_map,
} }
JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
def _parse_table(self, schema=False): def _parse_table(self, schema=False):
this = super()._parse_table(schema) this = super()._parse_table(schema)
@ -55,7 +55,7 @@ class ClickHouse(Dialect):
STRUCT_DELIMITER = ("(", ")") STRUCT_DELIMITER = ("(", ")")
TYPE_MAPPING = { TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.NULLABLE: "Nullable", exp.DataType.Type.NULLABLE: "Nullable",
exp.DataType.Type.DATETIME: "DateTime64", exp.DataType.Type.DATETIME: "DateTime64",
exp.DataType.Type.MAP: "Map", exp.DataType.Type.MAP: "Map",
@ -70,7 +70,7 @@ class ClickHouse(Dialect):
} }
TRANSFORMS = { TRANSFORMS = {
**generator.Generator.TRANSFORMS, **generator.Generator.TRANSFORMS, # type: ignore
exp.Array: inline_array_sql, exp.Array: inline_array_sql,
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",

View file

@@ -198,7 +198,7 @@ class Dialect(metaclass=_Dialect):
 def rename_func(name):
     def _rename(self, expression):
         args = flatten(expression.args.values())
-        return f"{name}({self.format_args(*args)})"
+        return f"{self.normalize_func(name)}({self.format_args(*args)})"
     return _rename
@@ -217,11 +217,11 @@ def if_sql(self, expression):
 def arrow_json_extract_sql(self, expression):
-    return f"{self.sql(expression, 'this')}->{self.sql(expression, 'path')}"
+    return self.binary(expression, "->")
 def arrow_json_extract_scalar_sql(self, expression):
-    return f"{self.sql(expression, 'this')}->>{self.sql(expression, 'path')}"
+    return self.binary(expression, "->>")
 def inline_array_sql(self, expression):
@@ -373,3 +373,11 @@ def strposition_to_local_sql(self, expression):
         expression.args.get("substr"), expression.this, expression.args.get("position")
     )
     return f"LOCATE({args})"
+
+
+def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:
+    return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
+
+
+def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
+    return f"CAST({self.sql(expression, 'this')} AS DATE)"

View file

@ -6,13 +6,14 @@ from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import ( from sqlglot.dialects.dialect import (
Dialect, Dialect,
create_with_partitions_sql, create_with_partitions_sql,
datestrtodate_sql,
format_time_lambda, format_time_lambda,
no_pivot_sql, no_pivot_sql,
no_trycast_sql, no_trycast_sql,
rename_func, rename_func,
str_position_sql, str_position_sql,
timestrtotime_sql,
) )
from sqlglot.dialects.postgres import _lateral_sql
def _to_timestamp(args): def _to_timestamp(args):
@ -117,14 +118,14 @@ class Drill(Dialect):
STRICT_CAST = False STRICT_CAST = False
FUNCTIONS = { FUNCTIONS = {
**parser.Parser.FUNCTIONS, **parser.Parser.FUNCTIONS, # type: ignore
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"), "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
} }
class Generator(generator.Generator): class Generator(generator.Generator):
TYPE_MAPPING = { TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER", exp.DataType.Type.SMALLINT: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER", exp.DataType.Type.TINYINT: "INTEGER",
@ -139,14 +140,13 @@ class Drill(Dialect):
ROOT_PROPERTIES = {exp.PartitionedByProperty} ROOT_PROPERTIES = {exp.PartitionedByProperty}
TRANSFORMS = { TRANSFORMS = {
**generator.Generator.TRANSFORMS, **generator.Generator.TRANSFORMS, # type: ignore
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.Lateral: _lateral_sql,
exp.ArrayContains: rename_func("REPEATED_CONTAINS"), exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"), exp.ArraySize: rename_func("REPEATED_COUNT"),
exp.Create: create_with_partitions_sql, exp.Create: create_with_partitions_sql,
exp.DateAdd: _date_add_sql("ADD"), exp.DateAdd: _date_add_sql("ADD"),
exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"), exp.DateSub: _date_add_sql("SUB"),
exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)", exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
@ -160,7 +160,7 @@ class Drill(Dialect):
exp.StrToDate: _str_to_date, exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),

View file

@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
approx_count_distinct_sql, approx_count_distinct_sql,
arrow_json_extract_scalar_sql, arrow_json_extract_scalar_sql,
arrow_json_extract_sql, arrow_json_extract_sql,
datestrtodate_sql,
format_time_lambda, format_time_lambda,
no_pivot_sql, no_pivot_sql,
no_properties_sql, no_properties_sql,
@ -13,6 +14,7 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql, no_tablesample_sql,
rename_func, rename_func,
str_position_sql, str_position_sql,
timestrtotime_sql,
) )
from sqlglot.helper import seq_get from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType from sqlglot.tokens import TokenType
@ -83,11 +85,12 @@ class DuckDB(Dialect):
KEYWORDS = { KEYWORDS = {
**tokens.Tokenizer.KEYWORDS, **tokens.Tokenizer.KEYWORDS,
":=": TokenType.EQ, ":=": TokenType.EQ,
"CHARACTER VARYING": TokenType.VARCHAR,
} }
class Parser(parser.Parser): class Parser(parser.Parser):
FUNCTIONS = { FUNCTIONS = {
**parser.Parser.FUNCTIONS, **parser.Parser.FUNCTIONS, # type: ignore
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list, "ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list, "ARRAY_SORT": exp.SortArray.from_arg_list,
@ -119,16 +122,18 @@ class DuckDB(Dialect):
STRUCT_DELIMITER = ("(", ")") STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = { TRANSFORMS = {
**generator.Generator.TRANSFORMS, **generator.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: approx_count_distinct_sql, exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: rename_func("LIST_VALUE"), exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
if isinstance(seq_get(e.expressions, 0), exp.Select)
else rename_func("LIST_VALUE")(self, e),
exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql, exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"), exp.ArraySum: rename_func("LIST_SUM"),
exp.DataType: _datatype_sql, exp.DataType: _datatype_sql,
exp.DateAdd: _date_add, exp.DateAdd: _date_add,
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.format_args(e.args.get("unit") or "'day'", e.expression, e.this)})""", exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.format_args(e.args.get("unit") or "'day'", e.expression, e.this)})""",
exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)", exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)", exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
exp.Explode: rename_func("UNNEST"), exp.Explode: rename_func("UNNEST"),
@ -136,6 +141,7 @@ class DuckDB(Dialect):
exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: arrow_json_extract_sql, exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Pivot: no_pivot_sql, exp.Pivot: no_pivot_sql,
exp.Properties: no_properties_sql, exp.Properties: no_properties_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.RegexpLike: rename_func("REGEXP_MATCHES"),
@ -150,7 +156,7 @@ class DuckDB(Dialect):
exp.Struct: _struct_pack_sql, exp.Struct: _struct_pack_sql,
exp.TableSample: no_tablesample_sql, exp.TableSample: no_tablesample_sql,
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))", exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))",
exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("EPOCH"), exp.TimeToUnix: rename_func("EPOCH"),
@ -163,7 +169,7 @@ class DuckDB(Dialect):
} }
TYPE_MAPPING = { TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.VARCHAR: "TEXT", exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.NVARCHAR: "TEXT", exp.DataType.Type.NVARCHAR: "TEXT",
} }

View file

@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import (
rename_func, rename_func,
strposition_to_local_sql, strposition_to_local_sql,
struct_extract_sql, struct_extract_sql,
timestrtotime_sql,
var_map_sql, var_map_sql,
) )
from sqlglot.helper import seq_get from sqlglot.helper import seq_get
@ -197,7 +198,7 @@ class Hive(Dialect):
STRICT_CAST = False STRICT_CAST = False
FUNCTIONS = { FUNCTIONS = {
**parser.Parser.FUNCTIONS, **parser.Parser.FUNCTIONS, # type: ignore
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list, "COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd( "DATE_ADD": lambda args: exp.TsOrDsAdd(
@ -217,7 +218,12 @@ class Hive(Dialect):
), ),
unit=exp.Literal.string("DAY"), unit=exp.Literal.string("DAY"),
), ),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"), "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")(
[
exp.TimeStrToTime(this=seq_get(args, 0)),
seq_get(args, 1),
]
),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list, "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
@ -240,7 +246,7 @@ class Hive(Dialect):
} }
PROPERTY_PARSERS = { PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS, **parser.Parser.PROPERTY_PARSERS, # type: ignore
TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties( TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties(
expressions=self._parse_wrapped_csv(self._parse_property) expressions=self._parse_wrapped_csv(self._parse_property)
), ),
@ -248,14 +254,14 @@ class Hive(Dialect):
class Generator(generator.Generator): class Generator(generator.Generator):
TYPE_MAPPING = { TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TEXT: "STRING", exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.DATETIME: "TIMESTAMP", exp.DataType.Type.DATETIME: "TIMESTAMP",
exp.DataType.Type.VARBINARY: "BINARY", exp.DataType.Type.VARBINARY: "BINARY",
} }
TRANSFORMS = { TRANSFORMS = {
**generator.Generator.TRANSFORMS, **generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore
exp.Property: _property_sql, exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql, exp.ApproxDistinct: approx_count_distinct_sql,
@ -294,7 +300,7 @@ class Hive(Dialect):
exp.StructExtract: struct_extract_sql, exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}", exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
exp.TimeStrToDate: rename_func("TO_DATE"), exp.TimeStrToDate: rename_func("TO_DATE"),
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: _time_to_str, exp.TimeToStr: _time_to_str,
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),

View file

@ -161,8 +161,6 @@ class MySQL(Dialect):
"_UCS2": TokenType.INTRODUCER, "_UCS2": TokenType.INTRODUCER,
"_UJIS": TokenType.INTRODUCER, "_UJIS": TokenType.INTRODUCER,
# https://dev.mysql.com/doc/refman/8.0/en/string-literals.html # https://dev.mysql.com/doc/refman/8.0/en/string-literals.html
"N": TokenType.INTRODUCER,
"n": TokenType.INTRODUCER,
"_UTF8": TokenType.INTRODUCER, "_UTF8": TokenType.INTRODUCER,
"_UTF16": TokenType.INTRODUCER, "_UTF16": TokenType.INTRODUCER,
"_UTF16LE": TokenType.INTRODUCER, "_UTF16LE": TokenType.INTRODUCER,
@ -175,10 +173,10 @@ class MySQL(Dialect):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW} COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
class Parser(parser.Parser): class Parser(parser.Parser):
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA} FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA} # type: ignore
FUNCTIONS = { FUNCTIONS = {
**parser.Parser.FUNCTIONS, **parser.Parser.FUNCTIONS, # type: ignore
"DATE_ADD": _date_add(exp.DateAdd), "DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub), "DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date, "STR_TO_DATE": _str_to_date,
@ -190,7 +188,7 @@ class MySQL(Dialect):
} }
FUNCTION_PARSERS = { FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS, **parser.Parser.FUNCTION_PARSERS, # type: ignore
"GROUP_CONCAT": lambda self: self.expression( "GROUP_CONCAT": lambda self: self.expression(
exp.GroupConcat, exp.GroupConcat,
this=self._parse_lambda(), this=self._parse_lambda(),
@ -199,12 +197,12 @@ class MySQL(Dialect):
} }
PROPERTY_PARSERS = { PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS, **parser.Parser.PROPERTY_PARSERS, # type: ignore
TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty), TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
} }
STATEMENT_PARSERS = { STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS, **parser.Parser.STATEMENT_PARSERS, # type: ignore
TokenType.SHOW: lambda self: self._parse_show(), TokenType.SHOW: lambda self: self._parse_show(),
TokenType.SET: lambda self: self._parse_set(), TokenType.SET: lambda self: self._parse_set(),
} }
@ -429,7 +427,7 @@ class MySQL(Dialect):
NULL_ORDERING_SUPPORTED = False NULL_ORDERING_SUPPORTED = False
TRANSFORMS = { TRANSFORMS = {
**generator.Generator.TRANSFORMS, **generator.Generator.TRANSFORMS, # type: ignore
exp.CurrentDate: no_paren_current_date_sql, exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ILike: no_ilike_sql, exp.ILike: no_ilike_sql,

View file

@ -39,13 +39,13 @@ class Oracle(Dialect):
class Parser(parser.Parser): class Parser(parser.Parser):
FUNCTIONS = { FUNCTIONS = {
**parser.Parser.FUNCTIONS, **parser.Parser.FUNCTIONS, # type: ignore
"DECODE": exp.Matches.from_arg_list, "DECODE": exp.Matches.from_arg_list,
} }
class Generator(generator.Generator): class Generator(generator.Generator):
TYPE_MAPPING = { TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "NUMBER", exp.DataType.Type.TINYINT: "NUMBER",
exp.DataType.Type.SMALLINT: "NUMBER", exp.DataType.Type.SMALLINT: "NUMBER",
exp.DataType.Type.INT: "NUMBER", exp.DataType.Type.INT: "NUMBER",
@ -60,7 +60,7 @@ class Oracle(Dialect):
} }
TRANSFORMS = { TRANSFORMS = {
**generator.Generator.TRANSFORMS, **generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore
exp.ILike: no_ilike_sql, exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql, exp.Limit: _limit_sql,

View file

@ -11,9 +11,19 @@ from sqlglot.dialects.dialect import (
no_trycast_sql, no_trycast_sql,
str_position_sql, str_position_sql,
) )
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType from sqlglot.tokens import TokenType
from sqlglot.transforms import delegate, preprocess from sqlglot.transforms import delegate, preprocess
DATE_DIFF_FACTOR = {
    "MICROSECOND": " * 1000000",
    "MILLISECOND": " * 1000",
    "SECOND": "",
    "MINUTE": " / 60",
    "HOUR": " / 3600",
    "DAY": " / 86400",
}
def _date_add_sql(kind):
    def func(self, expression):
@@ -34,16 +44,30 @@ def _date_add_sql(kind):
     return func
-def _lateral_sql(self, expression):
-    this = self.sql(expression, "this")
-    if isinstance(expression.this, exp.Subquery):
-        return f"LATERAL{self.sep()}{this}"
-    alias = expression.args["alias"]
-    table = alias.name
-    table = f" {table}" if table else table
-    columns = self.expressions(alias, key="columns", flat=True)
-    columns = f" AS {columns}" if columns else ""
-    return f"LATERAL{self.sep()}{this}{table}{columns}"
+def _date_diff_sql(self, expression):
+    unit = expression.text("unit").upper()
+    factor = DATE_DIFF_FACTOR.get(unit)
+
+    end = f"CAST({expression.this} AS TIMESTAMP)"
+    start = f"CAST({expression.expression} AS TIMESTAMP)"
+
+    if factor is not None:
+        return f"CAST(EXTRACT(epoch FROM {end} - {start}){factor} AS BIGINT)"
+
+    age = f"AGE({end}, {start})"
+
+    if unit == "WEEK":
+        extract = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
+    elif unit == "MONTH":
+        extract = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
+    elif unit == "QUARTER":
+        extract = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3"
+    elif unit == "YEAR":
+        extract = f"EXTRACT(year FROM {age})"
+    else:
+        self.unsupported(f"Unsupported DATEDIFF unit {unit}")
+
+    return f"CAST({extract} AS BIGINT)"
 def _substring_sql(self, expression):
@ -141,7 +165,7 @@ def _serial_to_generated(expression):
def _to_timestamp(args): def _to_timestamp(args):
# TO_TIMESTAMP accepts either a single double argument or (text, text) # TO_TIMESTAMP accepts either a single double argument or (text, text)
if len(args) == 1 and args[0].is_number: if len(args) == 1:
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE
return exp.UnixToTime.from_arg_list(args) return exp.UnixToTime.from_arg_list(args)
# https://www.postgresql.org/docs/current/functions-formatting.html # https://www.postgresql.org/docs/current/functions-formatting.html
@ -211,11 +235,16 @@ class Postgres(Dialect):
KEYWORDS = { KEYWORDS = {
**tokens.Tokenizer.KEYWORDS, **tokens.Tokenizer.KEYWORDS,
"~~": TokenType.LIKE,
"~~*": TokenType.ILIKE,
"~*": TokenType.IRLIKE,
"~": TokenType.RLIKE,
"ALWAYS": TokenType.ALWAYS, "ALWAYS": TokenType.ALWAYS,
"BEGIN": TokenType.COMMAND, "BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN, "BEGIN TRANSACTION": TokenType.BEGIN,
"BIGSERIAL": TokenType.BIGSERIAL, "BIGSERIAL": TokenType.BIGSERIAL,
"BY DEFAULT": TokenType.BY_DEFAULT, "BY DEFAULT": TokenType.BY_DEFAULT,
"CHARACTER VARYING": TokenType.VARCHAR,
"COMMENT ON": TokenType.COMMAND, "COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND, "DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND, "DO": TokenType.COMMAND,
@ -233,6 +262,7 @@ class Postgres(Dialect):
"SMALLSERIAL": TokenType.SMALLSERIAL, "SMALLSERIAL": TokenType.SMALLSERIAL,
"TEMP": TokenType.TEMPORARY, "TEMP": TokenType.TEMPORARY,
"UUID": TokenType.UUID, "UUID": TokenType.UUID,
"CSTRING": TokenType.PSEUDO_TYPE,
**{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES}, **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
**{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES}, **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
} }
@ -244,17 +274,16 @@ class Postgres(Dialect):
class Parser(parser.Parser): class Parser(parser.Parser):
STRICT_CAST = False STRICT_CAST = False
LATERAL_FUNCTION_AS_VIEW = True
FUNCTIONS = { FUNCTIONS = {
**parser.Parser.FUNCTIONS, **parser.Parser.FUNCTIONS, # type: ignore
"TO_TIMESTAMP": _to_timestamp, "TO_TIMESTAMP": _to_timestamp,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
} }
class Generator(generator.Generator): class Generator(generator.Generator):
TYPE_MAPPING = { TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "SMALLINT", exp.DataType.Type.TINYINT: "SMALLINT",
exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
@ -264,7 +293,7 @@ class Postgres(Dialect):
} }
TRANSFORMS = { TRANSFORMS = {
**generator.Generator.TRANSFORMS, **generator.Generator.TRANSFORMS, # type: ignore
exp.ColumnDef: preprocess( exp.ColumnDef: preprocess(
[ [
_auto_increment_to_serial, _auto_increment_to_serial,
@ -274,13 +303,16 @@ class Postgres(Dialect):
), ),
exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}", exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
exp.JSONBExtractScalar: lambda self, e: f"{self.sql(e, 'this')}#>>{self.sql(e, 'path')}", exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
exp.JSONBContains: lambda self, e: self.binary(e, "?"),
exp.CurrentDate: no_paren_current_date_sql, exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql("+"), exp.DateAdd: _date_add_sql("+"),
exp.DateSub: _date_add_sql("-"), exp.DateSub: _date_add_sql("-"),
exp.Lateral: _lateral_sql, exp.DateDiff: _date_diff_sql,
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql, exp.StrPosition: str_position_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Substring: _substring_sql, exp.Substring: _substring_sql,
@ -291,5 +323,7 @@ class Postgres(Dialect):
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.DataType: _datatype_sql, exp.DataType: _datatype_sql,
exp.GroupConcat: _string_agg_sql, exp.GroupConcat: _string_agg_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
if isinstance(seq_get(e.expressions, 0), exp.Select)
else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]",
} }
@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
rename_func, rename_func,
str_position_sql, str_position_sql,
struct_extract_sql, struct_extract_sql,
timestrtotime_sql,
) )
from sqlglot.dialects.mysql import MySQL from sqlglot.dialects.mysql import MySQL
from sqlglot.errors import UnsupportedError from sqlglot.errors import UnsupportedError
@ -38,10 +39,6 @@ def _datatype_sql(self, expression):
return sql return sql
def _date_parse_sql(self, expression):
return f"DATE_PARSE({self.sql(expression, 'this')}, '%Y-%m-%d %H:%i:%s')"
def _explode_to_unnest_sql(self, expression): def _explode_to_unnest_sql(self, expression):
if isinstance(expression.this, (exp.Explode, exp.Posexplode)): if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
return self.sql( return self.sql(
@ -137,7 +134,7 @@ class Presto(Dialect):
class Parser(parser.Parser): class Parser(parser.Parser):
FUNCTIONS = { FUNCTIONS = {
**parser.Parser.FUNCTIONS, **parser.Parser.FUNCTIONS, # type: ignore
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"CARDINALITY": exp.ArraySize.from_arg_list, "CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list, "CONTAINS": exp.ArrayContains.from_arg_list,
@ -174,7 +171,7 @@ class Presto(Dialect):
ROOT_PROPERTIES = {exp.SchemaCommentProperty} ROOT_PROPERTIES = {exp.SchemaCommentProperty}
TYPE_MAPPING = { TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.BINARY: "VARBINARY", exp.DataType.Type.BINARY: "VARBINARY",
@ -184,7 +181,7 @@ class Presto(Dialect):
} }
TRANSFORMS = { TRANSFORMS = {
**generator.Generator.TRANSFORMS, **generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql, exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
@ -224,8 +221,8 @@ class Presto(Dialect):
exp.StructExtract: struct_extract_sql, exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'", exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.TimeStrToDate: _date_parse_sql, exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: _date_parse_sql, exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))", exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("TO_UNIXTIME"), exp.TimeToUnix: rename_func("TO_UNIXTIME"),
@ -36,7 +36,6 @@ class Redshift(Postgres):
"TIMETZ": TokenType.TIMESTAMPTZ, "TIMETZ": TokenType.TIMESTAMPTZ,
"UNLOAD": TokenType.COMMAND, "UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY, "VARBYTE": TokenType.VARBINARY,
"SIMILAR TO": TokenType.SIMILAR_TO,
} }
class Generator(Postgres.Generator): class Generator(Postgres.Generator):
@ -3,13 +3,15 @@ from __future__ import annotations
from sqlglot import exp, generator, parser, tokens from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import ( from sqlglot.dialects.dialect import (
Dialect, Dialect,
datestrtodate_sql,
format_time_lambda, format_time_lambda,
inline_array_sql, inline_array_sql,
rename_func, rename_func,
timestrtotime_sql,
var_map_sql, var_map_sql,
) )
from sqlglot.expressions import Literal from sqlglot.expressions import Literal
from sqlglot.helper import seq_get from sqlglot.helper import flatten, seq_get
from sqlglot.tokens import TokenType from sqlglot.tokens import TokenType
@ -183,7 +185,7 @@ class Snowflake(Dialect):
class Tokenizer(tokens.Tokenizer): class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"] QUOTES = ["'", "$$"]
ESCAPES = ["\\"] ESCAPES = ["\\", "'"]
SINGLE_TOKENS = { SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS, **tokens.Tokenizer.SINGLE_TOKENS,
@ -206,9 +208,10 @@ class Snowflake(Dialect):
CREATE_TRANSIENT = True CREATE_TRANSIENT = True
TRANSFORMS = { TRANSFORMS = {
**generator.Generator.TRANSFORMS, **generator.Generator.TRANSFORMS, # type: ignore
exp.Array: inline_array_sql, exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql, exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"), exp.If: rename_func("IFF"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
@ -218,13 +221,14 @@ class Snowflake(Dialect):
exp.Matches: rename_func("DECODE"), exp.Matches: rename_func("DECODE"),
exp.StrPosition: rename_func("POSITION"), exp.StrPosition: rename_func("POSITION"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})", exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
exp.UnixToTime: _unix_to_time_sql, exp.UnixToTime: _unix_to_time_sql,
} }
TYPE_MAPPING = { TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
} }
@ -246,3 +250,47 @@ class Snowflake(Dialect):
if not expression.args.get("distinct", False): if not expression.args.get("distinct", False):
self.unsupported("INTERSECT with All is not supported in Snowflake") self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression) return super().intersect_op(expression)
def values_sql(self, expression: exp.Values) -> str:
"""Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted.
We also want to make sure that after we find matches where we need to unquote a column that we prevent users
from adding quotes to the column by using the `identify` argument when generating the SQL.
"""
alias = expression.args.get("alias")
if alias and alias.args.get("columns"):
expression = expression.transform(
lambda node: exp.Identifier(**{**node.args, "quoted": False})
if isinstance(node, exp.Identifier)
and isinstance(node.parent, exp.TableAlias)
and node.arg_key == "columns"
else node,
)
return self.no_identify(lambda: super(self.__class__, self).values_sql(expression))
return super().values_sql(expression)
def select_sql(self, expression: exp.Select) -> str:
"""Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also
that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need
to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when
generating the SQL.
Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the
expression. This might not be true in a case where the same column name can be sourced from another table that can
properly quote but should be true in most cases.
"""
values_expressions = expression.find_all(exp.Values)
values_identifiers = set(
flatten(
v.args.get("alias", exp.Alias()).args.get("columns", [])
for v in values_expressions
)
)
if values_identifiers:
expression = expression.transform(
lambda node: exp.Identifier(**{**node.args, "quoted": False})
if isinstance(node, exp.Identifier) and node in values_identifiers
else node,
)
return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
return super().select_sql(expression)
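A rough usage sketch of the behaviour these docstrings describe; the input query and the claim about its output are assumptions for illustration only:

import sqlglot

# Even with identify=True, which would normally quote every identifier, the column
# aliases attached to a VALUES clause (and their references in the SELECT) should
# come out unquoted when generating Snowflake SQL.
sql = sqlglot.transpile(
    'SELECT "c1" FROM (VALUES (1), (2)) AS "t"("c1")',
    write="snowflake",
    identify=True,
)[0]
print(sql)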
@ -76,7 +76,7 @@ class Spark(Hive):
} }
FUNCTION_PARSERS = { FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS, **parser.Parser.FUNCTION_PARSERS, # type: ignore
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"), "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"), "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"), "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@ -87,6 +87,16 @@ class Spark(Hive):
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"), "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
} }
def _parse_add_column(self):
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
def _parse_drop_column(self):
return self._match_text_seq("DROP", "COLUMNS") and self.expression(
exp.Drop,
this=self._parse_schema(),
kind="COLUMNS",
)
class Generator(Hive.Generator): class Generator(Hive.Generator):
TYPE_MAPPING = { TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING, # type: ignore **Hive.Generator.TYPE_MAPPING, # type: ignore
@ -42,13 +42,13 @@ class SQLite(Dialect):
class Parser(parser.Parser): class Parser(parser.Parser):
FUNCTIONS = { FUNCTIONS = {
**parser.Parser.FUNCTIONS, **parser.Parser.FUNCTIONS, # type: ignore
"EDITDIST3": exp.Levenshtein.from_arg_list, "EDITDIST3": exp.Levenshtein.from_arg_list,
} }
class Generator(generator.Generator): class Generator(generator.Generator):
TYPE_MAPPING = { TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BOOLEAN: "INTEGER", exp.DataType.Type.BOOLEAN: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER", exp.DataType.Type.TINYINT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER", exp.DataType.Type.SMALLINT: "INTEGER",
@ -70,7 +70,7 @@ class SQLite(Dialect):
} }
TRANSFORMS = { TRANSFORMS = {
**generator.Generator.TRANSFORMS, **generator.Generator.TRANSFORMS, # type: ignore
exp.ILike: no_ilike_sql, exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
@ -8,7 +8,7 @@ from sqlglot.dialects.mysql import MySQL
class StarRocks(MySQL): class StarRocks(MySQL):
class Generator(MySQL.Generator): # type: ignore class Generator(MySQL.Generator): # type: ignore
TYPE_MAPPING = { TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING, **MySQL.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TEXT: "STRING", exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TIMESTAMP: "DATETIME", exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "DATETIME", exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
@ -30,7 +30,7 @@ class Tableau(Dialect):
class Parser(parser.Parser): class Parser(parser.Parser):
FUNCTIONS = { FUNCTIONS = {
**parser.Parser.FUNCTIONS, **parser.Parser.FUNCTIONS, # type: ignore
"IFNULL": exp.Coalesce.from_arg_list, "IFNULL": exp.Coalesce.from_arg_list,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)), "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
} }
@ -224,11 +224,7 @@ class TSQL(Dialect):
class Tokenizer(tokens.Tokenizer): class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]")] IDENTIFIERS = ['"', ("[", "]")]
QUOTES = [ QUOTES = ["'", '"']
(prefix + quote, quote) if prefix else quote
for quote in ["'", '"']
for prefix in ["", "n", "N"]
]
KEYWORDS = { KEYWORDS = {
**tokens.Tokenizer.KEYWORDS, **tokens.Tokenizer.KEYWORDS,
@ -253,7 +249,7 @@ class TSQL(Dialect):
class Parser(parser.Parser): class Parser(parser.Parser):
FUNCTIONS = { FUNCTIONS = {
**parser.Parser.FUNCTIONS, **parser.Parser.FUNCTIONS, # type: ignore
"CHARINDEX": exp.StrPosition.from_arg_list, "CHARINDEX": exp.StrPosition.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list, "ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
@ -314,7 +310,7 @@ class TSQL(Dialect):
class Generator(generator.Generator): class Generator(generator.Generator):
TYPE_MAPPING = { TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BOOLEAN: "BIT", exp.DataType.Type.BOOLEAN: "BIT",
exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC", exp.DataType.Type.DECIMAL: "NUMERIC",
@ -1,3 +1,7 @@
"""
.. include:: ../posts/sql_diff.md
"""
from __future__ import annotations from __future__ import annotations
import typing as t import typing as t
@ -29,10 +29,10 @@ class Context:
self._table: t.Optional[Table] = None self._table: t.Optional[Table] = None
self.range_readers = {name: table.range_reader for name, table in self.tables.items()} self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()} self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers} self.env = {**ENV, **(env or {}), "scope": self.row_readers}
def eval(self, code): def eval(self, code):
return eval(code, ENV, self.env) return eval(code, self.env)
def eval_tuple(self, codes): def eval_tuple(self, codes):
return tuple(self.eval(code) for code in codes) return tuple(self.eval(code) for code in codes)
@ -127,14 +127,16 @@ def interval(this, unit):
ENV = { ENV = {
"exp": exp, "exp": exp,
# aggs # aggs
"SUM": filter_nulls(sum), "ARRAYAGG": list,
"AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore
"COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False), "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
"MAX": filter_nulls(max), "MAX": filter_nulls(max),
"MIN": filter_nulls(min), "MIN": filter_nulls(min),
"SUM": filter_nulls(sum),
# scalar functions # scalar functions
"ABS": null_if_any(lambda this: abs(this)), "ABS": null_if_any(lambda this: abs(this)),
"ADD": null_if_any(lambda e, this: e + this), "ADD": null_if_any(lambda e, this: e + this),
"ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
"BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high), "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
"BITWISEAND": null_if_any(lambda this, e: this & e), "BITWISEAND": null_if_any(lambda this, e: this & e),
"BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e), "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
@ -394,6 +394,18 @@ def _case_sql(self, expression):
return chain return chain
def _lambda_sql(self, e: exp.Lambda) -> str:
names = {e.name.lower() for e in e.expressions}
e = e.transform(
lambda n: exp.Var(this=n.name)
if isinstance(n, exp.Identifier) and n.name.lower() in names
else n
)
return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}"
class Python(Dialect): class Python(Dialect):
class Tokenizer(tokens.Tokenizer): class Tokenizer(tokens.Tokenizer):
ESCAPES = ["\\"] ESCAPES = ["\\"]
@ -414,6 +426,7 @@ class Python(Dialect):
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})", exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
exp.Is: lambda self, e: self.binary(e, "is"), exp.Is: lambda self, e: self.binary(e, "is"),
exp.Lambda: _lambda_sql,
exp.Not: lambda self, e: f"not {self.sql(e.this)}", exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None", exp.Null: lambda *_: "None",
exp.Or: lambda self, e: self.binary(e, "or"), exp.Or: lambda self, e: self.binary(e, "or"),
@ -1,6 +1,11 @@
"""
.. include:: ../pdoc/docs/expressions.md
"""
from __future__ import annotations from __future__ import annotations
import datetime import datetime
import math
import numbers import numbers
import re import re
import typing as t import typing as t
@ -682,6 +687,10 @@ class CharacterSet(Expression):
class With(Expression): class With(Expression):
arg_types = {"expressions": True, "recursive": False} arg_types = {"expressions": True, "recursive": False}
@property
def recursive(self) -> bool:
return bool(self.args.get("recursive"))
class WithinGroup(Expression): class WithinGroup(Expression):
arg_types = {"this": True, "expression": False} arg_types = {"this": True, "expression": False}
@ -724,6 +733,18 @@ class ColumnDef(Expression):
"this": True, "this": True,
"kind": True, "kind": True,
"constraints": False, "constraints": False,
"exists": False,
}
class AlterColumn(Expression):
arg_types = {
"this": True,
"dtype": False,
"collate": False,
"using": False,
"default": False,
"drop": False,
} }
@ -877,6 +898,11 @@ class Introducer(Expression):
arg_types = {"this": True, "expression": True} arg_types = {"this": True, "expression": True}
# national char, like n'utf8'
class National(Expression):
pass
class LoadData(Expression): class LoadData(Expression):
arg_types = { arg_types = {
"this": True, "this": True,
@ -894,7 +920,7 @@ class Partition(Expression):
class Fetch(Expression): class Fetch(Expression):
arg_types = {"direction": False, "count": True} arg_types = {"direction": False, "count": False}
class Group(Expression): class Group(Expression):
@ -1316,7 +1342,7 @@ QUERY_MODIFIERS = {
"group": False, "group": False,
"having": False, "having": False,
"qualify": False, "qualify": False,
"window": False, "windows": False,
"distribute": False, "distribute": False,
"sort": False, "sort": False,
"cluster": False, "cluster": False,
@ -1353,7 +1379,7 @@ class Union(Subqueryable):
Example: Example:
>>> select("1").union(select("1")).limit(1).sql() >>> select("1").union(select("1")).limit(1).sql()
'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1' 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1'
Args: Args:
expression (str | int | Expression): the SQL code string to parse. expression (str | int | Expression): the SQL code string to parse.
@ -1889,6 +1915,18 @@ class Select(Subqueryable):
**opts, **opts,
) )
def window(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
return _apply_list_builder(
*expressions,
instance=self,
arg="windows",
append=append,
into=Window,
dialect=dialect,
copy=copy,
**opts,
)
def distinct(self, distinct=True, copy=True) -> Select: def distinct(self, distinct=True, copy=True) -> Select:
""" """
Set the OFFSET expression. Set the OFFSET expression.
@ -2140,6 +2178,11 @@ class DataType(Expression):
) )
# https://www.postgresql.org/docs/15/datatype-pseudo.html
class PseudoType(Expression):
pass
class StructKwarg(Expression): class StructKwarg(Expression):
arg_types = {"this": True, "expression": True} arg_types = {"this": True, "expression": True}
@ -2167,18 +2210,26 @@ class Command(Expression):
arg_types = {"this": True, "expression": False} arg_types = {"this": True, "expression": False}
class Transaction(Command): class Transaction(Expression):
arg_types = {"this": False, "modes": False} arg_types = {"this": False, "modes": False}
class Commit(Command): class Commit(Expression):
arg_types = {"chain": False} arg_types = {"chain": False}
class Rollback(Command): class Rollback(Expression):
arg_types = {"savepoint": False} arg_types = {"savepoint": False}
class AlterTable(Expression):
arg_types = {
"this": True,
"actions": True,
"exists": False,
}
# Binary expressions like (ADD a b) # Binary expressions like (ADD a b)
class Binary(Expression): class Binary(Expression):
arg_types = {"this": True, "expression": True} arg_types = {"this": True, "expression": True}
@ -2312,6 +2363,10 @@ class SimilarTo(Binary, Predicate):
pass pass
class Slice(Binary):
arg_types = {"this": False, "expression": False}
class Sub(Binary): class Sub(Binary):
pass pass
@ -2392,7 +2447,7 @@ class TimeUnit(Expression):
class Interval(TimeUnit): class Interval(TimeUnit):
arg_types = {"this": True, "unit": False} arg_types = {"this": False, "unit": False}
class IgnoreNulls(Expression): class IgnoreNulls(Expression):
@ -2730,8 +2785,11 @@ class Initcap(Func):
pass pass
class JSONExtract(Func): class JSONBContains(Binary):
arg_types = {"this": True, "path": True} _sql_names = ["JSONB_CONTAINS"]
class JSONExtract(Binary, Func):
_sql_names = ["JSON_EXTRACT"] _sql_names = ["JSON_EXTRACT"]
@ -2776,6 +2834,10 @@ class Log10(Func):
pass pass
class LogicalOr(AggFunc):
_sql_names = ["LOGICAL_OR", "BOOL_OR"]
class Lower(Func): class Lower(Func):
_sql_names = ["LOWER", "LCASE"] _sql_names = ["LOWER", "LCASE"]
@ -2846,6 +2908,10 @@ class RegexpLike(Func):
arg_types = {"this": True, "expression": True, "flag": False} arg_types = {"this": True, "expression": True, "flag": False}
class RegexpILike(Func):
arg_types = {"this": True, "expression": True, "flag": False}
class RegexpSplit(Func): class RegexpSplit(Func):
arg_types = {"this": True, "expression": True} arg_types = {"this": True, "expression": True}
@ -3388,11 +3454,17 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> U
], ],
) )
if from_: if from_:
update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts)) update.set(
"from",
maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts),
)
if isinstance(where, Condition): if isinstance(where, Condition):
where = Where(this=where) where = Where(this=where)
if where: if where:
update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts)) update.set(
"where",
maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts),
)
return update return update
@ -3522,7 +3594,7 @@ def paren(expression) -> Paren:
return Paren(this=expression) return Paren(this=expression)
SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$") SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
def to_identifier(alias, quoted=None) -> t.Optional[Identifier]: def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
@ -3724,6 +3796,8 @@ def convert(value) -> Expression:
return Boolean(this=value) return Boolean(this=value)
if isinstance(value, str): if isinstance(value, str):
return Literal.string(value) return Literal.string(value)
if isinstance(value, float) and math.isnan(value):
return NULL
if isinstance(value, numbers.Number): if isinstance(value, numbers.Number):
return Literal.number(value) return Literal.number(value)
if isinstance(value, tuple): if isinstance(value, tuple):
@ -3732,11 +3806,13 @@ def convert(value) -> Expression:
return Array(expressions=[convert(v) for v in value]) return Array(expressions=[convert(v) for v in value])
if isinstance(value, dict): if isinstance(value, dict):
return Map( return Map(
keys=[convert(k) for k in value.keys()], keys=[convert(k) for k in value],
values=[convert(v) for v in value.values()], values=[convert(v) for v in value.values()],
) )
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S.%f%z")) datetime_literal = Literal.string(
(value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
)
return TimeStrToTime(this=datetime_literal) return TimeStrToTime(this=datetime_literal)
if isinstance(value, datetime.date): if isinstance(value, datetime.date):
date_literal = Literal.string(value.strftime("%Y-%m-%d")) date_literal = Literal.string(value.strftime("%Y-%m-%d"))
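A small sketch of the resulting convert() behaviour; the rendered strings in the comments are approximate:

import datetime
import math

from sqlglot import exp

# NaN floats are now converted to a NULL literal instead of a number.
print(exp.convert(math.nan).sql())  # NULL

# Naive datetimes are treated as UTC and serialized with isoformat(), e.g.
# TIME_STR_TO_TIME('2022-01-01T00:00:00+00:00')
print(exp.convert(datetime.datetime(2022, 1, 1)).sql())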
@ -361,10 +361,11 @@ class Generator:
column = self.sql(expression, "this") column = self.sql(expression, "this")
kind = self.sql(expression, "kind") kind = self.sql(expression, "kind")
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True) constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
if not constraints: if not constraints:
return f"{column} {kind}" return f"{exists}{column} {kind}"
return f"{column} {kind} {constraints}" return f"{exists}{column} {kind} {constraints}"
def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str: def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
this = self.sql(expression, "this") this = self.sql(expression, "this")
@ -549,6 +550,9 @@ class Generator:
text = f"{self.identifier_start}{text}{self.identifier_end}" text = f"{self.identifier_start}{text}{self.identifier_end}"
return text return text
def national_sql(self, expression: exp.National) -> str:
return f"N{self.sql(expression, 'this')}"
def partition_sql(self, expression: exp.Partition) -> str: def partition_sql(self, expression: exp.Partition) -> str:
keys = csv( keys = csv(
*[ *[
@ -633,6 +637,9 @@ class Generator:
def introducer_sql(self, expression: exp.Introducer) -> str: def introducer_sql(self, expression: exp.Introducer) -> str:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
def pseudotype_sql(self, expression: exp.PseudoType) -> str:
return expression.name.upper()
def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str: def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str:
fields = expression.args.get("fields") fields = expression.args.get("fields")
fields = f" FIELDS TERMINATED BY {fields}" if fields else "" fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
@ -793,19 +800,17 @@ class Generator:
if isinstance(expression.this, exp.Subquery): if isinstance(expression.this, exp.Subquery):
return f"LATERAL {this}" return f"LATERAL {this}"
alias = expression.args["alias"]
table = alias.name
columns = self.expressions(alias, key="columns", flat=True)
if expression.args.get("view"): if expression.args.get("view"):
table = f" {table}" if table else table alias = expression.args["alias"]
columns = self.expressions(alias, key="columns", flat=True)
table = f" {alias.name}" if alias.name else ""
columns = f" AS {columns}" if columns else "" columns = f" AS {columns}" if columns else ""
op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}") op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}")
return f"{op_sql}{self.sep()}{this}{table}{columns}" return f"{op_sql}{self.sep()}{this}{table}{columns}"
table = f" AS {table}" if table else table alias = self.sql(expression, "alias")
columns = f"({columns})" if columns else "" alias = f" AS {alias}" if alias else ""
return f"LATERAL {this}{table}{columns}" return f"LATERAL {this}{alias}"
def limit_sql(self, expression: exp.Limit) -> str: def limit_sql(self, expression: exp.Limit) -> str:
this = self.sql(expression, "this") this = self.sql(expression, "this")
@ -891,13 +896,15 @@ class Generator:
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
return csv( return csv(
*sqls, *sqls,
*[self.sql(sql) for sql in expression.args.get("joins", [])], *[self.sql(sql) for sql in expression.args.get("joins") or []],
*[self.sql(sql) for sql in expression.args.get("laterals", [])], *[self.sql(sql) for sql in expression.args.get("laterals") or []],
self.sql(expression, "where"), self.sql(expression, "where"),
self.sql(expression, "group"), self.sql(expression, "group"),
self.sql(expression, "having"), self.sql(expression, "having"),
self.sql(expression, "qualify"), self.sql(expression, "qualify"),
self.sql(expression, "window"), self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True)
if expression.args.get("windows")
else "",
self.sql(expression, "distribute"), self.sql(expression, "distribute"),
self.sql(expression, "sort"), self.sql(expression, "sort"),
self.sql(expression, "cluster"), self.sql(expression, "cluster"),
@ -1008,11 +1015,7 @@ class Generator:
spec_sql = " " + self.window_spec_sql(spec) if spec else "" spec_sql = " " + self.window_spec_sql(spec) if spec else ""
alias = self.sql(expression, "alias") alias = self.sql(expression, "alias")
this = f"{this} {'AS' if expression.arg_key == 'windows' else 'OVER'}"
if expression.arg_key == "window":
this = this = f"{self.seg('WINDOW')} {this} AS"
else:
this = f"{this} OVER"
if not partition and not order and not spec and alias: if not partition and not order and not spec and alias:
return f"{this} {alias}" return f"{this} {alias}"
@ -1141,9 +1144,11 @@ class Generator:
return f"(SELECT {self.sql(unnest)})" return f"(SELECT {self.sql(unnest)})"
def interval_sql(self, expression: exp.Interval) -> str: def interval_sql(self, expression: exp.Interval) -> str:
this = self.sql(expression, "this")
this = f" {this}" if this else ""
unit = self.sql(expression, "unit") unit = self.sql(expression, "unit")
unit = f" {unit}" if unit else "" unit = f" {unit}" if unit else ""
return f"INTERVAL {self.sql(expression, 'this')}{unit}" return f"INTERVAL{this}{unit}"
def reference_sql(self, expression: exp.Reference) -> str: def reference_sql(self, expression: exp.Reference) -> str:
this = self.sql(expression, "this") this = self.sql(expression, "this")
@ -1245,6 +1250,43 @@ class Generator:
savepoint = f" TO {savepoint}" if savepoint else "" savepoint = f" TO {savepoint}" if savepoint else ""
return f"ROLLBACK{savepoint}" return f"ROLLBACK{savepoint}"
def altercolumn_sql(self, expression: exp.AlterColumn) -> str:
this = self.sql(expression, "this")
dtype = self.sql(expression, "dtype")
if dtype:
collate = self.sql(expression, "collate")
collate = f" COLLATE {collate}" if collate else ""
using = self.sql(expression, "using")
using = f" USING {using}" if using else ""
return f"ALTER COLUMN {this} TYPE {dtype}{collate}{using}"
default = self.sql(expression, "default")
if default:
return f"ALTER COLUMN {this} SET DEFAULT {default}"
if not expression.args.get("drop"):
self.unsupported("Unsupported ALTER COLUMN syntax")
return f"ALTER COLUMN {this} DROP DEFAULT"
def altertable_sql(self, expression: exp.AlterTable) -> str:
actions = expression.args["actions"]
if isinstance(actions[0], exp.ColumnDef):
actions = self.expressions(expression, "actions", prefix="ADD COLUMN ")
elif isinstance(actions[0], exp.Schema):
actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Drop):
actions = self.expressions(expression, "actions")
elif isinstance(actions[0], exp.AlterColumn):
actions = self.sql(actions[0])
else:
self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}")
exists = " IF EXISTS" if expression.args.get("exists") else ""
return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}"
def distinct_sql(self, expression: exp.Distinct) -> str: def distinct_sql(self, expression: exp.Distinct) -> str:
this = self.expressions(expression, flat=True) this = self.expressions(expression, flat=True)
this = f" {this}" if this else "" this = f" {this}" if this else ""
@ -1327,6 +1369,9 @@ class Generator:
def or_sql(self, expression: exp.Or) -> str: def or_sql(self, expression: exp.Or) -> str:
return self.connector_sql(expression, "OR") return self.connector_sql(expression, "OR")
def slice_sql(self, expression: exp.Slice) -> str:
return self.binary(expression, ":")
def sub_sql(self, expression: exp.Sub) -> str: def sub_sql(self, expression: exp.Sub) -> str:
return self.binary(expression, "-") return self.binary(expression, "-")
@ -1369,6 +1414,7 @@ class Generator:
flat: bool = False, flat: bool = False,
indent: bool = True, indent: bool = True,
sep: str = ", ", sep: str = ", ",
prefix: str = "",
) -> str: ) -> str:
expressions = expression.args.get(key or "expressions") expressions = expression.args.get(key or "expressions")
@ -1391,11 +1437,13 @@ class Generator:
if self.pretty: if self.pretty:
if self._leading_comma: if self._leading_comma:
result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}") result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}")
else: else:
result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}") result_sqls.append(
f"{prefix}{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}"
)
else: else:
result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}") result_sqls.append(f"{prefix}{sql}{comments}{sep if i + 1 < num_sqls else ''}")
result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
return self.indent(result_sql, skip_first=False) if indent else result_sql return self.indent(result_sql, skip_first=False) if indent else result_sql
@ -18,6 +18,9 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
expression = coerce_type(expression) expression = coerce_type(expression)
expression = remove_redundant_casts(expression) expression = remove_redundant_casts(expression)
if isinstance(expression, exp.Identifier):
expression.set("quoted", True)
return expression return expression
@ -129,10 +129,23 @@ def join_condition(join):
""" """
name = join.this.alias_or_name name = join.this.alias_or_name
on = (join.args.get("on") or exp.true()).copy() on = (join.args.get("on") or exp.true()).copy()
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
source_key = [] source_key = []
join_key = [] join_key = []
def extract_condition(condition):
left, right = condition.unnest_operands()
left_tables = exp.column_table_names(left)
right_tables = exp.column_table_names(right)
if name in left_tables and name not in right_tables:
join_key.append(left)
source_key.append(right)
condition.replace(exp.true())
elif name in right_tables and name not in left_tables:
join_key.append(right)
source_key.append(left)
condition.replace(exp.true())
# find the join keys # find the join keys
# SELECT # SELECT
# FROM x # FROM x
@ -141,20 +154,30 @@ def join_condition(join):
# #
# should pull y.b as the join key and x.a as the source key # should pull y.b as the join key and x.a as the source key
if normalized(on): if normalized(on):
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
for condition in on.flatten(): for condition in on.flatten():
if isinstance(condition, exp.EQ): if isinstance(condition, exp.EQ):
left, right = condition.unnest_operands() extract_condition(condition)
left_tables = exp.column_table_names(left) elif normalized(on, dnf=True):
right_tables = exp.column_table_names(right) conditions = None
if name in left_tables and name not in right_tables: for condition in on.flatten():
join_key.append(left) parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
source_key.append(right) if conditions is None:
condition.replace(exp.true()) conditions = parts
elif name in right_tables and name not in left_tables: else:
join_key.append(right) temp = []
source_key.append(left) for p in parts:
condition.replace(exp.true()) cs = [c for c in conditions if p == c]
if cs:
temp.append(p)
temp.extend(cs)
conditions = temp
for condition in conditions:
extract_condition(condition)
on = simplify(on) on = simplify(on)
remaining_condition = None if on == exp.true() else on remaining_condition = None if on == exp.true() else on
@ -58,7 +58,9 @@ def eliminate_subqueries(expression):
existing_ctes = {} existing_ctes = {}
with_ = root.expression.args.get("with") with_ = root.expression.args.get("with")
recursive = False
if with_: if with_:
recursive = with_.args.get("recursive")
for cte in with_.expressions: for cte in with_.expressions:
existing_ctes[cte.this] = cte.alias existing_ctes[cte.this] = cte.alias
new_ctes = [] new_ctes = []
@ -88,7 +90,7 @@ def eliminate_subqueries(expression):
new_ctes.append(new_cte) new_ctes.append(new_cte)
if new_ctes: if new_ctes:
expression.set("with", exp.With(expressions=new_ctes)) expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))
return expression return expression
@ -69,8 +69,9 @@ def _predicate_lengths(expression, dnf):
left, right = expression.args.values() left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or): if isinstance(expression, exp.And if dnf else exp.Or):
x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)] return [
return x a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
]
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
@ -14,7 +14,6 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.quote_identities import quote_identities
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
RULES = ( RULES = (
@ -34,7 +33,6 @@ RULES = (
eliminate_ctes, eliminate_ctes,
annotate_types, annotate_types,
canonicalize, canonicalize,
quote_identities,
) )
@ -27,7 +27,14 @@ def pushdown_predicates(expression):
select = scope.expression select = scope.expression
where = select.args.get("where") where = select.args.get("where")
if where: if where:
pushdown(where.this, scope.selected_sources, scope_ref_count) selected_sources = scope.selected_sources
# a right join can only push down to itself and not the source FROM table
for k, (node, source) in selected_sources.items():
parent = node.find_ancestor(exp.Join, exp.From)
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
selected_sources = {k: (node, source)}
break
pushdown(where.this, selected_sources, scope_ref_count)
# joins should only pushdown into itself, not to other joins # joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself # so we limit the selected sources to only itself
@ -148,10 +155,13 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
# a node can reference a CTE which should be pushed down # a node can reference a CTE which should be pushed down
if isinstance(node, exp.From) and not isinstance(source, exp.Table): if isinstance(node, exp.From) and not isinstance(source, exp.Table):
with_ = source.parent.expression.args.get("with")
if with_ and with_.recursive:
return {}
node = source.expression node = source.expression
if isinstance(node, exp.Join): if isinstance(node, exp.Join):
if node.side: if node.side and node.side != "RIGHT":
return {} return {}
nodes[table] = node nodes[table] = node
elif isinstance(node, exp.Select) and len(tables) == 1: elif isinstance(node, exp.Select) and len(tables) == 1:
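An illustrative, unverified sketch of the RIGHT JOIN case these changes target; the query is made up:

import sqlglot
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

# The WHERE predicate may be pushed into the subquery on the right side of the
# join, but it must not reach the preserved FROM table.
expression = sqlglot.parse_one(
    "SELECT * FROM x RIGHT JOIN (SELECT y.a AS a FROM y) AS y ON x.a = y.a WHERE y.a = 1"
)
print(pushdown_predicates(expression).sql())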
@ -6,7 +6,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
# Sentinel value that means an outer query selecting ALL columns # Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object() SELECT_ALL = object()
# SELECTION TO USE IF SELECTION LIST IS EMPTY # Selection to use if selection list is empty
DEFAULT_SELECTION = alias("1", "_") DEFAULT_SELECTION = alias("1", "_")
@ -91,7 +91,7 @@ def _remove_unused_selections(scope, parent_selections):
# If there are no remaining selections, just select a single constant # If there are no remaining selections, just select a single constant
if not new_selections: if not new_selections:
new_selections.append(DEFAULT_SELECTION) new_selections.append(DEFAULT_SELECTION.copy())
scope.expression.set("expressions", new_selections) scope.expression.set("expressions", new_selections)
return removed_indexes return removed_indexes
@ -102,5 +102,5 @@ def _remove_indexed_selections(scope, indexes_to_remove):
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
] ]
if not new_selections: if not new_selections:
new_selections.append(DEFAULT_SELECTION) new_selections.append(DEFAULT_SELECTION.copy())
scope.expression.set("expressions", new_selections) scope.expression.set("expressions", new_selections)
@ -311,6 +311,9 @@ def _qualify_outputs(scope):
alias_ = alias(exp.column(""), alias=selection.name) alias_ = alias(exp.column(""), alias=selection.name)
alias_.set("this", selection) alias_.set("this", selection)
selection = alias_ selection = alias_
elif isinstance(selection, exp.Subquery):
if not selection.alias:
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
elif not isinstance(selection, exp.Alias): elif not isinstance(selection, exp.Alias):
alias_ = alias(exp.column(""), f"_col_{i}") alias_ = alias(exp.column(""), f"_col_{i}")
alias_.set("this", selection) alias_.set("this", selection)
@ -1,25 +0,0 @@
from sqlglot import exp
def quote_identities(expression):
"""
Rewrite sqlglot AST to ensure all identities are quoted.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x")
>>> quote_identities(expression).sql()
'SELECT "x"."a" AS "a" FROM "db"."x"'
Args:
expression (sqlglot.Expression): expression to quote
Returns:
sqlglot.Expression: quoted expression
"""
def qualify(node):
if isinstance(node, exp.Identifier):
node.set("quoted", True)
return node
return expression.transform(qualify, copy=False)
@ -511,9 +511,20 @@ def _traverse_union(scope):
def _traverse_derived_tables(derived_tables, scope, scope_type): def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {} sources = {}
is_cte = scope_type == ScopeType.CTE
for derived_table in derived_tables: for derived_table in derived_tables:
top = None recursive_scope = None
# if the scope is a recursive cte, it must be in the form of
# base_case UNION recursive. Thus the recursive scope is the first
# section of the union.
if is_cte and scope.expression.args["with"].recursive:
union = derived_table.this
if isinstance(union, exp.Union):
recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
for child_scope in _traverse_scope( for child_scope in _traverse_scope(
scope.branch( scope.branch(
derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this, derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
@ -523,16 +534,23 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
) )
): ):
yield child_scope yield child_scope
top = child_scope
# Tables without aliases will be set as "" # Tables without aliases will be set as ""
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather, # Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins. # the latest one wins.
sources[derived_table.alias] = child_scope alias = derived_table.alias
if scope_type == ScopeType.CTE: sources[alias] = child_scope
scope.cte_scopes.append(top)
if recursive_scope:
child_scope.add_source(alias, recursive_scope)
# append the final child_scope yielded
if is_cte:
scope.cte_scopes.append(child_scope)
else: else:
scope.derived_table_scopes.append(top) scope.derived_table_scopes.append(child_scope)
scope.sources.update(sources) scope.sources.update(sources)
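A quick sketch of traversing a recursive CTE after this change; the query and the printed fields are illustrative:

import sqlglot
from sqlglot.optimizer.scope import traverse_scope

sql = """
WITH RECURSIVE t(n) AS (
    SELECT 1
    UNION ALL
    SELECT n + 1 FROM t WHERE n < 5
)
SELECT n FROM t
"""

# The recursive member's scope should now be able to resolve "t" as a source.
for scope in traverse_scope(sqlglot.parse_one(sql)):
    print(type(scope.expression).__name__, list(scope.sources))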
@ -16,7 +16,7 @@ def unnest_subqueries(expression):
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql() >>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\ 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)' AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
Args: Args:
expression (sqlglot.Expression): expression to unnest expression (sqlglot.Expression): expression to unnest
@ -97,8 +97,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
table_alias = _alias(sequence) table_alias = _alias(sequence)
keys = [] keys = []
# for all external columns in the where statement, # for all external columns in the where statement, find the relevant predicate
# split out the relevant data to convert it into a join # keys to convert it into a join
for column in external_columns: for column in external_columns:
if column.find_ancestor(exp.Where) is not where: if column.find_ancestor(exp.Where) is not where:
return return
@ -122,6 +122,10 @@ def decorrelate(select, parent_select, external_columns, sequence):
if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys): if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
return return
is_subquery_projection = any(
node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
)
value = select.selects[0] value = select.selects[0]
key_aliases = {} key_aliases = {}
group_by = [] group_by = []
@ -142,9 +146,14 @@ def decorrelate(select, parent_select, external_columns, sequence):
parent_predicate = select.find_ancestor(exp.Predicate) parent_predicate = select.find_ancestor(exp.Predicate)
# if the value of the subquery is not an agg or a key, we need to collect it into an array # if the value of the subquery is not an agg or a key, we need to collect it into an array
# so that it can be grouped # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
if not value.find(exp.AggFunc) and value.this not in group_by: if not value.find(exp.AggFunc) and value.this not in group_by:
select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False) select.select(
exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
append=False,
copy=False,
)
# exists queries should not have any selects as it only checks if there are any rows # exists queries should not have any selects as it only checks if there are any rows
# all selects will be added by the optimizer and only used for join keys # all selects will be added by the optimizer and only used for join keys
@ -158,7 +167,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
if isinstance(parent_predicate, exp.Exists) or key != value.this: if isinstance(parent_predicate, exp.Exists) or key != value.this:
select.select(f"{key} AS {alias}", copy=False) select.select(f"{key} AS {alias}", copy=False)
else: else:
select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False) select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)
alias = exp.column(value.alias, table_alias) alias = exp.column(value.alias, table_alias)
other = _other_operand(parent_predicate) other = _other_operand(parent_predicate)
@ -186,12 +195,18 @@ def decorrelate(select, parent_select, external_columns, sequence):
f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})", f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
) )
else: else:
if is_subquery_projection:
alias = exp.alias_(alias, select.parent.alias)
select.parent.replace(alias) select.parent.replace(alias)
for key, column, predicate in keys: for key, column, predicate in keys:
predicate.replace(exp.true()) predicate.replace(exp.true())
nested = exp.column(key_aliases[key], table_alias) nested = exp.column(key_aliases[key], table_alias)
if is_subquery_projection:
key.replace(nested)
continue
if key in group_by: if key in group_by:
key.replace(nested) key.replace(nested)
parent_predicate = _replace( parent_predicate = _replace(

from sqlglot import exp from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
from sqlglot.helper import apply_index_offset, ensure_collection, seq_get from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get
from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie from sqlglot.trie import in_trie, new_trie
@ -117,6 +117,7 @@ class Parser(metaclass=_Parser):
TokenType.GEOMETRY, TokenType.GEOMETRY,
TokenType.HLLSKETCH, TokenType.HLLSKETCH,
TokenType.HSTORE, TokenType.HSTORE,
TokenType.PSEUDO_TYPE,
TokenType.SUPER, TokenType.SUPER,
TokenType.SERIAL, TokenType.SERIAL,
TokenType.SMALLSERIAL, TokenType.SMALLSERIAL,
@ -153,6 +154,7 @@ class Parser(metaclass=_Parser):
TokenType.CACHE, TokenType.CACHE,
TokenType.CASCADE, TokenType.CASCADE,
TokenType.COLLATE, TokenType.COLLATE,
TokenType.COLUMN,
TokenType.COMMAND, TokenType.COMMAND,
TokenType.COMMIT, TokenType.COMMIT,
TokenType.COMPOUND, TokenType.COMPOUND,
@ -169,6 +171,7 @@ class Parser(metaclass=_Parser):
TokenType.ESCAPE, TokenType.ESCAPE,
TokenType.FALSE, TokenType.FALSE,
TokenType.FIRST, TokenType.FIRST,
TokenType.FILTER,
TokenType.FOLLOWING, TokenType.FOLLOWING,
TokenType.FORMAT, TokenType.FORMAT,
TokenType.FUNCTION, TokenType.FUNCTION,
@ -188,6 +191,7 @@ class Parser(metaclass=_Parser):
TokenType.MERGE, TokenType.MERGE,
TokenType.NATURAL, TokenType.NATURAL,
TokenType.NEXT, TokenType.NEXT,
TokenType.OFFSET,
TokenType.ONLY, TokenType.ONLY,
TokenType.OPTIONS, TokenType.OPTIONS,
TokenType.ORDINALITY, TokenType.ORDINALITY,
@ -222,12 +226,18 @@ class Parser(metaclass=_Parser):
TokenType.PROPERTIES, TokenType.PROPERTIES,
TokenType.PROCEDURE, TokenType.PROCEDURE,
TokenType.VOLATILE, TokenType.VOLATILE,
TokenType.WINDOW,
*SUBQUERY_PREDICATES, *SUBQUERY_PREDICATES,
*TYPE_TOKENS, *TYPE_TOKENS,
*NO_PAREN_FUNCTIONS, *NO_PAREN_FUNCTIONS,
} }
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY} TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
TokenType.APPLY,
TokenType.NATURAL,
TokenType.OFFSET,
TokenType.WINDOW,
}
UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET} UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
@ -257,6 +267,7 @@ class Parser(metaclass=_Parser):
TokenType.TABLE, TokenType.TABLE,
TokenType.TIMESTAMP, TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ, TokenType.TIMESTAMPTZ,
TokenType.WINDOW,
*TYPE_TOKENS, *TYPE_TOKENS,
*SUBQUERY_PREDICATES, *SUBQUERY_PREDICATES,
} }
@ -351,22 +362,27 @@ class Parser(metaclass=_Parser):
TokenType.ARROW: lambda self, this, path: self.expression( TokenType.ARROW: lambda self, this, path: self.expression(
exp.JSONExtract, exp.JSONExtract,
this=this, this=this,
path=path, expression=path,
), ),
TokenType.DARROW: lambda self, this, path: self.expression( TokenType.DARROW: lambda self, this, path: self.expression(
exp.JSONExtractScalar, exp.JSONExtractScalar,
this=this, this=this,
path=path, expression=path,
), ),
TokenType.HASH_ARROW: lambda self, this, path: self.expression( TokenType.HASH_ARROW: lambda self, this, path: self.expression(
exp.JSONBExtract, exp.JSONBExtract,
this=this, this=this,
path=path, expression=path,
), ),
TokenType.DHASH_ARROW: lambda self, this, path: self.expression( TokenType.DHASH_ARROW: lambda self, this, path: self.expression(
exp.JSONBExtractScalar, exp.JSONBExtractScalar,
this=this, this=this,
path=path, expression=path,
),
TokenType.PLACEHOLDER: lambda self, this, key: self.expression(
exp.JSONBContains,
this=this,
expression=key,
), ),
} }
@ -392,25 +408,27 @@ class Parser(metaclass=_Parser):
exp.Ordered: lambda self: self._parse_ordered(), exp.Ordered: lambda self: self._parse_ordered(),
exp.Having: lambda self: self._parse_having(), exp.Having: lambda self: self._parse_having(),
exp.With: lambda self: self._parse_with(), exp.With: lambda self: self._parse_with(),
exp.Window: lambda self: self._parse_named_window(),
"JOIN_TYPE": lambda self: self._parse_join_side_and_kind(), "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
} }
STATEMENT_PARSERS = { STATEMENT_PARSERS = {
TokenType.ALTER: lambda self: self._parse_alter(),
TokenType.BEGIN: lambda self: self._parse_transaction(),
TokenType.CACHE: lambda self: self._parse_cache(),
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.CREATE: lambda self: self._parse_create(), TokenType.CREATE: lambda self: self._parse_create(),
TokenType.DELETE: lambda self: self._parse_delete(),
TokenType.DESCRIBE: lambda self: self._parse_describe(), TokenType.DESCRIBE: lambda self: self._parse_describe(),
TokenType.DROP: lambda self: self._parse_drop(), TokenType.DROP: lambda self: self._parse_drop(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
TokenType.INSERT: lambda self: self._parse_insert(), TokenType.INSERT: lambda self: self._parse_insert(),
TokenType.LOAD_DATA: lambda self: self._parse_load_data(), TokenType.LOAD_DATA: lambda self: self._parse_load_data(),
TokenType.UPDATE: lambda self: self._parse_update(),
TokenType.DELETE: lambda self: self._parse_delete(),
TokenType.CACHE: lambda self: self._parse_cache(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
TokenType.BEGIN: lambda self: self._parse_transaction(),
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.MERGE: lambda self: self._parse_merge(), TokenType.MERGE: lambda self: self._parse_merge(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
TokenType.UPDATE: lambda self: self._parse_update(),
TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
} }
UNARY_PARSERS = { UNARY_PARSERS = {
@ -441,6 +459,7 @@ class Parser(metaclass=_Parser):
TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text), TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text), TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token), TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
TokenType.NATIONAL: lambda self, token: self._parse_national(token),
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(), TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
} }
@ -454,6 +473,9 @@ class Parser(metaclass=_Parser):
TokenType.ILIKE: lambda self, this: self._parse_escape( TokenType.ILIKE: lambda self, this: self._parse_escape(
self.expression(exp.ILike, this=this, expression=self._parse_bitwise()) self.expression(exp.ILike, this=this, expression=self._parse_bitwise())
), ),
TokenType.IRLIKE: lambda self, this: self.expression(
exp.RegexpILike, this=this, expression=self._parse_bitwise()
),
TokenType.RLIKE: lambda self, this: self.expression( TokenType.RLIKE: lambda self, this: self.expression(
exp.RegexpLike, this=this, expression=self._parse_bitwise() exp.RegexpLike, this=this, expression=self._parse_bitwise()
), ),
@ -535,8 +557,7 @@ class Parser(metaclass=_Parser):
"group": lambda self: self._parse_group(), "group": lambda self: self._parse_group(),
"having": lambda self: self._parse_having(), "having": lambda self: self._parse_having(),
"qualify": lambda self: self._parse_qualify(), "qualify": lambda self: self._parse_qualify(),
"window": lambda self: self._match(TokenType.WINDOW) "windows": lambda self: self._parse_window_clause(),
and self._parse_window(self._parse_id_var(), alias=True),
"distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute), "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
"sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort), "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
"cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster), "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
@ -551,18 +572,18 @@ class Parser(metaclass=_Parser):
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
CREATABLES = { CREATABLES = {
TokenType.TABLE, TokenType.COLUMN,
TokenType.VIEW,
TokenType.FUNCTION, TokenType.FUNCTION,
TokenType.INDEX, TokenType.INDEX,
TokenType.PROCEDURE, TokenType.PROCEDURE,
TokenType.SCHEMA, TokenType.SCHEMA,
TokenType.TABLE,
TokenType.VIEW,
} }
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
STRICT_CAST = True STRICT_CAST = True
LATERAL_FUNCTION_AS_VIEW = False
__slots__ = ( __slots__ = (
"error_level", "error_level",
@ -782,13 +803,16 @@ class Parser(metaclass=_Parser):
self._parse_query_modifiers(expression) self._parse_query_modifiers(expression)
return expression return expression
def _parse_drop(self): def _parse_drop(self, default_kind=None):
temporary = self._match(TokenType.TEMPORARY) temporary = self._match(TokenType.TEMPORARY)
materialized = self._match(TokenType.MATERIALIZED) materialized = self._match(TokenType.MATERIALIZED)
kind = self._match_set(self.CREATABLES) and self._prev.text kind = self._match_set(self.CREATABLES) and self._prev.text
if not kind: if not kind:
self.raise_error(f"Expected {self.CREATABLES}") if default_kind:
return kind = default_kind
else:
self.raise_error(f"Expected {self.CREATABLES}")
return
return self.expression( return self.expression(
exp.Drop, exp.Drop,
@ -876,7 +900,7 @@ class Parser(metaclass=_Parser):
) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False) ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False)
if assignment: if assignment:
key = self._parse_var() or self._parse_string() key = self._parse_var_or_string()
self._match(TokenType.EQ) self._match(TokenType.EQ)
return self.expression(exp.Property, this=key, value=self._parse_column()) return self.expression(exp.Property, this=key, value=self._parse_column())
@ -1152,18 +1176,32 @@ class Parser(metaclass=_Parser):
elif (table or nested) and self._match(TokenType.L_PAREN): elif (table or nested) and self._match(TokenType.L_PAREN):
this = self._parse_table() if table else self._parse_select(nested=True) this = self._parse_table() if table else self._parse_select(nested=True)
self._parse_query_modifiers(this) self._parse_query_modifiers(this)
this = self._parse_set_operations(this)
self._match_r_paren() self._match_r_paren()
this = self._parse_subquery(this) # early return so that subquery unions aren't parsed again
# SELECT * FROM (SELECT 1) UNION ALL SELECT 1
# Union ALL should be a property of the top select node, not the subquery
return self._parse_subquery(this)
elif self._match(TokenType.VALUES): elif self._match(TokenType.VALUES):
if self._curr.token_type == TokenType.L_PAREN:
# We don't consume the left paren because it's consumed in _parse_value
expressions = self._parse_csv(self._parse_value)
else:
# In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
# Source: https://prestodb.io/docs/current/sql/values.html
expressions = self._parse_csv(
lambda: self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
)
this = self.expression( this = self.expression(
exp.Values, exp.Values,
expressions=self._parse_csv(self._parse_value), expressions=expressions,
alias=self._parse_table_alias(), alias=self._parse_table_alias(),
) )
else: else:
this = None this = None
return self._parse_set_operations(this) if this else None return self._parse_set_operations(this)
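A short sketch of the two behaviours implemented above, using the public parse_one API; the class name printed for the union query is an assumption based on how sqlglot models set operations.

import sqlglot

# Presto-style VALUES without parentheses: one column, two rows
values = sqlglot.parse_one("SELECT * FROM (VALUES 1, 2) AS t(x)", read="presto")

# The UNION ALL attaches to the outer select, not to the parenthesized subquery
union = sqlglot.parse_one("SELECT * FROM (SELECT 1) UNION ALL SELECT 1")
print(type(union).__name__)  # Union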
def _parse_with(self, skip_with_token=False): def _parse_with(self, skip_with_token=False):
if not skip_with_token and not self._match(TokenType.WITH): if not skip_with_token and not self._match(TokenType.WITH):
@ -1201,11 +1239,12 @@ class Parser(metaclass=_Parser):
alias = self._parse_id_var( alias = self._parse_id_var(
any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS
) )
columns = None
if self._match(TokenType.L_PAREN): if self._match(TokenType.L_PAREN):
columns = self._parse_csv(lambda: self._parse_id_var(any_token)) columns = self._parse_csv(lambda: self._parse_column_def(self._parse_id_var()))
self._match_r_paren() self._match_r_paren()
else:
columns = None
if not alias and not columns: if not alias and not columns:
return None return None
@ -1295,26 +1334,19 @@ class Parser(metaclass=_Parser):
expression=self._parse_function() or self._parse_id_var(any_token=False), expression=self._parse_function() or self._parse_id_var(any_token=False),
) )
columns = None if view:
table_alias = None table = self._parse_id_var(any_token=False)
if view or self.LATERAL_FUNCTION_AS_VIEW: columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else []
table_alias = self._parse_id_var(any_token=False) table_alias = self.expression(exp.TableAlias, this=table, columns=columns)
if self._match(TokenType.ALIAS):
columns = self._parse_csv(self._parse_id_var)
else: else:
self._match(TokenType.ALIAS) table_alias = self._parse_table_alias()
table_alias = self._parse_id_var(any_token=False)
if self._match(TokenType.L_PAREN):
columns = self._parse_csv(self._parse_id_var)
self._match_r_paren()
expression = self.expression( expression = self.expression(
exp.Lateral, exp.Lateral,
this=this, this=this,
view=view, view=view,
outer=outer, outer=outer,
alias=self.expression(exp.TableAlias, this=table_alias, columns=columns), alias=table_alias,
) )
if outer_apply or cross_apply: if outer_apply or cross_apply:
@ -1693,6 +1725,9 @@ class Parser(metaclass=_Parser):
if negate: if negate:
this = self.expression(exp.Not, this=this) this = self.expression(exp.Not, this=this)
if self._match(TokenType.IS):
this = self._parse_is(this)
return this return this
def _parse_is(self, this): def _parse_is(self, this):
@ -1796,6 +1831,10 @@ class Parser(metaclass=_Parser):
return None return None
type_token = self._prev.token_type type_token = self._prev.token_type
if type_token == TokenType.PSEUDO_TYPE:
return self.expression(exp.PseudoType, this=self._prev.text)
nested = type_token in self.NESTED_TYPE_TOKENS nested = type_token in self.NESTED_TYPE_TOKENS
is_struct = type_token == TokenType.STRUCT is_struct = type_token == TokenType.STRUCT
expressions = None expressions = None
@ -1851,6 +1890,8 @@ class Parser(metaclass=_Parser):
if value is None: if value is None:
value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
elif type_token == TokenType.INTERVAL:
value = self.expression(exp.Interval, unit=self._parse_var())
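INTERVAL types can now carry a unit; a quick sketch, with the expected output taken from the Postgres test further down.

import sqlglot

print(sqlglot.transpile("'45 days'::interval day", read="postgres")[0])
# CAST('45 days' AS INTERVAL day)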
if maybe_func and check_func: if maybe_func and check_func:
index2 = self._index index2 = self._index
@ -1924,7 +1965,16 @@ class Parser(metaclass=_Parser):
def _parse_primary(self): def _parse_primary(self):
if self._match_set(self.PRIMARY_PARSERS): if self._match_set(self.PRIMARY_PARSERS):
return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev) token_type = self._prev.token_type
primary = self.PRIMARY_PARSERS[token_type](self, self._prev)
if token_type == TokenType.STRING:
expressions = [primary]
while self._match(TokenType.STRING):
expressions.append(exp.Literal.string(self._prev.text))
if len(expressions) > 1:
return self.expression(exp.Concat, expressions=expressions)
return primary
if self._match_pair(TokenType.DOT, TokenType.NUMBER): if self._match_pair(TokenType.DOT, TokenType.NUMBER):
return exp.Literal.number(f"0.{self._prev.text}") return exp.Literal.number(f"0.{self._prev.text}")
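Adjacent string literals are now folded into a single CONCAT; a minimal sketch whose expected output matches the Postgres test further down.

import sqlglot

print(sqlglot.transpile("'x' 'y' 'z'", read="postgres")[0])
# CONCAT('x', 'y', 'z')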
@ -2027,6 +2077,9 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Identifier, this=token.text) return self.expression(exp.Identifier, this=token.text)
def _parse_national(self, token):
return self.expression(exp.National, this=exp.Literal.string(token.text))
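N-prefixed strings become National literals instead of plain strings; a sketch whose expected output follows the T-SQL tests further down.

import sqlglot

print(sqlglot.transpile("SELECT N'test'", read="tsql", write="spark")[0])
# SELECT N'test'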
def _parse_session_parameter(self): def _parse_session_parameter(self):
kind = None kind = None
this = self._parse_id_var() or self._parse_primary() this = self._parse_id_var() or self._parse_primary()
@ -2051,7 +2104,9 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.L_PAREN): if self._match(TokenType.L_PAREN):
expressions = self._parse_csv(self._parse_id_var) expressions = self._parse_csv(self._parse_id_var)
self._match(TokenType.R_PAREN)
if not self._match(TokenType.R_PAREN):
self._retreat(index)
else: else:
expressions = [self._parse_id_var()] expressions = [self._parse_id_var()]
@ -2065,14 +2120,14 @@ class Parser(metaclass=_Parser):
exp.Distinct, expressions=self._parse_csv(self._parse_conjunction) exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)
) )
else: else:
this = self._parse_conjunction() this = self._parse_select_or_expression()
if self._match(TokenType.IGNORE_NULLS): if self._match(TokenType.IGNORE_NULLS):
this = self.expression(exp.IgnoreNulls, this=this) this = self.expression(exp.IgnoreNulls, this=this)
else: else:
self._match(TokenType.RESPECT_NULLS) self._match(TokenType.RESPECT_NULLS)
return self._parse_alias(self._parse_limit(self._parse_order(this))) return self._parse_limit(self._parse_order(this))
def _parse_schema(self, this=None): def _parse_schema(self, this=None):
index = self._index index = self._index
@ -2081,7 +2136,8 @@ class Parser(metaclass=_Parser):
return this return this
args = self._parse_csv( args = self._parse_csv(
lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True)) lambda: self._parse_constraint()
or self._parse_column_def(self._parse_field(any_token=True))
) )
self._match_r_paren() self._match_r_paren()
return self.expression(exp.Schema, this=this, expressions=args) return self.expression(exp.Schema, this=this, expressions=args)
@ -2120,7 +2176,7 @@ class Parser(metaclass=_Parser):
elif self._match(TokenType.ENCODE): elif self._match(TokenType.ENCODE):
kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var()) kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var())
elif self._match(TokenType.DEFAULT): elif self._match(TokenType.DEFAULT):
kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction()) kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_bitwise())
elif self._match_pair(TokenType.NOT, TokenType.NULL): elif self._match_pair(TokenType.NOT, TokenType.NULL):
kind = exp.NotNullColumnConstraint() kind = exp.NotNullColumnConstraint()
elif self._match(TokenType.NULL): elif self._match(TokenType.NULL):
@ -2211,7 +2267,10 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.L_BRACKET): if not self._match(TokenType.L_BRACKET):
return this return this
expressions = self._parse_csv(self._parse_conjunction) if self._match(TokenType.COLON):
expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())]
else:
expressions = self._parse_csv(lambda: self._parse_slice(self._parse_conjunction()))
if not this or this.name.upper() == "ARRAY": if not this or this.name.upper() == "ARRAY":
this = self.expression(exp.Array, expressions=expressions) this = self.expression(exp.Array, expressions=expressions)
@ -2225,6 +2284,11 @@ class Parser(metaclass=_Parser):
this.comments = self._prev_comments this.comments = self._prev_comments
return self._parse_bracket(this) return self._parse_bracket(this)
def _parse_slice(self, this):
if self._match(TokenType.COLON):
return self.expression(exp.Slice, this=this, expression=self._parse_conjunction())
return this
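Bracket subscripts now understand slice syntax; a sketch of the round trip, expected per the identity fixtures further down.

import sqlglot

print(sqlglot.parse_one("x[1 : 2]").sql())    # x[1 : 2]
print(sqlglot.parse_one("x[-4 : -1]").sql())  # x[-4 : -1]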
def _parse_case(self): def _parse_case(self):
ifs = [] ifs = []
default = None default = None
@ -2386,6 +2450,12 @@ class Parser(metaclass=_Parser):
collation=collation, collation=collation,
) )
def _parse_window_clause(self):
return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window)
def _parse_named_window(self):
return self._parse_window(self._parse_id_var(), alias=True)
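Named WINDOW clauses are parsed as a comma-separated list of window definitions; a sketch using the Postgres identity test further down as the expected round trip.

import sqlglot

sql = "SELECT SUM(x) OVER a, SUM(y) OVER b FROM c WINDOW a AS (PARTITION BY d), b AS (PARTITION BY e)"
print(sqlglot.parse_one(sql, read="postgres").sql(dialect="postgres") == sql)  # True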
def _parse_window(self, this, alias=False): def _parse_window(self, this, alias=False):
if self._match(TokenType.FILTER): if self._match(TokenType.FILTER):
where = self._parse_wrapped(self._parse_where) where = self._parse_wrapped(self._parse_where)
@ -2501,11 +2571,9 @@ class Parser(metaclass=_Parser):
if identifier: if identifier:
return identifier return identifier
if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS):
self._advance() return exp.Identifier(this=self._prev.text, quoted=False)
elif not self._match_set(tokens or self.ID_VAR_TOKENS): return None
return None
return exp.Identifier(this=self._prev.text, quoted=False)
def _parse_string(self): def _parse_string(self):
if self._match(TokenType.STRING): if self._match(TokenType.STRING):
@ -2522,11 +2590,17 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Identifier, this=self._prev.text, quoted=True) return self.expression(exp.Identifier, this=self._prev.text, quoted=True)
return self._parse_placeholder() return self._parse_placeholder()
def _parse_var(self): def _parse_var(self, any_token=False):
if self._match(TokenType.VAR): if (any_token and self._advance_any()) or self._match(TokenType.VAR):
return self.expression(exp.Var, this=self._prev.text) return self.expression(exp.Var, this=self._prev.text)
return self._parse_placeholder() return self._parse_placeholder()
def _advance_any(self):
if self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
self._advance()
return self._prev
return None
def _parse_var_or_string(self): def _parse_var_or_string(self):
return self._parse_var() or self._parse_string() return self._parse_var() or self._parse_string()
@ -2551,8 +2625,9 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.PLACEHOLDER): if self._match(TokenType.PLACEHOLDER):
return self.expression(exp.Placeholder) return self.expression(exp.Placeholder)
elif self._match(TokenType.COLON): elif self._match(TokenType.COLON):
self._advance() if self._match_set((TokenType.NUMBER, TokenType.VAR)):
return self.expression(exp.Placeholder, this=self._prev.text) return self.expression(exp.Placeholder, this=self._prev.text)
self._advance(-1)
return None return None
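A colon is only treated as a placeholder when a number or identifier follows, otherwise the parser backtracks; a rough sketch (the query and the printed AST shape are assumptions for illustration, not asserted by the tests in this diff).

import sqlglot

# :1 parses as a named placeholder; a bare colon no longer consumes the next token
print(repr(sqlglot.parse_one("SELECT x FROM t WHERE id = :1")))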
def _parse_except(self): def _parse_except(self):
@ -2647,6 +2722,54 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Rollback, savepoint=savepoint) return self.expression(exp.Rollback, savepoint=savepoint)
return self.expression(exp.Commit, chain=chain) return self.expression(exp.Commit, chain=chain)
def _parse_add_column(self):
if not self._match_text_seq("ADD"):
return None
self._match(TokenType.COLUMN)
exists_column = self._parse_exists(not_=True)
expression = self._parse_column_def(self._parse_field(any_token=True))
expression.set("exists", exists_column)
return expression
def _parse_drop_column(self):
return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN")
def _parse_alter(self):
if not self._match(TokenType.TABLE):
return None
exists = self._parse_exists()
this = self._parse_table(schema=True)
actions = None
if self._match_text_seq("ADD", advance=False):
actions = self._parse_csv(self._parse_add_column)
elif self._match_text_seq("DROP", advance=False):
actions = self._parse_csv(self._parse_drop_column)
elif self._match_text_seq("ALTER"):
self._match(TokenType.COLUMN)
column = self._parse_field(any_token=True)
if self._match_pair(TokenType.DROP, TokenType.DEFAULT):
actions = self.expression(exp.AlterColumn, this=column, drop=True)
elif self._match_pair(TokenType.SET, TokenType.DEFAULT):
actions = self.expression(
exp.AlterColumn, this=column, default=self._parse_conjunction()
)
else:
self._match_text_seq("SET", "DATA")
actions = self.expression(
exp.AlterColumn,
this=column,
dtype=self._match_text_seq("TYPE") and self._parse_types(),
collate=self._match(TokenType.COLLATE) and self._parse_term(),
using=self._match(TokenType.USING) and self._parse_conjunction(),
)
actions = ensure_list(actions)
return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions)
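ALTER TABLE ADD/DROP/ALTER COLUMN statements now produce a structured AlterTable expression rather than a generic command; a sketch whose round trips are expected per the new identity fixtures further down.

import sqlglot

for sql in (
    "ALTER TABLE integers ADD COLUMN k INT",
    "ALTER TABLE integers DROP COLUMN IF EXISTS k",
    "ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR",
):
    print(sqlglot.parse_one(sql).sql() == sql)  # True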
def _parse_show(self): def _parse_show(self):
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
if parser: if parser:
@ -2782,7 +2905,7 @@ class Parser(metaclass=_Parser):
return True return True
return False return False
def _match_text_seq(self, *texts): def _match_text_seq(self, *texts, advance=True):
index = self._index index = self._index
for text in texts: for text in texts:
if self._curr and self._curr.text.upper() == text: if self._curr and self._curr.text.upper() == text:
@ -2790,6 +2913,10 @@ class Parser(metaclass=_Parser):
else: else:
self._retreat(index) self._retreat(index)
return False return False
if not advance:
self._retreat(index)
return True return True
def _replace_columns_with_dots(self, this): def _replace_columns_with_dots(self, this):

View file

@ -160,9 +160,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
super().__init__(schema) super().__init__(schema)
self.visible = visible or {} self.visible = visible or {}
self.dialect = dialect self.dialect = dialect
self._type_mapping_cache: t.Dict[str, exp.DataType] = { self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
"STR": exp.DataType.build("text"),
}
@classmethod @classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:

View file

@ -48,6 +48,7 @@ class TokenType(AutoName):
DOLLAR = auto() DOLLAR = auto()
PARAMETER = auto() PARAMETER = auto()
SESSION_PARAMETER = auto() SESSION_PARAMETER = auto()
NATIONAL = auto()
BLOCK_START = auto() BLOCK_START = auto()
BLOCK_END = auto() BLOCK_END = auto()
@ -111,6 +112,7 @@ class TokenType(AutoName):
# keywords # keywords
ALIAS = auto() ALIAS = auto()
ALTER = auto()
ALWAYS = auto() ALWAYS = auto()
ALL = auto() ALL = auto()
ANTI = auto() ANTI = auto()
@ -196,6 +198,7 @@ class TokenType(AutoName):
INTERVAL = auto() INTERVAL = auto()
INTO = auto() INTO = auto()
INTRODUCER = auto() INTRODUCER = auto()
IRLIKE = auto()
IS = auto() IS = auto()
ISNULL = auto() ISNULL = auto()
JOIN = auto() JOIN = auto()
@ -241,6 +244,7 @@ class TokenType(AutoName):
PRIMARY_KEY = auto() PRIMARY_KEY = auto()
PROCEDURE = auto() PROCEDURE = auto()
PROPERTIES = auto() PROPERTIES = auto()
PSEUDO_TYPE = auto()
QUALIFY = auto() QUALIFY = auto()
QUOTE = auto() QUOTE = auto()
RANGE = auto() RANGE = auto()
@ -346,7 +350,11 @@ class _Tokenizer(type):
def __new__(cls, clsname, bases, attrs): # type: ignore def __new__(cls, clsname, bases, attrs): # type: ignore
klass = super().__new__(cls, clsname, bases, attrs) klass = super().__new__(cls, clsname, bases, attrs)
klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES) klass._QUOTES = {
f"{prefix}{s}": e
for s, e in cls._delimeter_list_to_dict(klass.QUOTES).items()
for prefix in (("",) if s[0].isalpha() else ("", "n", "N"))
}
klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS) klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS)
klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS) klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS) klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
@ -470,6 +478,7 @@ class Tokenizer(metaclass=_Tokenizer):
"CHECK": TokenType.CHECK, "CHECK": TokenType.CHECK,
"CLUSTER BY": TokenType.CLUSTER_BY, "CLUSTER BY": TokenType.CLUSTER_BY,
"COLLATE": TokenType.COLLATE, "COLLATE": TokenType.COLLATE,
"COLUMN": TokenType.COLUMN,
"COMMENT": TokenType.SCHEMA_COMMENT, "COMMENT": TokenType.SCHEMA_COMMENT,
"COMMIT": TokenType.COMMIT, "COMMIT": TokenType.COMMIT,
"COMPOUND": TokenType.COMPOUND, "COMPOUND": TokenType.COMPOUND,
@ -587,6 +596,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SEMI": TokenType.SEMI, "SEMI": TokenType.SEMI,
"SET": TokenType.SET, "SET": TokenType.SET,
"SHOW": TokenType.SHOW, "SHOW": TokenType.SHOW,
"SIMILAR TO": TokenType.SIMILAR_TO,
"SOME": TokenType.SOME, "SOME": TokenType.SOME,
"SORTKEY": TokenType.SORTKEY, "SORTKEY": TokenType.SORTKEY,
"SORT BY": TokenType.SORT_BY, "SORT BY": TokenType.SORT_BY,
@ -614,6 +624,7 @@ class Tokenizer(metaclass=_Tokenizer):
"VOLATILE": TokenType.VOLATILE, "VOLATILE": TokenType.VOLATILE,
"WHEN": TokenType.WHEN, "WHEN": TokenType.WHEN,
"WHERE": TokenType.WHERE, "WHERE": TokenType.WHERE,
"WINDOW": TokenType.WINDOW,
"WITH": TokenType.WITH, "WITH": TokenType.WITH,
"WITH TIME ZONE": TokenType.WITH_TIME_ZONE, "WITH TIME ZONE": TokenType.WITH_TIME_ZONE,
"WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE, "WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE,
@ -652,6 +663,7 @@ class Tokenizer(metaclass=_Tokenizer):
"VARCHAR2": TokenType.VARCHAR, "VARCHAR2": TokenType.VARCHAR,
"NVARCHAR": TokenType.NVARCHAR, "NVARCHAR": TokenType.NVARCHAR,
"NVARCHAR2": TokenType.NVARCHAR, "NVARCHAR2": TokenType.NVARCHAR,
"STR": TokenType.TEXT,
"STRING": TokenType.TEXT, "STRING": TokenType.TEXT,
"TEXT": TokenType.TEXT, "TEXT": TokenType.TEXT,
"CLOB": TokenType.TEXT, "CLOB": TokenType.TEXT,
@ -667,7 +679,16 @@ class Tokenizer(metaclass=_Tokenizer):
"UNIQUE": TokenType.UNIQUE, "UNIQUE": TokenType.UNIQUE,
"STRUCT": TokenType.STRUCT, "STRUCT": TokenType.STRUCT,
"VARIANT": TokenType.VARIANT, "VARIANT": TokenType.VARIANT,
"ALTER": TokenType.COMMAND, "ALTER": TokenType.ALTER,
"ALTER AGGREGATE": TokenType.COMMAND,
"ALTER DEFAULT": TokenType.COMMAND,
"ALTER DOMAIN": TokenType.COMMAND,
"ALTER ROLE": TokenType.COMMAND,
"ALTER RULE": TokenType.COMMAND,
"ALTER SEQUENCE": TokenType.COMMAND,
"ALTER TYPE": TokenType.COMMAND,
"ALTER USER": TokenType.COMMAND,
"ALTER VIEW": TokenType.COMMAND,
"ANALYZE": TokenType.COMMAND, "ANALYZE": TokenType.COMMAND,
"CALL": TokenType.COMMAND, "CALL": TokenType.COMMAND,
"EXPLAIN": TokenType.COMMAND, "EXPLAIN": TokenType.COMMAND,
@ -967,7 +988,7 @@ class Tokenizer(metaclass=_Tokenizer):
text = self._extract_string(quote_end) text = self._extract_string(quote_end)
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore
text = text.replace("\\\\", "\\") if self._replace_backslash else text text = text.replace("\\\\", "\\") if self._replace_backslash else text
self._add(TokenType.STRING, text) self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
return True return True
# X'1234, b'0110', E'\\\\\' etc. # X'1234, b'0110', E'\\\\\' etc.

View file

@ -150,8 +150,8 @@ class TestDataframeColumn(unittest.TestCase):
F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(), F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(),
) )
self.assertEqual( self.assertEqual(
"cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) " "cola BETWEEN CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP) "
"AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)", "AND CAST('2022-03-01T01:01:01+00:00' AS TIMESTAMP)",
F.col("cola") F.col("cola")
.between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)) .between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1))
.sql(), .sql(),

View file

@ -30,7 +30,7 @@ class TestFunctions(unittest.TestCase):
test_date = SF.lit(datetime.date(2022, 1, 1)) test_date = SF.lit(datetime.date(2022, 1, 1))
self.assertEqual("TO_DATE('2022-01-01')", test_date.sql()) self.assertEqual("TO_DATE('2022-01-01')", test_date.sql())
test_datetime = SF.lit(datetime.datetime(2022, 1, 1, 1, 1, 1)) test_datetime = SF.lit(datetime.datetime(2022, 1, 1, 1, 1, 1))
self.assertEqual("CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP)", test_datetime.sql()) self.assertEqual("CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql())
test_dict = SF.lit({"cola": 1, "colb": "test"}) test_dict = SF.lit({"cola": 1, "colb": "test"})
self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql()) self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql())
@ -52,7 +52,7 @@ class TestFunctions(unittest.TestCase):
test_date = SF.col(datetime.date(2022, 1, 1)) test_date = SF.col(datetime.date(2022, 1, 1))
self.assertEqual("TO_DATE('2022-01-01')", test_date.sql()) self.assertEqual("TO_DATE('2022-01-01')", test_date.sql())
test_datetime = SF.col(datetime.datetime(2022, 1, 1, 1, 1, 1)) test_datetime = SF.col(datetime.datetime(2022, 1, 1, 1, 1, 1))
self.assertEqual("CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP)", test_datetime.sql()) self.assertEqual("CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql())
test_dict = SF.col({"cola": 1, "colb": "test"}) test_dict = SF.col({"cola": 1, "colb": "test"})
self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql()) self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql())

View file

@ -318,3 +318,9 @@ class TestBigQuery(Validator):
self.validate_identity( self.validate_identity(
"CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t" "CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t"
) )
def test_group_concat(self):
self.validate_all(
"SELECT a, GROUP_CONCAT(b) FROM table GROUP BY a",
write={"bigquery": "SELECT a, STRING_AGG(b) FROM table GROUP BY a"},
)

View file

@ -12,6 +12,76 @@ class TestDatabricks(Validator):
"databricks": "SELECT DATEDIFF(year, 'start', 'end')", "databricks": "SELECT DATEDIFF(year, 'start', 'end')",
}, },
) )
self.validate_all(
"SELECT DATEDIFF(microsecond, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(microsecond, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) * 1000000 AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(millisecond, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(millisecond, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) * 1000 AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(second, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(second, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(minute, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(minute, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) / 60 AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(hour, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(hour, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) / 3600 AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(day, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(day, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) / 86400 AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(week, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(week, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(year FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) * 48 + EXTRACT(month FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) * 4 + EXTRACT(day FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) / 7 AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(month, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(month, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(year FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) * 12 + EXTRACT(month FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(quarter, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(quarter, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(year FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) * 4 + EXTRACT(month FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) / 3 AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(year, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(year, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(year FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) AS BIGINT)",
},
)
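A runnable condensation of the tests above, assuming the public transpile API; the printed output for the day unit is copied from the test.

import sqlglot

print(sqlglot.transpile("SELECT DATEDIFF(day, 'start', 'end')", read="databricks", write="postgres")[0])
# SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) / 86400 AS BIGINT)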
def test_add_date(self): def test_add_date(self):
self.validate_all( self.validate_all(

View file

@ -333,7 +333,7 @@ class TestDialect(Validator):
"drill": "CAST('2020-01-01' AS DATE)", "drill": "CAST('2020-01-01' AS DATE)",
"duckdb": "CAST('2020-01-01' AS DATE)", "duckdb": "CAST('2020-01-01' AS DATE)",
"hive": "TO_DATE('2020-01-01')", "hive": "TO_DATE('2020-01-01')",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", "presto": "CAST('2020-01-01' AS TIMESTAMP)",
"starrocks": "TO_DATE('2020-01-01')", "starrocks": "TO_DATE('2020-01-01')",
}, },
) )
@ -343,7 +343,7 @@ class TestDialect(Validator):
"drill": "CAST('2020-01-01' AS TIMESTAMP)", "drill": "CAST('2020-01-01' AS TIMESTAMP)",
"duckdb": "CAST('2020-01-01' AS TIMESTAMP)", "duckdb": "CAST('2020-01-01' AS TIMESTAMP)",
"hive": "CAST('2020-01-01' AS TIMESTAMP)", "hive": "CAST('2020-01-01' AS TIMESTAMP)",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", "presto": "CAST('2020-01-01' AS TIMESTAMP)",
}, },
) )
self.validate_all( self.validate_all(
@ -723,23 +723,23 @@ class TestDialect(Validator):
read={ read={
"postgres": "x->'y'", "postgres": "x->'y'",
"presto": "JSON_EXTRACT(x, 'y')", "presto": "JSON_EXTRACT(x, 'y')",
"starrocks": "x->'y'", "starrocks": "x -> 'y'",
}, },
write={ write={
"oracle": "JSON_EXTRACT(x, 'y')", "oracle": "JSON_EXTRACT(x, 'y')",
"postgres": "x->'y'", "postgres": "x -> 'y'",
"presto": "JSON_EXTRACT(x, 'y')", "presto": "JSON_EXTRACT(x, 'y')",
"starrocks": "x->'y'", "starrocks": "x -> 'y'",
}, },
) )
self.validate_all( self.validate_all(
"JSON_EXTRACT_SCALAR(x, 'y')", "JSON_EXTRACT_SCALAR(x, 'y')",
read={ read={
"postgres": "x->>'y'", "postgres": "x ->> 'y'",
"presto": "JSON_EXTRACT_SCALAR(x, 'y')", "presto": "JSON_EXTRACT_SCALAR(x, 'y')",
}, },
write={ write={
"postgres": "x->>'y'", "postgres": "x ->> 'y'",
"presto": "JSON_EXTRACT_SCALAR(x, 'y')", "presto": "JSON_EXTRACT_SCALAR(x, 'y')",
}, },
) )
@ -749,7 +749,7 @@ class TestDialect(Validator):
"postgres": "x#>'y'", "postgres": "x#>'y'",
}, },
write={ write={
"postgres": "x#>'y'", "postgres": "x #> 'y'",
}, },
) )
self.validate_all( self.validate_all(
@ -758,7 +758,7 @@ class TestDialect(Validator):
"postgres": "x#>>'y'", "postgres": "x#>>'y'",
}, },
write={ write={
"postgres": "x#>>'y'", "postgres": "x #>> 'y'",
}, },
) )

View file

@ -59,7 +59,7 @@ class TestDuckDB(Validator):
"TO_TIMESTAMP(x)", "TO_TIMESTAMP(x)",
write={ write={
"duckdb": "CAST(x AS TIMESTAMP)", "duckdb": "CAST(x AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%Y-%m-%d %H:%i:%s')", "presto": "CAST(x AS TIMESTAMP)",
"hive": "CAST(x AS TIMESTAMP)", "hive": "CAST(x AS TIMESTAMP)",
}, },
) )
@ -302,3 +302,20 @@ class TestDuckDB(Validator):
read="duckdb", read="duckdb",
unsupported_level=ErrorLevel.IMMEDIATE, unsupported_level=ErrorLevel.IMMEDIATE,
) )
def test_array(self):
self.validate_identity("ARRAY(SELECT id FROM t)")
def test_cast(self):
self.validate_all(
"123::CHARACTER VARYING",
write={
"duckdb": "CAST(123 AS TEXT)",
},
)
def test_bool_or(self):
self.validate_all(
"SELECT a, LOGICAL_OR(b) FROM table GROUP BY a",
write={"duckdb": "SELECT a, BOOL_OR(b) FROM table GROUP BY a"},
)

View file

@ -268,10 +268,10 @@ class TestHive(Validator):
self.validate_all( self.validate_all(
"DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')", "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')",
write={ write={
"duckdb": "STRFTIME('2020-01-01', '%Y-%m-%d %H:%M:%S')", "duckdb": "STRFTIME(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %H:%M:%S')",
"presto": "DATE_FORMAT('2020-01-01', '%Y-%m-%d %H:%i:%S')", "presto": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %H:%i:%S')",
"hive": "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')", "hive": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss')",
"spark": "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')", "spark": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss')",
}, },
) )
self.validate_all( self.validate_all(

View file

@ -91,12 +91,12 @@ class TestMySQL(Validator):
}, },
) )
self.validate_all( self.validate_all(
"N 'some text'", "N'some text'",
read={ read={
"mysql": "N'some text'", "mysql": "n'some text'",
}, },
write={ write={
"mysql": "N 'some text'", "mysql": "N'some text'",
}, },
) )
self.validate_all( self.validate_all(

View file

@ -3,6 +3,7 @@ from tests.dialects.test_dialect import Validator
class TestPostgres(Validator): class TestPostgres(Validator):
maxDiff = None
dialect = "postgres" dialect = "postgres"
def test_ddl(self): def test_ddl(self):
@ -94,6 +95,7 @@ class TestPostgres(Validator):
self.validate_identity("COMMENT ON TABLE mytable IS 'this'") self.validate_identity("COMMENT ON TABLE mytable IS 'this'")
self.validate_identity("SELECT e'\\xDEADBEEF'") self.validate_identity("SELECT e'\\xDEADBEEF'")
self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)") self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
self.validate_identity("""SELECT * FROM JSON_TO_RECORDSET(z) AS y("rank" INT)""")
self.validate_all( self.validate_all(
"END WORK AND NO CHAIN", "END WORK AND NO CHAIN",
@ -112,6 +114,14 @@ class TestPostgres(Validator):
"spark": "CREATE TABLE x (a UUID, b BINARY)", "spark": "CREATE TABLE x (a UUID, b BINARY)",
}, },
) )
self.validate_all(
"123::CHARACTER VARYING",
write={"postgres": "CAST(123 AS VARCHAR)"},
)
self.validate_all(
"TO_TIMESTAMP(123::DOUBLE PRECISION)",
write={"postgres": "TO_TIMESTAMP(CAST(123 AS DOUBLE PRECISION))"},
)
self.validate_identity( self.validate_identity(
"CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)" "CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)"
@ -193,15 +203,21 @@ class TestPostgres(Validator):
}, },
) )
self.validate_all( self.validate_all(
"SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL", "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) pname ON TRUE WHERE pname IS NULL",
read={ write={
"postgres": "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL", "postgres": "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL",
}, },
) )
self.validate_all( self.validate_all(
"SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id", "SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id",
read={ write={
"postgres": "SELECT p1.id, p2.id, v1, v2 FROM polygons p1, polygons p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id != p2.id", "postgres": "SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) AS v1, LATERAL VERTICES(p2.poly) AS v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id",
},
)
self.validate_all(
"SELECT * FROM r CROSS JOIN LATERAL unnest(array(1)) AS s(location)",
write={
"postgres": "SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)",
}, },
) )
self.validate_all( self.validate_all(
@ -218,35 +234,46 @@ class TestPostgres(Validator):
) )
self.validate_all( self.validate_all(
"'[1,2,3]'::json->2", "'[1,2,3]'::json->2",
write={"postgres": "CAST('[1,2,3]' AS JSON)->'2'"}, write={"postgres": "CAST('[1,2,3]' AS JSON) -> '2'"},
) )
self.validate_all( self.validate_all(
"""'{"a":1,"b":2}'::json->'b'""", """'{"a":1,"b":2}'::json->'b'""",
write={"postgres": """CAST('{"a":1,"b":2}' AS JSON)->'b'"""}, write={"postgres": """CAST('{"a":1,"b":2}' AS JSON) -> 'b'"""},
) )
self.validate_all( self.validate_all(
"""'{"x": {"y": 1}}'::json->'x'->'y'""", """'{"x": {"y": 1}}'::json->'x'->'y'""",
write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""}, write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON) -> 'x' -> 'y'"""},
) )
self.validate_all( self.validate_all(
"""'{"x": {"y": 1}}'::json->'x'::json->'y'""", """'{"x": {"y": 1}}'::json->'x'::json->'y'""",
write={"postgres": """CAST(CAST('{"x": {"y": 1}}' AS JSON)->'x' AS JSON)->'y'"""}, write={"postgres": """CAST(CAST('{"x": {"y": 1}}' AS JSON) -> 'x' AS JSON) -> 'y'"""},
) )
self.validate_all( self.validate_all(
"""'[1,2,3]'::json->>2""", """'[1,2,3]'::json->>2""",
write={"postgres": "CAST('[1,2,3]' AS JSON)->>'2'"}, write={"postgres": "CAST('[1,2,3]' AS JSON) ->> '2'"},
) )
self.validate_all( self.validate_all(
"""'{"a":1,"b":2}'::json->>'b'""", """'{"a":1,"b":2}'::json->>'b'""",
write={"postgres": """CAST('{"a":1,"b":2}' AS JSON)->>'b'"""}, write={"postgres": """CAST('{"a":1,"b":2}' AS JSON) ->> 'b'"""},
) )
self.validate_all( self.validate_all(
"""'{"a":[1,2,3],"b":[4,5,6]}'::json#>'{a,2}'""", """'{"a":[1,2,3],"b":[4,5,6]}'::json#>'{a,2}'""",
write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON)#>'{a,2}'"""}, write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON) #> '{a,2}'"""},
) )
self.validate_all( self.validate_all(
"""'{"a":[1,2,3],"b":[4,5,6]}'::json#>>'{a,2}'""", """'{"a":[1,2,3],"b":[4,5,6]}'::json#>>'{a,2}'""",
write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON)#>>'{a,2}'"""}, write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON) #>> '{a,2}'"""},
)
self.validate_all(
"""SELECT JSON_ARRAY_ELEMENTS((foo->'sections')::JSON) AS sections""",
write={
"postgres": """SELECT JSON_ARRAY_ELEMENTS(CAST((foo -> 'sections') AS JSON)) AS sections""",
"presto": """SELECT JSON_ARRAY_ELEMENTS(CAST((JSON_EXTRACT(foo, 'sections')) AS JSON)) AS sections""",
},
)
self.validate_all(
"""x ? 'x'""",
write={"postgres": "x ? 'x'"},
) )
self.validate_all( self.validate_all(
"SELECT $$a$$", "SELECT $$a$$",
@ -260,3 +287,49 @@ class TestPostgres(Validator):
"UPDATE MYTABLE T1 SET T1.COL = 13", "UPDATE MYTABLE T1 SET T1.COL = 13",
write={"postgres": "UPDATE MYTABLE AS T1 SET T1.COL = 13"}, write={"postgres": "UPDATE MYTABLE AS T1 SET T1.COL = 13"},
) )
self.validate_identity("x ~ 'y'")
self.validate_identity("x ~* 'y'")
self.validate_all(
"x !~ 'y'",
write={"postgres": "NOT x ~ 'y'"},
)
self.validate_all(
"x !~* 'y'",
write={"postgres": "NOT x ~* 'y'"},
)
self.validate_all(
"x ~~ 'y'",
write={"postgres": "x LIKE 'y'"},
)
self.validate_all(
"x ~~* 'y'",
write={"postgres": "x ILIKE 'y'"},
)
self.validate_all(
"x !~~ 'y'",
write={"postgres": "NOT x LIKE 'y'"},
)
self.validate_all(
"x !~~* 'y'",
write={"postgres": "NOT x ILIKE 'y'"},
)
self.validate_all(
"'45 days'::interval day",
write={"postgres": "CAST('45 days' AS INTERVAL day)"},
)
self.validate_all(
"'x' 'y' 'z'",
write={"postgres": "CONCAT('x', 'y', 'z')"},
)
self.validate_identity("SELECT ARRAY(SELECT 1)")
self.validate_all(
"x::cstring",
write={"postgres": "CAST(x AS CSTRING)"},
)
self.validate_identity(
"SELECT SUM(x) OVER a, SUM(y) OVER b FROM c WINDOW a AS (PARTITION BY d), b AS (PARTITION BY e)"
)

View file

@ -53,7 +53,7 @@ class TestRedshift(Validator):
self.validate_all( self.validate_all(
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC", "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
write={ write={
"redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1', "redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1',
}, },
) )
self.validate_all( self.validate_all(

View file

@ -6,6 +6,12 @@ class TestSnowflake(Validator):
dialect = "snowflake" dialect = "snowflake"
def test_snowflake(self): def test_snowflake(self):
self.validate_all(
"SELECT * FROM xxx WHERE col ilike '%Don''t%'",
write={
"snowflake": "SELECT * FROM xxx WHERE col ILIKE '%Don\\'t%'",
},
)
self.validate_all( self.validate_all(
'x:a:"b c"', 'x:a:"b c"',
write={ write={
@ -509,3 +515,11 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERA
"snowflake": "SELECT 1 MINUS SELECT 1", "snowflake": "SELECT 1 MINUS SELECT 1",
}, },
) )
def test_values(self):
self.validate_all(
'SELECT c0, c1 FROM (VALUES (1, 2), (3, 4)) AS "t0"(c0, c1)',
read={
"spark": "SELECT `c0`, `c1` FROM (VALUES (1, 2), (3, 4)) AS `t0`(`c0`, `c1`)",
},
)
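The Spark-to-Snowflake direction of the test above as a runnable sketch; the expected output is the quoted-alias, unquoted-column form from the test.

import sqlglot

print(sqlglot.transpile(
    "SELECT `c0`, `c1` FROM (VALUES (1, 2), (3, 4)) AS `t0`(`c0`, `c1`)",
    read="spark",
    write="snowflake",
)[0])
# SELECT c0, c1 FROM (VALUES (1, 2), (3, 4)) AS "t0"(c0, c1)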

View file

@ -101,6 +101,18 @@ TBLPROPERTIES (
"spark": "CACHE TABLE testCache OPTIONS('storageLevel' = 'DISK_ONLY') AS SELECT * FROM testData" "spark": "CACHE TABLE testCache OPTIONS('storageLevel' = 'DISK_ONLY') AS SELECT * FROM testData"
}, },
) )
self.validate_all(
"ALTER TABLE StudentInfo ADD COLUMNS (LastName STRING, DOB TIMESTAMP)",
write={
"spark": "ALTER TABLE StudentInfo ADD COLUMNS (LastName STRING, DOB TIMESTAMP)",
},
)
self.validate_all(
"ALTER TABLE StudentInfo DROP COLUMNS (LastName, DOB)",
write={
"spark": "ALTER TABLE StudentInfo DROP COLUMNS (LastName, DOB)",
},
)
def test_to_date(self): def test_to_date(self):
self.validate_all( self.validate_all(

View file

@ -431,11 +431,11 @@ class TestTSQL(Validator):
def test_string(self): def test_string(self):
self.validate_all( self.validate_all(
"SELECT N'test'", "SELECT N'test'",
write={"spark": "SELECT 'test'"}, write={"spark": "SELECT N'test'"},
) )
self.validate_all( self.validate_all(
"SELECT n'test'", "SELECT n'test'",
write={"spark": "SELECT 'test'"}, write={"spark": "SELECT N'test'"},
) )
self.validate_all( self.validate_all(
"SELECT '''test'''", "SELECT '''test'''",

View file

@ -17,6 +17,7 @@ SUM(CASE WHEN x > 1 THEN 1 ELSE 0 END) / y
'\x' '\x'
"x" "x"
"" ""
N'abc'
x x
x % 1 x % 1
x < 1 x < 1
@ -33,6 +34,10 @@ x << 1
x >> 1 x >> 1
x >> 1 | 1 & 1 ^ 1 x >> 1 | 1 & 1 ^ 1
x || y x || y
x[ : ]
x[1 : ]
x[1 : 2]
x[-4 : -1]
1 - -1 1 - -1
- -5 - -5
dec.x + y dec.x + y
@ -62,6 +67,8 @@ x BETWEEN 'a' || b AND 'c' || d
NOT x IS NULL NOT x IS NULL
x IS TRUE x IS TRUE
x IS FALSE x IS FALSE
x IS TRUE IS TRUE
x LIKE y IS TRUE
time time
zone zone
ARRAY<TEXT> ARRAY<TEXT>
@ -93,10 +100,11 @@ x LIKE '%y%' ESCAPE '\'
x ILIKE '%y%' ESCAPE '\' x ILIKE '%y%' ESCAPE '\'
1 AS escape 1 AS escape
INTERVAL '1' day INTERVAL '1' day
INTERVAL '1' month INTERVAL '1' MONTH
INTERVAL '1 day' INTERVAL '1 day'
INTERVAL 2 months INTERVAL 2 months
INTERVAL 1 + 3 days INTERVAL 1 + 3 DAYS
CAST('45' AS INTERVAL DAYS)
TIMESTAMP_DIFF(CURRENT_TIMESTAMP(), 1, DAY) TIMESTAMP_DIFF(CURRENT_TIMESTAMP(), 1, DAY)
DATETIME_DIFF(CURRENT_DATE, 1, DAY) DATETIME_DIFF(CURRENT_DATE, 1, DAY)
QUANTILE(x, 0.5) QUANTILE(x, 0.5)
@ -144,6 +152,7 @@ SELECT 1 AS count FROM test
SELECT 1 AS comment FROM test SELECT 1 AS comment FROM test
SELECT 1 AS numeric FROM test SELECT 1 AS numeric FROM test
SELECT 1 AS number FROM test SELECT 1 AS number FROM test
SELECT COALESCE(offset, 1)
SELECT t.count SELECT t.count
SELECT DISTINCT x FROM test SELECT DISTINCT x FROM test
SELECT DISTINCT x, y FROM test SELECT DISTINCT x, y FROM test
@ -196,6 +205,7 @@ SELECT JSON_EXTRACT_SCALAR(x, '$.name')
SELECT x LIKE '%x%' FROM test SELECT x LIKE '%x%' FROM test
SELECT * FROM test LIMIT 100 SELECT * FROM test LIMIT 100
SELECT * FROM test LIMIT 100 OFFSET 200 SELECT * FROM test LIMIT 100 OFFSET 200
SELECT * FROM test FETCH FIRST ROWS ONLY
SELECT * FROM test FETCH FIRST 1 ROWS ONLY SELECT * FROM test FETCH FIRST 1 ROWS ONLY
SELECT * FROM test FETCH NEXT 1 ROWS ONLY SELECT * FROM test FETCH NEXT 1 ROWS ONLY
SELECT (1 > 2) AS x FROM test SELECT (1 > 2) AS x FROM test
@ -460,6 +470,7 @@ CREATE TABLE z (end INT)
CREATE TABLE z (a ARRAY<TEXT>, b MAP<TEXT, DOUBLE>, c DECIMAL(5, 3)) CREATE TABLE z (a ARRAY<TEXT>, b MAP<TEXT, DOUBLE>, c DECIMAL(5, 3))
CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECIMAL(5, 3)) CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECIMAL(5, 3))
CREATE TABLE z (a INT(11) DEFAULT UUID()) CREATE TABLE z (a INT(11) DEFAULT UUID())
CREATE TABLE z (n INT DEFAULT 0 NOT NULL)
CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id') CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id')
CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1) CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1)
CREATE TABLE z (a INT(11) NOT NULL DEFAULT -1) CREATE TABLE z (a INT(11) NOT NULL DEFAULT -1)
@ -511,7 +522,13 @@ INSERT OVERWRITE TABLE a.b PARTITION(ds) SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds, hour) SELECT x FROM y INSERT OVERWRITE TABLE a.b PARTITION(ds, hour) SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD', hour='hh') SELECT x FROM y INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD', hour='hh') SELECT x FROM y
ALTER AGGREGATE bla(foo) OWNER TO CURRENT_USER
ALTER RULE foo ON bla RENAME TO baz
ALTER ROLE CURRENT_USER WITH REPLICATION
ALTER SEQUENCE IF EXISTS baz RESTART WITH boo
ALTER TYPE electronic_mail RENAME TO email ALTER TYPE electronic_mail RENAME TO email
ALTER VIEW foo ALTER COLUMN bla SET DEFAULT 'NOT SET'
ALTER DOMAIN foo VALIDATE CONSTRAINT bla
ANALYZE a.y ANALYZE a.y
DELETE FROM x WHERE y > 1 DELETE FROM x WHERE y > 1
DELETE FROM y DELETE FROM y
@ -596,3 +613,17 @@ SELECT x AS INTO FROM bla
SELECT * INTO newevent FROM event SELECT * INTO newevent FROM event
SELECT * INTO TEMPORARY newevent FROM event SELECT * INTO TEMPORARY newevent FROM event
SELECT * INTO UNLOGGED newevent FROM event SELECT * INTO UNLOGGED newevent FROM event
ALTER TABLE integers ADD COLUMN k INT
ALTER TABLE integers ADD COLUMN IF NOT EXISTS k INT
ALTER TABLE IF EXISTS integers ADD COLUMN k INT
ALTER TABLE integers ADD COLUMN l INT DEFAULT 10
ALTER TABLE measurements ADD COLUMN mtime TIMESTAMPTZ DEFAULT NOW()
ALTER TABLE integers DROP COLUMN k
ALTER TABLE integers DROP COLUMN IF EXISTS k
ALTER TABLE integers DROP COLUMN k CASCADE
ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR
ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR USING CONCAT(i, '_', j)
ALTER TABLE integers ALTER COLUMN i SET DEFAULT 10
ALTER TABLE integers ALTER COLUMN i DROP DEFAULT
ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B
ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT

View file

@ -1,11 +1,11 @@
SELECT w.d + w.e AS c FROM w AS w; SELECT w.d + w.e AS c FROM w AS w;
SELECT CONCAT(w.d, w.e) AS c FROM w AS w; SELECT CONCAT("w"."d", "w"."e") AS "c" FROM "w" AS "w";
SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w; SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w;
SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w; SELECT CAST("w"."d" AS DATE) > CAST("w"."e" AS DATE) AS "a" FROM "w" AS "w";
SELECT CAST(1 AS VARCHAR) AS a FROM w AS w; SELECT CAST(1 AS VARCHAR) AS a FROM w AS w;
SELECT CAST(1 AS VARCHAR) AS a FROM w AS w; SELECT CAST(1 AS VARCHAR) AS "a" FROM "w" AS "w";
SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w; SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w;
SELECT 1 + 3.2 AS a FROM w AS w; SELECT 1 + 3.2 AS "a" FROM "w" AS "w";

View file

@ -291,3 +291,81 @@ SELECT a1 FROM cte1;
SELECT SELECT
"x"."a" AS "a1" "x"."a" AS "a1"
FROM "x" AS "x"; FROM "x" AS "x";
# title: recursive cte
WITH RECURSIVE cte1 AS (
SELECT *
FROM (
SELECT 1 AS a, 2 AS b
) base
CROSS JOIN (SELECT 3 c) y
UNION ALL
SELECT *
FROM cte1
WHERE a < 1
)
SELECT *
FROM cte1;
WITH RECURSIVE "base" AS (
SELECT
1 AS "a",
2 AS "b"
), "y" AS (
SELECT
3 AS "c"
), "cte1" AS (
SELECT
"base"."a" AS "a",
"base"."b" AS "b",
"y"."c" AS "c"
FROM "base" AS "base"
CROSS JOIN "y" AS "y"
UNION ALL
SELECT
"cte1"."a" AS "a",
"cte1"."b" AS "b",
"cte1"."c" AS "c"
FROM "cte1"
WHERE
"cte1"."a" < 1
)
SELECT
"cte1"."a" AS "a",
"cte1"."b" AS "b",
"cte1"."c" AS "c"
FROM "cte1";
# title: right join should not push down to from
SELECT x.a, y.b
FROM x
RIGHT JOIN y
ON x.a = y.b
WHERE x.b = 1;
SELECT
"x"."a" AS "a",
"y"."b" AS "b"
FROM "x" AS "x"
RIGHT JOIN "y" AS "y"
ON "x"."a" = "y"."b"
WHERE
"x"."b" = 1;
# title: right join can push down to itself
SELECT x.a, y.b
FROM x
RIGHT JOIN y
ON x.a = y.b
WHERE y.b = 1;
WITH "y_2" AS (
SELECT
"y"."b" AS "b"
FROM "y" AS "y"
WHERE
"y"."b" = 1
)
SELECT
"x"."a" AS "a",
"y"."b" AS "b"
FROM "x" AS "x"
RIGHT JOIN "y_2" AS "y"
ON "x"."a" = "y"."b";
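These fixtures exercise the optimizer end to end; a minimal sketch of invoking it directly, where the schema mapping is an assumption made up for illustration.

import sqlglot
from sqlglot.optimizer import optimize

sql = "SELECT x.a, y.b FROM x RIGHT JOIN y ON x.a = y.b WHERE y.b = 1"
schema = {"x": {"a": "int", "b": "int"}, "y": {"b": "int", "c": "int"}}  # assumed column types
print(optimize(sqlglot.parse_one(sql), schema=schema).sql(pretty=True))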

View file

@ -1,32 +1,32 @@
SELECT a FROM (SELECT * FROM x); SELECT a FROM (SELECT * FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0;
SELECT 1 FROM (SELECT * FROM x) WHERE b = 2; SELECT 1 FROM (SELECT * FROM x) WHERE b = 2;
SELECT 1 AS "_col_0" FROM (SELECT x.b AS b FROM x AS x) AS "_q_0" WHERE "_q_0".b = 2; SELECT 1 AS _col_0 FROM (SELECT x.b AS b FROM x AS x) AS _q_0 WHERE _q_0.b = 2;
SELECT (SELECT c FROM y WHERE q.b = y.b) FROM (SELECT * FROM x) AS q; SELECT (SELECT c FROM y WHERE q.b = y.b) FROM (SELECT * FROM x) AS q;
SELECT (SELECT y.c AS c FROM y AS y WHERE q.b = y.b) AS "_col_0" FROM (SELECT x.b AS b FROM x AS x) AS q; SELECT (SELECT y.c AS c FROM y AS y WHERE q.b = y.b) AS _col_0 FROM (SELECT x.b AS b FROM x AS x) AS q;
SELECT a FROM x JOIN (SELECT b, c FROM y) AS z ON x.b = z.b; SELECT a FROM x JOIN (SELECT b, c FROM y) AS z ON x.b = z.b;
SELECT x.a AS a FROM x AS x JOIN (SELECT y.b AS b FROM y AS y) AS z ON x.b = z.b; SELECT x.a AS a FROM x AS x JOIN (SELECT y.b AS b FROM y AS y) AS z ON x.b = z.b;
SELECT x1.a FROM (SELECT * FROM x) AS x1, (SELECT * FROM x) AS x2; SELECT x1.a FROM (SELECT * FROM x) AS x1, (SELECT * FROM x) AS x2;
SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS "_" FROM x AS x) AS x2; SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS _ FROM x AS x) AS x2;
SELECT x1.a FROM (SELECT * FROM x) AS x1, (SELECT * FROM x) AS x2; SELECT x1.a FROM (SELECT * FROM x) AS x1, (SELECT * FROM x) AS x2;
SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS "_" FROM x AS x) AS x2; SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS _ FROM x AS x) AS x2;
SELECT a FROM (SELECT DISTINCT a, b FROM x); SELECT a FROM (SELECT DISTINCT a, b FROM x);
SELECT "_q_0".a AS a FROM (SELECT DISTINCT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; SELECT _q_0.a AS a FROM (SELECT DISTINCT x.a AS a, x.b AS b FROM x AS x) AS _q_0;
SELECT a FROM (SELECT a, b FROM x UNION ALL SELECT a, b FROM x); SELECT a FROM (SELECT a, b FROM x UNION ALL SELECT a, b FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION ALL SELECT x.a AS a FROM x AS x) AS "_q_0"; SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x UNION ALL SELECT x.a AS a FROM x AS x) AS _q_0;
WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x UNION ALL SELECT z.b AS b, z.c AS c FROM z) SELECT a, b FROM t1; WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x UNION ALL SELECT z.b AS b, z.c AS c FROM z) SELECT a, b FROM t1;
WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x AS x UNION ALL SELECT z.b AS b, z.c AS c FROM z AS z) SELECT t1.a AS a, t1.b AS b FROM t1; WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x AS x UNION ALL SELECT z.b AS b, z.c AS c FROM z AS z) SELECT t1.a AS a, t1.b AS b FROM t1;
SELECT a FROM (SELECT a, b FROM x UNION SELECT a, b FROM x); SELECT a FROM (SELECT a, b FROM x UNION SELECT a, b FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x UNION SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; SELECT _q_0.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x UNION SELECT x.a AS a, x.b AS b FROM x AS x) AS _q_0;
WITH y AS (SELECT * FROM x) SELECT a FROM y; WITH y AS (SELECT * FROM x) SELECT a FROM y;
WITH y AS (SELECT x.a AS a FROM x AS x) SELECT y.a AS a FROM y; WITH y AS (SELECT x.a AS a FROM x AS x) SELECT y.a AS a FROM y;
@ -38,10 +38,10 @@ WITH z AS (SELECT * FROM x) SELECT a FROM z UNION SELECT a FROM z;
WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z UNION SELECT z.a AS a FROM z; WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z UNION SELECT z.a AS a FROM z;
SELECT b FROM (SELECT a, SUM(b) AS b FROM x GROUP BY a); SELECT b FROM (SELECT a, SUM(b) AS b FROM x GROUP BY a);
SELECT "_q_0".b AS b FROM (SELECT SUM(x.b) AS b FROM x AS x GROUP BY x.a) AS "_q_0"; SELECT _q_0.b AS b FROM (SELECT SUM(x.b) AS b FROM x AS x GROUP BY x.a) AS _q_0;
SELECT b FROM (SELECT a, SUM(b) AS b FROM x ORDER BY a); SELECT b FROM (SELECT a, SUM(b) AS b FROM x ORDER BY a);
SELECT "_q_0".b AS b FROM (SELECT x.a AS a, SUM(x.b) AS b FROM x AS x ORDER BY a) AS "_q_0"; SELECT _q_0.b AS b FROM (SELECT x.a AS a, SUM(x.b) AS b FROM x AS x ORDER BY a) AS _q_0;
SELECT x FROM (VALUES(1, 2)) AS q(x, y); SELECT x FROM (VALUES(1, 2)) AS q(x, y);
SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y); SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y);
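Note: the fixture pairs above put the input query on one line and the expected rewrite on the next; the only change in this commit is that generated identifiers such as _col_0, _q_0 and _ are no longer double-quoted. A minimal sketch of the underlying behavior, using the identifier/alias helpers exercised in the expression tests further down (the SUM(x.b) and x.a inputs are just illustrative):

from sqlglot import alias, exp

# Leading-underscore names now count as safe identifiers, so they render unquoted.
print(exp.to_identifier("_col_0").quoted)   # False
print(alias("SUM(x.b)", "_col_1").sql())    # SUM(x.b) AS _col_1
print(alias("x.a", "_q_0").sql())           # x.a AS _q_0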
View file
@ -21,15 +21,15 @@ SELECT x.a AS b FROM x AS x;
# execute: false # execute: false
SELECT 1, 2 FROM x; SELECT 1, 2 FROM x;
SELECT 1 AS "_col_0", 2 AS "_col_1" FROM x AS x; SELECT 1 AS _col_0, 2 AS _col_1 FROM x AS x;
# execute: false # execute: false
SELECT a + b FROM x; SELECT a + b FROM x;
SELECT x.a + x.b AS "_col_0" FROM x AS x; SELECT x.a + x.b AS _col_0 FROM x AS x;
# execute: false # execute: false
SELECT a, SUM(b) FROM x WHERE a > 1 AND b > 1 GROUP BY a; SELECT a, SUM(b) FROM x WHERE a > 1 AND b > 1 GROUP BY a;
SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 AND x.b > 1 GROUP BY x.a; SELECT x.a AS a, SUM(x.b) AS _col_1 FROM x AS x WHERE x.a > 1 AND x.b > 1 GROUP BY x.a;
SELECT SUM(a) AS c FROM x HAVING SUM(a) > 3; SELECT SUM(a) AS c FROM x HAVING SUM(a) > 3;
SELECT SUM(x.a) AS c FROM x AS x HAVING SUM(x.a) > 3; SELECT SUM(x.a) AS c FROM x AS x HAVING SUM(x.a) > 3;
@ -59,7 +59,7 @@ SELECT SUM(x.a) AS c, SUM(x.b) AS d FROM x AS x ORDER BY SUM(x.a), SUM(x.b);
# execute: false # execute: false
SELECT SUM(a), SUM(b) AS c FROM x ORDER BY 1, 2; SELECT SUM(a), SUM(b) AS c FROM x ORDER BY 1, 2;
SELECT SUM(x.a) AS "_col_0", SUM(x.b) AS c FROM x AS x ORDER BY SUM(x.a), SUM(x.b); SELECT SUM(x.a) AS _col_0, SUM(x.b) AS c FROM x AS x ORDER BY SUM(x.a), SUM(x.b);
SELECT a AS j, b FROM x GROUP BY j, b; SELECT a AS j, b FROM x GROUP BY j, b;
SELECT x.a AS j, x.b AS b FROM x AS x GROUP BY x.a, x.b; SELECT x.a AS j, x.b AS b FROM x AS x GROUP BY x.a, x.b;
@ -72,7 +72,7 @@ SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a, x.b;
# execute: false # execute: false
SELECT DATE(a), DATE(b) AS c FROM x GROUP BY 1, 2; SELECT DATE(a), DATE(b) AS c FROM x GROUP BY 1, 2;
SELECT DATE(x.a) AS "_col_0", DATE(x.b) AS c FROM x AS x GROUP BY DATE(x.a), DATE(x.b); SELECT DATE(x.a) AS _col_0, DATE(x.b) AS c FROM x AS x GROUP BY DATE(x.a), DATE(x.b);
SELECT SUM(x.a) AS c FROM x JOIN y ON x.b = y.b GROUP BY c; SELECT SUM(x.a) AS c FROM x JOIN y ON x.b = y.b GROUP BY c;
SELECT SUM(x.a) AS c FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY y.c; SELECT SUM(x.a) AS c FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY y.c;
@ -130,10 +130,10 @@ SELECT a FROM (SELECT a FROM x AS x) y;
SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y; SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y;
SELECT a FROM (SELECT a AS a FROM x); SELECT a FROM (SELECT a AS a FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0;
SELECT a FROM (SELECT a FROM (SELECT a FROM x)); SELECT a FROM (SELECT a FROM (SELECT a FROM x));
SELECT "_q_1".a AS a FROM (SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0") AS "_q_1"; SELECT _q_1.a AS a FROM (SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0) AS _q_1;
SELECT x.a FROM x AS x JOIN (SELECT * FROM x) AS y ON x.a = y.a; SELECT x.a FROM x AS x JOIN (SELECT * FROM x) AS y ON x.a = y.a;
SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a, x.b AS b FROM x AS x) AS y ON x.a = y.a; SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a, x.b AS b FROM x AS x) AS y ON x.a = y.a;
@ -157,7 +157,7 @@ SELECT a FROM x UNION SELECT a FROM x UNION SELECT a FROM x;
SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x; SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x;
SELECT a FROM (SELECT a FROM x UNION SELECT a FROM x); SELECT a FROM (SELECT a FROM x UNION SELECT a FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x) AS "_q_0"; SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x) AS _q_0;
-------------------------------------- --------------------------------------
-- Subqueries -- Subqueries
@ -167,10 +167,10 @@ SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT y.c AS c FROM y AS y);
# execute: false # execute: false
SELECT (SELECT c FROM y) FROM x; SELECT (SELECT c FROM y) FROM x;
SELECT (SELECT y.c AS c FROM y AS y) AS "_col_0" FROM x AS x; SELECT (SELECT y.c AS c FROM y AS y) AS _col_0 FROM x AS x;
SELECT a FROM (SELECT a FROM x) WHERE a IN (SELECT b FROM (SELECT b FROM y)); SELECT a FROM (SELECT a FROM x) WHERE a IN (SELECT b FROM (SELECT b FROM y));
SELECT "_q_1".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_1" WHERE "_q_1".a IN (SELECT "_q_0".b AS b FROM (SELECT y.b AS b FROM y AS y) AS "_q_0"); SELECT _q_1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_1 WHERE _q_1.a IN (SELECT _q_0.b AS b FROM (SELECT y.b AS b FROM y AS y) AS _q_0);
-------------------------------------- --------------------------------------
-- Correlated subqueries -- Correlated subqueries
@ -215,10 +215,10 @@ SELECT x.*, y.* FROM x JOIN y ON x.b = y.b;
SELECT x.a AS a, x.b AS b, y.b AS b, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; SELECT x.a AS a, x.b AS b, y.b AS b, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
SELECT a FROM (SELECT * FROM x); SELECT a FROM (SELECT * FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; SELECT _q_0.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x) AS _q_0;
SELECT * FROM (SELECT a FROM x); SELECT * FROM (SELECT a FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0;
-------------------------------------- --------------------------------------
-- CTEs -- CTEs
View file
@ -11,10 +11,10 @@ SELECT x.b AS b FROM x AS x;
-- Derived tables -- Derived tables
-------------------------------------- --------------------------------------
SELECT x.a FROM x AS x JOIN (SELECT * FROM x); SELECT x.a FROM x AS x JOIN (SELECT * FROM x);
SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a FROM x AS x) AS "_q_0"; SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a FROM x AS x) AS _q_0;
SELECT x.b FROM x AS x JOIN (SELECT b FROM x); SELECT x.b FROM x AS x JOIN (SELECT b FROM x);
SELECT x.b AS b FROM x AS x JOIN (SELECT x.b AS b FROM x AS x) AS "_q_0"; SELECT x.b AS b FROM x AS x JOIN (SELECT x.b AS b FROM x AS x) AS _q_0;
-------------------------------------- --------------------------------------
-- Expand * -- Expand *
@ -29,7 +29,7 @@ SELECT * FROM y JOIN z ON y.c = z.c;
SELECT y.b AS b, z.b AS b FROM y AS y JOIN z AS z ON y.c = z.c; SELECT y.b AS b, z.b AS b FROM y AS y JOIN z AS z ON y.c = z.c;
SELECT a FROM (SELECT * FROM x); SELECT a FROM (SELECT * FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0;
SELECT * FROM (SELECT a FROM x); SELECT * FROM (SELECT a FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0;
View file
@ -30,14 +30,14 @@ CROSS JOIN (
SELECT SELECT
SUM(y.a) AS a SUM(y.a) AS a
FROM y FROM y
) AS "_u_0" ) AS _u_0
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
y.a AS a y.a AS a
FROM y FROM y
GROUP BY GROUP BY
y.a y.a
) AS "_u_1" ) AS _u_1
ON x.a = "_u_1"."a" ON x.a = "_u_1"."a"
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
@ -45,7 +45,7 @@ LEFT JOIN (
FROM y FROM y
GROUP BY GROUP BY
y.b y.b
) AS "_u_2" ) AS _u_2
ON x.a = "_u_2"."b" ON x.a = "_u_2"."b"
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
@ -53,7 +53,7 @@ LEFT JOIN (
FROM y FROM y
GROUP BY GROUP BY
y.a y.a
) AS "_u_3" ) AS _u_3
ON x.a = "_u_3"."a" ON x.a = "_u_3"."a"
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
@ -64,8 +64,8 @@ LEFT JOIN (
TRUE TRUE
GROUP BY GROUP BY
y.a y.a
) AS "_u_4" ) AS _u_4
ON x.a = "_u_4"."_u_5" ON x.a = _u_4._u_5
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
SUM(y.b) AS b, SUM(y.b) AS b,
@ -75,8 +75,8 @@ LEFT JOIN (
TRUE TRUE
GROUP BY GROUP BY
y.a y.a
) AS "_u_6" ) AS _u_6
ON x.a = "_u_6"."_u_7" ON x.a = _u_6._u_7
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
y.a AS a y.a AS a
@ -85,8 +85,8 @@ LEFT JOIN (
TRUE TRUE
GROUP BY GROUP BY
y.a y.a
) AS "_u_8" ) AS _u_8
ON "_u_8".a = x.a ON _u_8.a = x.a
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
y.a AS a y.a AS a
@ -95,8 +95,8 @@ LEFT JOIN (
TRUE TRUE
GROUP BY GROUP BY
y.a y.a
) AS "_u_9" ) AS _u_9
ON "_u_9".a = x.a ON _u_9.a = x.a
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
ARRAY_AGG(y.a) AS a, ARRAY_AGG(y.a) AS a,
@ -106,8 +106,8 @@ LEFT JOIN (
TRUE TRUE
GROUP BY GROUP BY
y.b y.b
) AS "_u_10" ) AS _u_10
ON "_u_10"."_u_11" = x.a ON _u_10._u_11 = x.a
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
SUM(y.a) AS a, SUM(y.a) AS a,
@ -118,8 +118,8 @@ LEFT JOIN (
TRUE AND TRUE AND TRUE TRUE AND TRUE AND TRUE
GROUP BY GROUP BY
y.a y.a
) AS "_u_12" ) AS _u_12
ON "_u_12"."_u_13" = x.a AND "_u_12"."_u_13" = x.b ON _u_12._u_13 = x.a AND _u_12._u_13 = x.b
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
y.a AS a y.a AS a
@ -128,38 +128,38 @@ LEFT JOIN (
TRUE TRUE
GROUP BY GROUP BY
y.a y.a
) AS "_u_15" ) AS _u_15
ON x.a = "_u_15".a ON x.a = _u_15.a
WHERE WHERE
x.a = "_u_0".a x.a = _u_0.a
AND NOT "_u_1"."a" IS NULL AND NOT "_u_1"."a" IS NULL
AND NOT "_u_2"."b" IS NULL AND NOT "_u_2"."b" IS NULL
AND NOT "_u_3"."a" IS NULL AND NOT "_u_3"."a" IS NULL
AND ( AND (
x.a = "_u_4".b AND NOT "_u_4"."_u_5" IS NULL x.a = _u_4.b AND NOT _u_4._u_5 IS NULL
) )
AND ( AND (
x.a > "_u_6".b AND NOT "_u_6"."_u_7" IS NULL x.a > _u_6.b AND NOT _u_6._u_7 IS NULL
) )
AND ( AND (
None = "_u_8".a AND NOT "_u_8".a IS NULL None = _u_8.a AND NOT _u_8.a IS NULL
) )
AND NOT ( AND NOT (
x.a = "_u_9".a AND NOT "_u_9".a IS NULL x.a = _u_9.a AND NOT _u_9.a IS NULL
) )
AND ( AND (
ARRAY_ANY("_u_10".a, _x -> _x = x.a) AND NOT "_u_10"."_u_11" IS NULL ARRAY_ANY(_u_10.a, _x -> _x = x.a) AND NOT _u_10._u_11 IS NULL
) )
AND ( AND (
( (
( (
x.a < "_u_12".a AND NOT "_u_12"."_u_13" IS NULL x.a < _u_12.a AND NOT _u_12._u_13 IS NULL
) AND NOT "_u_12"."_u_13" IS NULL ) AND NOT _u_12._u_13 IS NULL
) )
AND ARRAY_ANY("_u_12"."_u_14", "_x" -> "_x" <> x.d) AND ARRAY_ANY(_u_12._u_14, "_x" -> _x <> x.d)
) )
AND ( AND (
NOT "_u_15".a IS NULL AND NOT "_u_15".a IS NULL NOT _u_15.a IS NULL AND NOT _u_15.a IS NULL
) )
AND x.a IN ( AND x.a IN (
SELECT SELECT
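Note: the _u_N derived tables above come from the subquery-unnesting rule, which rewrites scalar and IN/EXISTS subqueries in the WHERE clause into CROSS/LEFT JOINs. A rough sketch of invoking it directly (assuming the rule can be applied to a standalone parsed query; in the test suite it runs as part of the full optimize pipeline, and the table names are just the fixture's):

from sqlglot import parse_one
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

sql = "SELECT * FROM x WHERE x.a = (SELECT SUM(y.a) AS a FROM y)"
# The uncorrelated scalar subquery should become a CROSS JOIN aliased _u_0,
# mirroring the first join in the fixture output above.
print(unnest_subqueries(parse_one(sql)).sql(pretty=True))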
View file
@ -481,6 +481,19 @@ class TestBuild(unittest.TestCase):
), ),
(lambda: exp.delete("y", where="x > 1"), "DELETE FROM y WHERE x > 1"), (lambda: exp.delete("y", where="x > 1"), "DELETE FROM y WHERE x > 1"),
(lambda: exp.delete("y", where=exp.and_("x > 1")), "DELETE FROM y WHERE x > 1"), (lambda: exp.delete("y", where=exp.and_("x > 1")), "DELETE FROM y WHERE x > 1"),
(
lambda: select("AVG(a) OVER b")
.from_("table")
.window("b AS (PARTITION BY c ORDER BY d)"),
"SELECT AVG(a) OVER b FROM table WINDOW b AS (PARTITION BY c ORDER BY d)",
),
(
lambda: select("AVG(a) OVER b", "MIN(c) OVER d")
.from_("table")
.window("b AS (PARTITION BY e ORDER BY f)")
.window("d AS (PARTITION BY g ORDER BY h)"),
"SELECT AVG(a) OVER b, MIN(c) OVER d FROM table WINDOW b AS (PARTITION BY e ORDER BY f), d AS (PARTITION BY g ORDER BY h)",
),
]: ]:
with self.subTest(sql): with self.subTest(sql):
self.assertEqual(expression().sql(dialect[0] if dialect else None), sql) self.assertEqual(expression().sql(dialect[0] if dialect else None), sql)
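Note: the two added cases cover a window() method on the select builder, so named WINDOW definitions can be attached fluently. A minimal usage sketch taken from the first case above:

from sqlglot import select

query = (
    select("AVG(a) OVER b")
    .from_("table")
    .window("b AS (PARTITION BY c ORDER BY d)")
)
print(query.sql())
# SELECT AVG(a) OVER b FROM table WINDOW b AS (PARTITION BY c ORDER BY d)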
View file
@ -74,7 +74,7 @@ class TestExecutor(unittest.TestCase):
) )
return expression return expression
for i, (sql, _) in enumerate(self.sqls[0:18]): for i, (sql, _) in enumerate(self.sqls):
with self.subTest(f"tpch-h {i + 1}"): with self.subTest(f"tpch-h {i + 1}"):
a = self.cached_execute(sql) a = self.cached_execute(sql)
sql = parse_one(sql).transform(to_csv).sql(pretty=True) sql = parse_one(sql).transform(to_csv).sql(pretty=True)
View file
@ -1,4 +1,5 @@
import datetime import datetime
import math
import unittest import unittest
from sqlglot import alias, exp, parse_one from sqlglot import alias, exp, parse_one
@ -491,7 +492,7 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(alias("foo", "bar-1").sql(), 'foo AS "bar-1"') self.assertEqual(alias("foo", "bar-1").sql(), 'foo AS "bar-1"')
self.assertEqual(alias("foo", "bar_1").sql(), "foo AS bar_1") self.assertEqual(alias("foo", "bar_1").sql(), "foo AS bar_1")
self.assertEqual(alias("foo * 2", "2bar").sql(), 'foo * 2 AS "2bar"') self.assertEqual(alias("foo * 2", "2bar").sql(), 'foo * 2 AS "2bar"')
self.assertEqual(alias('"foo"', "_bar").sql(), '"foo" AS "_bar"') self.assertEqual(alias('"foo"', "_bar").sql(), '"foo" AS _bar')
self.assertEqual(alias("foo", "bar", quoted=True).sql(), 'foo AS "bar"') self.assertEqual(alias("foo", "bar", quoted=True).sql(), 'foo AS "bar"')
def test_unit(self): def test_unit(self):
@ -503,6 +504,8 @@ class TestExpressions(unittest.TestCase):
def test_identifier(self): def test_identifier(self):
self.assertTrue(exp.to_identifier('"x"').quoted) self.assertTrue(exp.to_identifier('"x"').quoted)
self.assertFalse(exp.to_identifier("x").quoted) self.assertFalse(exp.to_identifier("x").quoted)
self.assertTrue(exp.to_identifier("foo ").quoted)
self.assertFalse(exp.to_identifier("_x").quoted)
def test_function_normalizer(self): def test_function_normalizer(self):
self.assertEqual(parse_one("HELLO()").sql(normalize_functions="lower"), "hello()") self.assertEqual(parse_one("HELLO()").sql(normalize_functions="lower"), "hello()")
@ -549,14 +552,15 @@ class TestExpressions(unittest.TestCase):
([1, "2", None], "ARRAY(1, '2', NULL)"), ([1, "2", None], "ARRAY(1, '2', NULL)"),
({"x": None}, "MAP('x', NULL)"), ({"x": None}, "MAP('x', NULL)"),
( (
datetime.datetime(2022, 10, 1, 1, 1, 1), datetime.datetime(2022, 10, 1, 1, 1, 1, 1),
"TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')", "TIME_STR_TO_TIME('2022-10-01T01:01:01.000001+00:00')",
), ),
( (
datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc), datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc),
"TIME_STR_TO_TIME('2022-10-01 01:01:01.000000+0000')", "TIME_STR_TO_TIME('2022-10-01T01:01:01+00:00')",
), ),
(datetime.date(2022, 10, 1), "DATE_STR_TO_DATE('2022-10-01')"), (datetime.date(2022, 10, 1), "DATE_STR_TO_DATE('2022-10-01')"),
(math.nan, "NULL"),
]: ]:
with self.subTest(value): with self.subTest(value):
self.assertEqual(exp.convert(value).sql(), expected) self.assertEqual(exp.convert(value).sql(), expected)
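Note: the updated cases pin down exp.convert: naive datetimes are now rendered in ISO form and treated as UTC, aware datetimes keep their offset, and float NaN maps to NULL. A short sketch built only from the values above:

import datetime
import math
from sqlglot import exp

print(exp.convert(math.nan).sql())                    # NULL
print(exp.convert(datetime.date(2022, 10, 1)).sql())  # DATE_STR_TO_DATE('2022-10-01')
print(exp.convert({"x": None}).sql())                 # MAP('x', NULL)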
View file
@ -164,9 +164,6 @@ class TestOptimizer(unittest.TestCase):
with self.assertRaises(OptimizeError): with self.assertRaises(OptimizeError):
optimizer.qualify_columns.qualify_columns(parse_one(sql), schema=self.schema) optimizer.qualify_columns.qualify_columns(parse_one(sql), schema=self.schema)
def test_quote_identities(self):
self.check_file("quote_identities", optimizer.quote_identities.quote_identities)
def test_lower_identities(self): def test_lower_identities(self):
self.check_file("lower_identities", optimizer.lower_identities.lower_identities) self.check_file("lower_identities", optimizer.lower_identities.lower_identities)
@ -555,3 +552,29 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema
) )
self.assertEqual(expression.expressions[0].type.this, target_type) self.assertEqual(expression.expressions[0].type.this, target_type)
def test_recursive_cte(self):
query = parse_one(
"""
with recursive t(n) AS
(
select 1
union all
select n + 1
FROM t
where n < 3
), y AS (
select n
FROM t
union all
select n + 1
FROM y
where n < 2
)
select * from y
"""
)
scope_t, scope_y = build_scope(query).cte_scopes
self.assertEqual(set(scope_t.cte_sources), {"t"})
self.assertEqual(set(scope_y.cte_sources), {"t", "y"})
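Note: the new test documents how scoping treats recursive CTEs: each CTE scope can resolve the CTEs defined before it plus itself. A minimal sketch of inspecting this with the same build_scope helper, on a query trimmed down from the one above:

from sqlglot import parse_one
from sqlglot.optimizer.scope import build_scope

sql = "WITH RECURSIVE t(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM t WHERE n < 3) SELECT * FROM t"
root = build_scope(parse_one(sql))
for scope in root.cte_scopes:
    print(set(scope.cte_sources))  # {'t'}: the recursive CTE sees itself as a source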
View file
@ -76,6 +76,9 @@ class TestParser(unittest.TestCase):
tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)] tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
self.assertEqual(tables, ["a", "b.c", "d"]) self.assertEqual(tables, ["a", "b.c", "d"])
def test_union_order(self):
self.assertIsInstance(parse_one("SELECT * FROM (SELECT 1) UNION SELECT 2"), exp.Union)
def test_select(self): def test_select(self):
self.assertIsNotNone(parse_one("select 1 natural")) self.assertIsNotNone(parse_one("select 1 natural"))
self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"]) self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"])
View file
@ -40,17 +40,17 @@ class TestTime(unittest.TestCase):
self.validate( self.validate(
eliminate_distinct_on, eliminate_distinct_on,
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC", "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1', 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1',
) )
self.validate( self.validate(
eliminate_distinct_on, eliminate_distinct_on,
"SELECT DISTINCT ON (a) a, b FROM x", "SELECT DISTINCT ON (a) a, b FROM x",
'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a) AS "_row_number" FROM x) WHERE "_row_number" = 1', 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a) AS _row_number FROM x) WHERE "_row_number" = 1',
) )
self.validate( self.validate(
eliminate_distinct_on, eliminate_distinct_on,
"SELECT DISTINCT ON (a, b) a, b FROM x ORDER BY c DESC", "SELECT DISTINCT ON (a, b) a, b FROM x ORDER BY c DESC",
'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1', 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1',
) )
self.validate( self.validate(
eliminate_distinct_on, eliminate_distinct_on,
@ -60,5 +60,5 @@ class TestTime(unittest.TestCase):
self.validate( self.validate(
eliminate_distinct_on, eliminate_distinct_on,
"SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC", "SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC",
'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS "_row_number_2" FROM x) WHERE "_row_number_2" = 1', 'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE "_row_number_2" = 1',
) )
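Note: these expectations changed only because the generated _row_number alias is no longer quoted in the subquery projection (the outer WHERE comparison still quotes it). A rough sketch of running the transform on its own, assuming eliminate_distinct_on is importable from sqlglot.transforms as the test module does:

from sqlglot import parse_one
from sqlglot.transforms import eliminate_distinct_on

expression = parse_one("SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC")
print(eliminate_distinct_on(expression).sql())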
View file
@ -28,7 +28,7 @@ class TestTranspile(unittest.TestCase):
self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime") self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
self.assertEqual(transpile("SELECT 1 row")[0], "SELECT 1 AS row") self.assertEqual(transpile("SELECT 1 row")[0], "SELECT 1 AS row")
for key in ("union", "filter", "over", "from", "join"): for key in ("union", "over", "from", "join"):
with self.subTest(f"alias {key}"): with self.subTest(f"alias {key}"):
self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}") self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}")
self.validate(f'SELECT x "{key}"', f'SELECT x AS "{key}"') self.validate(f'SELECT x "{key}"', f'SELECT x AS "{key}"')
@ -263,6 +263,25 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
"WITH a AS (SELECT 1), WITH b AS (SELECT 2) SELECT *", "WITH a AS (SELECT 1), WITH b AS (SELECT 2) SELECT *",
"WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *", "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *",
) )
self.validate(
"WITH A(filter) AS (VALUES 1, 2, 3) SELECT * FROM A WHERE filter >= 2",
"WITH A(filter) AS (VALUES (1), (2), (3)) SELECT * FROM A WHERE filter >= 2",
)
def test_alter(self):
self.validate(
"ALTER TABLE integers ADD k INTEGER",
"ALTER TABLE integers ADD COLUMN k INT",
)
self.validate("ALTER TABLE integers DROP k", "ALTER TABLE integers DROP COLUMN k")
self.validate(
"ALTER TABLE integers ALTER i SET DATA TYPE VARCHAR",
"ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR",
)
self.validate(
"ALTER TABLE integers ALTER i TYPE VARCHAR COLLATE foo USING bar",
"ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR COLLATE foo USING bar",
)
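Note: the new test_alter cases show ALTER TABLE column clauses being normalized (ADD to ADD COLUMN, DROP to DROP COLUMN, SET DATA TYPE to TYPE). The same behavior is reachable through the top-level transpile call; a small sketch using the first case above:

from sqlglot import transpile

print(transpile("ALTER TABLE integers ADD k INTEGER")[0])
# ALTER TABLE integers ADD COLUMN k INT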
def test_time(self): def test_time(self):
self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)") self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)")
@ -403,6 +422,14 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
with self.subTest(sql): with self.subTest(sql):
self.assertEqual(transpile(sql)[0], sql.strip()) self.assertEqual(transpile(sql)[0], sql.strip())
def test_normalize_name(self):
self.assertEqual(
transpile("cardinality(x)", read="presto", write="presto", normalize_functions="lower")[
0
],
"cardinality(x)",
)
def test_partial(self): def test_partial(self):
for sql in load_sql_fixtures("partial.sql"): for sql in load_sql_fixtures("partial.sql"):
with self.subTest(sql): with self.subTest(sql):