1
0
Fork 0

Merging upstream version 10.4.2.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:01:55 +01:00
parent de4e42d4d3
commit 0c79f8b507
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
88 changed files with 1637 additions and 436 deletions

View file

@ -20,7 +20,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -r dev-requirements.txt
make install-dev
- name: Run checks (linter, code style, tests)
run: |
./run_checks.sh
make check

5
.gitignore vendored
View file

@ -130,3 +130,8 @@ dmypy.json
# PyCharm
.idea/
# Visual Studio Code
.vscode
.DS_STORE

31
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,31 @@
repos:
- repo: local
hooks:
- id: autoflake
name: autoflake
entry: autoflake -i -r
language: system
types: [ python ]
require_serial: true
files: ^(sqlglot/|tests/|setup.py)
- id: isort
name: isort
entry: isort
language: system
types: [ python ]
files: ^(sqlglot/|tests/|setup.py)
require_serial: true
- id: black
name: black
entry: black --line-length 100
language: system
types: [ python ]
require_serial: true
files: ^(sqlglot/|tests/|setup.py)
- id: mypy
name: mypy
entry: mypy
language: system
types: [ python ]
files: ^(sqlglot/|tests/)
require_serial: true

View file

@ -1,3 +0,0 @@
{
"python.linting.pylintEnabled": true
}

View file

@ -1,6 +1,37 @@
Changelog
=========
v10.4.0
------
Changes:
- Breaking: Removed the quote_identities optimizer rule.
- New: ARRAYAGG, SUM, ARRAYANY support in the engine. SQLGlot is now able to execute all TPC-H queries.
- Improvement: Transpile DATEDIFF to postgres.
- Improvement: Right join pushdown fixes.
- Improvement: Have Snowflake generate VALUES columns without quotes.
- Improvement: Support NaN values in convert.
- Improvement: Recursive CTE scope [fixes](https://github.com/tobymao/sqlglot/commit/bec36391d85152fa478222403d06beffa8d6ddfb).
v10.3.0
------
Changes:
- Breaking: Json ops changed to binary expressions.
- New: Jinja tokenization.
- Improvement: More robust type inference.
v10.2.0
------

24
Makefile Normal file
View file

@ -0,0 +1,24 @@
.PHONY: install install-dev install-pre-commit test style check docs docs-serve
install:
pip install -e .
install-dev:
pip install -e ".[dev]"
install-pre-commit:
pre-commit install
test:
python -m unittest
style:
pre-commit run --all-files
check: style test
docs:
pdoc/cli.py -o pdoc/docs
docs-serve:
pdoc/cli.py

View file

@ -1,8 +1,8 @@
# SQLGlot
SQLGlot is a no dependency Python SQL parser, transpiler, and optimizer. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
It is a very comprehensive generic SQL parser with a robust [test suite](tests). It is also quite [performant](#benchmarks) while being written purely in Python.
It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks) while being written purely in Python.
You can easily [customize](#custom-dialects) the parser, [analyze](#metadata) queries, traverse expression trees, and programmatically [build](#build-and-modify-sql) SQL.
@ -13,8 +13,7 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/
## Table of Contents
* [Install](#install)
* [Documentation](#documentation)
* [Run Tests and Lint](#run-tests-and-lint)
* [Get in Touch](#get-in-touch)
* [Examples](#examples)
* [Formatting and Transpiling](#formatting-and-transpiling)
* [Metadata](#metadata)
@ -26,6 +25,8 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/
* [AST Diff](#ast-diff)
* [Custom Dialects](#custom-dialects)
* [SQL Execution](#sql-execution)
* [Documentation](#documentation)
* [Run Tests and Lint](#run-tests-and-lint)
* [Benchmarks](#benchmarks)
* [Optional Dependencies](#optional-dependencies)
@ -40,30 +41,17 @@ pip3 install sqlglot
Or with a local checkout:
```
pip3 install -e .
make install
```
Requirements for development (optional):
```
pip3 install -r dev-requirements.txt
```
## Documentation
SQLGlot uses [pdocs](https://pdoc.dev/) to serve its API documentation:
```
pdoc sqlglot --docformat google
```
## Run Tests and Lint
```
# set `SKIP_INTEGRATION=1` to skip integration tests
./run_checks.sh
make install-dev
```
## Get in Touch
We'd love to hear from you. Join our community [Slack channel](https://join.slack.com/t/tobiko-data/shared_invite/zt-1ma66d79v-a4dbf4DUpLAQJ8ptQrJygg)!
## Examples
@ -274,7 +262,7 @@ transformed_tree.sql()
### SQL Optimizer
SQLGlot can rewrite queries into an "optimized" form. It performs a variety of [techniques](sqlglot/optimizer/optimizer.py) to create a new canonical AST. This AST can be used to standardize queries or provide the foundations for implementing an actual engine. For example:
SQLGlot can rewrite queries into an "optimized" form. It performs a variety of [techniques](https://github.com/tobymao/sqlglot/blob/main/sqlglot/optimizer/optimizer.py) to create a new canonical AST. This AST can be used to standardize queries or provide the foundations for implementing an actual engine. For example:
```python
import sqlglot
@ -292,7 +280,7 @@ print(
)
```
```
```sql
SELECT
(
"x"."A" OR "x"."B" OR "x"."C"
@ -351,9 +339,11 @@ diff(parse_one("SELECT a + b, c, d"), parse_one("SELECT c, a - b, d"))
]
```
See also: [Semantic Diff for SQL](https://github.com/tobymao/sqlglot/blob/main/posts/sql_diff.md).
### Custom Dialects
[Dialects](sqlglot/dialects) can be added by subclassing `Dialect`:
[Dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) can be added by subclassing `Dialect`:
```python
from sqlglot import exp
@ -391,7 +381,7 @@ class Custom(Dialect):
print(Dialect["custom"])
```
```python
```
<class '__main__.Custom'>
```
@ -442,9 +432,23 @@ user_id price
2 3.0
```
## Documentation
SQLGlot uses [pdocs](https://pdoc.dev/) to serve its API documentation:
```
make docs-serve
```
## Run Tests and Lint
```
make check # Set SKIP_INTEGRATION=1 to skip integration tests
```
## Benchmarks
[Benchmarks](benchmarks) run on Python 3.10.5 in seconds.
[Benchmarks](https://github.com/tobymao/sqlglot/blob/main/benchmarks/bench.py) run on Python 3.10.5 in seconds.
| Query | sqlglot | sqlfluff | sqltree | sqlparse | moz_sql_parser | sqloxide |
| --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- |

View file

@ -1,9 +0,0 @@
autoflake
black
duckdb
isort
mypy
pandas
pyspark
python-dateutil
pdoc

34
pdoc/cli.py Executable file
View file

@ -0,0 +1,34 @@
#!/usr/bin/env python3
from importlib import import_module
from pathlib import Path
from unittest import mock
from pdoc.__main__ import cli, parser
# Need this import or else import_module doesn't work
import sqlglot
def mocked_import(*args, **kwargs):
"""Return a MagicMock if import fails for any reason"""
try:
return import_module(*args, **kwargs)
except Exception:
mocked_module = mock.MagicMock()
mocked_module.__name__ = args[0]
return mocked_module
if __name__ == "__main__":
# Mock uninstalled dependencies so pdoc can still work
with mock.patch("importlib.import_module", side_effect=mocked_import):
opts = parser.parse_args()
opts.docformat = "google"
opts.modules = ["sqlglot"]
opts.footer_text = "Copyright (c) 2022 Toby Mao"
opts.template_directory = Path(__file__).parent.joinpath("templates").absolute()
opts.edit_url = ["sqlglot=https://github.com/tobymao/sqlglot/"]
with mock.patch("pdoc.__main__.parser", **{"parse_args.return_value": opts}):
cli()

41
pdoc/docs/expressions.md Normal file
View file

@ -0,0 +1,41 @@
# Expressions
Every AST node in SQLGlot is represented by a subclass of `Expression`. Each such expression encapsulates any necessary context, such as its child expressions, their names, or arg keys, and whether each child expression is optional or not.
Furthermore, the following attributes are common across all expressions:
#### key
A unique key for each class in the `Expression` hierarchy. This is useful for hashing and representing expressions as strings.
#### args
A dictionary used for mapping child arg keys, to the corresponding expressions. A value in this mapping is usually either a single or a list of `Expression` instances, but SQLGlot doesn't impose any constraints on the actual type of the value.
#### arg_types
A dictionary used for mapping arg keys to booleans that determine whether the corresponding expressions are optional or not. Consider the following example:
```python
class Limit(Expression):
arg_types = {"this": False, "expression": True}
```
Here, `Limit` declares that it expects to have one optional and one required child expression, which can be referenced through `this` and `expression`, respectively. The arg keys are generally arbitrary, but there are helper methods for keys like `this`, `expression` and `expressions` that abstract away dictionary lookups and related checks. For this reason, these keys are common throughout SQLGlot's codebase.
#### parent
A reference to the parent expression (may be `None`).
#### arg_key
The arg key an expression is associated with, i.e. the name its parent expression uses to refer to it.
#### comments
A list of comments that are associated with a given expression. This is used in order to preserve comments when transpiling SQL code.
#### type
The data type of an expression, as inferred by SQLGlot's optimizer.

View file

@ -0,0 +1,6 @@
{% extends "default/module.html.jinja2" %}
{% if module.docstring %}
{% macro module_name() %}
{% endmacro %}
{% endif %}

208
posts/python_sql_engine.md Normal file
View file

@ -0,0 +1,208 @@
# Writing a Python SQL engine from scratch
[Toby Mao](https://www.linkedin.com/in/toby-mao/)
## Introduction
When I first started writing SQLGlot in early 2021, my goal was just to translate SQL queries from SparkSQL to Presto and vice versa. However, over the last year and a half, I've ended up with a full-fledged SQL engine. SQLGlot can now parse and transpile between [18 SQL dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) and can execute all 24 [TPC-H](https://www.tpc.org/tpch/) SQL queries. The parser and engine are all written from scratch using Python.
This post will cover [why](#why) I went through the effort of creating a Python SQL engine and [how](#how) a simple query goes from a string to actually transforming data. The following steps are briefly summarized:
* [Tokenizing](#tokenizing)
* [Parsing](#parsing)
* [Optimizing](#optimizing)
* [Planning](#planning)
* [Executing](#executing)
## Why?
I started working on SQLGlot because of my work on the [experimentation and metrics platform](https://netflixtechblog.com/reimagining-experimentation-analysis-at-netflix-71356393af21) at Netflix, where I built tools that allowed data scientists to define and compute SQL-based metrics. Netflix relied on multiple engines to query data (Spark, Presto, and Druid), so my team built the metrics platform around [PyPika](https://github.com/kayak/pypika), a Python SQL query builder. This way, definitions could be reused across multiple engines. However, it became quickly apparent that writing python code to programatically generate SQL was challenging for data scientists, especially those with academic backgrounds, since they were mostly familiar with R and SQL. At the time, the only Python SQL parser was [sqlparse]([https://github.com/andialbrecht/sqlparse), which is not actually a parser but a tokenizer, so having users write raw SQL into the platform wasn't really an option. Some time later, I randomly stumbled across [Crafting Interpreters](https://craftinginterpreters.com/) and realized that I could use it as a guide towards creating my own SQL parser/transpiler.
Why did I do this? Isn't a Python SQL engine going to be extremely slow?
The main reason why I ended up building a SQL engine was...just for **entertainment**. It's been fun learning about all the things required to actually run a SQL query, and seeing it actually work is extremely rewarding. Before SQLGlot, I had zero experience with lexers, parsers, or compilers.
In terms of practical use cases, I planned to use the Python SQL engine for unit testing SQL pipelines. Big data pipelines are tough to test because many of the engines are not open source and cannot be run locally. With SQLGlot, you can take a SQL query targeting a warehouse such as [Snowflake](https://www.snowflake.com/en/) and seamlessly run it in CI on mock Python data. It's easy to mock data and create arbitrary [UDFs](https://en.wikipedia.org/wiki/User-defined_function) because everything is just Python. Although the implementation is slow and unsuitable for large amounts of data (> 1 millon rows), there's very little overhead/startup and you can run queries on test data in a couple of milliseconds.
Finally, the components that have been built to support execution can be used as a **foundation** for a faster engine. I'm inspired by what [Apache Calcite](https://github.com/apache/calcite) has done for the JVM world. Even though Python is commonly used for data, there hasn't been a Calcite for Python. So, you could say that SQLGlot aims to be that framework. For example, it wouldn't take much work to replace the Python execution engine with numpy/pandas/arrow to become a respectably-performing query engine. The implementation would be able to leverage the parser, optimizer, and logical planner, only needing to implement physical execution. There is a lot of work in the Python ecosystem around high performance vectorized computation, which I think could benefit from a pure Python-based [AST](https://en.wikipedia.org/wiki/Abstract_syntax_tree)/[plan](https://en.wikipedia.org/wiki/Query_plan). Parsing and planning doesn't have to be fast when the bottleneck of running queries is processing terabytes of data. So, having a Python-based ecosystem around SQL is beneficial given the ease of development in Python, despite not having bare metal performance.
Parts of SQLGlot's toolkit are being used today by the following:
* [Ibis](https://github.com/ibis-project/ibis): A Python library that provides a lightweight, universal interface for data wrangling.
- Uses the Python SQL expression builder and leverages the optimizer/planner to convert SQL into dataframe operations.
* [mysql-mimic](https://github.com/kelsin/mysql-mimic): Pure-Python implementation of the MySQL server wire protocol
- Parses / transforms SQL and executes INFORMATION_SCHEMA queries.
* [Quokka](https://github.com/marsupialtail/quokka): Push-based vectorized query engine
- Parse and optimizes SQL.
* [Splink](https://github.com/moj-analytical-services/splink): Fast, accurate and scalable probabilistic data linkage using your choice of SQL backend.
- Transpiles queries.
## How?
There are many steps involved with actually running a simple query like:
```sql
SELECT
bar.a,
b + 1 AS b
FROM bar
JOIN baz
ON bar.a = baz.a
WHERE bar.a > 1
```
In this post, I'll walk through all the steps SQLGlot takes to run this query over Python objects.
## Tokenizing
The first step is to convert the sql string into a list of tokens. SQLGlot's tokenizer is quite simple and can be found [here](https://github.com/tobymao/sqlglot/blob/main/sqlglot/tokens.py). In a while loop, it checks each character and either appends the character to the current token, or makes a new token.
Running the SQLGlot tokenizer shows the output.
![Tokenizer Output](python_sql_engine_images/tokenizer.png)
Each keyword has been converted to a SQLGlot Token object. Each token has some metadata associated with it, like line/column information for error messages. Comments are also a part of the token, so that comments can be preserved.
## Parsing
Once a SQL statement is tokenized, we don't need to worry about white space and other formatting, so it's easier to work with. We can now convert the list of tokens into an AST. The SQLGlot [parser](https://github.com/tobymao/sqlglot/blob/main/sqlglot/parser.py) is a handwritten [recursive descent](https://en.wikipedia.org/wiki/Recursive_descent_parser) parser.
Similar to the tokenizer, it consumes the tokens sequentially, but it instead uses a recursive algorithm. The tokens are converted into a single AST node that presents the SQL query. The SQLGlot parser was designed to support various dialects, so it contains many options for overriding parsing functionality.
![Parser Output](python_sql_engine_images/parser.png)
The AST is a generic representation of a given SQL query. Each dialect can override or implement its own generator, which can convert an AST object into syntatically-correct SQL.
## Optimizing
Once we have our AST, we can transform it into an equivalent query that produces the same results more efficiently. When optimizing queries, most engines first convert the AST into a logical plan and then optimize the plan. However, I chose to **optimize the AST directly** for the following reasons:
1. It's easier to debug and [validate](https://github.com/tobymao/sqlglot/blob/main/tests/fixtures/optimizer) the optimizations when the input and output are both SQL.
2. Rules can be applied a la carte to transform SQL into a more desireable form.
3. I wanted a way to generate 'canonical sql'. Having a canonical representation of SQL is useful for understanding if two queries are semantically equivalent (e.g. `SELECT 1 + 1` and `SELECT 2`).
I've yet to find another engine that takes this approach, but I'm quite happy with this decision. The optimizer currently does not perform any "physical optimizations" such as join reordering. Those are left to the execution layer, as additional statistics and information could become relevant.
![Optimizer Output](python_sql_engine_images/optimizer.png)
The optimizer currently has [17 rules](https://github.com/tobymao/sqlglot/tree/main/sqlglot/optimizer). Each of these rules is applied, transforming the AST in place. The combination of these rules creates "canonical" sql that can then be more easily converted into a logical plan and executed.
Some example rules are:
### qualify\_tables and qualify_columns
- Adds all db/catalog qualifiers to tables and forces an alias.
- Ensure each column is unambiguous and expand stars.
```sql
SELECT * FROM x;
SELECT "db"."x" AS "x";
```
### simplify
Boolean and math simplification. Check out all the [test cases](https://github.com/tobymao/sqlglot/blob/main/tests/fixtures/optimizer/simplify.sql).
```sql
((NOT FALSE) AND (x = x)) AND (TRUE OR 1 <> 3);
x = x;
1 + 1;
2;
```
### normalize
Attempts to convert all predicates into [conjunctive normal form](https://en.wikipedia.org/wiki/Conjunctive_normal_form).
```sql
-- DNF
(A AND B) OR (B AND C AND D);
-- CNF
(A OR C) AND (A OR D) AND B;
```
### unnest\_subqueries
Converts subqueries in predicates into joins.
```sql
-- The subquery can be converted into a left join
SELECT *
FROM x AS x
WHERE (
SELECT y.a AS a
FROM y AS y
WHERE x.a = y.a
) = 1;
SELECT *
FROM x AS x
LEFT JOIN (
SELECT y.a AS a
FROM y AS y
WHERE TRUE
GROUP BY y.a
) AS "_u_0"
ON x.a = "_u_0".a
WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)
```
### pushdown_predicates
Push down filters into the innermost query.
```sql
SELECT *
FROM (
SELECT *
FROM x AS x
) AS y
WHERE y.a = 1;
SELECT *
FROM (
SELECT *
FROM x AS x
WHERE y.a = 1
) AS y WHERE TRUE
```
### annotate_types
Infer all types throughout the AST given schema information and function type definitions.
## Planning
After the SQL AST has been "optimized", it's much easier to [convert into a logical plan](https://github.com/tobymao/sqlglot/blob/main/sqlglot/planner.py). The AST is traversed and converted into a [DAG](https://en.wikipedia.org/wiki/Directed_acyclic_graph) consisting of one of five steps. The different steps are:
### Scan
Selects columns from a table, applies projections, and finally filters the table.
### Sort
Sorts a table for order by expressions.
### Set
Applies the operators union/union all/except/intersect.
### Aggregate
Applies an aggregation/group by.
### Join
Joins multiple tables together.
![Planner Output](python_sql_engine_images/planner.png)
The logical plan is quite simple and contains the information required to convert it into a physical plan (execution).
## Executing
Finally, we can actually execute the SQL query. The [Python engine](https://github.com/tobymao/sqlglot/blob/main/sqlglot/executor/python.py) is not fast, but it's very small (~400 LOC)! It iterates the DAG with a queue and runs each step, passing each intermediary table to the next step.
In order to keep things simple, it evaluates expressions with `eval`. Because SQLGlot was built primarily to be a transpiler, it was simple to create a "Python SQL" dialect. So a SQL expression `x + 1` can just be converted into `scope['x'] + 1`.
![Executor Output](python_sql_engine_images/executor.png)
## What's next
SQLGlot's main focus will always be on parsing/transpiling, but I plan to continue development on the execution engine. I'd like to pass [TPC-DS](https://www.tpc.org/tpcds/). If someone doesn't beat me to it, I may even take a stab at writing a Pandas/Arrow execution engine.
I'm hoping that over time, SQLGlot will spark the Python SQL ecosystem just like Calcite has for Java.
## Special thanks
SQLGlot would not be what it is without it's core contributors. In particular, the execution engine would not exist without [Barak Alon](https://github.com/barakalon) and [George Sittas](https://github.com/GeorgeSittas).
## Get in touch
If you'd like to chat more about SQLGlot, please join my [Slack Channel](https://join.slack.com/t/tobiko-data/shared_invite/zt-1ma66d79v-a4dbf4DUpLAQJ8ptQrJygg)!

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 80 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 154 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 121 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 333 KiB

View file

@ -1,8 +0,0 @@
#!/bin/bash -e
[[ -z "${GITHUB_ACTIONS}" ]] && RETURN_ERROR_CODE='' || RETURN_ERROR_CODE='--check'
TARGETS="sqlglot/ tests/"
python -m mypy $TARGETS
python -m autoflake -i -r ${RETURN_ERROR_CODE} $TARGETS
python -m isort $TARGETS
python -m black --line-length 100 ${RETURN_ERROR_CODE} $TARGETS
python -m unittest

View file

@ -22,6 +22,20 @@ setup(
license="MIT",
packages=find_packages(include=["sqlglot", "sqlglot.*"]),
package_data={"sqlglot": ["py.typed"]},
extras_require={
"dev": [
"autoflake",
"black",
"duckdb",
"isort",
"mypy",
"pandas",
"pyspark",
"python-dateutil",
"pdoc",
"pre-commit",
],
},
classifiers=[
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",

View file

@ -1,4 +1,6 @@
"""## Python SQL parser, transpiler and optimizer."""
"""
.. include:: ../README.md
"""
from __future__ import annotations
@ -30,7 +32,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "10.2.9"
__version__ = "10.4.2"
pretty = False

View file

@ -1,9 +1,15 @@
import argparse
import sys
import sqlglot
parser = argparse.ArgumentParser(description="Transpile SQL")
parser.add_argument("sql", metavar="sql", type=str, help="SQL string to transpile")
parser.add_argument(
"sql",
metavar="sql",
type=str,
help="SQL statement(s) to transpile, or - to parse stdin.",
)
parser.add_argument(
"--read",
dest="read",
@ -48,14 +54,20 @@ parser.add_argument(
args = parser.parse_args()
error_level = sqlglot.ErrorLevel[args.error_level.upper()]
sql = sys.stdin.read() if args.sql == "-" else args.sql
if args.parse:
sqls = [
repr(expression)
for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)
for expression in sqlglot.parse(
sql,
read=args.read,
error_level=error_level,
)
]
else:
sqls = sqlglot.transpile(
args.sql,
sql,
read=args.read,
write=args.write,
identify=args.identify,

View file

@ -0,0 +1,3 @@
"""
.. include:: ./README.md
"""

View file

@ -9,18 +9,8 @@ if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.types import StructType
ColumnLiterals = t.TypeVar(
"ColumnLiterals",
bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
)
ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
ColumnOrLiteral = t.TypeVar(
"ColumnOrLiteral",
bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
)
SchemaInput = t.TypeVar(
"SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]]
)
OutputExpressionContainer = t.TypeVar(
"OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert]
)
ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
ColumnOrName = t.Union[Column, str]
ColumnOrLiteral = t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]]
OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert]

View file

@ -634,7 +634,7 @@ class DataFrame:
all_columns = self._get_outer_select_columns(new_df.expression)
all_column_mapping = {column.alias_or_name: column for column in all_columns}
if isinstance(value, dict):
values = value.values()
values = list(value.values())
columns = self._ensure_and_normalize_cols(list(value))
if not columns:
columns = self._ensure_and_normalize_cols(subset) if subset else all_columns

View file

@ -1,11 +1,15 @@
"""Supports BigQuery Standard SQL."""
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
datestrtodate_sql,
inline_array_sql,
no_ilike_sql,
rename_func,
timestrtotime_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@ -120,13 +124,12 @@ class BigQuery(Dialect):
"NOT DETERMINISTIC": TokenType.VOLATILE,
"QUALIFY": TokenType.QUALIFY,
"UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW,
}
KEYWORDS.pop("DIV")
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"DATE_TRUNC": _date_trunc,
"DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd),
@ -144,31 +147,33 @@ class BigQuery(Dialect):
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
**parser.Parser.FUNCTION_PARSERS, # type: ignore
"ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]),
}
FUNCTION_PARSERS.pop("TRIM")
NO_PAREN_FUNCTIONS = {
**parser.Parser.NO_PAREN_FUNCTIONS,
**parser.Parser.NO_PAREN_FUNCTIONS, # type: ignore
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
TokenType.CURRENT_TIME: exp.CurrentTime,
}
NESTED_TYPE_TOKENS = {
*parser.Parser.NESTED_TYPE_TOKENS,
*parser.Parser.NESTED_TYPE_TOKENS, # type: ignore
TokenType.TABLE,
}
class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.DateStrToDate: datestrtodate_sql,
exp.GroupConcat: rename_func("STRING_AGG"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
@ -176,6 +181,7 @@ class BigQuery(Dialect):
exp.TimeSub: _date_add_sql("TIME", "SUB"),
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
exp.VariancePop: rename_func("VAR_POP"),
exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
@ -188,7 +194,7 @@ class BigQuery(Dialect):
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.INT: "INT64",

View file

@ -35,13 +35,13 @@ class ClickHouse(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"MAP": parse_var_map,
}
JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF}
JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY}
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
def _parse_table(self, schema=False):
this = super()._parse_table(schema)
@ -55,7 +55,7 @@ class ClickHouse(Dialect):
STRUCT_DELIMITER = ("(", ")")
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.NULLABLE: "Nullable",
exp.DataType.Type.DATETIME: "DateTime64",
exp.DataType.Type.MAP: "Map",
@ -70,7 +70,7 @@ class ClickHouse(Dialect):
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.Array: inline_array_sql,
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",

View file

@ -198,7 +198,7 @@ class Dialect(metaclass=_Dialect):
def rename_func(name):
def _rename(self, expression):
args = flatten(expression.args.values())
return f"{name}({self.format_args(*args)})"
return f"{self.normalize_func(name)}({self.format_args(*args)})"
return _rename
@ -217,11 +217,11 @@ def if_sql(self, expression):
def arrow_json_extract_sql(self, expression):
return f"{self.sql(expression, 'this')}->{self.sql(expression, 'path')}"
return self.binary(expression, "->")
def arrow_json_extract_scalar_sql(self, expression):
return f"{self.sql(expression, 'this')}->>{self.sql(expression, 'path')}"
return self.binary(expression, "->>")
def inline_array_sql(self, expression):
@ -373,3 +373,11 @@ def strposition_to_local_sql(self, expression):
expression.args.get("substr"), expression.this, expression.args.get("position")
)
return f"LOCATE({args})"
def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:
return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
return f"CAST({self.sql(expression, 'this')} AS DATE)"

View file

@ -6,13 +6,14 @@ from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
create_with_partitions_sql,
datestrtodate_sql,
format_time_lambda,
no_pivot_sql,
no_trycast_sql,
rename_func,
str_position_sql,
timestrtotime_sql,
)
from sqlglot.dialects.postgres import _lateral_sql
def _to_timestamp(args):
@ -117,14 +118,14 @@ class Drill(Dialect):
STRICT_CAST = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
}
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
@ -139,14 +140,13 @@ class Drill(Dialect):
ROOT_PROPERTIES = {exp.PartitionedByProperty}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.Lateral: _lateral_sql,
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"),
exp.Create: create_with_partitions_sql,
exp.DateAdd: _date_add_sql("ADD"),
exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"),
exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
@ -160,7 +160,7 @@ class Drill(Dialect):
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),

View file

@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
datestrtodate_sql,
format_time_lambda,
no_pivot_sql,
no_properties_sql,
@ -13,6 +14,7 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql,
rename_func,
str_position_sql,
timestrtotime_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@ -83,11 +85,12 @@ class DuckDB(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
":=": TokenType.EQ,
"CHARACTER VARYING": TokenType.VARCHAR,
}
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
@ -119,16 +122,18 @@ class DuckDB(Dialect):
STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: rename_func("LIST_VALUE"),
exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
if isinstance(seq_get(e.expressions, 0), exp.Select)
else rename_func("LIST_VALUE")(self, e),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
exp.DataType: _datatype_sql,
exp.DateAdd: _date_add,
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.format_args(e.args.get("unit") or "'day'", e.expression, e.this)})""",
exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
exp.Explode: rename_func("UNNEST"),
@ -136,6 +141,7 @@ class DuckDB(Dialect):
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Pivot: no_pivot_sql,
exp.Properties: no_properties_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
@ -150,7 +156,7 @@ class DuckDB(Dialect):
exp.Struct: _struct_pack_sql,
exp.TableSample: no_tablesample_sql,
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))",
exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("EPOCH"),
@ -163,7 +169,7 @@ class DuckDB(Dialect):
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.NVARCHAR: "TEXT",
}

View file

@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import (
rename_func,
strposition_to_local_sql,
struct_extract_sql,
timestrtotime_sql,
var_map_sql,
)
from sqlglot.helper import seq_get
@ -197,7 +198,7 @@ class Hive(Dialect):
STRICT_CAST = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
@ -217,7 +218,12 @@ class Hive(Dialect):
),
unit=exp.Literal.string("DAY"),
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"),
"DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")(
[
exp.TimeStrToTime(this=seq_get(args, 0)),
seq_get(args, 1),
]
),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
@ -240,7 +246,7 @@ class Hive(Dialect):
}
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS,
**parser.Parser.PROPERTY_PARSERS, # type: ignore
TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties(
expressions=self._parse_wrapped_csv(self._parse_property)
),
@ -248,14 +254,14 @@ class Hive(Dialect):
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.DATETIME: "TIMESTAMP",
exp.DataType.Type.VARBINARY: "BINARY",
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
@ -294,7 +300,7 @@ class Hive(Dialect):
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: _time_to_str,
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),

View file

@ -161,8 +161,6 @@ class MySQL(Dialect):
"_UCS2": TokenType.INTRODUCER,
"_UJIS": TokenType.INTRODUCER,
# https://dev.mysql.com/doc/refman/8.0/en/string-literals.html
"N": TokenType.INTRODUCER,
"n": TokenType.INTRODUCER,
"_UTF8": TokenType.INTRODUCER,
"_UTF16": TokenType.INTRODUCER,
"_UTF16LE": TokenType.INTRODUCER,
@ -175,10 +173,10 @@ class MySQL(Dialect):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
class Parser(parser.Parser):
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA}
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA} # type: ignore
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date,
@ -190,7 +188,7 @@ class MySQL(Dialect):
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
**parser.Parser.FUNCTION_PARSERS, # type: ignore
"GROUP_CONCAT": lambda self: self.expression(
exp.GroupConcat,
this=self._parse_lambda(),
@ -199,12 +197,12 @@ class MySQL(Dialect):
}
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS,
**parser.Parser.PROPERTY_PARSERS, # type: ignore
TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
}
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS,
**parser.Parser.STATEMENT_PARSERS, # type: ignore
TokenType.SHOW: lambda self: self._parse_show(),
TokenType.SET: lambda self: self._parse_set(),
}
@ -429,7 +427,7 @@ class MySQL(Dialect):
NULL_ORDERING_SUPPORTED = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ILike: no_ilike_sql,

View file

@ -39,13 +39,13 @@ class Oracle(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"DECODE": exp.Matches.from_arg_list,
}
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "NUMBER",
exp.DataType.Type.SMALLINT: "NUMBER",
exp.DataType.Type.INT: "NUMBER",
@ -60,7 +60,7 @@ class Oracle(Dialect):
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql,

View file

@ -11,9 +11,19 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
str_position_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
from sqlglot.transforms import delegate, preprocess
DATE_DIFF_FACTOR = {
"MICROSECOND": " * 1000000",
"MILLISECOND": " * 1000",
"SECOND": "",
"MINUTE": " / 60",
"HOUR": " / 3600",
"DAY": " / 86400",
}
def _date_add_sql(kind):
def func(self, expression):
@ -34,16 +44,30 @@ def _date_add_sql(kind):
return func
def _lateral_sql(self, expression):
this = self.sql(expression, "this")
if isinstance(expression.this, exp.Subquery):
return f"LATERAL{self.sep()}{this}"
alias = expression.args["alias"]
table = alias.name
table = f" {table}" if table else table
columns = self.expressions(alias, key="columns", flat=True)
columns = f" AS {columns}" if columns else ""
return f"LATERAL{self.sep()}{this}{table}{columns}"
def _date_diff_sql(self, expression):
unit = expression.text("unit").upper()
factor = DATE_DIFF_FACTOR.get(unit)
end = f"CAST({expression.this} AS TIMESTAMP)"
start = f"CAST({expression.expression} AS TIMESTAMP)"
if factor is not None:
return f"CAST(EXTRACT(epoch FROM {end} - {start}){factor} AS BIGINT)"
age = f"AGE({end}, {start})"
if unit == "WEEK":
extract = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
elif unit == "MONTH":
extract = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
elif unit == "QUARTER":
extract = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3"
elif unit == "YEAR":
extract = f"EXTRACT(year FROM {age})"
else:
self.unsupported(f"Unsupported DATEDIFF unit {unit}")
return f"CAST({extract} AS BIGINT)"
def _substring_sql(self, expression):
@ -141,7 +165,7 @@ def _serial_to_generated(expression):
def _to_timestamp(args):
# TO_TIMESTAMP accepts either a single double argument or (text, text)
if len(args) == 1 and args[0].is_number:
if len(args) == 1:
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE
return exp.UnixToTime.from_arg_list(args)
# https://www.postgresql.org/docs/current/functions-formatting.html
@ -211,11 +235,16 @@ class Postgres(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"~~": TokenType.LIKE,
"~~*": TokenType.ILIKE,
"~*": TokenType.IRLIKE,
"~": TokenType.RLIKE,
"ALWAYS": TokenType.ALWAYS,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BIGSERIAL": TokenType.BIGSERIAL,
"BY DEFAULT": TokenType.BY_DEFAULT,
"CHARACTER VARYING": TokenType.VARCHAR,
"COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
@ -233,6 +262,7 @@ class Postgres(Dialect):
"SMALLSERIAL": TokenType.SMALLSERIAL,
"TEMP": TokenType.TEMPORARY,
"UUID": TokenType.UUID,
"CSTRING": TokenType.PSEUDO_TYPE,
**{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
**{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
}
@ -244,17 +274,16 @@ class Postgres(Dialect):
class Parser(parser.Parser):
STRICT_CAST = False
LATERAL_FUNCTION_AS_VIEW = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"TO_TIMESTAMP": _to_timestamp,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
}
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "SMALLINT",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
@ -264,7 +293,7 @@ class Postgres(Dialect):
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.ColumnDef: preprocess(
[
_auto_increment_to_serial,
@ -274,13 +303,16 @@ class Postgres(Dialect):
),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}",
exp.JSONBExtractScalar: lambda self, e: f"{self.sql(e, 'this')}#>>{self.sql(e, 'path')}",
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
exp.JSONBContains: lambda self, e: self.binary(e, "?"),
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql("+"),
exp.DateSub: _date_add_sql("-"),
exp.Lateral: _lateral_sql,
exp.DateDiff: _date_diff_sql,
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Substring: _substring_sql,
@ -291,5 +323,7 @@ class Postgres(Dialect):
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.DataType: _datatype_sql,
exp.GroupConcat: _string_agg_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
if isinstance(seq_get(e.expressions, 0), exp.Select)
else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]",
}

View file

@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
rename_func,
str_position_sql,
struct_extract_sql,
timestrtotime_sql,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.errors import UnsupportedError
@ -38,10 +39,6 @@ def _datatype_sql(self, expression):
return sql
def _date_parse_sql(self, expression):
return f"DATE_PARSE({self.sql(expression, 'this')}, '%Y-%m-%d %H:%i:%s')"
def _explode_to_unnest_sql(self, expression):
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
return self.sql(
@ -137,7 +134,7 @@ class Presto(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
@ -174,7 +171,7 @@ class Presto(Dialect):
ROOT_PROPERTIES = {exp.SchemaCommentProperty}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.BINARY: "VARBINARY",
@ -184,7 +181,7 @@ class Presto(Dialect):
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
@ -224,8 +221,8 @@ class Presto(Dialect):
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.TimeStrToDate: _date_parse_sql,
exp.TimeStrToTime: _date_parse_sql,
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("TO_UNIXTIME"),

View file

@ -36,7 +36,6 @@ class Redshift(Postgres):
"TIMETZ": TokenType.TIMESTAMPTZ,
"UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY,
"SIMILAR TO": TokenType.SIMILAR_TO,
}
class Generator(Postgres.Generator):

View file

@ -3,13 +3,15 @@ from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
datestrtodate_sql,
format_time_lambda,
inline_array_sql,
rename_func,
timestrtotime_sql,
var_map_sql,
)
from sqlglot.expressions import Literal
from sqlglot.helper import seq_get
from sqlglot.helper import flatten, seq_get
from sqlglot.tokens import TokenType
@ -183,7 +185,7 @@ class Snowflake(Dialect):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
ESCAPES = ["\\"]
ESCAPES = ["\\", "'"]
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
@ -206,9 +208,10 @@ class Snowflake(Dialect):
CREATE_TRANSIENT = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
@ -218,13 +221,14 @@ class Snowflake(Dialect):
exp.Matches: rename_func("DECODE"),
exp.StrPosition: rename_func("POSITION"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
exp.UnixToTime: _unix_to_time_sql,
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
@ -246,3 +250,47 @@ class Snowflake(Dialect):
if not expression.args.get("distinct", False):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)
def values_sql(self, expression: exp.Values) -> str:
"""Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted.
We also want to make sure that after we find matches where we need to unquote a column that we prevent users
from adding quotes to the column by using the `identify` argument when generating the SQL.
"""
alias = expression.args.get("alias")
if alias and alias.args.get("columns"):
expression = expression.transform(
lambda node: exp.Identifier(**{**node.args, "quoted": False})
if isinstance(node, exp.Identifier)
and isinstance(node.parent, exp.TableAlias)
and node.arg_key == "columns"
else node,
)
return self.no_identify(lambda: super(self.__class__, self).values_sql(expression))
return super().values_sql(expression)
def select_sql(self, expression: exp.Select) -> str:
"""Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also
that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need
to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when
generating the SQL.
Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the
expression. This might not be true in a case where the same column name can be sourced from another table that can
properly quote but should be true in most cases.
"""
values_expressions = expression.find_all(exp.Values)
values_identifiers = set(
flatten(
v.args.get("alias", exp.Alias()).args.get("columns", [])
for v in values_expressions
)
)
if values_identifiers:
expression = expression.transform(
lambda node: exp.Identifier(**{**node.args, "quoted": False})
if isinstance(node, exp.Identifier) and node in values_identifiers
else node,
)
return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
return super().select_sql(expression)

View file

@ -76,7 +76,7 @@ class Spark(Hive):
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
**parser.Parser.FUNCTION_PARSERS, # type: ignore
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@ -87,6 +87,16 @@ class Spark(Hive):
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
}
def _parse_add_column(self):
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
def _parse_drop_column(self):
return self._match_text_seq("DROP", "COLUMNS") and self.expression(
exp.Drop,
this=self._parse_schema(),
kind="COLUMNS",
)
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING, # type: ignore

View file

@ -42,13 +42,13 @@ class SQLite(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"EDITDIST3": exp.Levenshtein.from_arg_list,
}
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BOOLEAN: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
@ -70,7 +70,7 @@ class SQLite(Dialect):
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,

View file

@ -8,7 +8,7 @@ from sqlglot.dialects.mysql import MySQL
class StarRocks(MySQL):
class Generator(MySQL.Generator): # type: ignore
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING,
**MySQL.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",

View file

@ -30,7 +30,7 @@ class Tableau(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"IFNULL": exp.Coalesce.from_arg_list,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
}

View file

@ -224,11 +224,7 @@ class TSQL(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]")]
QUOTES = [
(prefix + quote, quote) if prefix else quote
for quote in ["'", '"']
for prefix in ["", "n", "N"]
]
QUOTES = ["'", '"']
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@ -253,7 +249,7 @@ class TSQL(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"CHARINDEX": exp.StrPosition.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
@ -314,7 +310,7 @@ class TSQL(Dialect):
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BOOLEAN: "BIT",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",

View file

@ -1,3 +1,7 @@
"""
.. include:: ../posts/sql_diff.md
"""
from __future__ import annotations
import typing as t

View file

@ -29,10 +29,10 @@ class Context:
self._table: t.Optional[Table] = None
self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers}
self.env = {**ENV, **(env or {}), "scope": self.row_readers}
def eval(self, code):
return eval(code, ENV, self.env)
return eval(code, self.env)
def eval_tuple(self, codes):
return tuple(self.eval(code) for code in codes)

View file

@ -127,14 +127,16 @@ def interval(this, unit):
ENV = {
"exp": exp,
# aggs
"SUM": filter_nulls(sum),
"ARRAYAGG": list,
"AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore
"COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
"MAX": filter_nulls(max),
"MIN": filter_nulls(min),
"SUM": filter_nulls(sum),
# scalar functions
"ABS": null_if_any(lambda this: abs(this)),
"ADD": null_if_any(lambda e, this: e + this),
"ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
"BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
"BITWISEAND": null_if_any(lambda this, e: this & e),
"BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),

View file

@ -394,6 +394,18 @@ def _case_sql(self, expression):
return chain
def _lambda_sql(self, e: exp.Lambda) -> str:
names = {e.name.lower() for e in e.expressions}
e = e.transform(
lambda n: exp.Var(this=n.name)
if isinstance(n, exp.Identifier) and n.name.lower() in names
else n
)
return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}"
class Python(Dialect):
class Tokenizer(tokens.Tokenizer):
ESCAPES = ["\\"]
@ -414,6 +426,7 @@ class Python(Dialect):
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
exp.Is: lambda self, e: self.binary(e, "is"),
exp.Lambda: _lambda_sql,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
exp.Or: lambda self, e: self.binary(e, "or"),

View file

@ -1,6 +1,11 @@
"""
.. include:: ../pdoc/docs/expressions.md
"""
from __future__ import annotations
import datetime
import math
import numbers
import re
import typing as t
@ -682,6 +687,10 @@ class CharacterSet(Expression):
class With(Expression):
arg_types = {"expressions": True, "recursive": False}
@property
def recursive(self) -> bool:
return bool(self.args.get("recursive"))
class WithinGroup(Expression):
arg_types = {"this": True, "expression": False}
@ -724,6 +733,18 @@ class ColumnDef(Expression):
"this": True,
"kind": True,
"constraints": False,
"exists": False,
}
class AlterColumn(Expression):
arg_types = {
"this": True,
"dtype": False,
"collate": False,
"using": False,
"default": False,
"drop": False,
}
@ -877,6 +898,11 @@ class Introducer(Expression):
arg_types = {"this": True, "expression": True}
# national char, like n'utf8'
class National(Expression):
pass
class LoadData(Expression):
arg_types = {
"this": True,
@ -894,7 +920,7 @@ class Partition(Expression):
class Fetch(Expression):
arg_types = {"direction": False, "count": True}
arg_types = {"direction": False, "count": False}
class Group(Expression):
@ -1316,7 +1342,7 @@ QUERY_MODIFIERS = {
"group": False,
"having": False,
"qualify": False,
"window": False,
"windows": False,
"distribute": False,
"sort": False,
"cluster": False,
@ -1353,7 +1379,7 @@ class Union(Subqueryable):
Example:
>>> select("1").union(select("1")).limit(1).sql()
'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1'
'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1'
Args:
expression (str | int | Expression): the SQL code string to parse.
@ -1889,6 +1915,18 @@ class Select(Subqueryable):
**opts,
)
def window(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
return _apply_list_builder(
*expressions,
instance=self,
arg="windows",
append=append,
into=Window,
dialect=dialect,
copy=copy,
**opts,
)
def distinct(self, distinct=True, copy=True) -> Select:
"""
Set the OFFSET expression.
@ -2140,6 +2178,11 @@ class DataType(Expression):
)
# https://www.postgresql.org/docs/15/datatype-pseudo.html
class PseudoType(Expression):
pass
class StructKwarg(Expression):
arg_types = {"this": True, "expression": True}
@ -2167,18 +2210,26 @@ class Command(Expression):
arg_types = {"this": True, "expression": False}
class Transaction(Command):
class Transaction(Expression):
arg_types = {"this": False, "modes": False}
class Commit(Command):
class Commit(Expression):
arg_types = {"chain": False}
class Rollback(Command):
class Rollback(Expression):
arg_types = {"savepoint": False}
class AlterTable(Expression):
arg_types = {
"this": True,
"actions": True,
"exists": False,
}
# Binary expressions like (ADD a b)
class Binary(Expression):
arg_types = {"this": True, "expression": True}
@ -2312,6 +2363,10 @@ class SimilarTo(Binary, Predicate):
pass
class Slice(Binary):
arg_types = {"this": False, "expression": False}
class Sub(Binary):
pass
@ -2392,7 +2447,7 @@ class TimeUnit(Expression):
class Interval(TimeUnit):
arg_types = {"this": True, "unit": False}
arg_types = {"this": False, "unit": False}
class IgnoreNulls(Expression):
@ -2730,8 +2785,11 @@ class Initcap(Func):
pass
class JSONExtract(Func):
arg_types = {"this": True, "path": True}
class JSONBContains(Binary):
_sql_names = ["JSONB_CONTAINS"]
class JSONExtract(Binary, Func):
_sql_names = ["JSON_EXTRACT"]
@ -2776,6 +2834,10 @@ class Log10(Func):
pass
class LogicalOr(AggFunc):
_sql_names = ["LOGICAL_OR", "BOOL_OR"]
class Lower(Func):
_sql_names = ["LOWER", "LCASE"]
@ -2846,6 +2908,10 @@ class RegexpLike(Func):
arg_types = {"this": True, "expression": True, "flag": False}
class RegexpILike(Func):
arg_types = {"this": True, "expression": True, "flag": False}
class RegexpSplit(Func):
arg_types = {"this": True, "expression": True}
@ -3388,11 +3454,17 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> U
],
)
if from_:
update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
update.set(
"from",
maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts),
)
if isinstance(where, Condition):
where = Where(this=where)
if where:
update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts))
update.set(
"where",
maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts),
)
return update
@ -3522,7 +3594,7 @@ def paren(expression) -> Paren:
return Paren(this=expression)
SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$")
SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
@ -3724,6 +3796,8 @@ def convert(value) -> Expression:
return Boolean(this=value)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, float) and math.isnan(value):
return NULL
if isinstance(value, numbers.Number):
return Literal.number(value)
if isinstance(value, tuple):
@ -3732,11 +3806,13 @@ def convert(value) -> Expression:
return Array(expressions=[convert(v) for v in value])
if isinstance(value, dict):
return Map(
keys=[convert(k) for k in value.keys()],
keys=[convert(k) for k in value],
values=[convert(v) for v in value.values()],
)
if isinstance(value, datetime.datetime):
datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S.%f%z"))
datetime_literal = Literal.string(
(value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
)
return TimeStrToTime(this=datetime_literal)
if isinstance(value, datetime.date):
date_literal = Literal.string(value.strftime("%Y-%m-%d"))

View file

@ -361,10 +361,11 @@ class Generator:
column = self.sql(expression, "this")
kind = self.sql(expression, "kind")
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
if not constraints:
return f"{column} {kind}"
return f"{column} {kind} {constraints}"
return f"{exists}{column} {kind}"
return f"{exists}{column} {kind} {constraints}"
def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
this = self.sql(expression, "this")
@ -549,6 +550,9 @@ class Generator:
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
def national_sql(self, expression: exp.National) -> str:
return f"N{self.sql(expression, 'this')}"
def partition_sql(self, expression: exp.Partition) -> str:
keys = csv(
*[
@ -633,6 +637,9 @@ class Generator:
def introducer_sql(self, expression: exp.Introducer) -> str:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
def pseudotype_sql(self, expression: exp.PseudoType) -> str:
return expression.name.upper()
def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str:
fields = expression.args.get("fields")
fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
@ -793,19 +800,17 @@ class Generator:
if isinstance(expression.this, exp.Subquery):
return f"LATERAL {this}"
alias = expression.args["alias"]
table = alias.name
columns = self.expressions(alias, key="columns", flat=True)
if expression.args.get("view"):
table = f" {table}" if table else table
alias = expression.args["alias"]
columns = self.expressions(alias, key="columns", flat=True)
table = f" {alias.name}" if alias.name else ""
columns = f" AS {columns}" if columns else ""
op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}")
return f"{op_sql}{self.sep()}{this}{table}{columns}"
table = f" AS {table}" if table else table
columns = f"({columns})" if columns else ""
return f"LATERAL {this}{table}{columns}"
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
return f"LATERAL {this}{alias}"
def limit_sql(self, expression: exp.Limit) -> str:
this = self.sql(expression, "this")
@ -891,13 +896,15 @@ class Generator:
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("joins", [])],
*[self.sql(sql) for sql in expression.args.get("laterals", [])],
*[self.sql(sql) for sql in expression.args.get("joins") or []],
*[self.sql(sql) for sql in expression.args.get("laterals") or []],
self.sql(expression, "where"),
self.sql(expression, "group"),
self.sql(expression, "having"),
self.sql(expression, "qualify"),
self.sql(expression, "window"),
self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True)
if expression.args.get("windows")
else "",
self.sql(expression, "distribute"),
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
@ -1008,11 +1015,7 @@ class Generator:
spec_sql = " " + self.window_spec_sql(spec) if spec else ""
alias = self.sql(expression, "alias")
if expression.arg_key == "window":
this = this = f"{self.seg('WINDOW')} {this} AS"
else:
this = f"{this} OVER"
this = f"{this} {'AS' if expression.arg_key == 'windows' else 'OVER'}"
if not partition and not order and not spec and alias:
return f"{this} {alias}"
@ -1141,9 +1144,11 @@ class Generator:
return f"(SELECT {self.sql(unnest)})"
def interval_sql(self, expression: exp.Interval) -> str:
this = self.sql(expression, "this")
this = f" {this}" if this else ""
unit = self.sql(expression, "unit")
unit = f" {unit}" if unit else ""
return f"INTERVAL {self.sql(expression, 'this')}{unit}"
return f"INTERVAL{this}{unit}"
def reference_sql(self, expression: exp.Reference) -> str:
this = self.sql(expression, "this")
@ -1245,6 +1250,43 @@ class Generator:
savepoint = f" TO {savepoint}" if savepoint else ""
return f"ROLLBACK{savepoint}"
def altercolumn_sql(self, expression: exp.AlterColumn) -> str:
this = self.sql(expression, "this")
dtype = self.sql(expression, "dtype")
if dtype:
collate = self.sql(expression, "collate")
collate = f" COLLATE {collate}" if collate else ""
using = self.sql(expression, "using")
using = f" USING {using}" if using else ""
return f"ALTER COLUMN {this} TYPE {dtype}{collate}{using}"
default = self.sql(expression, "default")
if default:
return f"ALTER COLUMN {this} SET DEFAULT {default}"
if not expression.args.get("drop"):
self.unsupported("Unsupported ALTER COLUMN syntax")
return f"ALTER COLUMN {this} DROP DEFAULT"
def altertable_sql(self, expression: exp.AlterTable) -> str:
actions = expression.args["actions"]
if isinstance(actions[0], exp.ColumnDef):
actions = self.expressions(expression, "actions", prefix="ADD COLUMN ")
elif isinstance(actions[0], exp.Schema):
actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Drop):
actions = self.expressions(expression, "actions")
elif isinstance(actions[0], exp.AlterColumn):
actions = self.sql(actions[0])
else:
self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}")
exists = " IF EXISTS" if expression.args.get("exists") else ""
return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}"
def distinct_sql(self, expression: exp.Distinct) -> str:
this = self.expressions(expression, flat=True)
this = f" {this}" if this else ""
@ -1327,6 +1369,9 @@ class Generator:
def or_sql(self, expression: exp.Or) -> str:
return self.connector_sql(expression, "OR")
def slice_sql(self, expression: exp.Slice) -> str:
return self.binary(expression, ":")
def sub_sql(self, expression: exp.Sub) -> str:
return self.binary(expression, "-")
@ -1369,6 +1414,7 @@ class Generator:
flat: bool = False,
indent: bool = True,
sep: str = ", ",
prefix: str = "",
) -> str:
expressions = expression.args.get(key or "expressions")
@ -1391,11 +1437,13 @@ class Generator:
if self.pretty:
if self._leading_comma:
result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}")
result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}")
else:
result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}")
result_sqls.append(
f"{prefix}{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}"
)
else:
result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}")
result_sqls.append(f"{prefix}{sql}{comments}{sep if i + 1 < num_sqls else ''}")
result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
return self.indent(result_sql, skip_first=False) if indent else result_sql

View file

@ -18,6 +18,9 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
if isinstance(expression, exp.Identifier):
expression.set("quoted", True)
return expression

View file

@ -129,20 +129,10 @@ def join_condition(join):
"""
name = join.this.alias_or_name
on = (join.args.get("on") or exp.true()).copy()
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
source_key = []
join_key = []
# find the join keys
# SELECT
# FROM x
# JOIN y
# ON x.a = y.b AND y.b > 1
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
for condition in on.flatten():
if isinstance(condition, exp.EQ):
def extract_condition(condition):
left, right = condition.unnest_operands()
left_tables = exp.column_table_names(left)
right_tables = exp.column_table_names(right)
@ -156,6 +146,39 @@ def join_condition(join):
source_key.append(left)
condition.replace(exp.true())
# find the join keys
# SELECT
# FROM x
# JOIN y
# ON x.a = y.b AND y.b > 1
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
for condition in on.flatten():
if isinstance(condition, exp.EQ):
extract_condition(condition)
elif normalized(on, dnf=True):
conditions = None
for condition in on.flatten():
parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
if conditions is None:
conditions = parts
else:
temp = []
for p in parts:
cs = [c for c in conditions if p == c]
if cs:
temp.append(p)
temp.extend(cs)
conditions = temp
for condition in conditions:
extract_condition(condition)
on = simplify(on)
remaining_condition = None if on == exp.true() else on
return source_key, join_key, remaining_condition

View file

@ -58,7 +58,9 @@ def eliminate_subqueries(expression):
existing_ctes = {}
with_ = root.expression.args.get("with")
recursive = False
if with_:
recursive = with_.args.get("recursive")
for cte in with_.expressions:
existing_ctes[cte.this] = cte.alias
new_ctes = []
@ -88,7 +90,7 @@ def eliminate_subqueries(expression):
new_ctes.append(new_cte)
if new_ctes:
expression.set("with", exp.With(expressions=new_ctes))
expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))
return expression

View file

@ -69,8 +69,9 @@ def _predicate_lengths(expression, dnf):
left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or):
x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)]
return x
return [
a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
]
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)

View file

@ -14,7 +14,6 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.quote_identities import quote_identities
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
RULES = (
@ -34,7 +33,6 @@ RULES = (
eliminate_ctes,
annotate_types,
canonicalize,
quote_identities,
)

View file

@ -27,7 +27,14 @@ def pushdown_predicates(expression):
select = scope.expression
where = select.args.get("where")
if where:
pushdown(where.this, scope.selected_sources, scope_ref_count)
selected_sources = scope.selected_sources
# a right join can only push down to itself and not the source FROM table
for k, (node, source) in selected_sources.items():
parent = node.find_ancestor(exp.Join, exp.From)
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
selected_sources = {k: (node, source)}
break
pushdown(where.this, selected_sources, scope_ref_count)
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
@ -148,10 +155,13 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
# a node can reference a CTE which should be pushed down
if isinstance(node, exp.From) and not isinstance(source, exp.Table):
with_ = source.parent.expression.args.get("with")
if with_ and with_.recursive:
return {}
node = source.expression
if isinstance(node, exp.Join):
if node.side:
if node.side and node.side != "RIGHT":
return {}
nodes[table] = node
elif isinstance(node, exp.Select) and len(tables) == 1:

View file

@ -6,7 +6,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()
# SELECTION TO USE IF SELECTION LIST IS EMPTY
# Selection to use if selection list is empty
DEFAULT_SELECTION = alias("1", "_")
@ -91,7 +91,7 @@ def _remove_unused_selections(scope, parent_selections):
# If there are no remaining selections, just select a single constant
if not new_selections:
new_selections.append(DEFAULT_SELECTION)
new_selections.append(DEFAULT_SELECTION.copy())
scope.expression.set("expressions", new_selections)
return removed_indexes
@ -102,5 +102,5 @@ def _remove_indexed_selections(scope, indexes_to_remove):
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
]
if not new_selections:
new_selections.append(DEFAULT_SELECTION)
new_selections.append(DEFAULT_SELECTION.copy())
scope.expression.set("expressions", new_selections)

View file

@ -311,6 +311,9 @@ def _qualify_outputs(scope):
alias_ = alias(exp.column(""), alias=selection.name)
alias_.set("this", selection)
selection = alias_
elif isinstance(selection, exp.Subquery):
if not selection.alias:
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
elif not isinstance(selection, exp.Alias):
alias_ = alias(exp.column(""), f"_col_{i}")
alias_.set("this", selection)

View file

@ -1,25 +0,0 @@
from sqlglot import exp
def quote_identities(expression):
"""
Rewrite sqlglot AST to ensure all identities are quoted.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x")
>>> quote_identities(expression).sql()
'SELECT "x"."a" AS "a" FROM "db"."x"'
Args:
expression (sqlglot.Expression): expression to quote
Returns:
sqlglot.Expression: quoted expression
"""
def qualify(node):
if isinstance(node, exp.Identifier):
node.set("quoted", True)
return node
return expression.transform(qualify, copy=False)

View file

@ -511,9 +511,20 @@ def _traverse_union(scope):
def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {}
is_cte = scope_type == ScopeType.CTE
for derived_table in derived_tables:
top = None
recursive_scope = None
# if the scope is a recursive cte, it must be in the form of
# base_case UNION recursive. thus the recursive scope is the first
# section of the union.
if is_cte and scope.expression.args["with"].recursive:
union = derived_table.this
if isinstance(union, exp.Union):
recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
for child_scope in _traverse_scope(
scope.branch(
derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
@ -523,16 +534,23 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
)
):
yield child_scope
top = child_scope
# Tables without aliases will be set as ""
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins.
sources[derived_table.alias] = child_scope
if scope_type == ScopeType.CTE:
scope.cte_scopes.append(top)
alias = derived_table.alias
sources[alias] = child_scope
if recursive_scope:
child_scope.add_source(alias, recursive_scope)
# append the final child_scope yielded
if is_cte:
scope.cte_scopes.append(child_scope)
else:
scope.derived_table_scopes.append(top)
scope.derived_table_scopes.append(child_scope)
scope.sources.update(sources)

View file

@ -16,7 +16,7 @@ def unnest_subqueries(expression):
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'
AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
Args:
expression (sqlglot.Expression): expression to unnest
@ -97,8 +97,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
table_alias = _alias(sequence)
keys = []
# for all external columns in the where statement,
# split out the relevant data to convert it into a join
# for all external columns in the where statement, find the relevant predicate
# keys to convert it into a join
for column in external_columns:
if column.find_ancestor(exp.Where) is not where:
return
@ -122,6 +122,10 @@ def decorrelate(select, parent_select, external_columns, sequence):
if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
return
is_subquery_projection = any(
node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
)
value = select.selects[0]
key_aliases = {}
group_by = []
@ -142,9 +146,14 @@ def decorrelate(select, parent_select, external_columns, sequence):
parent_predicate = select.find_ancestor(exp.Predicate)
# if the value of the subquery is not an agg or a key, we need to collect it into an array
# so that it can be grouped
# so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
if not value.find(exp.AggFunc) and value.this not in group_by:
select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False)
select.select(
exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
append=False,
copy=False,
)
# exists queries should not have any selects as it only checks if there are any rows
# all selects will be added by the optimizer and only used for join keys
@ -158,7 +167,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
if isinstance(parent_predicate, exp.Exists) or key != value.this:
select.select(f"{key} AS {alias}", copy=False)
else:
select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False)
select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)
alias = exp.column(value.alias, table_alias)
other = _other_operand(parent_predicate)
@ -186,12 +195,18 @@ def decorrelate(select, parent_select, external_columns, sequence):
f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
)
else:
if is_subquery_projection:
alias = exp.alias_(alias, select.parent.alias)
select.parent.replace(alias)
for key, column, predicate in keys:
predicate.replace(exp.true())
nested = exp.column(key_aliases[key], table_alias)
if is_subquery_projection:
key.replace(nested)
continue
if key in group_by:
key.replace(nested)
parent_predicate = _replace(

View file

@ -5,7 +5,7 @@ import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
from sqlglot.helper import apply_index_offset, ensure_collection, seq_get
from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie
@ -117,6 +117,7 @@ class Parser(metaclass=_Parser):
TokenType.GEOMETRY,
TokenType.HLLSKETCH,
TokenType.HSTORE,
TokenType.PSEUDO_TYPE,
TokenType.SUPER,
TokenType.SERIAL,
TokenType.SMALLSERIAL,
@ -153,6 +154,7 @@ class Parser(metaclass=_Parser):
TokenType.CACHE,
TokenType.CASCADE,
TokenType.COLLATE,
TokenType.COLUMN,
TokenType.COMMAND,
TokenType.COMMIT,
TokenType.COMPOUND,
@ -169,6 +171,7 @@ class Parser(metaclass=_Parser):
TokenType.ESCAPE,
TokenType.FALSE,
TokenType.FIRST,
TokenType.FILTER,
TokenType.FOLLOWING,
TokenType.FORMAT,
TokenType.FUNCTION,
@ -188,6 +191,7 @@ class Parser(metaclass=_Parser):
TokenType.MERGE,
TokenType.NATURAL,
TokenType.NEXT,
TokenType.OFFSET,
TokenType.ONLY,
TokenType.OPTIONS,
TokenType.ORDINALITY,
@ -222,12 +226,18 @@ class Parser(metaclass=_Parser):
TokenType.PROPERTIES,
TokenType.PROCEDURE,
TokenType.VOLATILE,
TokenType.WINDOW,
*SUBQUERY_PREDICATES,
*TYPE_TOKENS,
*NO_PAREN_FUNCTIONS,
}
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY}
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
TokenType.APPLY,
TokenType.NATURAL,
TokenType.OFFSET,
TokenType.WINDOW,
}
UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
@ -257,6 +267,7 @@ class Parser(metaclass=_Parser):
TokenType.TABLE,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.WINDOW,
*TYPE_TOKENS,
*SUBQUERY_PREDICATES,
}
@ -351,22 +362,27 @@ class Parser(metaclass=_Parser):
TokenType.ARROW: lambda self, this, path: self.expression(
exp.JSONExtract,
this=this,
path=path,
expression=path,
),
TokenType.DARROW: lambda self, this, path: self.expression(
exp.JSONExtractScalar,
this=this,
path=path,
expression=path,
),
TokenType.HASH_ARROW: lambda self, this, path: self.expression(
exp.JSONBExtract,
this=this,
path=path,
expression=path,
),
TokenType.DHASH_ARROW: lambda self, this, path: self.expression(
exp.JSONBExtractScalar,
this=this,
path=path,
expression=path,
),
TokenType.PLACEHOLDER: lambda self, this, key: self.expression(
exp.JSONBContains,
this=this,
expression=key,
),
}
@ -392,25 +408,27 @@ class Parser(metaclass=_Parser):
exp.Ordered: lambda self: self._parse_ordered(),
exp.Having: lambda self: self._parse_having(),
exp.With: lambda self: self._parse_with(),
exp.Window: lambda self: self._parse_named_window(),
"JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
}
STATEMENT_PARSERS = {
TokenType.ALTER: lambda self: self._parse_alter(),
TokenType.BEGIN: lambda self: self._parse_transaction(),
TokenType.CACHE: lambda self: self._parse_cache(),
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.CREATE: lambda self: self._parse_create(),
TokenType.DELETE: lambda self: self._parse_delete(),
TokenType.DESCRIBE: lambda self: self._parse_describe(),
TokenType.DROP: lambda self: self._parse_drop(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
TokenType.INSERT: lambda self: self._parse_insert(),
TokenType.LOAD_DATA: lambda self: self._parse_load_data(),
TokenType.UPDATE: lambda self: self._parse_update(),
TokenType.DELETE: lambda self: self._parse_delete(),
TokenType.CACHE: lambda self: self._parse_cache(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
TokenType.BEGIN: lambda self: self._parse_transaction(),
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.MERGE: lambda self: self._parse_merge(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
TokenType.UPDATE: lambda self: self._parse_update(),
TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
}
UNARY_PARSERS = {
@ -441,6 +459,7 @@ class Parser(metaclass=_Parser):
TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
TokenType.NATIONAL: lambda self, token: self._parse_national(token),
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
}
@ -454,6 +473,9 @@ class Parser(metaclass=_Parser):
TokenType.ILIKE: lambda self, this: self._parse_escape(
self.expression(exp.ILike, this=this, expression=self._parse_bitwise())
),
TokenType.IRLIKE: lambda self, this: self.expression(
exp.RegexpILike, this=this, expression=self._parse_bitwise()
),
TokenType.RLIKE: lambda self, this: self.expression(
exp.RegexpLike, this=this, expression=self._parse_bitwise()
),
@ -535,8 +557,7 @@ class Parser(metaclass=_Parser):
"group": lambda self: self._parse_group(),
"having": lambda self: self._parse_having(),
"qualify": lambda self: self._parse_qualify(),
"window": lambda self: self._match(TokenType.WINDOW)
and self._parse_window(self._parse_id_var(), alias=True),
"windows": lambda self: self._parse_window_clause(),
"distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
"sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
"cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
@ -551,18 +572,18 @@ class Parser(metaclass=_Parser):
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
CREATABLES = {
TokenType.TABLE,
TokenType.VIEW,
TokenType.COLUMN,
TokenType.FUNCTION,
TokenType.INDEX,
TokenType.PROCEDURE,
TokenType.SCHEMA,
TokenType.TABLE,
TokenType.VIEW,
}
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
STRICT_CAST = True
LATERAL_FUNCTION_AS_VIEW = False
__slots__ = (
"error_level",
@ -782,11 +803,14 @@ class Parser(metaclass=_Parser):
self._parse_query_modifiers(expression)
return expression
def _parse_drop(self):
def _parse_drop(self, default_kind=None):
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match(TokenType.MATERIALIZED)
kind = self._match_set(self.CREATABLES) and self._prev.text
if not kind:
if default_kind:
kind = default_kind
else:
self.raise_error(f"Expected {self.CREATABLES}")
return
@ -876,7 +900,7 @@ class Parser(metaclass=_Parser):
) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False)
if assignment:
key = self._parse_var() or self._parse_string()
key = self._parse_var_or_string()
self._match(TokenType.EQ)
return self.expression(exp.Property, this=key, value=self._parse_column())
@ -1152,18 +1176,32 @@ class Parser(metaclass=_Parser):
elif (table or nested) and self._match(TokenType.L_PAREN):
this = self._parse_table() if table else self._parse_select(nested=True)
self._parse_query_modifiers(this)
this = self._parse_set_operations(this)
self._match_r_paren()
this = self._parse_subquery(this)
# early return so that subquery unions aren't parsed again
# SELECT * FROM (SELECT 1) UNION ALL SELECT 1
# Union ALL should be a property of the top select node, not the subquery
return self._parse_subquery(this)
elif self._match(TokenType.VALUES):
if self._curr.token_type == TokenType.L_PAREN:
# We don't consume the left paren because it's consumed in _parse_value
expressions = self._parse_csv(self._parse_value)
else:
# In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
# Source: https://prestodb.io/docs/current/sql/values.html
expressions = self._parse_csv(
lambda: self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
)
this = self.expression(
exp.Values,
expressions=self._parse_csv(self._parse_value),
expressions=expressions,
alias=self._parse_table_alias(),
)
else:
this = None
return self._parse_set_operations(this) if this else None
return self._parse_set_operations(this)
def _parse_with(self, skip_with_token=False):
if not skip_with_token and not self._match(TokenType.WITH):
@ -1201,11 +1239,12 @@ class Parser(metaclass=_Parser):
alias = self._parse_id_var(
any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS
)
columns = None
if self._match(TokenType.L_PAREN):
columns = self._parse_csv(lambda: self._parse_id_var(any_token))
columns = self._parse_csv(lambda: self._parse_column_def(self._parse_id_var()))
self._match_r_paren()
else:
columns = None
if not alias and not columns:
return None
@ -1295,26 +1334,19 @@ class Parser(metaclass=_Parser):
expression=self._parse_function() or self._parse_id_var(any_token=False),
)
columns = None
table_alias = None
if view or self.LATERAL_FUNCTION_AS_VIEW:
table_alias = self._parse_id_var(any_token=False)
if self._match(TokenType.ALIAS):
columns = self._parse_csv(self._parse_id_var)
if view:
table = self._parse_id_var(any_token=False)
columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else []
table_alias = self.expression(exp.TableAlias, this=table, columns=columns)
else:
self._match(TokenType.ALIAS)
table_alias = self._parse_id_var(any_token=False)
if self._match(TokenType.L_PAREN):
columns = self._parse_csv(self._parse_id_var)
self._match_r_paren()
table_alias = self._parse_table_alias()
expression = self.expression(
exp.Lateral,
this=this,
view=view,
outer=outer,
alias=self.expression(exp.TableAlias, this=table_alias, columns=columns),
alias=table_alias,
)
if outer_apply or cross_apply:
@ -1693,6 +1725,9 @@ class Parser(metaclass=_Parser):
if negate:
this = self.expression(exp.Not, this=this)
if self._match(TokenType.IS):
this = self._parse_is(this)
return this
def _parse_is(self, this):
@ -1796,6 +1831,10 @@ class Parser(metaclass=_Parser):
return None
type_token = self._prev.token_type
if type_token == TokenType.PSEUDO_TYPE:
return self.expression(exp.PseudoType, this=self._prev.text)
nested = type_token in self.NESTED_TYPE_TOKENS
is_struct = type_token == TokenType.STRUCT
expressions = None
@ -1851,6 +1890,8 @@ class Parser(metaclass=_Parser):
if value is None:
value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
elif type_token == TokenType.INTERVAL:
value = self.expression(exp.Interval, unit=self._parse_var())
if maybe_func and check_func:
index2 = self._index
@ -1924,7 +1965,16 @@ class Parser(metaclass=_Parser):
def _parse_primary(self):
if self._match_set(self.PRIMARY_PARSERS):
return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev)
token_type = self._prev.token_type
primary = self.PRIMARY_PARSERS[token_type](self, self._prev)
if token_type == TokenType.STRING:
expressions = [primary]
while self._match(TokenType.STRING):
expressions.append(exp.Literal.string(self._prev.text))
if len(expressions) > 1:
return self.expression(exp.Concat, expressions=expressions)
return primary
if self._match_pair(TokenType.DOT, TokenType.NUMBER):
return exp.Literal.number(f"0.{self._prev.text}")
@ -2027,6 +2077,9 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Identifier, this=token.text)
def _parse_national(self, token):
return self.expression(exp.National, this=exp.Literal.string(token.text))
def _parse_session_parameter(self):
kind = None
this = self._parse_id_var() or self._parse_primary()
@ -2051,7 +2104,9 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.L_PAREN):
expressions = self._parse_csv(self._parse_id_var)
self._match(TokenType.R_PAREN)
if not self._match(TokenType.R_PAREN):
self._retreat(index)
else:
expressions = [self._parse_id_var()]
@ -2065,14 +2120,14 @@ class Parser(metaclass=_Parser):
exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)
)
else:
this = self._parse_conjunction()
this = self._parse_select_or_expression()
if self._match(TokenType.IGNORE_NULLS):
this = self.expression(exp.IgnoreNulls, this=this)
else:
self._match(TokenType.RESPECT_NULLS)
return self._parse_alias(self._parse_limit(self._parse_order(this)))
return self._parse_limit(self._parse_order(this))
def _parse_schema(self, this=None):
index = self._index
@ -2081,7 +2136,8 @@ class Parser(metaclass=_Parser):
return this
args = self._parse_csv(
lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True))
lambda: self._parse_constraint()
or self._parse_column_def(self._parse_field(any_token=True))
)
self._match_r_paren()
return self.expression(exp.Schema, this=this, expressions=args)
@ -2120,7 +2176,7 @@ class Parser(metaclass=_Parser):
elif self._match(TokenType.ENCODE):
kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var())
elif self._match(TokenType.DEFAULT):
kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction())
kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_bitwise())
elif self._match_pair(TokenType.NOT, TokenType.NULL):
kind = exp.NotNullColumnConstraint()
elif self._match(TokenType.NULL):
@ -2211,7 +2267,10 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.L_BRACKET):
return this
expressions = self._parse_csv(self._parse_conjunction)
if self._match(TokenType.COLON):
expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())]
else:
expressions = self._parse_csv(lambda: self._parse_slice(self._parse_conjunction()))
if not this or this.name.upper() == "ARRAY":
this = self.expression(exp.Array, expressions=expressions)
@ -2225,6 +2284,11 @@ class Parser(metaclass=_Parser):
this.comments = self._prev_comments
return self._parse_bracket(this)
def _parse_slice(self, this):
if self._match(TokenType.COLON):
return self.expression(exp.Slice, this=this, expression=self._parse_conjunction())
return this
def _parse_case(self):
ifs = []
default = None
@ -2386,6 +2450,12 @@ class Parser(metaclass=_Parser):
collation=collation,
)
def _parse_window_clause(self):
return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window)
def _parse_named_window(self):
return self._parse_window(self._parse_id_var(), alias=True)
def _parse_window(self, this, alias=False):
if self._match(TokenType.FILTER):
where = self._parse_wrapped(self._parse_where)
@ -2501,11 +2571,9 @@ class Parser(metaclass=_Parser):
if identifier:
return identifier
if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
self._advance()
elif not self._match_set(tokens or self.ID_VAR_TOKENS):
return None
if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS):
return exp.Identifier(this=self._prev.text, quoted=False)
return None
def _parse_string(self):
if self._match(TokenType.STRING):
@ -2522,11 +2590,17 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Identifier, this=self._prev.text, quoted=True)
return self._parse_placeholder()
def _parse_var(self):
if self._match(TokenType.VAR):
def _parse_var(self, any_token=False):
if (any_token and self._advance_any()) or self._match(TokenType.VAR):
return self.expression(exp.Var, this=self._prev.text)
return self._parse_placeholder()
def _advance_any(self):
if self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
self._advance()
return self._prev
return None
def _parse_var_or_string(self):
return self._parse_var() or self._parse_string()
@ -2551,8 +2625,9 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.PLACEHOLDER):
return self.expression(exp.Placeholder)
elif self._match(TokenType.COLON):
self._advance()
if self._match_set((TokenType.NUMBER, TokenType.VAR)):
return self.expression(exp.Placeholder, this=self._prev.text)
self._advance(-1)
return None
def _parse_except(self):
@ -2647,6 +2722,54 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Rollback, savepoint=savepoint)
return self.expression(exp.Commit, chain=chain)
def _parse_add_column(self):
if not self._match_text_seq("ADD"):
return None
self._match(TokenType.COLUMN)
exists_column = self._parse_exists(not_=True)
expression = self._parse_column_def(self._parse_field(any_token=True))
expression.set("exists", exists_column)
return expression
def _parse_drop_column(self):
return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN")
def _parse_alter(self):
if not self._match(TokenType.TABLE):
return None
exists = self._parse_exists()
this = self._parse_table(schema=True)
actions = None
if self._match_text_seq("ADD", advance=False):
actions = self._parse_csv(self._parse_add_column)
elif self._match_text_seq("DROP", advance=False):
actions = self._parse_csv(self._parse_drop_column)
elif self._match_text_seq("ALTER"):
self._match(TokenType.COLUMN)
column = self._parse_field(any_token=True)
if self._match_pair(TokenType.DROP, TokenType.DEFAULT):
actions = self.expression(exp.AlterColumn, this=column, drop=True)
elif self._match_pair(TokenType.SET, TokenType.DEFAULT):
actions = self.expression(
exp.AlterColumn, this=column, default=self._parse_conjunction()
)
else:
self._match_text_seq("SET", "DATA")
actions = self.expression(
exp.AlterColumn,
this=column,
dtype=self._match_text_seq("TYPE") and self._parse_types(),
collate=self._match(TokenType.COLLATE) and self._parse_term(),
using=self._match(TokenType.USING) and self._parse_conjunction(),
)
actions = ensure_list(actions)
return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions)
def _parse_show(self):
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
if parser:
@ -2782,7 +2905,7 @@ class Parser(metaclass=_Parser):
return True
return False
def _match_text_seq(self, *texts):
def _match_text_seq(self, *texts, advance=True):
index = self._index
for text in texts:
if self._curr and self._curr.text.upper() == text:
@ -2790,6 +2913,10 @@ class Parser(metaclass=_Parser):
else:
self._retreat(index)
return False
if not advance:
self._retreat(index)
return True
def _replace_columns_with_dots(self, this):

View file

@ -160,9 +160,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
super().__init__(schema)
self.visible = visible or {}
self.dialect = dialect
self._type_mapping_cache: t.Dict[str, exp.DataType] = {
"STR": exp.DataType.build("text"),
}
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:

View file

@ -48,6 +48,7 @@ class TokenType(AutoName):
DOLLAR = auto()
PARAMETER = auto()
SESSION_PARAMETER = auto()
NATIONAL = auto()
BLOCK_START = auto()
BLOCK_END = auto()
@ -111,6 +112,7 @@ class TokenType(AutoName):
# keywords
ALIAS = auto()
ALTER = auto()
ALWAYS = auto()
ALL = auto()
ANTI = auto()
@ -196,6 +198,7 @@ class TokenType(AutoName):
INTERVAL = auto()
INTO = auto()
INTRODUCER = auto()
IRLIKE = auto()
IS = auto()
ISNULL = auto()
JOIN = auto()
@ -241,6 +244,7 @@ class TokenType(AutoName):
PRIMARY_KEY = auto()
PROCEDURE = auto()
PROPERTIES = auto()
PSEUDO_TYPE = auto()
QUALIFY = auto()
QUOTE = auto()
RANGE = auto()
@ -346,7 +350,11 @@ class _Tokenizer(type):
def __new__(cls, clsname, bases, attrs): # type: ignore
klass = super().__new__(cls, clsname, bases, attrs)
klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)
klass._QUOTES = {
f"{prefix}{s}": e
for s, e in cls._delimeter_list_to_dict(klass.QUOTES).items()
for prefix in (("",) if s[0].isalpha() else ("", "n", "N"))
}
klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS)
klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
@ -470,6 +478,7 @@ class Tokenizer(metaclass=_Tokenizer):
"CHECK": TokenType.CHECK,
"CLUSTER BY": TokenType.CLUSTER_BY,
"COLLATE": TokenType.COLLATE,
"COLUMN": TokenType.COLUMN,
"COMMENT": TokenType.SCHEMA_COMMENT,
"COMMIT": TokenType.COMMIT,
"COMPOUND": TokenType.COMPOUND,
@ -587,6 +596,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SEMI": TokenType.SEMI,
"SET": TokenType.SET,
"SHOW": TokenType.SHOW,
"SIMILAR TO": TokenType.SIMILAR_TO,
"SOME": TokenType.SOME,
"SORTKEY": TokenType.SORTKEY,
"SORT BY": TokenType.SORT_BY,
@ -614,6 +624,7 @@ class Tokenizer(metaclass=_Tokenizer):
"VOLATILE": TokenType.VOLATILE,
"WHEN": TokenType.WHEN,
"WHERE": TokenType.WHERE,
"WINDOW": TokenType.WINDOW,
"WITH": TokenType.WITH,
"WITH TIME ZONE": TokenType.WITH_TIME_ZONE,
"WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE,
@ -652,6 +663,7 @@ class Tokenizer(metaclass=_Tokenizer):
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR": TokenType.NVARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,
"STR": TokenType.TEXT,
"STRING": TokenType.TEXT,
"TEXT": TokenType.TEXT,
"CLOB": TokenType.TEXT,
@ -667,7 +679,16 @@ class Tokenizer(metaclass=_Tokenizer):
"UNIQUE": TokenType.UNIQUE,
"STRUCT": TokenType.STRUCT,
"VARIANT": TokenType.VARIANT,
"ALTER": TokenType.COMMAND,
"ALTER": TokenType.ALTER,
"ALTER AGGREGATE": TokenType.COMMAND,
"ALTER DEFAULT": TokenType.COMMAND,
"ALTER DOMAIN": TokenType.COMMAND,
"ALTER ROLE": TokenType.COMMAND,
"ALTER RULE": TokenType.COMMAND,
"ALTER SEQUENCE": TokenType.COMMAND,
"ALTER TYPE": TokenType.COMMAND,
"ALTER USER": TokenType.COMMAND,
"ALTER VIEW": TokenType.COMMAND,
"ANALYZE": TokenType.COMMAND,
"CALL": TokenType.COMMAND,
"EXPLAIN": TokenType.COMMAND,
@ -967,7 +988,7 @@ class Tokenizer(metaclass=_Tokenizer):
text = self._extract_string(quote_end)
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore
text = text.replace("\\\\", "\\") if self._replace_backslash else text
self._add(TokenType.STRING, text)
self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
return True
# X'1234, b'0110', E'\\\\\' etc.

View file

@ -150,8 +150,8 @@ class TestDataframeColumn(unittest.TestCase):
F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(),
)
self.assertEqual(
"cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) "
"AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)",
"cola BETWEEN CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP) "
"AND CAST('2022-03-01T01:01:01+00:00' AS TIMESTAMP)",
F.col("cola")
.between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1))
.sql(),

View file

@ -30,7 +30,7 @@ class TestFunctions(unittest.TestCase):
test_date = SF.lit(datetime.date(2022, 1, 1))
self.assertEqual("TO_DATE('2022-01-01')", test_date.sql())
test_datetime = SF.lit(datetime.datetime(2022, 1, 1, 1, 1, 1))
self.assertEqual("CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP)", test_datetime.sql())
self.assertEqual("CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql())
test_dict = SF.lit({"cola": 1, "colb": "test"})
self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql())
@ -52,7 +52,7 @@ class TestFunctions(unittest.TestCase):
test_date = SF.col(datetime.date(2022, 1, 1))
self.assertEqual("TO_DATE('2022-01-01')", test_date.sql())
test_datetime = SF.col(datetime.datetime(2022, 1, 1, 1, 1, 1))
self.assertEqual("CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP)", test_datetime.sql())
self.assertEqual("CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql())
test_dict = SF.col({"cola": 1, "colb": "test"})
self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql())

View file

@ -318,3 +318,9 @@ class TestBigQuery(Validator):
self.validate_identity(
"CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t"
)
def test_group_concat(self):
self.validate_all(
"SELECT a, GROUP_CONCAT(b) FROM table GROUP BY a",
write={"bigquery": "SELECT a, STRING_AGG(b) FROM table GROUP BY a"},
)

View file

@ -12,6 +12,76 @@ class TestDatabricks(Validator):
"databricks": "SELECT DATEDIFF(year, 'start', 'end')",
},
)
self.validate_all(
"SELECT DATEDIFF(microsecond, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(microsecond, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) * 1000000 AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(millisecond, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(millisecond, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) * 1000 AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(second, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(second, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(minute, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(minute, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) / 60 AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(hour, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(hour, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) / 3600 AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(day, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(day, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) / 86400 AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(week, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(week, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(year FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) * 48 + EXTRACT(month FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) * 4 + EXTRACT(day FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) / 7 AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(month, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(month, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(year FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) * 12 + EXTRACT(month FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(quarter, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(quarter, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(year FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) * 4 + EXTRACT(month FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) / 3 AS BIGINT)",
},
)
self.validate_all(
"SELECT DATEDIFF(year, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(year, 'start', 'end')",
"postgres": "SELECT CAST(EXTRACT(year FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) AS BIGINT)",
},
)
def test_add_date(self):
self.validate_all(

View file

@ -333,7 +333,7 @@ class TestDialect(Validator):
"drill": "CAST('2020-01-01' AS DATE)",
"duckdb": "CAST('2020-01-01' AS DATE)",
"hive": "TO_DATE('2020-01-01')",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
"presto": "CAST('2020-01-01' AS TIMESTAMP)",
"starrocks": "TO_DATE('2020-01-01')",
},
)
@ -343,7 +343,7 @@ class TestDialect(Validator):
"drill": "CAST('2020-01-01' AS TIMESTAMP)",
"duckdb": "CAST('2020-01-01' AS TIMESTAMP)",
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
"presto": "CAST('2020-01-01' AS TIMESTAMP)",
},
)
self.validate_all(
@ -723,23 +723,23 @@ class TestDialect(Validator):
read={
"postgres": "x->'y'",
"presto": "JSON_EXTRACT(x, 'y')",
"starrocks": "x->'y'",
"starrocks": "x -> 'y'",
},
write={
"oracle": "JSON_EXTRACT(x, 'y')",
"postgres": "x->'y'",
"postgres": "x -> 'y'",
"presto": "JSON_EXTRACT(x, 'y')",
"starrocks": "x->'y'",
"starrocks": "x -> 'y'",
},
)
self.validate_all(
"JSON_EXTRACT_SCALAR(x, 'y')",
read={
"postgres": "x->>'y'",
"postgres": "x ->> 'y'",
"presto": "JSON_EXTRACT_SCALAR(x, 'y')",
},
write={
"postgres": "x->>'y'",
"postgres": "x ->> 'y'",
"presto": "JSON_EXTRACT_SCALAR(x, 'y')",
},
)
@ -749,7 +749,7 @@ class TestDialect(Validator):
"postgres": "x#>'y'",
},
write={
"postgres": "x#>'y'",
"postgres": "x #> 'y'",
},
)
self.validate_all(
@ -758,7 +758,7 @@ class TestDialect(Validator):
"postgres": "x#>>'y'",
},
write={
"postgres": "x#>>'y'",
"postgres": "x #>> 'y'",
},
)

View file

@ -59,7 +59,7 @@ class TestDuckDB(Validator):
"TO_TIMESTAMP(x)",
write={
"duckdb": "CAST(x AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%Y-%m-%d %H:%i:%s')",
"presto": "CAST(x AS TIMESTAMP)",
"hive": "CAST(x AS TIMESTAMP)",
},
)
@ -302,3 +302,20 @@ class TestDuckDB(Validator):
read="duckdb",
unsupported_level=ErrorLevel.IMMEDIATE,
)
def test_array(self):
self.validate_identity("ARRAY(SELECT id FROM t)")
def test_cast(self):
self.validate_all(
"123::CHARACTER VARYING",
write={
"duckdb": "CAST(123 AS TEXT)",
},
)
def test_bool_or(self):
self.validate_all(
"SELECT a, LOGICAL_OR(b) FROM table GROUP BY a",
write={"duckdb": "SELECT a, BOOL_OR(b) FROM table GROUP BY a"},
)

View file

@ -268,10 +268,10 @@ class TestHive(Validator):
self.validate_all(
"DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')",
write={
"duckdb": "STRFTIME('2020-01-01', '%Y-%m-%d %H:%M:%S')",
"presto": "DATE_FORMAT('2020-01-01', '%Y-%m-%d %H:%i:%S')",
"hive": "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')",
"spark": "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')",
"duckdb": "STRFTIME(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %H:%M:%S')",
"presto": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %H:%i:%S')",
"hive": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss')",
"spark": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss')",
},
)
self.validate_all(

View file

@ -91,12 +91,12 @@ class TestMySQL(Validator):
},
)
self.validate_all(
"N 'some text'",
"N'some text'",
read={
"mysql": "N'some text'",
"mysql": "n'some text'",
},
write={
"mysql": "N 'some text'",
"mysql": "N'some text'",
},
)
self.validate_all(

View file

@ -3,6 +3,7 @@ from tests.dialects.test_dialect import Validator
class TestPostgres(Validator):
maxDiff = None
dialect = "postgres"
def test_ddl(self):
@ -94,6 +95,7 @@ class TestPostgres(Validator):
self.validate_identity("COMMENT ON TABLE mytable IS 'this'")
self.validate_identity("SELECT e'\\xDEADBEEF'")
self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
self.validate_identity("""SELECT * FROM JSON_TO_RECORDSET(z) AS y("rank" INT)""")
self.validate_all(
"END WORK AND NO CHAIN",
@ -112,6 +114,14 @@ class TestPostgres(Validator):
"spark": "CREATE TABLE x (a UUID, b BINARY)",
},
)
self.validate_all(
"123::CHARACTER VARYING",
write={"postgres": "CAST(123 AS VARCHAR)"},
)
self.validate_all(
"TO_TIMESTAMP(123::DOUBLE PRECISION)",
write={"postgres": "TO_TIMESTAMP(CAST(123 AS DOUBLE PRECISION))"},
)
self.validate_identity(
"CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)"
@ -193,15 +203,21 @@ class TestPostgres(Validator):
},
)
self.validate_all(
"SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL",
read={
"SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) pname ON TRUE WHERE pname IS NULL",
write={
"postgres": "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL",
},
)
self.validate_all(
"SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id",
read={
"postgres": "SELECT p1.id, p2.id, v1, v2 FROM polygons p1, polygons p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id != p2.id",
write={
"postgres": "SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) AS v1, LATERAL VERTICES(p2.poly) AS v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id",
},
)
self.validate_all(
"SELECT * FROM r CROSS JOIN LATERAL unnest(array(1)) AS s(location)",
write={
"postgres": "SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)",
},
)
self.validate_all(
@ -218,35 +234,46 @@ class TestPostgres(Validator):
)
self.validate_all(
"'[1,2,3]'::json->2",
write={"postgres": "CAST('[1,2,3]' AS JSON)->'2'"},
write={"postgres": "CAST('[1,2,3]' AS JSON) -> '2'"},
)
self.validate_all(
"""'{"a":1,"b":2}'::json->'b'""",
write={"postgres": """CAST('{"a":1,"b":2}' AS JSON)->'b'"""},
write={"postgres": """CAST('{"a":1,"b":2}' AS JSON) -> 'b'"""},
)
self.validate_all(
"""'{"x": {"y": 1}}'::json->'x'->'y'""",
write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""},
write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON) -> 'x' -> 'y'"""},
)
self.validate_all(
"""'{"x": {"y": 1}}'::json->'x'::json->'y'""",
write={"postgres": """CAST(CAST('{"x": {"y": 1}}' AS JSON)->'x' AS JSON)->'y'"""},
write={"postgres": """CAST(CAST('{"x": {"y": 1}}' AS JSON) -> 'x' AS JSON) -> 'y'"""},
)
self.validate_all(
"""'[1,2,3]'::json->>2""",
write={"postgres": "CAST('[1,2,3]' AS JSON)->>'2'"},
write={"postgres": "CAST('[1,2,3]' AS JSON) ->> '2'"},
)
self.validate_all(
"""'{"a":1,"b":2}'::json->>'b'""",
write={"postgres": """CAST('{"a":1,"b":2}' AS JSON)->>'b'"""},
write={"postgres": """CAST('{"a":1,"b":2}' AS JSON) ->> 'b'"""},
)
self.validate_all(
"""'{"a":[1,2,3],"b":[4,5,6]}'::json#>'{a,2}'""",
write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON)#>'{a,2}'"""},
write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON) #> '{a,2}'"""},
)
self.validate_all(
"""'{"a":[1,2,3],"b":[4,5,6]}'::json#>>'{a,2}'""",
write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON)#>>'{a,2}'"""},
write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON) #>> '{a,2}'"""},
)
self.validate_all(
"""SELECT JSON_ARRAY_ELEMENTS((foo->'sections')::JSON) AS sections""",
write={
"postgres": """SELECT JSON_ARRAY_ELEMENTS(CAST((foo -> 'sections') AS JSON)) AS sections""",
"presto": """SELECT JSON_ARRAY_ELEMENTS(CAST((JSON_EXTRACT(foo, 'sections')) AS JSON)) AS sections""",
},
)
self.validate_all(
"""x ? 'x'""",
write={"postgres": "x ? 'x'"},
)
self.validate_all(
"SELECT $$a$$",
@ -260,3 +287,49 @@ class TestPostgres(Validator):
"UPDATE MYTABLE T1 SET T1.COL = 13",
write={"postgres": "UPDATE MYTABLE AS T1 SET T1.COL = 13"},
)
self.validate_identity("x ~ 'y'")
self.validate_identity("x ~* 'y'")
self.validate_all(
"x !~ 'y'",
write={"postgres": "NOT x ~ 'y'"},
)
self.validate_all(
"x !~* 'y'",
write={"postgres": "NOT x ~* 'y'"},
)
self.validate_all(
"x ~~ 'y'",
write={"postgres": "x LIKE 'y'"},
)
self.validate_all(
"x ~~* 'y'",
write={"postgres": "x ILIKE 'y'"},
)
self.validate_all(
"x !~~ 'y'",
write={"postgres": "NOT x LIKE 'y'"},
)
self.validate_all(
"x !~~* 'y'",
write={"postgres": "NOT x ILIKE 'y'"},
)
self.validate_all(
"'45 days'::interval day",
write={"postgres": "CAST('45 days' AS INTERVAL day)"},
)
self.validate_all(
"'x' 'y' 'z'",
write={"postgres": "CONCAT('x', 'y', 'z')"},
)
self.validate_identity("SELECT ARRAY(SELECT 1)")
self.validate_all(
"x::cstring",
write={"postgres": "CAST(x AS CSTRING)"},
)
self.validate_identity(
"SELECT SUM(x) OVER a, SUM(y) OVER b FROM c WINDOW a AS (PARTITION BY d), b AS (PARTITION BY e)"
)

View file

@ -53,7 +53,7 @@ class TestRedshift(Validator):
self.validate_all(
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
write={
"redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
"redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1',
},
)
self.validate_all(

View file

@ -6,6 +6,12 @@ class TestSnowflake(Validator):
dialect = "snowflake"
def test_snowflake(self):
self.validate_all(
"SELECT * FROM xxx WHERE col ilike '%Don''t%'",
write={
"snowflake": "SELECT * FROM xxx WHERE col ILIKE '%Don\\'t%'",
},
)
self.validate_all(
'x:a:"b c"',
write={
@ -509,3 +515,11 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERA
"snowflake": "SELECT 1 MINUS SELECT 1",
},
)
def test_values(self):
self.validate_all(
'SELECT c0, c1 FROM (VALUES (1, 2), (3, 4)) AS "t0"(c0, c1)',
read={
"spark": "SELECT `c0`, `c1` FROM (VALUES (1, 2), (3, 4)) AS `t0`(`c0`, `c1`)",
},
)

View file

@ -101,6 +101,18 @@ TBLPROPERTIES (
"spark": "CACHE TABLE testCache OPTIONS('storageLevel' = 'DISK_ONLY') AS SELECT * FROM testData"
},
)
self.validate_all(
"ALTER TABLE StudentInfo ADD COLUMNS (LastName STRING, DOB TIMESTAMP)",
write={
"spark": "ALTER TABLE StudentInfo ADD COLUMNS (LastName STRING, DOB TIMESTAMP)",
},
)
self.validate_all(
"ALTER TABLE StudentInfo DROP COLUMNS (LastName, DOB)",
write={
"spark": "ALTER TABLE StudentInfo DROP COLUMNS (LastName, DOB)",
},
)
def test_to_date(self):
self.validate_all(

View file

@ -431,11 +431,11 @@ class TestTSQL(Validator):
def test_string(self):
self.validate_all(
"SELECT N'test'",
write={"spark": "SELECT 'test'"},
write={"spark": "SELECT N'test'"},
)
self.validate_all(
"SELECT n'test'",
write={"spark": "SELECT 'test'"},
write={"spark": "SELECT N'test'"},
)
self.validate_all(
"SELECT '''test'''",

View file

@ -17,6 +17,7 @@ SUM(CASE WHEN x > 1 THEN 1 ELSE 0 END) / y
'\x'
"x"
""
N'abc'
x
x % 1
x < 1
@ -33,6 +34,10 @@ x << 1
x >> 1
x >> 1 | 1 & 1 ^ 1
x || y
x[ : ]
x[1 : ]
x[1 : 2]
x[-4 : -1]
1 - -1
- -5
dec.x + y
@ -62,6 +67,8 @@ x BETWEEN 'a' || b AND 'c' || d
NOT x IS NULL
x IS TRUE
x IS FALSE
x IS TRUE IS TRUE
x LIKE y IS TRUE
time
zone
ARRAY<TEXT>
@ -93,10 +100,11 @@ x LIKE '%y%' ESCAPE '\'
x ILIKE '%y%' ESCAPE '\'
1 AS escape
INTERVAL '1' day
INTERVAL '1' month
INTERVAL '1' MONTH
INTERVAL '1 day'
INTERVAL 2 months
INTERVAL 1 + 3 days
INTERVAL 1 + 3 DAYS
CAST('45' AS INTERVAL DAYS)
TIMESTAMP_DIFF(CURRENT_TIMESTAMP(), 1, DAY)
DATETIME_DIFF(CURRENT_DATE, 1, DAY)
QUANTILE(x, 0.5)
@ -144,6 +152,7 @@ SELECT 1 AS count FROM test
SELECT 1 AS comment FROM test
SELECT 1 AS numeric FROM test
SELECT 1 AS number FROM test
SELECT COALESCE(offset, 1)
SELECT t.count
SELECT DISTINCT x FROM test
SELECT DISTINCT x, y FROM test
@ -196,6 +205,7 @@ SELECT JSON_EXTRACT_SCALAR(x, '$.name')
SELECT x LIKE '%x%' FROM test
SELECT * FROM test LIMIT 100
SELECT * FROM test LIMIT 100 OFFSET 200
SELECT * FROM test FETCH FIRST ROWS ONLY
SELECT * FROM test FETCH FIRST 1 ROWS ONLY
SELECT * FROM test FETCH NEXT 1 ROWS ONLY
SELECT (1 > 2) AS x FROM test
@ -460,6 +470,7 @@ CREATE TABLE z (end INT)
CREATE TABLE z (a ARRAY<TEXT>, b MAP<TEXT, DOUBLE>, c DECIMAL(5, 3))
CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECIMAL(5, 3))
CREATE TABLE z (a INT(11) DEFAULT UUID())
CREATE TABLE z (n INT DEFAULT 0 NOT NULL)
CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id')
CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1)
CREATE TABLE z (a INT(11) NOT NULL DEFAULT -1)
@ -511,7 +522,13 @@ INSERT OVERWRITE TABLE a.b PARTITION(ds) SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds, hour) SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD', hour='hh') SELECT x FROM y
ALTER AGGREGATE bla(foo) OWNER TO CURRENT_USER
ALTER RULE foo ON bla RENAME TO baz
ALTER ROLE CURRENT_USER WITH REPLICATION
ALTER SEQUENCE IF EXISTS baz RESTART WITH boo
ALTER TYPE electronic_mail RENAME TO email
ALTER VIEW foo ALTER COLUMN bla SET DEFAULT 'NOT SET'
ALTER DOMAIN foo VALIDATE CONSTRAINT bla
ANALYZE a.y
DELETE FROM x WHERE y > 1
DELETE FROM y
@ -596,3 +613,17 @@ SELECT x AS INTO FROM bla
SELECT * INTO newevent FROM event
SELECT * INTO TEMPORARY newevent FROM event
SELECT * INTO UNLOGGED newevent FROM event
ALTER TABLE integers ADD COLUMN k INT
ALTER TABLE integers ADD COLUMN IF NOT EXISTS k INT
ALTER TABLE IF EXISTS integers ADD COLUMN k INT
ALTER TABLE integers ADD COLUMN l INT DEFAULT 10
ALTER TABLE measurements ADD COLUMN mtime TIMESTAMPTZ DEFAULT NOW()
ALTER TABLE integers DROP COLUMN k
ALTER TABLE integers DROP COLUMN IF EXISTS k
ALTER TABLE integers DROP COLUMN k CASCADE
ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR
ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR USING CONCAT(i, '_', j)
ALTER TABLE integers ALTER COLUMN i SET DEFAULT 10
ALTER TABLE integers ALTER COLUMN i DROP DEFAULT
ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B
ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT

View file

@ -1,11 +1,11 @@
SELECT w.d + w.e AS c FROM w AS w;
SELECT CONCAT(w.d, w.e) AS c FROM w AS w;
SELECT CONCAT("w"."d", "w"."e") AS "c" FROM "w" AS "w";
SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w;
SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w;
SELECT CAST("w"."d" AS DATE) > CAST("w"."e" AS DATE) AS "a" FROM "w" AS "w";
SELECT CAST(1 AS VARCHAR) AS a FROM w AS w;
SELECT CAST(1 AS VARCHAR) AS a FROM w AS w;
SELECT CAST(1 AS VARCHAR) AS "a" FROM "w" AS "w";
SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w;
SELECT 1 + 3.2 AS a FROM w AS w;
SELECT 1 + 3.2 AS "a" FROM "w" AS "w";

View file

@ -291,3 +291,81 @@ SELECT a1 FROM cte1;
SELECT
"x"."a" AS "a1"
FROM "x" AS "x";
# title: recursive cte
WITH RECURSIVE cte1 AS (
SELECT *
FROM (
SELECT 1 AS a, 2 AS b
) base
CROSS JOIN (SELECT 3 c) y
UNION ALL
SELECT *
FROM cte1
WHERE a < 1
)
SELECT *
FROM cte1;
WITH RECURSIVE "base" AS (
SELECT
1 AS "a",
2 AS "b"
), "y" AS (
SELECT
3 AS "c"
), "cte1" AS (
SELECT
"base"."a" AS "a",
"base"."b" AS "b",
"y"."c" AS "c"
FROM "base" AS "base"
CROSS JOIN "y" AS "y"
UNION ALL
SELECT
"cte1"."a" AS "a",
"cte1"."b" AS "b",
"cte1"."c" AS "c"
FROM "cte1"
WHERE
"cte1"."a" < 1
)
SELECT
"cte1"."a" AS "a",
"cte1"."b" AS "b",
"cte1"."c" AS "c"
FROM "cte1";
# title: right join should not push down to from
SELECT x.a, y.b
FROM x
RIGHT JOIN y
ON x.a = y.b
WHERE x.b = 1;
SELECT
"x"."a" AS "a",
"y"."b" AS "b"
FROM "x" AS "x"
RIGHT JOIN "y" AS "y"
ON "x"."a" = "y"."b"
WHERE
"x"."b" = 1;
# title: right join can push down to itself
SELECT x.a, y.b
FROM x
RIGHT JOIN y
ON x.a = y.b
WHERE y.b = 1;
WITH "y_2" AS (
SELECT
"y"."b" AS "b"
FROM "y" AS "y"
WHERE
"y"."b" = 1
)
SELECT
"x"."a" AS "a",
"y"."b" AS "b"
FROM "x" AS "x"
RIGHT JOIN "y_2" AS "y"
ON "x"."a" = "y"."b";

View file

@ -1,32 +1,32 @@
SELECT a FROM (SELECT * FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0;
SELECT 1 FROM (SELECT * FROM x) WHERE b = 2;
SELECT 1 AS "_col_0" FROM (SELECT x.b AS b FROM x AS x) AS "_q_0" WHERE "_q_0".b = 2;
SELECT 1 AS _col_0 FROM (SELECT x.b AS b FROM x AS x) AS _q_0 WHERE _q_0.b = 2;
SELECT (SELECT c FROM y WHERE q.b = y.b) FROM (SELECT * FROM x) AS q;
SELECT (SELECT y.c AS c FROM y AS y WHERE q.b = y.b) AS "_col_0" FROM (SELECT x.b AS b FROM x AS x) AS q;
SELECT (SELECT y.c AS c FROM y AS y WHERE q.b = y.b) AS _col_0 FROM (SELECT x.b AS b FROM x AS x) AS q;
SELECT a FROM x JOIN (SELECT b, c FROM y) AS z ON x.b = z.b;
SELECT x.a AS a FROM x AS x JOIN (SELECT y.b AS b FROM y AS y) AS z ON x.b = z.b;
SELECT x1.a FROM (SELECT * FROM x) AS x1, (SELECT * FROM x) AS x2;
SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS "_" FROM x AS x) AS x2;
SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS _ FROM x AS x) AS x2;
SELECT x1.a FROM (SELECT * FROM x) AS x1, (SELECT * FROM x) AS x2;
SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS "_" FROM x AS x) AS x2;
SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS _ FROM x AS x) AS x2;
SELECT a FROM (SELECT DISTINCT a, b FROM x);
SELECT "_q_0".a AS a FROM (SELECT DISTINCT x.a AS a, x.b AS b FROM x AS x) AS "_q_0";
SELECT _q_0.a AS a FROM (SELECT DISTINCT x.a AS a, x.b AS b FROM x AS x) AS _q_0;
SELECT a FROM (SELECT a, b FROM x UNION ALL SELECT a, b FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION ALL SELECT x.a AS a FROM x AS x) AS "_q_0";
SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x UNION ALL SELECT x.a AS a FROM x AS x) AS _q_0;
WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x UNION ALL SELECT z.b AS b, z.c AS c FROM z) SELECT a, b FROM t1;
WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x AS x UNION ALL SELECT z.b AS b, z.c AS c FROM z AS z) SELECT t1.a AS a, t1.b AS b FROM t1;
SELECT a FROM (SELECT a, b FROM x UNION SELECT a, b FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x UNION SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0";
SELECT _q_0.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x UNION SELECT x.a AS a, x.b AS b FROM x AS x) AS _q_0;
WITH y AS (SELECT * FROM x) SELECT a FROM y;
WITH y AS (SELECT x.a AS a FROM x AS x) SELECT y.a AS a FROM y;
@ -38,10 +38,10 @@ WITH z AS (SELECT * FROM x) SELECT a FROM z UNION SELECT a FROM z;
WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z UNION SELECT z.a AS a FROM z;
SELECT b FROM (SELECT a, SUM(b) AS b FROM x GROUP BY a);
SELECT "_q_0".b AS b FROM (SELECT SUM(x.b) AS b FROM x AS x GROUP BY x.a) AS "_q_0";
SELECT _q_0.b AS b FROM (SELECT SUM(x.b) AS b FROM x AS x GROUP BY x.a) AS _q_0;
SELECT b FROM (SELECT a, SUM(b) AS b FROM x ORDER BY a);
SELECT "_q_0".b AS b FROM (SELECT x.a AS a, SUM(x.b) AS b FROM x AS x ORDER BY a) AS "_q_0";
SELECT _q_0.b AS b FROM (SELECT x.a AS a, SUM(x.b) AS b FROM x AS x ORDER BY a) AS _q_0;
SELECT x FROM (VALUES(1, 2)) AS q(x, y);
SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y);

View file

@ -21,15 +21,15 @@ SELECT x.a AS b FROM x AS x;
# execute: false
SELECT 1, 2 FROM x;
SELECT 1 AS "_col_0", 2 AS "_col_1" FROM x AS x;
SELECT 1 AS _col_0, 2 AS _col_1 FROM x AS x;
# execute: false
SELECT a + b FROM x;
SELECT x.a + x.b AS "_col_0" FROM x AS x;
SELECT x.a + x.b AS _col_0 FROM x AS x;
# execute: false
SELECT a, SUM(b) FROM x WHERE a > 1 AND b > 1 GROUP BY a;
SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 AND x.b > 1 GROUP BY x.a;
SELECT x.a AS a, SUM(x.b) AS _col_1 FROM x AS x WHERE x.a > 1 AND x.b > 1 GROUP BY x.a;
SELECT SUM(a) AS c FROM x HAVING SUM(a) > 3;
SELECT SUM(x.a) AS c FROM x AS x HAVING SUM(x.a) > 3;
@ -59,7 +59,7 @@ SELECT SUM(x.a) AS c, SUM(x.b) AS d FROM x AS x ORDER BY SUM(x.a), SUM(x.b);
# execute: false
SELECT SUM(a), SUM(b) AS c FROM x ORDER BY 1, 2;
SELECT SUM(x.a) AS "_col_0", SUM(x.b) AS c FROM x AS x ORDER BY SUM(x.a), SUM(x.b);
SELECT SUM(x.a) AS _col_0, SUM(x.b) AS c FROM x AS x ORDER BY SUM(x.a), SUM(x.b);
SELECT a AS j, b FROM x GROUP BY j, b;
SELECT x.a AS j, x.b AS b FROM x AS x GROUP BY x.a, x.b;
@ -72,7 +72,7 @@ SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a, x.b;
# execute: false
SELECT DATE(a), DATE(b) AS c FROM x GROUP BY 1, 2;
SELECT DATE(x.a) AS "_col_0", DATE(x.b) AS c FROM x AS x GROUP BY DATE(x.a), DATE(x.b);
SELECT DATE(x.a) AS _col_0, DATE(x.b) AS c FROM x AS x GROUP BY DATE(x.a), DATE(x.b);
SELECT SUM(x.a) AS c FROM x JOIN y ON x.b = y.b GROUP BY c;
SELECT SUM(x.a) AS c FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY y.c;
@ -130,10 +130,10 @@ SELECT a FROM (SELECT a FROM x AS x) y;
SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y;
SELECT a FROM (SELECT a AS a FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0;
SELECT a FROM (SELECT a FROM (SELECT a FROM x));
SELECT "_q_1".a AS a FROM (SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0") AS "_q_1";
SELECT _q_1.a AS a FROM (SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0) AS _q_1;
SELECT x.a FROM x AS x JOIN (SELECT * FROM x) AS y ON x.a = y.a;
SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a, x.b AS b FROM x AS x) AS y ON x.a = y.a;
@ -157,7 +157,7 @@ SELECT a FROM x UNION SELECT a FROM x UNION SELECT a FROM x;
SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x;
SELECT a FROM (SELECT a FROM x UNION SELECT a FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x) AS "_q_0";
SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x) AS _q_0;
--------------------------------------
-- Subqueries
@ -167,10 +167,10 @@ SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT y.c AS c FROM y AS y);
# execute: false
SELECT (SELECT c FROM y) FROM x;
SELECT (SELECT y.c AS c FROM y AS y) AS "_col_0" FROM x AS x;
SELECT (SELECT y.c AS c FROM y AS y) AS _col_0 FROM x AS x;
SELECT a FROM (SELECT a FROM x) WHERE a IN (SELECT b FROM (SELECT b FROM y));
SELECT "_q_1".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_1" WHERE "_q_1".a IN (SELECT "_q_0".b AS b FROM (SELECT y.b AS b FROM y AS y) AS "_q_0");
SELECT _q_1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_1 WHERE _q_1.a IN (SELECT _q_0.b AS b FROM (SELECT y.b AS b FROM y AS y) AS _q_0);
--------------------------------------
-- Correlated subqueries
@ -215,10 +215,10 @@ SELECT x.*, y.* FROM x JOIN y ON x.b = y.b;
SELECT x.a AS a, x.b AS b, y.b AS b, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
SELECT a FROM (SELECT * FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0";
SELECT _q_0.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x) AS _q_0;
SELECT * FROM (SELECT a FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0;
--------------------------------------
-- CTEs

View file

@ -11,10 +11,10 @@ SELECT x.b AS b FROM x AS x;
-- Derived tables
--------------------------------------
SELECT x.a FROM x AS x JOIN (SELECT * FROM x);
SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a FROM x AS x) AS "_q_0";
SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a FROM x AS x) AS _q_0;
SELECT x.b FROM x AS x JOIN (SELECT b FROM x);
SELECT x.b AS b FROM x AS x JOIN (SELECT x.b AS b FROM x AS x) AS "_q_0";
SELECT x.b AS b FROM x AS x JOIN (SELECT x.b AS b FROM x AS x) AS _q_0;
--------------------------------------
-- Expand *
@ -29,7 +29,7 @@ SELECT * FROM y JOIN z ON y.c = z.c;
SELECT y.b AS b, z.b AS b FROM y AS y JOIN z AS z ON y.c = z.c;
SELECT a FROM (SELECT * FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0;
SELECT * FROM (SELECT a FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0;

View file

@ -30,14 +30,14 @@ CROSS JOIN (
SELECT
SUM(y.a) AS a
FROM y
) AS "_u_0"
) AS _u_0
LEFT JOIN (
SELECT
y.a AS a
FROM y
GROUP BY
y.a
) AS "_u_1"
) AS _u_1
ON x.a = "_u_1"."a"
LEFT JOIN (
SELECT
@ -45,7 +45,7 @@ LEFT JOIN (
FROM y
GROUP BY
y.b
) AS "_u_2"
) AS _u_2
ON x.a = "_u_2"."b"
LEFT JOIN (
SELECT
@ -53,7 +53,7 @@ LEFT JOIN (
FROM y
GROUP BY
y.a
) AS "_u_3"
) AS _u_3
ON x.a = "_u_3"."a"
LEFT JOIN (
SELECT
@ -64,8 +64,8 @@ LEFT JOIN (
TRUE
GROUP BY
y.a
) AS "_u_4"
ON x.a = "_u_4"."_u_5"
) AS _u_4
ON x.a = _u_4._u_5
LEFT JOIN (
SELECT
SUM(y.b) AS b,
@ -75,8 +75,8 @@ LEFT JOIN (
TRUE
GROUP BY
y.a
) AS "_u_6"
ON x.a = "_u_6"."_u_7"
) AS _u_6
ON x.a = _u_6._u_7
LEFT JOIN (
SELECT
y.a AS a
@ -85,8 +85,8 @@ LEFT JOIN (
TRUE
GROUP BY
y.a
) AS "_u_8"
ON "_u_8".a = x.a
) AS _u_8
ON _u_8.a = x.a
LEFT JOIN (
SELECT
y.a AS a
@ -95,8 +95,8 @@ LEFT JOIN (
TRUE
GROUP BY
y.a
) AS "_u_9"
ON "_u_9".a = x.a
) AS _u_9
ON _u_9.a = x.a
LEFT JOIN (
SELECT
ARRAY_AGG(y.a) AS a,
@ -106,8 +106,8 @@ LEFT JOIN (
TRUE
GROUP BY
y.b
) AS "_u_10"
ON "_u_10"."_u_11" = x.a
) AS _u_10
ON _u_10._u_11 = x.a
LEFT JOIN (
SELECT
SUM(y.a) AS a,
@ -118,8 +118,8 @@ LEFT JOIN (
TRUE AND TRUE AND TRUE
GROUP BY
y.a
) AS "_u_12"
ON "_u_12"."_u_13" = x.a AND "_u_12"."_u_13" = x.b
) AS _u_12
ON _u_12._u_13 = x.a AND _u_12._u_13 = x.b
LEFT JOIN (
SELECT
y.a AS a
@ -128,38 +128,38 @@ LEFT JOIN (
TRUE
GROUP BY
y.a
) AS "_u_15"
ON x.a = "_u_15".a
) AS _u_15
ON x.a = _u_15.a
WHERE
x.a = "_u_0".a
x.a = _u_0.a
AND NOT "_u_1"."a" IS NULL
AND NOT "_u_2"."b" IS NULL
AND NOT "_u_3"."a" IS NULL
AND (
x.a = "_u_4".b AND NOT "_u_4"."_u_5" IS NULL
x.a = _u_4.b AND NOT _u_4._u_5 IS NULL
)
AND (
x.a > "_u_6".b AND NOT "_u_6"."_u_7" IS NULL
x.a > _u_6.b AND NOT _u_6._u_7 IS NULL
)
AND (
None = "_u_8".a AND NOT "_u_8".a IS NULL
None = _u_8.a AND NOT _u_8.a IS NULL
)
AND NOT (
x.a = "_u_9".a AND NOT "_u_9".a IS NULL
x.a = _u_9.a AND NOT _u_9.a IS NULL
)
AND (
ARRAY_ANY("_u_10".a, _x -> _x = x.a) AND NOT "_u_10"."_u_11" IS NULL
ARRAY_ANY(_u_10.a, _x -> _x = x.a) AND NOT _u_10._u_11 IS NULL
)
AND (
(
(
x.a < "_u_12".a AND NOT "_u_12"."_u_13" IS NULL
) AND NOT "_u_12"."_u_13" IS NULL
x.a < _u_12.a AND NOT _u_12._u_13 IS NULL
) AND NOT _u_12._u_13 IS NULL
)
AND ARRAY_ANY("_u_12"."_u_14", "_x" -> "_x" <> x.d)
AND ARRAY_ANY(_u_12._u_14, "_x" -> _x <> x.d)
)
AND (
NOT "_u_15".a IS NULL AND NOT "_u_15".a IS NULL
NOT _u_15.a IS NULL AND NOT _u_15.a IS NULL
)
AND x.a IN (
SELECT

View file

@ -481,6 +481,19 @@ class TestBuild(unittest.TestCase):
),
(lambda: exp.delete("y", where="x > 1"), "DELETE FROM y WHERE x > 1"),
(lambda: exp.delete("y", where=exp.and_("x > 1")), "DELETE FROM y WHERE x > 1"),
(
lambda: select("AVG(a) OVER b")
.from_("table")
.window("b AS (PARTITION BY c ORDER BY d)"),
"SELECT AVG(a) OVER b FROM table WINDOW b AS (PARTITION BY c ORDER BY d)",
),
(
lambda: select("AVG(a) OVER b", "MIN(c) OVER d")
.from_("table")
.window("b AS (PARTITION BY e ORDER BY f)")
.window("d AS (PARTITION BY g ORDER BY h)"),
"SELECT AVG(a) OVER b, MIN(c) OVER d FROM table WINDOW b AS (PARTITION BY e ORDER BY f), d AS (PARTITION BY g ORDER BY h)",
),
]:
with self.subTest(sql):
self.assertEqual(expression().sql(dialect[0] if dialect else None), sql)

View file

@ -74,7 +74,7 @@ class TestExecutor(unittest.TestCase):
)
return expression
for i, (sql, _) in enumerate(self.sqls[0:18]):
for i, (sql, _) in enumerate(self.sqls):
with self.subTest(f"tpch-h {i + 1}"):
a = self.cached_execute(sql)
sql = parse_one(sql).transform(to_csv).sql(pretty=True)

View file

@ -1,4 +1,5 @@
import datetime
import math
import unittest
from sqlglot import alias, exp, parse_one
@ -491,7 +492,7 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(alias("foo", "bar-1").sql(), 'foo AS "bar-1"')
self.assertEqual(alias("foo", "bar_1").sql(), "foo AS bar_1")
self.assertEqual(alias("foo * 2", "2bar").sql(), 'foo * 2 AS "2bar"')
self.assertEqual(alias('"foo"', "_bar").sql(), '"foo" AS "_bar"')
self.assertEqual(alias('"foo"', "_bar").sql(), '"foo" AS _bar')
self.assertEqual(alias("foo", "bar", quoted=True).sql(), 'foo AS "bar"')
def test_unit(self):
@ -503,6 +504,8 @@ class TestExpressions(unittest.TestCase):
def test_identifier(self):
self.assertTrue(exp.to_identifier('"x"').quoted)
self.assertFalse(exp.to_identifier("x").quoted)
self.assertTrue(exp.to_identifier("foo ").quoted)
self.assertFalse(exp.to_identifier("_x").quoted)
def test_function_normalizer(self):
self.assertEqual(parse_one("HELLO()").sql(normalize_functions="lower"), "hello()")
@ -549,14 +552,15 @@ class TestExpressions(unittest.TestCase):
([1, "2", None], "ARRAY(1, '2', NULL)"),
({"x": None}, "MAP('x', NULL)"),
(
datetime.datetime(2022, 10, 1, 1, 1, 1),
"TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')",
datetime.datetime(2022, 10, 1, 1, 1, 1, 1),
"TIME_STR_TO_TIME('2022-10-01T01:01:01.000001+00:00')",
),
(
datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc),
"TIME_STR_TO_TIME('2022-10-01 01:01:01.000000+0000')",
"TIME_STR_TO_TIME('2022-10-01T01:01:01+00:00')",
),
(datetime.date(2022, 10, 1), "DATE_STR_TO_DATE('2022-10-01')"),
(math.nan, "NULL"),
]:
with self.subTest(value):
self.assertEqual(exp.convert(value).sql(), expected)

View file

@ -164,9 +164,6 @@ class TestOptimizer(unittest.TestCase):
with self.assertRaises(OptimizeError):
optimizer.qualify_columns.qualify_columns(parse_one(sql), schema=self.schema)
def test_quote_identities(self):
self.check_file("quote_identities", optimizer.quote_identities.quote_identities)
def test_lower_identities(self):
self.check_file("lower_identities", optimizer.lower_identities.lower_identities)
@ -555,3 +552,29 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema
)
self.assertEqual(expression.expressions[0].type.this, target_type)
def test_recursive_cte(self):
query = parse_one(
"""
with recursive t(n) AS
(
select 1
union all
select n + 1
FROM t
where n < 3
), y AS (
select n
FROM t
union all
select n + 1
FROM y
where n < 2
)
select * from y
"""
)
scope_t, scope_y = build_scope(query).cte_scopes
self.assertEqual(set(scope_t.cte_sources), {"t"})
self.assertEqual(set(scope_y.cte_sources), {"t", "y"})

View file

@ -76,6 +76,9 @@ class TestParser(unittest.TestCase):
tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
self.assertEqual(tables, ["a", "b.c", "d"])
def test_union_order(self):
self.assertIsInstance(parse_one("SELECT * FROM (SELECT 1) UNION SELECT 2"), exp.Union)
def test_select(self):
self.assertIsNotNone(parse_one("select 1 natural"))
self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"])

View file

@ -40,17 +40,17 @@ class TestTime(unittest.TestCase):
self.validate(
eliminate_distinct_on,
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1',
)
self.validate(
eliminate_distinct_on,
"SELECT DISTINCT ON (a) a, b FROM x",
'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a) AS "_row_number" FROM x) WHERE "_row_number" = 1',
'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a) AS _row_number FROM x) WHERE "_row_number" = 1',
)
self.validate(
eliminate_distinct_on,
"SELECT DISTINCT ON (a, b) a, b FROM x ORDER BY c DESC",
'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1',
)
self.validate(
eliminate_distinct_on,
@ -60,5 +60,5 @@ class TestTime(unittest.TestCase):
self.validate(
eliminate_distinct_on,
"SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC",
'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS "_row_number_2" FROM x) WHERE "_row_number_2" = 1',
'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE "_row_number_2" = 1',
)

View file

@ -28,7 +28,7 @@ class TestTranspile(unittest.TestCase):
self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
self.assertEqual(transpile("SELECT 1 row")[0], "SELECT 1 AS row")
for key in ("union", "filter", "over", "from", "join"):
for key in ("union", "over", "from", "join"):
with self.subTest(f"alias {key}"):
self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}")
self.validate(f'SELECT x "{key}"', f'SELECT x AS "{key}"')
@ -263,6 +263,25 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
"WITH a AS (SELECT 1), WITH b AS (SELECT 2) SELECT *",
"WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *",
)
self.validate(
"WITH A(filter) AS (VALUES 1, 2, 3) SELECT * FROM A WHERE filter >= 2",
"WITH A(filter) AS (VALUES (1), (2), (3)) SELECT * FROM A WHERE filter >= 2",
)
def test_alter(self):
self.validate(
"ALTER TABLE integers ADD k INTEGER",
"ALTER TABLE integers ADD COLUMN k INT",
)
self.validate("ALTER TABLE integers DROP k", "ALTER TABLE integers DROP COLUMN k")
self.validate(
"ALTER TABLE integers ALTER i SET DATA TYPE VARCHAR",
"ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR",
)
self.validate(
"ALTER TABLE integers ALTER i TYPE VARCHAR COLLATE foo USING bar",
"ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR COLLATE foo USING bar",
)
def test_time(self):
self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)")
@ -403,6 +422,14 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
with self.subTest(sql):
self.assertEqual(transpile(sql)[0], sql.strip())
def test_normalize_name(self):
self.assertEqual(
transpile("cardinality(x)", read="presto", write="presto", normalize_functions="lower")[
0
],
"cardinality(x)",
)
def test_partial(self):
for sql in load_sql_fixtures("partial.sql"):
with self.subTest(sql):