Merging upstream version 10.5.2.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
77197f1e44
commit
e0f3bbb5f3
58 changed files with 1480 additions and 383 deletions
29
CHANGELOG.md
29
CHANGELOG.md
|
@ -1,6 +1,35 @@
|
||||||
Changelog
|
Changelog
|
||||||
=========
|
=========
|
||||||
|
|
||||||
|
v10.5.0
|
||||||
|
------
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
- Breaking: Added python type hints in the parser module, which may result in some mypy errors.
|
||||||
|
|
||||||
|
- New: SQLGlot expressions can [now be serialized / deserialized into JSON](https://github.com/tobymao/sqlglot/commit/bac38151a8d72687247922e6898696be43ff4992).
|
||||||
|
|
||||||
|
- New: Added support for T-SQL [hints](https://github.com/tobymao/sqlglot/commit/3220ec1adb1e1130b109677d03c9be947b03f9ca) and [EOMONTH](https://github.com/tobymao/sqlglot/commit/1ac05d9265667c883b9f6db5d825a6d864c95c73).
|
||||||
|
|
||||||
|
- New: Added support for Clickhouse's parametric function syntax.
|
||||||
|
|
||||||
|
- New: Added [wider support](https://github.com/tobymao/sqlglot/commit/beb660f943b73c730f1b06fce4986e26642ee8dc) for timestr and datestr.
|
||||||
|
|
||||||
|
- New: CLI now accepts a flag [for parsing SQL from the standard input stream](https://github.com/tobymao/sqlglot/commit/f89b38ebf3e24ba951ee8b249d73bbf48685928a).
|
||||||
|
|
||||||
|
- Improvement: Fixed BigQuery transpilation for [parameterized types and unnest](https://github.com/tobymao/sqlglot/pull/924).
|
||||||
|
|
||||||
|
- Improvement: Hive / Spark identifiers can now begin with a digit.
|
||||||
|
|
||||||
|
- Improvement: Bug fixes in [date/datetime simplification](https://github.com/tobymao/sqlglot/commit/b26b8d88af14f72d90c0019ec332d268a23b078f).
|
||||||
|
|
||||||
|
- Improvement: Bug fixes in [merge_subquery](https://github.com/tobymao/sqlglot/commit/e30e21b6c572d0931bfb5873cc6ac3949c6ef5aa).
|
||||||
|
|
||||||
|
- Improvement: Schema identifiers are now [converted to lowercase](https://github.com/tobymao/sqlglot/commit/8212032968a519c199b461eba1a2618e89bf0326) unless they're quoted.
|
||||||
|
|
||||||
|
- Improvement: Identifiers with a leading underscore are now regarded as [safe](https://github.com/tobymao/sqlglot/commit/de3b0804bb7606673d0bbb989997c13744957f7c#diff-7857fedd1d1451b1b9a5b8efaa1cc292c02e7ee4f0d04d7e2f9d5bfb9565802c) and hence are not quoted.
|
||||||
|
|
||||||
v10.4.0
|
v10.4.0
|
||||||
------
|
------
|
||||||
|
|
||||||
|
|
21
README.md
21
README.md
|
@ -1,12 +1,12 @@
|
||||||
# SQLGlot
|
# SQLGlot
|
||||||
|
|
||||||
SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
|
SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [18 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
|
||||||
|
|
||||||
It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks) while being written purely in Python.
|
It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks) while being written purely in Python.
|
||||||
|
|
||||||
You can easily [customize](#custom-dialects) the parser, [analyze](#metadata) queries, traverse expression trees, and programmatically [build](#build-and-modify-sql) SQL.
|
You can easily [customize](#custom-dialects) the parser, [analyze](#metadata) queries, traverse expression trees, and programmatically [build](#build-and-modify-sql) SQL.
|
||||||
|
|
||||||
Syntax [errors](#parser-errors) are highlighted and dialect incompatibilities can warn or raise depending on configurations.
|
Syntax [errors](#parser-errors) are highlighted and dialect incompatibilities can warn or raise depending on configurations. However, it should be noted that the parser is very lenient when it comes to detecting errors, because it aims to consume as much SQL as possible. On one hand, this makes its implementation simpler, and thus more comprehensible, but on the other hand it means that syntax errors may sometimes go unnoticed.
|
||||||
|
|
||||||
Contributions are very welcome in SQLGlot; read the [contribution guide](https://github.com/tobymao/sqlglot/blob/main/CONTRIBUTING.md) to get started!
|
Contributions are very welcome in SQLGlot; read the [contribution guide](https://github.com/tobymao/sqlglot/blob/main/CONTRIBUTING.md) to get started!
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/
|
||||||
* [AST Diff](#ast-diff)
|
* [AST Diff](#ast-diff)
|
||||||
* [Custom Dialects](#custom-dialects)
|
* [Custom Dialects](#custom-dialects)
|
||||||
* [SQL Execution](#sql-execution)
|
* [SQL Execution](#sql-execution)
|
||||||
|
* [Used By](#used-by)
|
||||||
* [Documentation](#documentation)
|
* [Documentation](#documentation)
|
||||||
* [Run Tests and Lint](#run-tests-and-lint)
|
* [Run Tests and Lint](#run-tests-and-lint)
|
||||||
* [Benchmarks](#benchmarks)
|
* [Benchmarks](#benchmarks)
|
||||||
|
@ -165,7 +166,7 @@ for table in parse_one("SELECT * FROM x JOIN y JOIN z").find_all(exp.Table):
|
||||||
|
|
||||||
### Parser Errors
|
### Parser Errors
|
||||||
|
|
||||||
A syntax error will result in a parser error:
|
When the parser detects an error in the syntax, it raises a ParserError:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import sqlglot
|
import sqlglot
|
||||||
|
@ -283,13 +284,13 @@ print(
|
||||||
```sql
|
```sql
|
||||||
SELECT
|
SELECT
|
||||||
(
|
(
|
||||||
"x"."A" OR "x"."B" OR "x"."C"
|
"x"."a" OR "x"."b" OR "x"."c"
|
||||||
) AND (
|
) AND (
|
||||||
"x"."A" OR "x"."B" OR "x"."D"
|
"x"."a" OR "x"."b" OR "x"."d"
|
||||||
) AS "_col_0"
|
) AS "_col_0"
|
||||||
FROM "x" AS "x"
|
FROM "x" AS "x"
|
||||||
WHERE
|
WHERE
|
||||||
"x"."Z" = CAST('2021-02-01' AS DATE)
|
CAST("x"."z" AS DATE) = CAST('2021-02-01' AS DATE)
|
||||||
```
|
```
|
||||||
|
|
||||||
### AST Introspection
|
### AST Introspection
|
||||||
|
@ -432,6 +433,14 @@ user_id price
|
||||||
2 3.0
|
2 3.0
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Used By
|
||||||
|
* [Fugue](https://github.com/fugue-project/fugue)
|
||||||
|
* [ibis](https://github.com/ibis-project/ibis)
|
||||||
|
* [mysql-mimic](https://github.com/kelsin/mysql-mimic)
|
||||||
|
* [Querybook](https://github.com/pinterest/querybook)
|
||||||
|
* [Quokka](https://github.com/marsupialtail/quokka)
|
||||||
|
* [Splink](https://github.com/moj-analytical-services/splink)
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
SQLGlot uses [pdocs](https://pdoc.dev/) to serve its API documentation:
|
SQLGlot uses [pdocs](https://pdoc.dev/) to serve its API documentation:
|
||||||
|
|
|
@ -23,7 +23,7 @@ SELECT
|
||||||
"e"."phone_number" AS "Phone",
|
"e"."phone_number" AS "Phone",
|
||||||
TO_CHAR("e"."hire_date", 'MM/DD/YYYY') AS "Hire Date",
|
TO_CHAR("e"."hire_date", 'MM/DD/YYYY') AS "Hire Date",
|
||||||
TO_CHAR("e"."salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') AS "Salary",
|
TO_CHAR("e"."salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') AS "Salary",
|
||||||
"e"."commission_pct" AS "Comission %",
|
"e"."commission_pct" AS "Commission %",
|
||||||
'works as ' || "j"."job_title" || ' in ' || "d"."department_name" || ' department (manager: ' || "dm"."first_name" || ' ' || "dm"."last_name" || ') and immediate supervisor: ' || "m"."first_name" || ' ' || "m"."last_name" AS "Current Job",
|
'works as ' || "j"."job_title" || ' in ' || "d"."department_name" || ' department (manager: ' || "dm"."first_name" || ' ' || "dm"."last_name" || ') and immediate supervisor: ' || "m"."first_name" || ' ' || "m"."last_name" AS "Current Job",
|
||||||
TO_CHAR("j"."min_salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') || ' - ' || TO_CHAR("j"."max_salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') AS "Current Salary",
|
TO_CHAR("j"."min_salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') || ' - ' || TO_CHAR("j"."max_salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') AS "Current Salary",
|
||||||
"l"."street_address" || ', ' || "l"."postal_code" || ', ' || "l"."city" || ', ' || "l"."state_province" || ', ' || "c"."country_name" || ' (' || "r"."region_name" || ')' AS "Location",
|
"l"."street_address" || ', ' || "l"."postal_code" || ', ' || "l"."city" || ', ' || "l"."state_province" || ', ' || "c"."country_name" || ' (' || "r"."region_name" || ')' AS "Location",
|
||||||
|
|
|
@ -14,13 +14,13 @@ This post will cover [why](#why) I went through the effort of creating a Python
|
||||||
* [Executing](#executing)
|
* [Executing](#executing)
|
||||||
|
|
||||||
## Why?
|
## Why?
|
||||||
I started working on SQLGlot because of my work on the [experimentation and metrics platform](https://netflixtechblog.com/reimagining-experimentation-analysis-at-netflix-71356393af21) at Netflix, where I built tools that allowed data scientists to define and compute SQL-based metrics. Netflix relied on multiple engines to query data (Spark, Presto, and Druid), so my team built the metrics platform around [PyPika](https://github.com/kayak/pypika), a Python SQL query builder. This way, definitions could be reused across multiple engines. However, it became quickly apparent that writing python code to programatically generate SQL was challenging for data scientists, especially those with academic backgrounds, since they were mostly familiar with R and SQL. At the time, the only Python SQL parser was [sqlparse]([https://github.com/andialbrecht/sqlparse), which is not actually a parser but a tokenizer, so having users write raw SQL into the platform wasn't really an option. Some time later, I randomly stumbled across [Crafting Interpreters](https://craftinginterpreters.com/) and realized that I could use it as a guide towards creating my own SQL parser/transpiler.
|
I started working on SQLGlot because of my work on the [experimentation and metrics platform](https://netflixtechblog.com/reimagining-experimentation-analysis-at-netflix-71356393af21) at Netflix, where I built tools that allowed data scientists to define and compute SQL-based metrics. Netflix relied on multiple engines to query data (Spark, Presto, and Druid), so my team built the metrics platform around [PyPika](https://github.com/kayak/pypika), a Python SQL query builder. This way, definitions could be reused across multiple engines. However, it became quickly apparent that writing python code to programmatically generate SQL was challenging for data scientists, especially those with academic backgrounds, since they were mostly familiar with R and SQL. At the time, the only Python SQL parser was [sqlparse]([https://github.com/andialbrecht/sqlparse), which is not actually a parser but a tokenizer, so having users write raw SQL into the platform wasn't really an option. Some time later, I randomly stumbled across [Crafting Interpreters](https://craftinginterpreters.com/) and realized that I could use it as a guide towards creating my own SQL parser/transpiler.
|
||||||
|
|
||||||
Why did I do this? Isn't a Python SQL engine going to be extremely slow?
|
Why did I do this? Isn't a Python SQL engine going to be extremely slow?
|
||||||
|
|
||||||
The main reason why I ended up building a SQL engine was...just for **entertainment**. It's been fun learning about all the things required to actually run a SQL query, and seeing it actually work is extremely rewarding. Before SQLGlot, I had zero experience with lexers, parsers, or compilers.
|
The main reason why I ended up building a SQL engine was...just for **entertainment**. It's been fun learning about all the things required to actually run a SQL query, and seeing it actually work is extremely rewarding. Before SQLGlot, I had zero experience with lexers, parsers, or compilers.
|
||||||
|
|
||||||
In terms of practical use cases, I planned to use the Python SQL engine for unit testing SQL pipelines. Big data pipelines are tough to test because many of the engines are not open source and cannot be run locally. With SQLGlot, you can take a SQL query targeting a warehouse such as [Snowflake](https://www.snowflake.com/en/) and seamlessly run it in CI on mock Python data. It's easy to mock data and create arbitrary [UDFs](https://en.wikipedia.org/wiki/User-defined_function) because everything is just Python. Although the implementation is slow and unsuitable for large amounts of data (> 1 millon rows), there's very little overhead/startup and you can run queries on test data in a couple of milliseconds.
|
In terms of practical use cases, I planned to use the Python SQL engine for unit testing SQL pipelines. Big data pipelines are tough to test because many of the engines are not open source and cannot be run locally. With SQLGlot, you can take a SQL query targeting a warehouse such as [Snowflake](https://www.snowflake.com/en/) and seamlessly run it in CI on mock Python data. It's easy to mock data and create arbitrary [UDFs](https://en.wikipedia.org/wiki/User-defined_function) because everything is just Python. Although the implementation is slow and unsuitable for large amounts of data (> 1 million rows), there's very little overhead/startup and you can run queries on test data in a couple of milliseconds.
|
||||||
|
|
||||||
Finally, the components that have been built to support execution can be used as a **foundation** for a faster engine. I'm inspired by what [Apache Calcite](https://github.com/apache/calcite) has done for the JVM world. Even though Python is commonly used for data, there hasn't been a Calcite for Python. So, you could say that SQLGlot aims to be that framework. For example, it wouldn't take much work to replace the Python execution engine with numpy/pandas/arrow to become a respectably-performing query engine. The implementation would be able to leverage the parser, optimizer, and logical planner, only needing to implement physical execution. There is a lot of work in the Python ecosystem around high performance vectorized computation, which I think could benefit from a pure Python-based [AST](https://en.wikipedia.org/wiki/Abstract_syntax_tree)/[plan](https://en.wikipedia.org/wiki/Query_plan). Parsing and planning doesn't have to be fast when the bottleneck of running queries is processing terabytes of data. So, having a Python-based ecosystem around SQL is beneficial given the ease of development in Python, despite not having bare metal performance.
|
Finally, the components that have been built to support execution can be used as a **foundation** for a faster engine. I'm inspired by what [Apache Calcite](https://github.com/apache/calcite) has done for the JVM world. Even though Python is commonly used for data, there hasn't been a Calcite for Python. So, you could say that SQLGlot aims to be that framework. For example, it wouldn't take much work to replace the Python execution engine with numpy/pandas/arrow to become a respectably-performing query engine. The implementation would be able to leverage the parser, optimizer, and logical planner, only needing to implement physical execution. There is a lot of work in the Python ecosystem around high performance vectorized computation, which I think could benefit from a pure Python-based [AST](https://en.wikipedia.org/wiki/Abstract_syntax_tree)/[plan](https://en.wikipedia.org/wiki/Query_plan). Parsing and planning doesn't have to be fast when the bottleneck of running queries is processing terabytes of data. So, having a Python-based ecosystem around SQL is beneficial given the ease of development in Python, despite not having bare metal performance.
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ Once we have our AST, we can transform it into an equivalent query that produces
|
||||||
|
|
||||||
1. It's easier to debug and [validate](https://github.com/tobymao/sqlglot/blob/main/tests/fixtures/optimizer) the optimizations when the input and output are both SQL.
|
1. It's easier to debug and [validate](https://github.com/tobymao/sqlglot/blob/main/tests/fixtures/optimizer) the optimizations when the input and output are both SQL.
|
||||||
|
|
||||||
2. Rules can be applied a la carte to transform SQL into a more desireable form.
|
2. Rules can be applied a la carte to transform SQL into a more desirable form.
|
||||||
|
|
||||||
3. I wanted a way to generate 'canonical sql'. Having a canonical representation of SQL is useful for understanding if two queries are semantically equivalent (e.g. `SELECT 1 + 1` and `SELECT 2`).
|
3. I wanted a way to generate 'canonical sql'. Having a canonical representation of SQL is useful for understanding if two queries are semantically equivalent (e.g. `SELECT 1 + 1` and `SELECT 2`).
|
||||||
|
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -28,7 +28,7 @@ setup(
|
||||||
"black",
|
"black",
|
||||||
"duckdb",
|
"duckdb",
|
||||||
"isort",
|
"isort",
|
||||||
"mypy",
|
"mypy>=0.990",
|
||||||
"pandas",
|
"pandas",
|
||||||
"pyspark",
|
"pyspark",
|
||||||
"python-dateutil",
|
"python-dateutil",
|
||||||
|
|
|
@ -32,7 +32,7 @@ from sqlglot.parser import Parser
|
||||||
from sqlglot.schema import MappingSchema
|
from sqlglot.schema import MappingSchema
|
||||||
from sqlglot.tokens import Tokenizer, TokenType
|
from sqlglot.tokens import Tokenizer, TokenType
|
||||||
|
|
||||||
__version__ = "10.4.2"
|
__version__ = "10.5.2"
|
||||||
|
|
||||||
pretty = False
|
pretty = False
|
||||||
|
|
||||||
|
@ -60,9 +60,9 @@ def parse(
|
||||||
def parse_one(
|
def parse_one(
|
||||||
sql: str,
|
sql: str,
|
||||||
read: t.Optional[str | Dialect] = None,
|
read: t.Optional[str | Dialect] = None,
|
||||||
into: t.Optional[Expression | str] = None,
|
into: t.Optional[t.Type[Expression] | str] = None,
|
||||||
**opts,
|
**opts,
|
||||||
) -> t.Optional[Expression]:
|
) -> Expression:
|
||||||
"""
|
"""
|
||||||
Parses the given SQL string and returns a syntax tree for the first parsed SQL statement.
|
Parses the given SQL string and returns a syntax tree for the first parsed SQL statement.
|
||||||
|
|
||||||
|
@ -83,7 +83,12 @@ def parse_one(
|
||||||
else:
|
else:
|
||||||
result = dialect.parse(sql, **opts)
|
result = dialect.parse(sql, **opts)
|
||||||
|
|
||||||
return result[0] if result else None
|
for expression in result:
|
||||||
|
if not expression:
|
||||||
|
raise ParseError(f"No expression was parsed from '{sql}'")
|
||||||
|
return expression
|
||||||
|
else:
|
||||||
|
raise ParseError(f"No expression was parsed from '{sql}'")
|
||||||
|
|
||||||
|
|
||||||
def transpile(
|
def transpile(
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from sqlglot import exp, generator, parser, tokens
|
from sqlglot import exp, generator, parser, tokens, transforms
|
||||||
from sqlglot.dialects.dialect import (
|
from sqlglot.dialects.dialect import (
|
||||||
Dialect,
|
Dialect,
|
||||||
datestrtodate_sql,
|
datestrtodate_sql,
|
||||||
|
@ -46,8 +46,9 @@ def _date_add_sql(data_type, kind):
|
||||||
|
|
||||||
def _derived_table_values_to_unnest(self, expression):
|
def _derived_table_values_to_unnest(self, expression):
|
||||||
if not isinstance(expression.unnest().parent, exp.From):
|
if not isinstance(expression.unnest().parent, exp.From):
|
||||||
|
expression = transforms.remove_precision_parameterized_types(expression)
|
||||||
return self.values_sql(expression)
|
return self.values_sql(expression)
|
||||||
rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.find_all(exp.Tuple)]
|
rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
|
||||||
structs = []
|
structs = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
aliases = [
|
aliases = [
|
||||||
|
@ -118,6 +119,7 @@ class BigQuery(Dialect):
|
||||||
"BEGIN TRANSACTION": TokenType.BEGIN,
|
"BEGIN TRANSACTION": TokenType.BEGIN,
|
||||||
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
|
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
|
||||||
"CURRENT_TIME": TokenType.CURRENT_TIME,
|
"CURRENT_TIME": TokenType.CURRENT_TIME,
|
||||||
|
"DECLARE": TokenType.COMMAND,
|
||||||
"GEOGRAPHY": TokenType.GEOGRAPHY,
|
"GEOGRAPHY": TokenType.GEOGRAPHY,
|
||||||
"FLOAT64": TokenType.DOUBLE,
|
"FLOAT64": TokenType.DOUBLE,
|
||||||
"INT64": TokenType.BIGINT,
|
"INT64": TokenType.BIGINT,
|
||||||
|
@ -166,6 +168,7 @@ class BigQuery(Dialect):
|
||||||
class Generator(generator.Generator):
|
class Generator(generator.Generator):
|
||||||
TRANSFORMS = {
|
TRANSFORMS = {
|
||||||
**generator.Generator.TRANSFORMS, # type: ignore
|
**generator.Generator.TRANSFORMS, # type: ignore
|
||||||
|
**transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore
|
||||||
exp.ArraySize: rename_func("ARRAY_LENGTH"),
|
exp.ArraySize: rename_func("ARRAY_LENGTH"),
|
||||||
exp.DateAdd: _date_add_sql("DATE", "ADD"),
|
exp.DateAdd: _date_add_sql("DATE", "ADD"),
|
||||||
exp.DateSub: _date_add_sql("DATE", "SUB"),
|
exp.DateSub: _date_add_sql("DATE", "SUB"),
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing as t
|
||||||
|
|
||||||
from sqlglot import exp, generator, parser, tokens
|
from sqlglot import exp, generator, parser, tokens
|
||||||
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
|
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
|
||||||
from sqlglot.parser import parse_var_map
|
from sqlglot.parser import parse_var_map
|
||||||
|
@ -22,6 +24,7 @@ class ClickHouse(Dialect):
|
||||||
KEYWORDS = {
|
KEYWORDS = {
|
||||||
**tokens.Tokenizer.KEYWORDS,
|
**tokens.Tokenizer.KEYWORDS,
|
||||||
"ASOF": TokenType.ASOF,
|
"ASOF": TokenType.ASOF,
|
||||||
|
"GLOBAL": TokenType.GLOBAL,
|
||||||
"DATETIME64": TokenType.DATETIME,
|
"DATETIME64": TokenType.DATETIME,
|
||||||
"FINAL": TokenType.FINAL,
|
"FINAL": TokenType.FINAL,
|
||||||
"FLOAT32": TokenType.FLOAT,
|
"FLOAT32": TokenType.FLOAT,
|
||||||
|
@ -37,14 +40,32 @@ class ClickHouse(Dialect):
|
||||||
FUNCTIONS = {
|
FUNCTIONS = {
|
||||||
**parser.Parser.FUNCTIONS, # type: ignore
|
**parser.Parser.FUNCTIONS, # type: ignore
|
||||||
"MAP": parse_var_map,
|
"MAP": parse_var_map,
|
||||||
|
"QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params),
|
||||||
|
"QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args),
|
||||||
|
"QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args),
|
||||||
|
}
|
||||||
|
|
||||||
|
RANGE_PARSERS = {
|
||||||
|
**parser.Parser.RANGE_PARSERS,
|
||||||
|
TokenType.GLOBAL: lambda self, this: self._match(TokenType.IN)
|
||||||
|
and self._parse_in(this, is_global=True),
|
||||||
}
|
}
|
||||||
|
|
||||||
JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore
|
JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore
|
||||||
|
|
||||||
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
|
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
|
||||||
|
|
||||||
def _parse_table(self, schema=False):
|
def _parse_in(
|
||||||
this = super()._parse_table(schema)
|
self, this: t.Optional[exp.Expression], is_global: bool = False
|
||||||
|
) -> exp.Expression:
|
||||||
|
this = super()._parse_in(this)
|
||||||
|
this.set("is_global", is_global)
|
||||||
|
return this
|
||||||
|
|
||||||
|
def _parse_table(
|
||||||
|
self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
|
||||||
|
) -> t.Optional[exp.Expression]:
|
||||||
|
this = super()._parse_table(schema=schema, alias_tokens=alias_tokens)
|
||||||
|
|
||||||
if self._match(TokenType.FINAL):
|
if self._match(TokenType.FINAL):
|
||||||
this = self.expression(exp.Final, this=this)
|
this = self.expression(exp.Final, this=this)
|
||||||
|
@ -76,6 +97,16 @@ class ClickHouse(Dialect):
|
||||||
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
|
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
|
||||||
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
|
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
|
||||||
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
|
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
|
||||||
|
exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}",
|
||||||
|
exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}",
|
||||||
|
exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}",
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPLICIT_UNION = True
|
EXPLICIT_UNION = True
|
||||||
|
|
||||||
|
def _param_args_sql(
|
||||||
|
self, expression: exp.Expression, params_name: str, args_name: str
|
||||||
|
) -> str:
|
||||||
|
params = self.format_args(self.expressions(expression, params_name))
|
||||||
|
args = self.format_args(self.expressions(expression, args_name))
|
||||||
|
return f"({params})({args})"
|
||||||
|
|
|
@ -381,3 +381,20 @@ def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:
|
||||||
|
|
||||||
def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
|
def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
|
||||||
return f"CAST({self.sql(expression, 'this')} AS DATE)"
|
return f"CAST({self.sql(expression, 'this')} AS DATE)"
|
||||||
|
|
||||||
|
|
||||||
|
def trim_sql(self, expression):
|
||||||
|
target = self.sql(expression, "this")
|
||||||
|
trim_type = self.sql(expression, "position")
|
||||||
|
remove_chars = self.sql(expression, "expression")
|
||||||
|
collation = self.sql(expression, "collation")
|
||||||
|
|
||||||
|
# Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
|
||||||
|
if not remove_chars and not collation:
|
||||||
|
return self.trim_sql(expression)
|
||||||
|
|
||||||
|
trim_type = f"{trim_type} " if trim_type else ""
|
||||||
|
remove_chars = f"{remove_chars} " if remove_chars else ""
|
||||||
|
from_part = "FROM " if trim_type or remove_chars else ""
|
||||||
|
collation = f" COLLATE {collation}" if collation else ""
|
||||||
|
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
|
||||||
|
|
|
@ -175,14 +175,6 @@ class Hive(Dialect):
|
||||||
ESCAPES = ["\\"]
|
ESCAPES = ["\\"]
|
||||||
ENCODE = "utf-8"
|
ENCODE = "utf-8"
|
||||||
|
|
||||||
NUMERIC_LITERALS = {
|
|
||||||
"L": "BIGINT",
|
|
||||||
"S": "SMALLINT",
|
|
||||||
"Y": "TINYINT",
|
|
||||||
"D": "DOUBLE",
|
|
||||||
"F": "FLOAT",
|
|
||||||
"BD": "DECIMAL",
|
|
||||||
}
|
|
||||||
KEYWORDS = {
|
KEYWORDS = {
|
||||||
**tokens.Tokenizer.KEYWORDS,
|
**tokens.Tokenizer.KEYWORDS,
|
||||||
"ADD ARCHIVE": TokenType.COMMAND,
|
"ADD ARCHIVE": TokenType.COMMAND,
|
||||||
|
@ -191,9 +183,21 @@ class Hive(Dialect):
|
||||||
"ADD FILES": TokenType.COMMAND,
|
"ADD FILES": TokenType.COMMAND,
|
||||||
"ADD JAR": TokenType.COMMAND,
|
"ADD JAR": TokenType.COMMAND,
|
||||||
"ADD JARS": TokenType.COMMAND,
|
"ADD JARS": TokenType.COMMAND,
|
||||||
|
"MSCK REPAIR": TokenType.COMMAND,
|
||||||
"WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
|
"WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NUMERIC_LITERALS = {
|
||||||
|
"L": "BIGINT",
|
||||||
|
"S": "SMALLINT",
|
||||||
|
"Y": "TINYINT",
|
||||||
|
"D": "DOUBLE",
|
||||||
|
"F": "FLOAT",
|
||||||
|
"BD": "DECIMAL",
|
||||||
|
}
|
||||||
|
|
||||||
|
IDENTIFIER_CAN_START_WITH_DIGIT = True
|
||||||
|
|
||||||
class Parser(parser.Parser):
|
class Parser(parser.Parser):
|
||||||
STRICT_CAST = False
|
STRICT_CAST = False
|
||||||
|
|
||||||
|
@ -315,6 +319,7 @@ class Hive(Dialect):
|
||||||
exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}",
|
exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}",
|
||||||
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
|
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
|
||||||
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
|
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
|
||||||
|
exp.LastDateOfMonth: rename_func("LAST_DAY"),
|
||||||
}
|
}
|
||||||
|
|
||||||
WITH_PROPERTIES = {exp.Property}
|
WITH_PROPERTIES = {exp.Property}
|
||||||
|
@ -342,4 +347,6 @@ class Hive(Dialect):
|
||||||
and not expression.expressions
|
and not expression.expressions
|
||||||
):
|
):
|
||||||
expression = exp.DataType.build("text")
|
expression = exp.DataType.build("text")
|
||||||
|
elif expression.this in exp.DataType.TEMPORAL_TYPES:
|
||||||
|
expression = exp.DataType.build(expression.this)
|
||||||
return super().datatype_sql(expression)
|
return super().datatype_sql(expression)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from sqlglot import exp, generator, parser, tokens, transforms
|
from sqlglot import exp, generator, parser, tokens, transforms
|
||||||
from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func
|
from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql
|
||||||
from sqlglot.helper import csv
|
from sqlglot.helper import csv
|
||||||
from sqlglot.tokens import TokenType
|
from sqlglot.tokens import TokenType
|
||||||
|
|
||||||
|
@ -64,6 +64,7 @@ class Oracle(Dialect):
|
||||||
**transforms.UNALIAS_GROUP, # type: ignore
|
**transforms.UNALIAS_GROUP, # type: ignore
|
||||||
exp.ILike: no_ilike_sql,
|
exp.ILike: no_ilike_sql,
|
||||||
exp.Limit: _limit_sql,
|
exp.Limit: _limit_sql,
|
||||||
|
exp.Trim: trim_sql,
|
||||||
exp.Matches: rename_func("DECODE"),
|
exp.Matches: rename_func("DECODE"),
|
||||||
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||||
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
|
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||||
|
|
|
@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
|
||||||
no_tablesample_sql,
|
no_tablesample_sql,
|
||||||
no_trycast_sql,
|
no_trycast_sql,
|
||||||
str_position_sql,
|
str_position_sql,
|
||||||
|
trim_sql,
|
||||||
)
|
)
|
||||||
from sqlglot.helper import seq_get
|
from sqlglot.helper import seq_get
|
||||||
from sqlglot.tokens import TokenType
|
from sqlglot.tokens import TokenType
|
||||||
|
@ -81,23 +82,6 @@ def _substring_sql(self, expression):
|
||||||
return f"SUBSTRING({this}{from_part}{for_part})"
|
return f"SUBSTRING({this}{from_part}{for_part})"
|
||||||
|
|
||||||
|
|
||||||
def _trim_sql(self, expression):
|
|
||||||
target = self.sql(expression, "this")
|
|
||||||
trim_type = self.sql(expression, "position")
|
|
||||||
remove_chars = self.sql(expression, "expression")
|
|
||||||
collation = self.sql(expression, "collation")
|
|
||||||
|
|
||||||
# Use TRIM/LTRIM/RTRIM syntax if the expression isn't postgres-specific
|
|
||||||
if not remove_chars and not collation:
|
|
||||||
return self.trim_sql(expression)
|
|
||||||
|
|
||||||
trim_type = f"{trim_type} " if trim_type else ""
|
|
||||||
remove_chars = f"{remove_chars} " if remove_chars else ""
|
|
||||||
from_part = "FROM " if trim_type or remove_chars else ""
|
|
||||||
collation = f" COLLATE {collation}" if collation else ""
|
|
||||||
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
|
|
||||||
|
|
||||||
|
|
||||||
def _string_agg_sql(self, expression):
|
def _string_agg_sql(self, expression):
|
||||||
expression = expression.copy()
|
expression = expression.copy()
|
||||||
separator = expression.args.get("separator") or exp.Literal.string(",")
|
separator = expression.args.get("separator") or exp.Literal.string(",")
|
||||||
|
@ -248,7 +232,6 @@ class Postgres(Dialect):
|
||||||
"COMMENT ON": TokenType.COMMAND,
|
"COMMENT ON": TokenType.COMMAND,
|
||||||
"DECLARE": TokenType.COMMAND,
|
"DECLARE": TokenType.COMMAND,
|
||||||
"DO": TokenType.COMMAND,
|
"DO": TokenType.COMMAND,
|
||||||
"DOUBLE PRECISION": TokenType.DOUBLE,
|
|
||||||
"GENERATED": TokenType.GENERATED,
|
"GENERATED": TokenType.GENERATED,
|
||||||
"GRANT": TokenType.COMMAND,
|
"GRANT": TokenType.COMMAND,
|
||||||
"HSTORE": TokenType.HSTORE,
|
"HSTORE": TokenType.HSTORE,
|
||||||
|
@ -318,7 +301,7 @@ class Postgres(Dialect):
|
||||||
exp.Substring: _substring_sql,
|
exp.Substring: _substring_sql,
|
||||||
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
|
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||||
exp.TableSample: no_tablesample_sql,
|
exp.TableSample: no_tablesample_sql,
|
||||||
exp.Trim: _trim_sql,
|
exp.Trim: trim_sql,
|
||||||
exp.TryCast: no_trycast_sql,
|
exp.TryCast: no_trycast_sql,
|
||||||
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
|
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
|
||||||
exp.DataType: _datatype_sql,
|
exp.DataType: _datatype_sql,
|
||||||
|
|
|
@ -195,7 +195,6 @@ class Snowflake(Dialect):
|
||||||
KEYWORDS = {
|
KEYWORDS = {
|
||||||
**tokens.Tokenizer.KEYWORDS,
|
**tokens.Tokenizer.KEYWORDS,
|
||||||
"QUALIFY": TokenType.QUALIFY,
|
"QUALIFY": TokenType.QUALIFY,
|
||||||
"DOUBLE PRECISION": TokenType.DOUBLE,
|
|
||||||
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
|
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
|
||||||
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
|
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
|
||||||
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
|
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
|
||||||
|
@ -294,3 +293,10 @@ class Snowflake(Dialect):
|
||||||
)
|
)
|
||||||
return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
|
return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
|
||||||
return super().select_sql(expression)
|
return super().select_sql(expression)
|
||||||
|
|
||||||
|
def describe_sql(self, expression: exp.Describe) -> str:
|
||||||
|
# Default to table if kind is unknown
|
||||||
|
kind_value = expression.args.get("kind") or "TABLE"
|
||||||
|
kind = f" {kind_value}" if kind_value else ""
|
||||||
|
this = f" {self.sql(expression, 'this')}"
|
||||||
|
return f"DESCRIBE{kind}{this}"
|
||||||
|
|
|
@ -75,6 +75,20 @@ def _parse_format(args):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_eomonth(args):
|
||||||
|
date = seq_get(args, 0)
|
||||||
|
month_lag = seq_get(args, 1)
|
||||||
|
unit = DATE_DELTA_INTERVAL.get("month")
|
||||||
|
|
||||||
|
if month_lag is None:
|
||||||
|
return exp.LastDateOfMonth(this=date)
|
||||||
|
|
||||||
|
# Remove month lag argument in parser as its compared with the number of arguments of the resulting class
|
||||||
|
args.remove(month_lag)
|
||||||
|
|
||||||
|
return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
|
||||||
|
|
||||||
|
|
||||||
def generate_date_delta_with_unit_sql(self, e):
|
def generate_date_delta_with_unit_sql(self, e):
|
||||||
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
|
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
|
||||||
return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
|
return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
|
||||||
|
@ -256,12 +270,14 @@ class TSQL(Dialect):
|
||||||
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
|
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
|
||||||
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
|
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
|
||||||
"DATEPART": _format_time_lambda(exp.TimeToStr),
|
"DATEPART": _format_time_lambda(exp.TimeToStr),
|
||||||
"GETDATE": exp.CurrentDate.from_arg_list,
|
"GETDATE": exp.CurrentTimestamp.from_arg_list,
|
||||||
|
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
|
||||||
"IIF": exp.If.from_arg_list,
|
"IIF": exp.If.from_arg_list,
|
||||||
"LEN": exp.Length.from_arg_list,
|
"LEN": exp.Length.from_arg_list,
|
||||||
"REPLICATE": exp.Repeat.from_arg_list,
|
"REPLICATE": exp.Repeat.from_arg_list,
|
||||||
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
|
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
|
||||||
"FORMAT": _parse_format,
|
"FORMAT": _parse_format,
|
||||||
|
"EOMONTH": _parse_eomonth,
|
||||||
}
|
}
|
||||||
|
|
||||||
VAR_LENGTH_DATATYPES = {
|
VAR_LENGTH_DATATYPES = {
|
||||||
|
@ -271,6 +287,9 @@ class TSQL(Dialect):
|
||||||
DataType.Type.NCHAR,
|
DataType.Type.NCHAR,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table
|
||||||
|
TABLE_PREFIX_TOKENS = {TokenType.HASH}
|
||||||
|
|
||||||
def _parse_convert(self, strict):
|
def _parse_convert(self, strict):
|
||||||
to = self._parse_types()
|
to = self._parse_types()
|
||||||
self._match(TokenType.COMMA)
|
self._match(TokenType.COMMA)
|
||||||
|
@ -323,6 +342,7 @@ class TSQL(Dialect):
|
||||||
exp.DateAdd: generate_date_delta_with_unit_sql,
|
exp.DateAdd: generate_date_delta_with_unit_sql,
|
||||||
exp.DateDiff: generate_date_delta_with_unit_sql,
|
exp.DateDiff: generate_date_delta_with_unit_sql,
|
||||||
exp.CurrentDate: rename_func("GETDATE"),
|
exp.CurrentDate: rename_func("GETDATE"),
|
||||||
|
exp.CurrentTimestamp: rename_func("GETDATE"),
|
||||||
exp.If: rename_func("IIF"),
|
exp.If: rename_func("IIF"),
|
||||||
exp.NumberToStr: _format_sql,
|
exp.NumberToStr: _format_sql,
|
||||||
exp.TimeToStr: _format_sql,
|
exp.TimeToStr: _format_sql,
|
||||||
|
|
|
@ -22,6 +22,7 @@ from sqlglot.helper import (
|
||||||
split_num_words,
|
split_num_words,
|
||||||
subclasses,
|
subclasses,
|
||||||
)
|
)
|
||||||
|
from sqlglot.tokens import Token
|
||||||
|
|
||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
from sqlglot.dialects.dialect import Dialect
|
from sqlglot.dialects.dialect import Dialect
|
||||||
|
@ -457,6 +458,23 @@ class Expression(metaclass=_Expression):
|
||||||
assert isinstance(self, type_)
|
assert isinstance(self, type_)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def dump(self):
|
||||||
|
"""
|
||||||
|
Dump this Expression to a JSON-serializable dict.
|
||||||
|
"""
|
||||||
|
from sqlglot.serde import dump
|
||||||
|
|
||||||
|
return dump(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, obj):
|
||||||
|
"""
|
||||||
|
Load a dict (as returned by `Expression.dump`) into an Expression instance.
|
||||||
|
"""
|
||||||
|
from sqlglot.serde import load
|
||||||
|
|
||||||
|
return load(obj)
|
||||||
|
|
||||||
|
|
||||||
class Condition(Expression):
|
class Condition(Expression):
|
||||||
def and_(self, *expressions, dialect=None, **opts):
|
def and_(self, *expressions, dialect=None, **opts):
|
||||||
|
@ -631,11 +649,15 @@ class Create(Expression):
|
||||||
"replace": False,
|
"replace": False,
|
||||||
"unique": False,
|
"unique": False,
|
||||||
"materialized": False,
|
"materialized": False,
|
||||||
|
"data": False,
|
||||||
|
"statistics": False,
|
||||||
|
"no_primary_index": False,
|
||||||
|
"indexes": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class Describe(Expression):
|
class Describe(Expression):
|
||||||
pass
|
arg_types = {"this": True, "kind": False}
|
||||||
|
|
||||||
|
|
||||||
class Set(Expression):
|
class Set(Expression):
|
||||||
|
@ -731,7 +753,7 @@ class Column(Condition):
|
||||||
class ColumnDef(Expression):
|
class ColumnDef(Expression):
|
||||||
arg_types = {
|
arg_types = {
|
||||||
"this": True,
|
"this": True,
|
||||||
"kind": True,
|
"kind": False,
|
||||||
"constraints": False,
|
"constraints": False,
|
||||||
"exists": False,
|
"exists": False,
|
||||||
}
|
}
|
||||||
|
@ -879,7 +901,15 @@ class Identifier(Expression):
|
||||||
|
|
||||||
|
|
||||||
class Index(Expression):
|
class Index(Expression):
|
||||||
arg_types = {"this": False, "table": False, "where": False, "columns": False}
|
arg_types = {
|
||||||
|
"this": False,
|
||||||
|
"table": False,
|
||||||
|
"where": False,
|
||||||
|
"columns": False,
|
||||||
|
"unique": False,
|
||||||
|
"primary": False,
|
||||||
|
"amp": False, # teradata
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class Insert(Expression):
|
class Insert(Expression):
|
||||||
|
@ -1361,6 +1391,7 @@ class Table(Expression):
|
||||||
"laterals": False,
|
"laterals": False,
|
||||||
"joins": False,
|
"joins": False,
|
||||||
"pivots": False,
|
"pivots": False,
|
||||||
|
"hints": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1818,7 +1849,12 @@ class Select(Subqueryable):
|
||||||
join.this.replace(join.this.subquery())
|
join.this.replace(join.this.subquery())
|
||||||
|
|
||||||
if join_type:
|
if join_type:
|
||||||
|
natural: t.Optional[Token]
|
||||||
|
side: t.Optional[Token]
|
||||||
|
kind: t.Optional[Token]
|
||||||
|
|
||||||
natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore
|
natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore
|
||||||
|
|
||||||
if natural:
|
if natural:
|
||||||
join.set("natural", True)
|
join.set("natural", True)
|
||||||
if side:
|
if side:
|
||||||
|
@ -2111,6 +2147,7 @@ class DataType(Expression):
|
||||||
JSON = auto()
|
JSON = auto()
|
||||||
JSONB = auto()
|
JSONB = auto()
|
||||||
INTERVAL = auto()
|
INTERVAL = auto()
|
||||||
|
TIME = auto()
|
||||||
TIMESTAMP = auto()
|
TIMESTAMP = auto()
|
||||||
TIMESTAMPTZ = auto()
|
TIMESTAMPTZ = auto()
|
||||||
TIMESTAMPLTZ = auto()
|
TIMESTAMPLTZ = auto()
|
||||||
|
@ -2171,11 +2208,24 @@ class DataType(Expression):
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build(cls, dtype, **kwargs) -> DataType:
|
def build(
|
||||||
return DataType(
|
cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs
|
||||||
this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
|
) -> DataType:
|
||||||
**kwargs,
|
from sqlglot import parse_one
|
||||||
)
|
|
||||||
|
if isinstance(dtype, str):
|
||||||
|
data_type_exp: t.Optional[Expression]
|
||||||
|
if dtype.upper() in cls.Type.__members__:
|
||||||
|
data_type_exp = DataType(this=DataType.Type[dtype.upper()])
|
||||||
|
else:
|
||||||
|
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
|
||||||
|
if data_type_exp is None:
|
||||||
|
raise ValueError(f"Unparsable data type value: {dtype}")
|
||||||
|
elif isinstance(dtype, DataType.Type):
|
||||||
|
data_type_exp = DataType(this=dtype)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
|
||||||
|
return DataType(**{**data_type_exp.args, **kwargs})
|
||||||
|
|
||||||
|
|
||||||
# https://www.postgresql.org/docs/15/datatype-pseudo.html
|
# https://www.postgresql.org/docs/15/datatype-pseudo.html
|
||||||
|
@ -2429,6 +2479,7 @@ class In(Predicate):
|
||||||
"query": False,
|
"query": False,
|
||||||
"unnest": False,
|
"unnest": False,
|
||||||
"field": False,
|
"field": False,
|
||||||
|
"is_global": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2678,6 +2729,10 @@ class DatetimeTrunc(Func, TimeUnit):
|
||||||
arg_types = {"this": True, "unit": True, "zone": False}
|
arg_types = {"this": True, "unit": True, "zone": False}
|
||||||
|
|
||||||
|
|
||||||
|
class LastDateOfMonth(Func):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Extract(Func):
|
class Extract(Func):
|
||||||
arg_types = {"this": True, "expression": True}
|
arg_types = {"this": True, "expression": True}
|
||||||
|
|
||||||
|
@ -2815,7 +2870,13 @@ class Length(Func):
|
||||||
|
|
||||||
|
|
||||||
class Levenshtein(Func):
|
class Levenshtein(Func):
|
||||||
arg_types = {"this": True, "expression": False}
|
arg_types = {
|
||||||
|
"this": True,
|
||||||
|
"expression": False,
|
||||||
|
"ins_cost": False,
|
||||||
|
"del_cost": False,
|
||||||
|
"sub_cost": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class Ln(Func):
|
class Ln(Func):
|
||||||
|
@ -2890,6 +2951,16 @@ class Quantile(AggFunc):
|
||||||
arg_types = {"this": True, "quantile": True}
|
arg_types = {"this": True, "quantile": True}
|
||||||
|
|
||||||
|
|
||||||
|
# Clickhouse-specific:
|
||||||
|
# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles
|
||||||
|
class Quantiles(AggFunc):
|
||||||
|
arg_types = {"parameters": True, "expressions": True}
|
||||||
|
|
||||||
|
|
||||||
|
class QuantileIf(AggFunc):
|
||||||
|
arg_types = {"parameters": True, "expressions": True}
|
||||||
|
|
||||||
|
|
||||||
class ApproxQuantile(Quantile):
|
class ApproxQuantile(Quantile):
|
||||||
arg_types = {"this": True, "quantile": True, "accuracy": False}
|
arg_types = {"this": True, "quantile": True, "accuracy": False}
|
||||||
|
|
||||||
|
@ -2962,8 +3033,10 @@ class StrToTime(Func):
|
||||||
arg_types = {"this": True, "format": True}
|
arg_types = {"this": True, "format": True}
|
||||||
|
|
||||||
|
|
||||||
|
# Spark allows unix_timestamp()
|
||||||
|
# https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.unix_timestamp.html
|
||||||
class StrToUnix(Func):
|
class StrToUnix(Func):
|
||||||
arg_types = {"this": True, "format": True}
|
arg_types = {"this": False, "format": False}
|
||||||
|
|
||||||
|
|
||||||
class NumberToStr(Func):
|
class NumberToStr(Func):
|
||||||
|
@ -3131,7 +3204,7 @@ def maybe_parse(
|
||||||
dialect=None,
|
dialect=None,
|
||||||
prefix=None,
|
prefix=None,
|
||||||
**opts,
|
**opts,
|
||||||
) -> t.Optional[Expression]:
|
) -> Expression:
|
||||||
"""Gracefully handle a possible string or expression.
|
"""Gracefully handle a possible string or expression.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
@ -3627,11 +3700,11 @@ def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
|
||||||
if not isinstance(sql_path, str):
|
if not isinstance(sql_path, str):
|
||||||
raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")
|
raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")
|
||||||
|
|
||||||
catalog, db, table_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 3)]
|
catalog, db, table_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 3))
|
||||||
return Table(this=table_name, db=db, catalog=catalog, **kwargs)
|
return Table(this=table_name, db=db, catalog=catalog, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def to_column(sql_path: str, **kwargs) -> Column:
|
def to_column(sql_path: str | Column, **kwargs) -> Column:
|
||||||
"""
|
"""
|
||||||
Create a column from a `[table].[column]` sql path. Schema is optional.
|
Create a column from a `[table].[column]` sql path. Schema is optional.
|
||||||
|
|
||||||
|
@ -3646,7 +3719,7 @@ def to_column(sql_path: str, **kwargs) -> Column:
|
||||||
return sql_path
|
return sql_path
|
||||||
if not isinstance(sql_path, str):
|
if not isinstance(sql_path, str):
|
||||||
raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
|
raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
|
||||||
table_name, column_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 2)]
|
table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2))
|
||||||
return Column(this=column_name, table=table_name, **kwargs)
|
return Column(this=column_name, table=table_name, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@ -3748,7 +3821,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
|
||||||
def values(
|
def values(
|
||||||
values: t.Iterable[t.Tuple[t.Any, ...]],
|
values: t.Iterable[t.Tuple[t.Any, ...]],
|
||||||
alias: t.Optional[str] = None,
|
alias: t.Optional[str] = None,
|
||||||
columns: t.Optional[t.Iterable[str]] = None,
|
columns: t.Optional[t.Iterable[str] | t.Dict[str, DataType]] = None,
|
||||||
) -> Values:
|
) -> Values:
|
||||||
"""Build VALUES statement.
|
"""Build VALUES statement.
|
||||||
|
|
||||||
|
@ -3759,7 +3832,10 @@ def values(
|
||||||
Args:
|
Args:
|
||||||
values: values statements that will be converted to SQL
|
values: values statements that will be converted to SQL
|
||||||
alias: optional alias
|
alias: optional alias
|
||||||
columns: Optional list of ordered column names. An alias is required when providing column names.
|
columns: Optional list of ordered column names or ordered dictionary of column names to types.
|
||||||
|
If either are provided then an alias is also required.
|
||||||
|
If a dictionary is provided then the first column of the values will be casted to the expected type
|
||||||
|
in order to help with type inference.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Values: the Values expression object
|
Values: the Values expression object
|
||||||
|
@ -3771,8 +3847,15 @@ def values(
|
||||||
if columns
|
if columns
|
||||||
else TableAlias(this=to_identifier(alias) if alias else None)
|
else TableAlias(this=to_identifier(alias) if alias else None)
|
||||||
)
|
)
|
||||||
|
expressions = [convert(tup) for tup in values]
|
||||||
|
if columns and isinstance(columns, dict):
|
||||||
|
types = list(columns.values())
|
||||||
|
expressions[0].set(
|
||||||
|
"expressions",
|
||||||
|
[Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)],
|
||||||
|
)
|
||||||
return Values(
|
return Values(
|
||||||
expressions=[convert(tup) for tup in values],
|
expressions=expressions,
|
||||||
alias=table_alias,
|
alias=table_alias,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,7 @@ class Generator:
|
||||||
The default is on the smaller end because the length only represents a segment and not the true
|
The default is on the smaller end because the length only represents a segment and not the true
|
||||||
line length.
|
line length.
|
||||||
Default: 80
|
Default: 80
|
||||||
comments: Whether or not to preserve comments in the ouput SQL code.
|
comments: Whether or not to preserve comments in the output SQL code.
|
||||||
Default: True
|
Default: True
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -236,7 +236,10 @@ class Generator:
|
||||||
return sql
|
return sql
|
||||||
|
|
||||||
sep = "\n" if self.pretty else " "
|
sep = "\n" if self.pretty else " "
|
||||||
comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments)
|
comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments if comment)
|
||||||
|
|
||||||
|
if not comments:
|
||||||
|
return sql
|
||||||
|
|
||||||
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
|
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
|
||||||
return f"{comments}{self.sep()}{sql}"
|
return f"{comments}{self.sep()}{sql}"
|
||||||
|
@ -362,10 +365,10 @@ class Generator:
|
||||||
kind = self.sql(expression, "kind")
|
kind = self.sql(expression, "kind")
|
||||||
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
|
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
|
||||||
exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
|
exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
|
||||||
|
kind = f" {kind}" if kind else ""
|
||||||
|
constraints = f" {constraints}" if constraints else ""
|
||||||
|
|
||||||
if not constraints:
|
return f"{exists}{column}{kind}{constraints}"
|
||||||
return f"{exists}{column} {kind}"
|
|
||||||
return f"{exists}{column} {kind} {constraints}"
|
|
||||||
|
|
||||||
def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
|
def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
|
||||||
this = self.sql(expression, "this")
|
this = self.sql(expression, "this")
|
||||||
|
@ -416,7 +419,7 @@ class Generator:
|
||||||
this = self.sql(expression, "this")
|
this = self.sql(expression, "this")
|
||||||
kind = self.sql(expression, "kind").upper()
|
kind = self.sql(expression, "kind").upper()
|
||||||
expression_sql = self.sql(expression, "expression")
|
expression_sql = self.sql(expression, "expression")
|
||||||
expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else ""
|
expression_sql = f" AS{self.sep()}{expression_sql}" if expression_sql else ""
|
||||||
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
|
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
|
||||||
transient = (
|
transient = (
|
||||||
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
|
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
|
||||||
|
@ -427,6 +430,40 @@ class Generator:
|
||||||
unique = " UNIQUE" if expression.args.get("unique") else ""
|
unique = " UNIQUE" if expression.args.get("unique") else ""
|
||||||
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
|
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
|
||||||
properties = self.sql(expression, "properties")
|
properties = self.sql(expression, "properties")
|
||||||
|
data = expression.args.get("data")
|
||||||
|
if data is None:
|
||||||
|
data = ""
|
||||||
|
elif data:
|
||||||
|
data = " WITH DATA"
|
||||||
|
else:
|
||||||
|
data = " WITH NO DATA"
|
||||||
|
statistics = expression.args.get("statistics")
|
||||||
|
if statistics is None:
|
||||||
|
statistics = ""
|
||||||
|
elif statistics:
|
||||||
|
statistics = " AND STATISTICS"
|
||||||
|
else:
|
||||||
|
statistics = " AND NO STATISTICS"
|
||||||
|
no_primary_index = " NO PRIMARY INDEX" if expression.args.get("no_primary_index") else ""
|
||||||
|
|
||||||
|
indexes = expression.args.get("indexes")
|
||||||
|
index_sql = ""
|
||||||
|
if indexes is not None:
|
||||||
|
indexes_sql = []
|
||||||
|
for index in indexes:
|
||||||
|
ind_unique = " UNIQUE" if index.args.get("unique") else ""
|
||||||
|
ind_primary = " PRIMARY" if index.args.get("primary") else ""
|
||||||
|
ind_amp = " AMP" if index.args.get("amp") else ""
|
||||||
|
ind_name = f" {index.name}" if index.name else ""
|
||||||
|
ind_columns = (
|
||||||
|
f' ({self.expressions(index, key="columns", flat=True)})'
|
||||||
|
if index.args.get("columns")
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
indexes_sql.append(
|
||||||
|
f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
|
||||||
|
)
|
||||||
|
index_sql = "".join(indexes_sql)
|
||||||
|
|
||||||
modifiers = "".join(
|
modifiers = "".join(
|
||||||
(
|
(
|
||||||
|
@ -438,7 +475,10 @@ class Generator:
|
||||||
materialized,
|
materialized,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties} {expression_sql}"
|
|
||||||
|
post_expression_modifiers = "".join((data, statistics, no_primary_index))
|
||||||
|
|
||||||
|
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}"
|
||||||
return self.prepend_ctes(expression, expression_sql)
|
return self.prepend_ctes(expression, expression_sql)
|
||||||
|
|
||||||
def describe_sql(self, expression: exp.Describe) -> str:
|
def describe_sql(self, expression: exp.Describe) -> str:
|
||||||
|
@ -668,6 +708,8 @@ class Generator:
|
||||||
|
|
||||||
alias = self.sql(expression, "alias")
|
alias = self.sql(expression, "alias")
|
||||||
alias = f"{sep}{alias}" if alias else ""
|
alias = f"{sep}{alias}" if alias else ""
|
||||||
|
hints = self.expressions(expression, key="hints", sep=", ", flat=True)
|
||||||
|
hints = f" WITH ({hints})" if hints else ""
|
||||||
laterals = self.expressions(expression, key="laterals", sep="")
|
laterals = self.expressions(expression, key="laterals", sep="")
|
||||||
joins = self.expressions(expression, key="joins", sep="")
|
joins = self.expressions(expression, key="joins", sep="")
|
||||||
pivots = self.expressions(expression, key="pivots", sep="")
|
pivots = self.expressions(expression, key="pivots", sep="")
|
||||||
|
@ -676,7 +718,7 @@ class Generator:
|
||||||
pivots = f"{pivots}{alias}"
|
pivots = f"{pivots}{alias}"
|
||||||
alias = ""
|
alias = ""
|
||||||
|
|
||||||
return f"{table}{alias}{laterals}{joins}{pivots}"
|
return f"{table}{alias}{hints}{laterals}{joins}{pivots}"
|
||||||
|
|
||||||
def tablesample_sql(self, expression: exp.TableSample) -> str:
|
def tablesample_sql(self, expression: exp.TableSample) -> str:
|
||||||
if self.alias_post_tablesample and expression.this.alias:
|
if self.alias_post_tablesample and expression.this.alias:
|
||||||
|
@ -1020,7 +1062,9 @@ class Generator:
|
||||||
if not partition and not order and not spec and alias:
|
if not partition and not order and not spec and alias:
|
||||||
return f"{this} {alias}"
|
return f"{this} {alias}"
|
||||||
|
|
||||||
return f"{this} ({alias}{partition_sql}{order_sql}{spec_sql})"
|
window_args = alias + partition_sql + order_sql + spec_sql
|
||||||
|
|
||||||
|
return f"{this} ({window_args.strip()})"
|
||||||
|
|
||||||
def window_spec_sql(self, expression: exp.WindowSpec) -> str:
|
def window_spec_sql(self, expression: exp.WindowSpec) -> str:
|
||||||
kind = self.sql(expression, "kind")
|
kind = self.sql(expression, "kind")
|
||||||
|
@ -1130,6 +1174,8 @@ class Generator:
|
||||||
query = expression.args.get("query")
|
query = expression.args.get("query")
|
||||||
unnest = expression.args.get("unnest")
|
unnest = expression.args.get("unnest")
|
||||||
field = expression.args.get("field")
|
field = expression.args.get("field")
|
||||||
|
is_global = " GLOBAL" if expression.args.get("is_global") else ""
|
||||||
|
|
||||||
if query:
|
if query:
|
||||||
in_sql = self.wrap(query)
|
in_sql = self.wrap(query)
|
||||||
elif unnest:
|
elif unnest:
|
||||||
|
@ -1138,7 +1184,8 @@ class Generator:
|
||||||
in_sql = self.sql(field)
|
in_sql = self.sql(field)
|
||||||
else:
|
else:
|
||||||
in_sql = f"({self.expressions(expression, flat=True)})"
|
in_sql = f"({self.expressions(expression, flat=True)})"
|
||||||
return f"{self.sql(expression, 'this')} IN {in_sql}"
|
|
||||||
|
return f"{self.sql(expression, 'this')}{is_global} IN {in_sql}"
|
||||||
|
|
||||||
def in_unnest_op(self, unnest: exp.Unnest) -> str:
|
def in_unnest_op(self, unnest: exp.Unnest) -> str:
|
||||||
return f"(SELECT {self.sql(unnest)})"
|
return f"(SELECT {self.sql(unnest)})"
|
||||||
|
@ -1433,7 +1480,7 @@ class Generator:
|
||||||
result_sqls = []
|
result_sqls = []
|
||||||
for i, e in enumerate(expressions):
|
for i, e in enumerate(expressions):
|
||||||
sql = self.sql(e, comment=False)
|
sql = self.sql(e, comment=False)
|
||||||
comments = self.maybe_comment("", e)
|
comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else ""
|
||||||
|
|
||||||
if self.pretty:
|
if self.pretty:
|
||||||
if self._leading_comma:
|
if self._leading_comma:
|
||||||
|
|
|
@ -131,7 +131,7 @@ def subclasses(
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
|
def apply_index_offset(expressions: t.List[t.Optional[E]], offset: int) -> t.List[t.Optional[E]]:
|
||||||
"""
|
"""
|
||||||
Applies an offset to a given integer literal expression.
|
Applies an offset to a given integer literal expression.
|
||||||
|
|
||||||
|
@ -148,10 +148,10 @@ def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
|
||||||
|
|
||||||
expression = expressions[0]
|
expression = expressions[0]
|
||||||
|
|
||||||
if expression.is_int:
|
if expression and expression.is_int:
|
||||||
expression = expression.copy()
|
expression = expression.copy()
|
||||||
logger.warning("Applying array index offset (%s)", offset)
|
logger.warning("Applying array index offset (%s)", offset)
|
||||||
expression.args["this"] = str(int(expression.this) + offset)
|
expression.args["this"] = str(int(expression.this) + offset) # type: ignore
|
||||||
return [expression]
|
return [expression]
|
||||||
|
|
||||||
return expressions
|
return expressions
|
||||||
|
@ -225,7 +225,7 @@ def open_file(file_name: str) -> t.TextIO:
|
||||||
|
|
||||||
return gzip.open(file_name, "rt", newline="")
|
return gzip.open(file_name, "rt", newline="")
|
||||||
|
|
||||||
return open(file_name, "rt", encoding="utf-8", newline="")
|
return open(file_name, encoding="utf-8", newline="")
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
@ -256,7 +256,7 @@ def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
|
||||||
file.close()
|
file.close()
|
||||||
|
|
||||||
|
|
||||||
def find_new_name(taken: t.Sequence[str], base: str) -> str:
|
def find_new_name(taken: t.Collection[str], base: str) -> str:
|
||||||
"""
|
"""
|
||||||
Searches for a new name.
|
Searches for a new name.
|
||||||
|
|
||||||
|
@ -356,6 +356,15 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any,
|
||||||
yield value
|
yield value
|
||||||
|
|
||||||
|
|
||||||
|
def count_params(function: t.Callable) -> int:
|
||||||
|
"""
|
||||||
|
Returns the number of formal parameters expected by a function, without counting "self"
|
||||||
|
and "cls", in case of instance and class methods, respectively.
|
||||||
|
"""
|
||||||
|
count = function.__code__.co_argcount
|
||||||
|
return count - 1 if inspect.ismethod(function) else count
|
||||||
|
|
||||||
|
|
||||||
def dict_depth(d: t.Dict) -> int:
|
def dict_depth(d: t.Dict) -> int:
|
||||||
"""
|
"""
|
||||||
Get the nesting depth of a dictionary.
|
Get the nesting depth of a dictionary.
|
||||||
|
@ -374,6 +383,7 @@ def dict_depth(d: t.Dict) -> int:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
d (dict): dictionary
|
d (dict): dictionary
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int: depth
|
int: depth
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -43,7 +43,7 @@ class TypeAnnotator:
|
||||||
},
|
},
|
||||||
exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
|
exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
|
||||||
exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
|
exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
|
||||||
exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr),
|
exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
|
||||||
exp.Alias: lambda self, expr: self._annotate_unary(expr),
|
exp.Alias: lambda self, expr: self._annotate_unary(expr),
|
||||||
exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
||||||
exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
||||||
|
|
|
@ -57,7 +57,7 @@ def _join_is_used(scope, join, alias):
|
||||||
# But columns in the ON clause shouldn't count.
|
# But columns in the ON clause shouldn't count.
|
||||||
on = join.args.get("on")
|
on = join.args.get("on")
|
||||||
if on:
|
if on:
|
||||||
on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
|
on_clause_columns = {id(column) for column in on.find_all(exp.Column)}
|
||||||
else:
|
else:
|
||||||
on_clause_columns = set()
|
on_clause_columns = set()
|
||||||
return any(
|
return any(
|
||||||
|
@ -71,7 +71,7 @@ def _is_joined_on_all_unique_outputs(scope, join):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
_, join_keys, _ = join_condition(join)
|
_, join_keys, _ = join_condition(join)
|
||||||
remaining_unique_outputs = unique_outputs - set(c.name for c in join_keys)
|
remaining_unique_outputs = unique_outputs - {c.name for c in join_keys}
|
||||||
return not remaining_unique_outputs
|
return not remaining_unique_outputs
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -67,11 +67,9 @@ def merge_ctes(expression, leave_tables_isolated=False):
|
||||||
|
|
||||||
singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
|
singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
|
||||||
for outer_scope, inner_scope, table in singular_cte_selections:
|
for outer_scope, inner_scope, table in singular_cte_selections:
|
||||||
inner_select = inner_scope.expression.unnest()
|
|
||||||
from_or_join = table.find_ancestor(exp.From, exp.Join)
|
from_or_join = table.find_ancestor(exp.From, exp.Join)
|
||||||
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
||||||
alias = table.alias_or_name
|
alias = table.alias_or_name
|
||||||
|
|
||||||
_rename_inner_sources(outer_scope, inner_scope, alias)
|
_rename_inner_sources(outer_scope, inner_scope, alias)
|
||||||
_merge_from(outer_scope, inner_scope, table, alias)
|
_merge_from(outer_scope, inner_scope, table, alias)
|
||||||
_merge_expressions(outer_scope, inner_scope, alias)
|
_merge_expressions(outer_scope, inner_scope, alias)
|
||||||
|
@ -80,18 +78,17 @@ def merge_ctes(expression, leave_tables_isolated=False):
|
||||||
_merge_order(outer_scope, inner_scope)
|
_merge_order(outer_scope, inner_scope)
|
||||||
_merge_hints(outer_scope, inner_scope)
|
_merge_hints(outer_scope, inner_scope)
|
||||||
_pop_cte(inner_scope)
|
_pop_cte(inner_scope)
|
||||||
|
outer_scope.clear_cache()
|
||||||
return expression
|
return expression
|
||||||
|
|
||||||
|
|
||||||
def merge_derived_tables(expression, leave_tables_isolated=False):
|
def merge_derived_tables(expression, leave_tables_isolated=False):
|
||||||
for outer_scope in traverse_scope(expression):
|
for outer_scope in traverse_scope(expression):
|
||||||
for subquery in outer_scope.derived_tables:
|
for subquery in outer_scope.derived_tables:
|
||||||
inner_select = subquery.unnest()
|
|
||||||
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
|
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
|
||||||
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
|
||||||
alias = subquery.alias_or_name
|
alias = subquery.alias_or_name
|
||||||
inner_scope = outer_scope.sources[alias]
|
inner_scope = outer_scope.sources[alias]
|
||||||
|
if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
||||||
_rename_inner_sources(outer_scope, inner_scope, alias)
|
_rename_inner_sources(outer_scope, inner_scope, alias)
|
||||||
_merge_from(outer_scope, inner_scope, subquery, alias)
|
_merge_from(outer_scope, inner_scope, subquery, alias)
|
||||||
_merge_expressions(outer_scope, inner_scope, alias)
|
_merge_expressions(outer_scope, inner_scope, alias)
|
||||||
|
@ -99,21 +96,23 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
|
||||||
_merge_where(outer_scope, inner_scope, from_or_join)
|
_merge_where(outer_scope, inner_scope, from_or_join)
|
||||||
_merge_order(outer_scope, inner_scope)
|
_merge_order(outer_scope, inner_scope)
|
||||||
_merge_hints(outer_scope, inner_scope)
|
_merge_hints(outer_scope, inner_scope)
|
||||||
|
outer_scope.clear_cache()
|
||||||
return expression
|
return expression
|
||||||
|
|
||||||
|
|
||||||
def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
||||||
"""
|
"""
|
||||||
Return True if `inner_select` can be merged into outer query.
|
Return True if `inner_select` can be merged into outer query.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
outer_scope (Scope)
|
outer_scope (Scope)
|
||||||
inner_select (exp.Select)
|
inner_scope (Scope)
|
||||||
leave_tables_isolated (bool)
|
leave_tables_isolated (bool)
|
||||||
from_or_join (exp.From|exp.Join)
|
from_or_join (exp.From|exp.Join)
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if can be merged
|
bool: True if can be merged
|
||||||
"""
|
"""
|
||||||
|
inner_select = inner_scope.expression.unnest()
|
||||||
|
|
||||||
def _is_a_window_expression_in_unmergable_operation():
|
def _is_a_window_expression_in_unmergable_operation():
|
||||||
window_expressions = inner_select.find_all(exp.Window)
|
window_expressions = inner_select.find_all(exp.Window)
|
||||||
|
@ -133,10 +132,40 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
||||||
]
|
]
|
||||||
return any(window_expressions_in_unmergable)
|
return any(window_expressions_in_unmergable)
|
||||||
|
|
||||||
|
def _outer_select_joins_on_inner_select_join():
|
||||||
|
"""
|
||||||
|
All columns from the inner select in the ON clause must be from the first FROM table.
|
||||||
|
|
||||||
|
That is, this can be merged:
|
||||||
|
SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
|
||||||
|
^^^ ^
|
||||||
|
But this can't:
|
||||||
|
SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
|
||||||
|
^^^ ^
|
||||||
|
"""
|
||||||
|
if not isinstance(from_or_join, exp.Join):
|
||||||
|
return False
|
||||||
|
|
||||||
|
alias = from_or_join.this.alias_or_name
|
||||||
|
|
||||||
|
on = from_or_join.args.get("on")
|
||||||
|
if not on:
|
||||||
|
return False
|
||||||
|
selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
|
||||||
|
inner_from = inner_scope.expression.args.get("from")
|
||||||
|
if not inner_from:
|
||||||
|
return False
|
||||||
|
inner_from_table = inner_from.expressions[0].alias_or_name
|
||||||
|
inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
|
||||||
|
return any(
|
||||||
|
col.table != inner_from_table
|
||||||
|
for selection in selections
|
||||||
|
for col in inner_projections[selection].find_all(exp.Column)
|
||||||
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
isinstance(outer_scope.expression, exp.Select)
|
isinstance(outer_scope.expression, exp.Select)
|
||||||
and isinstance(inner_select, exp.Select)
|
and isinstance(inner_select, exp.Select)
|
||||||
and isinstance(inner_select, exp.Select)
|
|
||||||
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
|
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
|
||||||
and inner_select.args.get("from")
|
and inner_select.args.get("from")
|
||||||
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
|
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
|
||||||
|
@ -153,6 +182,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
||||||
j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
|
j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
and not _outer_select_joins_on_inner_select_join()
|
||||||
and not _is_a_window_expression_in_unmergable_operation()
|
and not _is_a_window_expression_in_unmergable_operation()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -168,7 +198,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
|
||||||
"""
|
"""
|
||||||
taken = set(outer_scope.selected_sources)
|
taken = set(outer_scope.selected_sources)
|
||||||
conflicts = taken.intersection(set(inner_scope.selected_sources))
|
conflicts = taken.intersection(set(inner_scope.selected_sources))
|
||||||
conflicts = conflicts - {alias}
|
conflicts -= {alias}
|
||||||
|
|
||||||
for conflict in conflicts:
|
for conflict in conflicts:
|
||||||
new_name = find_new_name(taken, conflict)
|
new_name = find_new_name(taken, conflict)
|
||||||
|
|
|
@ -15,6 +15,7 @@ from sqlglot.optimizer.pushdown_projections import pushdown_projections
|
||||||
from sqlglot.optimizer.qualify_columns import qualify_columns
|
from sqlglot.optimizer.qualify_columns import qualify_columns
|
||||||
from sqlglot.optimizer.qualify_tables import qualify_tables
|
from sqlglot.optimizer.qualify_tables import qualify_tables
|
||||||
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
|
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
|
||||||
|
from sqlglot.schema import ensure_schema
|
||||||
|
|
||||||
RULES = (
|
RULES = (
|
||||||
lower_identities,
|
lower_identities,
|
||||||
|
@ -51,12 +52,13 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
|
||||||
If no schema is provided then the default schema defined at `sqlgot.schema` will be used
|
If no schema is provided then the default schema defined at `sqlgot.schema` will be used
|
||||||
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
|
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
|
||||||
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
|
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
|
||||||
rules (list): sequence of optimizer rules to use
|
rules (sequence): sequence of optimizer rules to use
|
||||||
**kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
|
**kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
|
||||||
Returns:
|
Returns:
|
||||||
sqlglot.Expression: optimized expression
|
sqlglot.Expression: optimized expression
|
||||||
"""
|
"""
|
||||||
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs}
|
schema = ensure_schema(schema or sqlglot.schema)
|
||||||
|
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
|
||||||
expression = expression.copy()
|
expression = expression.copy()
|
||||||
for rule in rules:
|
for rule in rules:
|
||||||
|
|
||||||
|
|
|
@ -79,6 +79,7 @@ def _remove_unused_selections(scope, parent_selections):
|
||||||
order_refs = set()
|
order_refs = set()
|
||||||
|
|
||||||
new_selections = []
|
new_selections = []
|
||||||
|
removed = False
|
||||||
for i, selection in enumerate(scope.selects):
|
for i, selection in enumerate(scope.selects):
|
||||||
if (
|
if (
|
||||||
SELECT_ALL in parent_selections
|
SELECT_ALL in parent_selections
|
||||||
|
@ -88,12 +89,15 @@ def _remove_unused_selections(scope, parent_selections):
|
||||||
new_selections.append(selection)
|
new_selections.append(selection)
|
||||||
else:
|
else:
|
||||||
removed_indexes.append(i)
|
removed_indexes.append(i)
|
||||||
|
removed = True
|
||||||
|
|
||||||
# If there are no remaining selections, just select a single constant
|
# If there are no remaining selections, just select a single constant
|
||||||
if not new_selections:
|
if not new_selections:
|
||||||
new_selections.append(DEFAULT_SELECTION.copy())
|
new_selections.append(DEFAULT_SELECTION.copy())
|
||||||
|
|
||||||
scope.expression.set("expressions", new_selections)
|
scope.expression.set("expressions", new_selections)
|
||||||
|
if removed:
|
||||||
|
scope.clear_cache()
|
||||||
return removed_indexes
|
return removed_indexes
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -365,9 +365,9 @@ class _Resolver:
|
||||||
def all_columns(self):
|
def all_columns(self):
|
||||||
"""All available columns of all sources in this scope"""
|
"""All available columns of all sources in this scope"""
|
||||||
if self._all_columns is None:
|
if self._all_columns is None:
|
||||||
self._all_columns = set(
|
self._all_columns = {
|
||||||
column for columns in self._get_all_source_columns().values() for column in columns
|
column for columns in self._get_all_source_columns().values() for column in columns
|
||||||
)
|
}
|
||||||
return self._all_columns
|
return self._all_columns
|
||||||
|
|
||||||
def get_source_columns(self, name, only_visible=False):
|
def get_source_columns(self, name, only_visible=False):
|
||||||
|
|
|
@ -361,7 +361,7 @@ def _simplify_binary(expression, a, b):
|
||||||
return boolean
|
return boolean
|
||||||
elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
|
elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
|
||||||
a, b = extract_date(a), extract_interval(b)
|
a, b = extract_date(a), extract_interval(b)
|
||||||
if b:
|
if a and b:
|
||||||
if isinstance(expression, exp.Add):
|
if isinstance(expression, exp.Add):
|
||||||
return date_literal(a + b)
|
return date_literal(a + b)
|
||||||
if isinstance(expression, exp.Sub):
|
if isinstance(expression, exp.Sub):
|
||||||
|
@ -369,7 +369,7 @@ def _simplify_binary(expression, a, b):
|
||||||
elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
|
elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
|
||||||
a, b = extract_interval(a), extract_date(b)
|
a, b = extract_interval(a), extract_date(b)
|
||||||
# you cannot subtract a date from an interval
|
# you cannot subtract a date from an interval
|
||||||
if a and isinstance(expression, exp.Add):
|
if a and b and isinstance(expression, exp.Add):
|
||||||
return date_literal(a + b)
|
return date_literal(a + b)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
@ -424,8 +424,14 @@ def eval_boolean(expression, a, b):
|
||||||
|
|
||||||
|
|
||||||
def extract_date(cast):
|
def extract_date(cast):
|
||||||
|
# The "fromisoformat" conversion could fail if the cast is used on an identifier,
|
||||||
|
# so in that case we can't extract the date.
|
||||||
|
try:
|
||||||
if cast.args["to"].this == exp.DataType.Type.DATE:
|
if cast.args["to"].this == exp.DataType.Type.DATE:
|
||||||
return datetime.date.fromisoformat(cast.name)
|
return datetime.date.fromisoformat(cast.name)
|
||||||
|
if cast.args["to"].this == exp.DataType.Type.DATETIME:
|
||||||
|
return datetime.datetime.fromisoformat(cast.name)
|
||||||
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -450,7 +456,8 @@ def extract_interval(interval):
|
||||||
|
|
||||||
|
|
||||||
def date_literal(date):
|
def date_literal(date):
|
||||||
return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE"))
|
expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE")
|
||||||
|
return exp.Cast(this=exp.Literal.string(date), to=expr_type)
|
||||||
|
|
||||||
|
|
||||||
def boolean_literal(condition):
|
def boolean_literal(condition):
|
||||||
|
|
|
@ -15,8 +15,7 @@ def unnest_subqueries(expression):
|
||||||
>>> import sqlglot
|
>>> import sqlglot
|
||||||
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
|
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
|
||||||
>>> unnest_subqueries(expression).sql()
|
>>> unnest_subqueries(expression).sql()
|
||||||
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
|
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
|
||||||
AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
expression (sqlglot.Expression): expression to unnest
|
expression (sqlglot.Expression): expression to unnest
|
||||||
|
@ -173,10 +172,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
||||||
other = _other_operand(parent_predicate)
|
other = _other_operand(parent_predicate)
|
||||||
|
|
||||||
if isinstance(parent_predicate, exp.Exists):
|
if isinstance(parent_predicate, exp.Exists):
|
||||||
if value.this in group_by:
|
alias = exp.column(list(key_aliases.values())[0], table_alias)
|
||||||
parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
|
parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
|
||||||
else:
|
|
||||||
parent_predicate = _replace(parent_predicate, "TRUE")
|
|
||||||
elif isinstance(parent_predicate, exp.All):
|
elif isinstance(parent_predicate, exp.All):
|
||||||
parent_predicate = _replace(
|
parent_predicate = _replace(
|
||||||
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
|
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
|
||||||
|
@ -197,6 +194,23 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
||||||
else:
|
else:
|
||||||
if is_subquery_projection:
|
if is_subquery_projection:
|
||||||
alias = exp.alias_(alias, select.parent.alias)
|
alias = exp.alias_(alias, select.parent.alias)
|
||||||
|
|
||||||
|
# COUNT always returns 0 on empty datasets, so we need take that into consideration here
|
||||||
|
# by transforming all counts into 0 and using that as the coalesced value
|
||||||
|
if value.find(exp.Count):
|
||||||
|
|
||||||
|
def remove_aggs(node):
|
||||||
|
if isinstance(node, exp.Count):
|
||||||
|
return exp.Literal.number(0)
|
||||||
|
elif isinstance(node, exp.AggFunc):
|
||||||
|
return exp.null()
|
||||||
|
return node
|
||||||
|
|
||||||
|
alias = exp.Coalesce(
|
||||||
|
this=alias,
|
||||||
|
expressions=[value.this.transform(remove_aggs)],
|
||||||
|
)
|
||||||
|
|
||||||
select.parent.replace(alias)
|
select.parent.replace(alias)
|
||||||
|
|
||||||
for key, column, predicate in keys:
|
for key, column, predicate in keys:
|
||||||
|
@ -209,9 +223,6 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
||||||
|
|
||||||
if key in group_by:
|
if key in group_by:
|
||||||
key.replace(nested)
|
key.replace(nested)
|
||||||
parent_predicate = _replace(
|
|
||||||
parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
|
|
||||||
)
|
|
||||||
elif isinstance(predicate, exp.EQ):
|
elif isinstance(predicate, exp.EQ):
|
||||||
parent_predicate = _replace(
|
parent_predicate = _replace(
|
||||||
parent_predicate,
|
parent_predicate,
|
||||||
|
@ -245,7 +256,14 @@ def _other_operand(expression):
|
||||||
if isinstance(expression, exp.In):
|
if isinstance(expression, exp.In):
|
||||||
return expression.this
|
return expression.this
|
||||||
|
|
||||||
|
if isinstance(expression, (exp.Any, exp.All)):
|
||||||
|
return _other_operand(expression.parent)
|
||||||
|
|
||||||
if isinstance(expression, exp.Binary):
|
if isinstance(expression, exp.Binary):
|
||||||
return expression.right if expression.arg_key == "this" else expression.left
|
return (
|
||||||
|
expression.right
|
||||||
|
if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
|
||||||
|
else expression.left
|
||||||
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||||
import abc
|
import abc
|
||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
|
import sqlglot
|
||||||
from sqlglot import expressions as exp
|
from sqlglot import expressions as exp
|
||||||
from sqlglot.errors import SchemaError
|
from sqlglot.errors import SchemaError
|
||||||
from sqlglot.helper import dict_depth
|
from sqlglot.helper import dict_depth
|
||||||
|
@ -157,10 +158,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
|
||||||
visible: t.Optional[t.Dict] = None,
|
visible: t.Optional[t.Dict] = None,
|
||||||
dialect: t.Optional[str] = None,
|
dialect: t.Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(schema)
|
|
||||||
self.visible = visible or {}
|
|
||||||
self.dialect = dialect
|
self.dialect = dialect
|
||||||
|
self.visible = visible or {}
|
||||||
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
|
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
|
||||||
|
super().__init__(self._normalize(schema or {}))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
|
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
|
||||||
|
@ -180,6 +181,33 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _normalize(self, schema: t.Dict) -> t.Dict:
|
||||||
|
"""
|
||||||
|
Converts all identifiers in the schema into lowercase, unless they're quoted.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schema: the schema to normalize.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The normalized schema mapping.
|
||||||
|
"""
|
||||||
|
flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
|
||||||
|
|
||||||
|
normalized_mapping: t.Dict = {}
|
||||||
|
for keys in flattened_schema:
|
||||||
|
columns = _nested_get(schema, *zip(keys, keys))
|
||||||
|
assert columns is not None
|
||||||
|
|
||||||
|
normalized_keys = [self._normalize_name(key) for key in keys]
|
||||||
|
for column_name, column_type in columns.items():
|
||||||
|
_nested_set(
|
||||||
|
normalized_mapping,
|
||||||
|
normalized_keys + [self._normalize_name(column_name)],
|
||||||
|
column_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
return normalized_mapping
|
||||||
|
|
||||||
def add_table(
|
def add_table(
|
||||||
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
|
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -204,6 +232,19 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
|
||||||
)
|
)
|
||||||
self.mapping_trie = self._build_trie(self.mapping)
|
self.mapping_trie = self._build_trie(self.mapping)
|
||||||
|
|
||||||
|
def _normalize_name(self, name: str) -> str:
|
||||||
|
try:
|
||||||
|
identifier: t.Optional[exp.Expression] = sqlglot.parse_one(
|
||||||
|
name, read=self.dialect, into=exp.Identifier
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
identifier = exp.to_identifier(name)
|
||||||
|
assert isinstance(identifier, exp.Identifier)
|
||||||
|
|
||||||
|
if identifier.quoted:
|
||||||
|
return identifier.name
|
||||||
|
return identifier.name.lower()
|
||||||
|
|
||||||
def _depth(self) -> int:
|
def _depth(self) -> int:
|
||||||
# The columns themselves are a mapping, but we don't want to include those
|
# The columns themselves are a mapping, but we don't want to include those
|
||||||
return super()._depth() - 1
|
return super()._depth() - 1
|
||||||
|
|
67
sqlglot/serde.py
Normal file
67
sqlglot/serde.py
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
from sqlglot import expressions as exp
|
||||||
|
|
||||||
|
if t.TYPE_CHECKING:
|
||||||
|
JSON = t.Union[dict, list, str, float, int, bool]
|
||||||
|
Node = t.Union[t.List["Node"], exp.DataType.Type, exp.Expression, JSON]
|
||||||
|
|
||||||
|
|
||||||
|
def dump(node: Node) -> JSON:
|
||||||
|
"""
|
||||||
|
Recursively dump an AST into a JSON-serializable dict.
|
||||||
|
"""
|
||||||
|
if isinstance(node, list):
|
||||||
|
return [dump(i) for i in node]
|
||||||
|
if isinstance(node, exp.DataType.Type):
|
||||||
|
return {
|
||||||
|
"class": "DataType.Type",
|
||||||
|
"value": node.value,
|
||||||
|
}
|
||||||
|
if isinstance(node, exp.Expression):
|
||||||
|
klass = node.__class__.__qualname__
|
||||||
|
if node.__class__.__module__ != exp.__name__:
|
||||||
|
klass = f"{node.__module__}.{klass}"
|
||||||
|
obj = {
|
||||||
|
"class": klass,
|
||||||
|
"args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []},
|
||||||
|
}
|
||||||
|
if node.type:
|
||||||
|
obj["type"] = node.type.sql()
|
||||||
|
if node.comments:
|
||||||
|
obj["comments"] = node.comments
|
||||||
|
return obj
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
def load(obj: JSON) -> Node:
|
||||||
|
"""
|
||||||
|
Recursively load a dict (as returned by `dump`) into an AST.
|
||||||
|
"""
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [load(i) for i in obj]
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
class_name = obj["class"]
|
||||||
|
|
||||||
|
if class_name == "DataType.Type":
|
||||||
|
return exp.DataType.Type(obj["value"])
|
||||||
|
|
||||||
|
if "." in class_name:
|
||||||
|
module_path, class_name = class_name.rsplit(".", maxsplit=1)
|
||||||
|
module = __import__(module_path, fromlist=[class_name])
|
||||||
|
else:
|
||||||
|
module = exp
|
||||||
|
|
||||||
|
klass = getattr(module, class_name)
|
||||||
|
|
||||||
|
expression = klass(**{k: load(v) for k, v in obj["args"].items()})
|
||||||
|
type_ = obj.get("type")
|
||||||
|
if type_:
|
||||||
|
expression.type = exp.DataType.build(type_)
|
||||||
|
comments = obj.get("comments")
|
||||||
|
if comments:
|
||||||
|
expression.comments = load(comments)
|
||||||
|
return expression
|
||||||
|
return obj
|
|
@ -86,6 +86,7 @@ class TokenType(AutoName):
|
||||||
VARBINARY = auto()
|
VARBINARY = auto()
|
||||||
JSON = auto()
|
JSON = auto()
|
||||||
JSONB = auto()
|
JSONB = auto()
|
||||||
|
TIME = auto()
|
||||||
TIMESTAMP = auto()
|
TIMESTAMP = auto()
|
||||||
TIMESTAMPTZ = auto()
|
TIMESTAMPTZ = auto()
|
||||||
TIMESTAMPLTZ = auto()
|
TIMESTAMPLTZ = auto()
|
||||||
|
@ -181,6 +182,7 @@ class TokenType(AutoName):
|
||||||
FUNCTION = auto()
|
FUNCTION = auto()
|
||||||
FROM = auto()
|
FROM = auto()
|
||||||
GENERATED = auto()
|
GENERATED = auto()
|
||||||
|
GLOBAL = auto()
|
||||||
GROUP_BY = auto()
|
GROUP_BY = auto()
|
||||||
GROUPING_SETS = auto()
|
GROUPING_SETS = auto()
|
||||||
HAVING = auto()
|
HAVING = auto()
|
||||||
|
@ -656,6 +658,7 @@ class Tokenizer(metaclass=_Tokenizer):
|
||||||
"FLOAT4": TokenType.FLOAT,
|
"FLOAT4": TokenType.FLOAT,
|
||||||
"FLOAT8": TokenType.DOUBLE,
|
"FLOAT8": TokenType.DOUBLE,
|
||||||
"DOUBLE": TokenType.DOUBLE,
|
"DOUBLE": TokenType.DOUBLE,
|
||||||
|
"DOUBLE PRECISION": TokenType.DOUBLE,
|
||||||
"JSON": TokenType.JSON,
|
"JSON": TokenType.JSON,
|
||||||
"CHAR": TokenType.CHAR,
|
"CHAR": TokenType.CHAR,
|
||||||
"NCHAR": TokenType.NCHAR,
|
"NCHAR": TokenType.NCHAR,
|
||||||
|
@ -671,6 +674,7 @@ class Tokenizer(metaclass=_Tokenizer):
|
||||||
"BLOB": TokenType.VARBINARY,
|
"BLOB": TokenType.VARBINARY,
|
||||||
"BYTEA": TokenType.VARBINARY,
|
"BYTEA": TokenType.VARBINARY,
|
||||||
"VARBINARY": TokenType.VARBINARY,
|
"VARBINARY": TokenType.VARBINARY,
|
||||||
|
"TIME": TokenType.TIME,
|
||||||
"TIMESTAMP": TokenType.TIMESTAMP,
|
"TIMESTAMP": TokenType.TIMESTAMP,
|
||||||
"TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
|
"TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
|
||||||
"TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
|
"TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
|
||||||
|
@ -721,6 +725,8 @@ class Tokenizer(metaclass=_Tokenizer):
|
||||||
COMMENTS = ["--", ("/*", "*/")]
|
COMMENTS = ["--", ("/*", "*/")]
|
||||||
KEYWORD_TRIE = None # autofilled
|
KEYWORD_TRIE = None # autofilled
|
||||||
|
|
||||||
|
IDENTIFIER_CAN_START_WITH_DIGIT = False
|
||||||
|
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
"sql",
|
"sql",
|
||||||
"size",
|
"size",
|
||||||
|
@ -938,17 +944,24 @@ class Tokenizer(metaclass=_Tokenizer):
|
||||||
elif self._peek.upper() == "E" and not scientific: # type: ignore
|
elif self._peek.upper() == "E" and not scientific: # type: ignore
|
||||||
scientific += 1
|
scientific += 1
|
||||||
self._advance()
|
self._advance()
|
||||||
elif self._peek.isalpha(): # type: ignore
|
elif self._peek.isidentifier(): # type: ignore
|
||||||
self._add(TokenType.NUMBER)
|
number_text = self._text
|
||||||
literal = []
|
literal = []
|
||||||
while self._peek.isalpha(): # type: ignore
|
while self._peek.isidentifier(): # type: ignore
|
||||||
literal.append(self._peek.upper()) # type: ignore
|
literal.append(self._peek.upper()) # type: ignore
|
||||||
self._advance()
|
self._advance()
|
||||||
|
|
||||||
literal = "".join(literal) # type: ignore
|
literal = "".join(literal) # type: ignore
|
||||||
token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore
|
token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore
|
||||||
|
|
||||||
if token_type:
|
if token_type:
|
||||||
|
self._add(TokenType.NUMBER, number_text)
|
||||||
self._add(TokenType.DCOLON, "::")
|
self._add(TokenType.DCOLON, "::")
|
||||||
return self._add(token_type, literal) # type: ignore
|
return self._add(token_type, literal) # type: ignore
|
||||||
|
elif self.IDENTIFIER_CAN_START_WITH_DIGIT:
|
||||||
|
return self._add(TokenType.VAR)
|
||||||
|
|
||||||
|
self._add(TokenType.NUMBER, number_text)
|
||||||
return self._advance(-len(literal))
|
return self._advance(-len(literal))
|
||||||
else:
|
else:
|
||||||
return self._add(TokenType.NUMBER)
|
return self._add(TokenType.NUMBER)
|
||||||
|
|
|
@ -82,6 +82,27 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
|
||||||
return expression
|
return expression
|
||||||
|
|
||||||
|
|
||||||
|
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
|
||||||
|
"""
|
||||||
|
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions.
|
||||||
|
This transforms removes the precision from parameterized types in expressions.
|
||||||
|
"""
|
||||||
|
return expression.transform(
|
||||||
|
lambda node: exp.DataType(
|
||||||
|
**{
|
||||||
|
**node.args,
|
||||||
|
"expressions": [
|
||||||
|
node_expression
|
||||||
|
for node_expression in node.expressions
|
||||||
|
if isinstance(node_expression, exp.DataType)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if isinstance(node, exp.DataType)
|
||||||
|
else node,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def preprocess(
|
def preprocess(
|
||||||
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
|
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
|
||||||
to_sql: t.Callable[[Generator, exp.Expression], str],
|
to_sql: t.Callable[[Generator, exp.Expression], str],
|
||||||
|
@ -121,3 +142,6 @@ def delegate(attr: str) -> t.Callable:
|
||||||
|
|
||||||
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
|
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
|
||||||
ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}
|
ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}
|
||||||
|
REMOVE_PRECISION_PARAMETERIZED_TYPES = {
|
||||||
|
exp.Cast: preprocess([remove_precision_parameterized_types], delegate("cast_sql"))
|
||||||
|
}
|
||||||
|
|
|
@ -52,7 +52,7 @@ def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the search stops, and `value`
|
A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the search stops, and `value`
|
||||||
is either 0 (search was unsuccessfull), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key is in `trie`).
|
is either 0 (search was unsuccessful), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key is in `trie`).
|
||||||
"""
|
"""
|
||||||
if not key:
|
if not key:
|
||||||
return (0, trie)
|
return (0, trie)
|
||||||
|
|
|
@ -1152,17 +1152,17 @@ class TestFunctions(unittest.TestCase):
|
||||||
|
|
||||||
def test_regexp_extract(self):
|
def test_regexp_extract(self):
|
||||||
col_str = SF.regexp_extract("cola", r"(\d+)-(\d+)", 1)
|
col_str = SF.regexp_extract("cola", r"(\d+)-(\d+)", 1)
|
||||||
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)', 1)", col_str.sql())
|
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col_str.sql())
|
||||||
col = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)", 1)
|
col = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)", 1)
|
||||||
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)', 1)", col.sql())
|
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col.sql())
|
||||||
col_no_idx = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)")
|
col_no_idx = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)")
|
||||||
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)')", col_no_idx.sql())
|
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)')", col_no_idx.sql())
|
||||||
|
|
||||||
def test_regexp_replace(self):
|
def test_regexp_replace(self):
|
||||||
col_str = SF.regexp_replace("cola", r"(\d+)", "--")
|
col_str = SF.regexp_replace("cola", r"(\d+)", "--")
|
||||||
self.assertEqual("REGEXP_REPLACE(cola, '(\\\d+)', '--')", col_str.sql())
|
self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col_str.sql())
|
||||||
col = SF.regexp_replace(SF.col("cola"), r"(\d+)", "--")
|
col = SF.regexp_replace(SF.col("cola"), r"(\d+)", "--")
|
||||||
self.assertEqual("REGEXP_REPLACE(cola, '(\\\d+)', '--')", col.sql())
|
self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col.sql())
|
||||||
|
|
||||||
def test_initcap(self):
|
def test_initcap(self):
|
||||||
col_str = SF.initcap("cola")
|
col_str = SF.initcap("cola")
|
||||||
|
|
|
@ -15,11 +15,11 @@ class TestDataframeWindow(unittest.TestCase):
|
||||||
|
|
||||||
def test_window_spec_rows_between(self):
|
def test_window_spec_rows_between(self):
|
||||||
rows_between = WindowSpec().rowsBetween(3, 5)
|
rows_between = WindowSpec().rowsBetween(3, 5)
|
||||||
self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
|
self.assertEqual("OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
|
||||||
|
|
||||||
def test_window_spec_range_between(self):
|
def test_window_spec_range_between(self):
|
||||||
range_between = WindowSpec().rangeBetween(3, 5)
|
range_between = WindowSpec().rangeBetween(3, 5)
|
||||||
self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
|
self.assertEqual("OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
|
||||||
|
|
||||||
def test_window_partition_by(self):
|
def test_window_partition_by(self):
|
||||||
partition_by = Window.partitionBy(F.col("cola"), F.col("colb"))
|
partition_by = Window.partitionBy(F.col("cola"), F.col("colb"))
|
||||||
|
@ -31,46 +31,46 @@ class TestDataframeWindow(unittest.TestCase):
|
||||||
|
|
||||||
def test_window_rows_between(self):
|
def test_window_rows_between(self):
|
||||||
rows_between = Window.rowsBetween(3, 5)
|
rows_between = Window.rowsBetween(3, 5)
|
||||||
self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
|
self.assertEqual("OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
|
||||||
|
|
||||||
def test_window_range_between(self):
|
def test_window_range_between(self):
|
||||||
range_between = Window.rangeBetween(3, 5)
|
range_between = Window.rangeBetween(3, 5)
|
||||||
self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
|
self.assertEqual("OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
|
||||||
|
|
||||||
def test_window_rows_unbounded(self):
|
def test_window_rows_unbounded(self):
|
||||||
rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
|
rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
|
"OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
|
||||||
rows_between_unbounded_start.sql(),
|
rows_between_unbounded_start.sql(),
|
||||||
)
|
)
|
||||||
rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
|
rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
"OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
|
"OVER (ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
|
||||||
rows_between_unbounded_end.sql(),
|
rows_between_unbounded_end.sql(),
|
||||||
)
|
)
|
||||||
rows_between_unbounded_both = Window.rowsBetween(
|
rows_between_unbounded_both = Window.rowsBetween(
|
||||||
Window.unboundedPreceding, Window.unboundedFollowing
|
Window.unboundedPreceding, Window.unboundedFollowing
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
|
"OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
|
||||||
rows_between_unbounded_both.sql(),
|
rows_between_unbounded_both.sql(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_window_range_unbounded(self):
|
def test_window_range_unbounded(self):
|
||||||
range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
|
range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
|
"OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
|
||||||
range_between_unbounded_start.sql(),
|
range_between_unbounded_start.sql(),
|
||||||
)
|
)
|
||||||
range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
|
range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
"OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
|
"OVER (RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
|
||||||
range_between_unbounded_end.sql(),
|
range_between_unbounded_end.sql(),
|
||||||
)
|
)
|
||||||
range_between_unbounded_both = Window.rangeBetween(
|
range_between_unbounded_both = Window.rangeBetween(
|
||||||
Window.unboundedPreceding, Window.unboundedFollowing
|
Window.unboundedPreceding, Window.unboundedFollowing
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
|
"OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
|
||||||
range_between_unbounded_both.sql(),
|
range_between_unbounded_both.sql(),
|
||||||
)
|
)
|
||||||
|
|
|
@ -125,7 +125,7 @@ class TestBigQuery(Validator):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"CURRENT_DATE",
|
"CURRENT_TIMESTAMP()",
|
||||||
read={
|
read={
|
||||||
"tsql": "GETDATE()",
|
"tsql": "GETDATE()",
|
||||||
},
|
},
|
||||||
|
@ -299,6 +299,14 @@ class TestBigQuery(Validator):
|
||||||
"snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
|
"snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)",
|
||||||
|
write={
|
||||||
|
"spark": "SELECT cola, colb, colc FROM VALUES (1, 'test', NULL) AS tab(cola, colb, colc)",
|
||||||
|
"bigquery": "SELECT cola, colb, colc FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb, NULL AS colc)])",
|
||||||
|
"snowflake": "SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)",
|
||||||
|
},
|
||||||
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) d, COUNT(*) e FOR c IN ('x', 'y'))",
|
"SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) d, COUNT(*) e FOR c IN ('x', 'y'))",
|
||||||
write={
|
write={
|
||||||
|
@ -324,3 +332,35 @@ class TestBigQuery(Validator):
|
||||||
"SELECT a, GROUP_CONCAT(b) FROM table GROUP BY a",
|
"SELECT a, GROUP_CONCAT(b) FROM table GROUP BY a",
|
||||||
write={"bigquery": "SELECT a, STRING_AGG(b) FROM table GROUP BY a"},
|
write={"bigquery": "SELECT a, STRING_AGG(b) FROM table GROUP BY a"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_remove_precision_parameterized_types(self):
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT CAST(1 AS NUMERIC(10, 2))",
|
||||||
|
write={
|
||||||
|
"bigquery": "SELECT CAST(1 AS NUMERIC)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"CREATE TABLE test (a NUMERIC(10, 2))",
|
||||||
|
write={
|
||||||
|
"bigquery": "CREATE TABLE test (a NUMERIC(10, 2))",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT CAST('1' AS STRING(10)) UNION ALL SELECT CAST('2' AS STRING(10))",
|
||||||
|
write={
|
||||||
|
"bigquery": "SELECT CAST('1' AS STRING) UNION ALL SELECT CAST('2' AS STRING)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT cola FROM (SELECT CAST('1' AS STRING(10)) AS cola UNION ALL SELECT CAST('2' AS STRING(10)) AS cola)",
|
||||||
|
write={
|
||||||
|
"bigquery": "SELECT cola FROM (SELECT CAST('1' AS STRING) AS cola UNION ALL SELECT CAST('2' AS STRING) AS cola)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING(10)), CAST(14 AS STRING(10)))",
|
||||||
|
write={
|
||||||
|
"bigquery": "INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING), CAST(14 AS STRING))",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
|
@ -14,6 +14,9 @@ class TestClickhouse(Validator):
|
||||||
self.validate_identity("SELECT * FROM foo LEFT ASOF JOIN bla")
|
self.validate_identity("SELECT * FROM foo LEFT ASOF JOIN bla")
|
||||||
self.validate_identity("SELECT * FROM foo ASOF JOIN bla")
|
self.validate_identity("SELECT * FROM foo ASOF JOIN bla")
|
||||||
self.validate_identity("SELECT * FROM foo ANY JOIN bla")
|
self.validate_identity("SELECT * FROM foo ANY JOIN bla")
|
||||||
|
self.validate_identity("SELECT quantile(0.5)(a)")
|
||||||
|
self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t")
|
||||||
|
self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)")
|
||||||
|
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
|
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
|
||||||
|
@ -38,3 +41,9 @@ class TestClickhouse(Validator):
|
||||||
"SELECT x #! comment",
|
"SELECT x #! comment",
|
||||||
write={"": "SELECT x /* comment */"},
|
write={"": "SELECT x /* comment */"},
|
||||||
)
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT quantileIf(0.5)(a, true)",
|
||||||
|
write={
|
||||||
|
"clickhouse": "SELECT quantileIf(0.5)(a, TRUE)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
|
@ -85,7 +85,7 @@ class TestDialect(Validator):
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"CAST(a AS BINARY(4))",
|
"CAST(a AS BINARY(4))",
|
||||||
write={
|
write={
|
||||||
"bigquery": "CAST(a AS BINARY(4))",
|
"bigquery": "CAST(a AS BINARY)",
|
||||||
"clickhouse": "CAST(a AS BINARY(4))",
|
"clickhouse": "CAST(a AS BINARY(4))",
|
||||||
"drill": "CAST(a AS VARBINARY(4))",
|
"drill": "CAST(a AS VARBINARY(4))",
|
||||||
"duckdb": "CAST(a AS BINARY(4))",
|
"duckdb": "CAST(a AS BINARY(4))",
|
||||||
|
@ -104,7 +104,7 @@ class TestDialect(Validator):
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"CAST(a AS VARBINARY(4))",
|
"CAST(a AS VARBINARY(4))",
|
||||||
write={
|
write={
|
||||||
"bigquery": "CAST(a AS VARBINARY(4))",
|
"bigquery": "CAST(a AS VARBINARY)",
|
||||||
"clickhouse": "CAST(a AS VARBINARY(4))",
|
"clickhouse": "CAST(a AS VARBINARY(4))",
|
||||||
"duckdb": "CAST(a AS VARBINARY(4))",
|
"duckdb": "CAST(a AS VARBINARY(4))",
|
||||||
"mysql": "CAST(a AS VARBINARY(4))",
|
"mysql": "CAST(a AS VARBINARY(4))",
|
||||||
|
@ -181,7 +181,7 @@ class TestDialect(Validator):
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"CAST(a AS VARCHAR(3))",
|
"CAST(a AS VARCHAR(3))",
|
||||||
write={
|
write={
|
||||||
"bigquery": "CAST(a AS STRING(3))",
|
"bigquery": "CAST(a AS STRING)",
|
||||||
"drill": "CAST(a AS VARCHAR(3))",
|
"drill": "CAST(a AS VARCHAR(3))",
|
||||||
"duckdb": "CAST(a AS TEXT(3))",
|
"duckdb": "CAST(a AS TEXT(3))",
|
||||||
"mysql": "CAST(a AS VARCHAR(3))",
|
"mysql": "CAST(a AS VARCHAR(3))",
|
||||||
|
|
|
@ -338,6 +338,24 @@ class TestHive(Validator):
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_hive(self):
|
def test_hive(self):
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT A.1a AS b FROM test_a AS A",
|
||||||
|
write={
|
||||||
|
"spark": "SELECT A.1a AS b FROM test_a AS A",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT 1_a AS a FROM test_table",
|
||||||
|
write={
|
||||||
|
"spark": "SELECT 1_a AS a FROM test_table",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT a_b AS 1_a FROM test_table",
|
||||||
|
write={
|
||||||
|
"spark": "SELECT a_b AS 1_a FROM test_table",
|
||||||
|
},
|
||||||
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"PERCENTILE(x, 0.5)",
|
"PERCENTILE(x, 0.5)",
|
||||||
write={
|
write={
|
||||||
|
@ -411,7 +429,7 @@ class TestHive(Validator):
|
||||||
"INITCAP('new york')",
|
"INITCAP('new york')",
|
||||||
write={
|
write={
|
||||||
"duckdb": "INITCAP('new york')",
|
"duckdb": "INITCAP('new york')",
|
||||||
"presto": "REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))",
|
"presto": r"REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))",
|
||||||
"hive": "INITCAP('new york')",
|
"hive": "INITCAP('new york')",
|
||||||
"spark": "INITCAP('new york')",
|
"spark": "INITCAP('new york')",
|
||||||
},
|
},
|
||||||
|
|
|
@ -122,6 +122,10 @@ class TestPostgres(Validator):
|
||||||
"TO_TIMESTAMP(123::DOUBLE PRECISION)",
|
"TO_TIMESTAMP(123::DOUBLE PRECISION)",
|
||||||
write={"postgres": "TO_TIMESTAMP(CAST(123 AS DOUBLE PRECISION))"},
|
write={"postgres": "TO_TIMESTAMP(CAST(123 AS DOUBLE PRECISION))"},
|
||||||
)
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT to_timestamp(123)::time without time zone",
|
||||||
|
write={"postgres": "SELECT CAST(TO_TIMESTAMP(123) AS TIME)"},
|
||||||
|
)
|
||||||
|
|
||||||
self.validate_identity(
|
self.validate_identity(
|
||||||
"CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)"
|
"CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)"
|
||||||
|
|
|
@ -60,11 +60,11 @@ class TestPresto(Validator):
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
|
"CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
|
||||||
write={
|
write={
|
||||||
"bigquery": "CAST(x AS TIMESTAMPTZ(9))",
|
"bigquery": "CAST(x AS TIMESTAMPTZ)",
|
||||||
"duckdb": "CAST(x AS TIMESTAMPTZ(9))",
|
"duckdb": "CAST(x AS TIMESTAMPTZ(9))",
|
||||||
"presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
|
"presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
|
||||||
"hive": "CAST(x AS TIMESTAMPTZ(9))",
|
"hive": "CAST(x AS TIMESTAMPTZ)",
|
||||||
"spark": "CAST(x AS TIMESTAMPTZ(9))",
|
"spark": "CAST(x AS TIMESTAMPTZ)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -523,3 +523,33 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERA
|
||||||
"spark": "SELECT `c0`, `c1` FROM (VALUES (1, 2), (3, 4)) AS `t0`(`c0`, `c1`)",
|
"spark": "SELECT `c0`, `c1` FROM (VALUES (1, 2), (3, 4)) AS `t0`(`c0`, `c1`)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_describe_table(self):
|
||||||
|
self.validate_all(
|
||||||
|
"DESCRIBE TABLE db.table",
|
||||||
|
write={
|
||||||
|
"snowflake": "DESCRIBE TABLE db.table",
|
||||||
|
"spark": "DESCRIBE db.table",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"DESCRIBE db.table",
|
||||||
|
write={
|
||||||
|
"snowflake": "DESCRIBE TABLE db.table",
|
||||||
|
"spark": "DESCRIBE db.table",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"DESC TABLE db.table",
|
||||||
|
write={
|
||||||
|
"snowflake": "DESCRIBE TABLE db.table",
|
||||||
|
"spark": "DESCRIBE db.table",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"DESC VIEW db.table",
|
||||||
|
write={
|
||||||
|
"snowflake": "DESCRIBE VIEW db.table",
|
||||||
|
"spark": "DESCRIBE db.table",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
|
@ -207,6 +207,7 @@ TBLPROPERTIES (
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_spark(self):
|
def test_spark(self):
|
||||||
|
self.validate_identity("SELECT UNIX_TIMESTAMP()")
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"ARRAY_SORT(x, (left, right) -> -1)",
|
"ARRAY_SORT(x, (left, right) -> -1)",
|
||||||
write={
|
write={
|
||||||
|
|
|
@ -6,6 +6,8 @@ class TestTSQL(Validator):
|
||||||
|
|
||||||
def test_tsql(self):
|
def test_tsql(self):
|
||||||
self.validate_identity('SELECT "x"."y" FROM foo')
|
self.validate_identity('SELECT "x"."y" FROM foo')
|
||||||
|
self.validate_identity("SELECT * FROM #foo")
|
||||||
|
self.validate_identity("SELECT * FROM ##foo")
|
||||||
self.validate_identity(
|
self.validate_identity(
|
||||||
"SELECT DISTINCT DepartmentName, PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY BaseRate) OVER (PARTITION BY DepartmentName) AS MedianCont FROM dbo.DimEmployee"
|
"SELECT DISTINCT DepartmentName, PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY BaseRate) OVER (PARTITION BY DepartmentName) AS MedianCont FROM dbo.DimEmployee"
|
||||||
)
|
)
|
||||||
|
@ -71,6 +73,12 @@ class TestTSQL(Validator):
|
||||||
"tsql": "CAST(x AS DATETIME2)",
|
"tsql": "CAST(x AS DATETIME2)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"CAST(x AS DATETIME2(6))",
|
||||||
|
write={
|
||||||
|
"hive": "CAST(x AS TIMESTAMP)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_charindex(self):
|
def test_charindex(self):
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
@ -300,6 +308,12 @@ class TestTSQL(Validator):
|
||||||
"spark": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y",
|
"spark": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT CAST((SELECT x FROM y) AS VARCHAR) AS test",
|
||||||
|
write={
|
||||||
|
"spark": "SELECT CAST((SELECT x FROM y) AS STRING) AS test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_add_date(self):
|
def test_add_date(self):
|
||||||
self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')")
|
self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')")
|
||||||
|
@ -441,3 +455,13 @@ class TestTSQL(Validator):
|
||||||
"SELECT '''test'''",
|
"SELECT '''test'''",
|
||||||
write={"spark": r"SELECT '\'test\''"},
|
write={"spark": r"SELECT '\'test\''"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_eomonth(self):
|
||||||
|
self.validate_all(
|
||||||
|
"EOMONTH(GETDATE())",
|
||||||
|
write={"spark": "LAST_DAY(CURRENT_TIMESTAMP())"},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"EOMONTH(GETDATE(), -1)",
|
||||||
|
write={"spark": "LAST_DAY(ADD_MONTHS(CURRENT_TIMESTAMP(), -1))"},
|
||||||
|
)
|
||||||
|
|
16
tests/fixtures/identity.sql
vendored
16
tests/fixtures/identity.sql
vendored
|
@ -89,6 +89,7 @@ POSEXPLODE("x") AS ("a", "b")
|
||||||
POSEXPLODE("x") AS ("a", "b", "c")
|
POSEXPLODE("x") AS ("a", "b", "c")
|
||||||
STR_POSITION(x, 'a')
|
STR_POSITION(x, 'a')
|
||||||
STR_POSITION(x, 'a', 3)
|
STR_POSITION(x, 'a', 3)
|
||||||
|
LEVENSHTEIN('gumbo', 'gambol', 2, 1, 1)
|
||||||
SPLIT(SPLIT(referrer, 'utm_source=')[OFFSET(1)], "&")[OFFSET(0)]
|
SPLIT(SPLIT(referrer, 'utm_source=')[OFFSET(1)], "&")[OFFSET(0)]
|
||||||
x[ORDINAL(1)][SAFE_OFFSET(2)]
|
x[ORDINAL(1)][SAFE_OFFSET(2)]
|
||||||
x LIKE SUBSTR('abc', 1, 1)
|
x LIKE SUBSTR('abc', 1, 1)
|
||||||
|
@ -425,6 +426,7 @@ SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT
|
||||||
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 AND 3)
|
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 AND 3)
|
||||||
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND 3)
|
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND 3)
|
||||||
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING)
|
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING)
|
||||||
|
SELECT AVG(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM t
|
||||||
SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x) AS y
|
SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x) AS y
|
||||||
SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x DESC)
|
SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x DESC)
|
||||||
SELECT SUM(x) FILTER(WHERE x > 1)
|
SELECT SUM(x) FILTER(WHERE x > 1)
|
||||||
|
@ -450,14 +452,24 @@ SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(a, b)
|
||||||
SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b)
|
SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b)
|
||||||
SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score)
|
SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score)
|
||||||
SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score)
|
SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score)
|
||||||
|
SELECT * FROM t WITH (TABLOCK, INDEX(myindex))
|
||||||
|
SELECT * FROM t WITH (NOWAIT)
|
||||||
|
CREATE TABLE foo AS (SELECT 1) UNION ALL (SELECT 2)
|
||||||
CREATE TABLE foo (id INT PRIMARY KEY ASC)
|
CREATE TABLE foo (id INT PRIMARY KEY ASC)
|
||||||
CREATE TABLE a.b AS SELECT 1
|
CREATE TABLE a.b AS SELECT 1
|
||||||
|
CREATE TABLE a.b AS SELECT 1 WITH DATA AND STATISTICS
|
||||||
|
CREATE TABLE a.b AS SELECT 1 WITH NO DATA AND NO STATISTICS
|
||||||
|
CREATE TABLE a.b AS (SELECT 1) NO PRIMARY INDEX
|
||||||
|
CREATE TABLE a.b AS (SELECT 1) UNIQUE PRIMARY INDEX index1 (a) UNIQUE INDEX index2 (b)
|
||||||
|
CREATE TABLE a.b AS (SELECT 1) PRIMARY AMP INDEX index1 (a) UNIQUE INDEX index2 (b)
|
||||||
CREATE TABLE a.b AS SELECT a FROM a.c
|
CREATE TABLE a.b AS SELECT a FROM a.c
|
||||||
CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d
|
CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d
|
||||||
CREATE TEMPORARY TABLE x AS SELECT a FROM d
|
CREATE TEMPORARY TABLE x AS SELECT a FROM d
|
||||||
CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d
|
CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d
|
||||||
CREATE VIEW x AS SELECT a FROM b
|
CREATE VIEW x AS SELECT a FROM b
|
||||||
CREATE VIEW IF NOT EXISTS x AS SELECT a FROM b
|
CREATE VIEW IF NOT EXISTS x AS SELECT a FROM b
|
||||||
|
CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c') AS SELECT a, b, c FROM d
|
||||||
|
CREATE VIEW IF NOT EXISTS z (a, b COMMENT 'b', c COMMENT 'c') AS SELECT a, b, c FROM d
|
||||||
CREATE OR REPLACE VIEW x AS SELECT *
|
CREATE OR REPLACE VIEW x AS SELECT *
|
||||||
CREATE OR REPLACE TEMPORARY VIEW x AS SELECT *
|
CREATE OR REPLACE TEMPORARY VIEW x AS SELECT *
|
||||||
CREATE TEMPORARY VIEW x AS SELECT a FROM d
|
CREATE TEMPORARY VIEW x AS SELECT a FROM d
|
||||||
|
@ -490,6 +502,8 @@ CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT)
|
||||||
CREATE TABLE z (a INT REFERENCES parent(b, c))
|
CREATE TABLE z (a INT REFERENCES parent(b, c))
|
||||||
CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id))
|
CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id))
|
||||||
CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c))
|
CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c))
|
||||||
|
CREATE VIEW z (a, b)
|
||||||
|
CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c')
|
||||||
CREATE TEMPORARY FUNCTION f
|
CREATE TEMPORARY FUNCTION f
|
||||||
CREATE TEMPORARY FUNCTION f AS 'g'
|
CREATE TEMPORARY FUNCTION f AS 'g'
|
||||||
CREATE FUNCTION f
|
CREATE FUNCTION f
|
||||||
|
@ -559,6 +573,7 @@ INSERT INTO x.z IF EXISTS SELECT * FROM y
|
||||||
INSERT INTO x VALUES (1, 'a', 2.0)
|
INSERT INTO x VALUES (1, 'a', 2.0)
|
||||||
INSERT INTO x VALUES (1, 'a', 2.0), (1, 'a', 3.0), (X(), y[1], z.x)
|
INSERT INTO x VALUES (1, 'a', 2.0), (1, 'a', 3.0), (X(), y[1], z.x)
|
||||||
INSERT INTO y (a, b, c) SELECT a, b, c FROM x
|
INSERT INTO y (a, b, c) SELECT a, b, c FROM x
|
||||||
|
INSERT INTO y (SELECT 1) UNION (SELECT 2)
|
||||||
INSERT OVERWRITE TABLE x IF EXISTS SELECT * FROM y
|
INSERT OVERWRITE TABLE x IF EXISTS SELECT * FROM y
|
||||||
INSERT OVERWRITE TABLE a.b IF EXISTS SELECT * FROM y
|
INSERT OVERWRITE TABLE a.b IF EXISTS SELECT * FROM y
|
||||||
INSERT OVERWRITE DIRECTORY 'x' SELECT 1
|
INSERT OVERWRITE DIRECTORY 'x' SELECT 1
|
||||||
|
@ -627,3 +642,4 @@ ALTER TABLE integers ALTER COLUMN i SET DEFAULT 10
|
||||||
ALTER TABLE integers ALTER COLUMN i DROP DEFAULT
|
ALTER TABLE integers ALTER COLUMN i DROP DEFAULT
|
||||||
ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B
|
ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B
|
||||||
ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT
|
ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT
|
||||||
|
SELECT div.a FROM test_table AS div
|
||||||
|
|
39
tests/fixtures/optimizer/merge_subqueries.sql
vendored
39
tests/fixtures/optimizer/merge_subqueries.sql
vendored
|
@ -311,3 +311,42 @@ FROM
|
||||||
ON
|
ON
|
||||||
t1.cola = t2.cola;
|
t1.cola = t2.cola;
|
||||||
SELECT /*+ BROADCAST(a2) */ a1.cola AS cola, a2.cola AS cola FROM VALUES (1) AS a1(cola) JOIN VALUES (1) AS a2(cola) ON a1.cola = a2.cola;
|
SELECT /*+ BROADCAST(a2) */ a1.cola AS cola, a2.cola AS cola FROM VALUES (1) AS a1(cola) JOIN VALUES (1) AS a2(cola) ON a1.cola = a2.cola;
|
||||||
|
|
||||||
|
# title: Nested subquery selects from same table as another subquery
|
||||||
|
WITH i AS (
|
||||||
|
SELECT
|
||||||
|
x.a AS a
|
||||||
|
FROM x AS x
|
||||||
|
), j AS (
|
||||||
|
SELECT
|
||||||
|
x.a,
|
||||||
|
x.b
|
||||||
|
FROM x AS x
|
||||||
|
), k AS (
|
||||||
|
SELECT
|
||||||
|
j.a,
|
||||||
|
j.b
|
||||||
|
FROM j AS j
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
i.a,
|
||||||
|
k.b
|
||||||
|
FROM i AS i
|
||||||
|
LEFT JOIN k AS k
|
||||||
|
ON i.a = k.a;
|
||||||
|
SELECT x.a AS a, x_2.b AS b FROM x AS x LEFT JOIN x AS x_2 ON x.a = x_2.a;
|
||||||
|
|
||||||
|
# title: Outer select joins on inner select join
|
||||||
|
WITH i AS (
|
||||||
|
SELECT
|
||||||
|
x.a AS a
|
||||||
|
FROM y AS y
|
||||||
|
JOIN x AS x
|
||||||
|
ON y.b = x.b
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
x.a AS a
|
||||||
|
FROM x AS x
|
||||||
|
LEFT JOIN i AS i
|
||||||
|
ON x.a = i.a;
|
||||||
|
WITH i AS (SELECT x.a AS a FROM y AS y JOIN x AS x ON y.b = x.b) SELECT x.a AS a FROM x AS x LEFT JOIN i AS i ON x.a = i.a;
|
||||||
|
|
2
tests/fixtures/optimizer/optimizer.sql
vendored
2
tests/fixtures/optimizer/optimizer.sql
vendored
|
@ -105,7 +105,7 @@ LEFT JOIN "_u_0" AS "_u_0"
|
||||||
JOIN "y" AS "y"
|
JOIN "y" AS "y"
|
||||||
ON "x"."b" = "y"."b"
|
ON "x"."b" = "y"."b"
|
||||||
WHERE
|
WHERE
|
||||||
"_u_0"."_col_0" >= 0 AND "x"."a" > 1 AND NOT "_u_0"."_u_1" IS NULL
|
"_u_0"."_col_0" >= 0 AND "x"."a" > 1
|
||||||
GROUP BY
|
GROUP BY
|
||||||
"x"."a";
|
"x"."a";
|
||||||
|
|
||||||
|
|
|
@ -54,3 +54,6 @@ WITH t1 AS (SELECT q.cola AS cola FROM UNNEST(ARRAY(STRUCT(1 AS cola, 'test' AS
|
||||||
|
|
||||||
SELECT x FROM VALUES(1, 2) AS q(x, y);
|
SELECT x FROM VALUES(1, 2) AS q(x, y);
|
||||||
SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y);
|
SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y);
|
||||||
|
|
||||||
|
SELECT i.a FROM x AS i LEFT JOIN (SELECT a, b FROM (SELECT a, b FROM x)) AS j ON i.a = j.a;
|
||||||
|
SELECT i.a AS a FROM x AS i LEFT JOIN (SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0) AS j ON i.a = j.a;
|
||||||
|
|
12
tests/fixtures/optimizer/simplify.sql
vendored
12
tests/fixtures/optimizer/simplify.sql
vendored
|
@ -375,6 +375,18 @@ CAST('1998-12-01' AS DATE) - INTERVAL '90' foo;
|
||||||
date '1998-12-01' + interval '90' foo;
|
date '1998-12-01' + interval '90' foo;
|
||||||
CAST('1998-12-01' AS DATE) + INTERVAL '90' foo;
|
CAST('1998-12-01' AS DATE) + INTERVAL '90' foo;
|
||||||
|
|
||||||
|
CAST(x AS DATE) + interval '1' week;
|
||||||
|
CAST(x AS DATE) + INTERVAL '1' week;
|
||||||
|
|
||||||
|
CAST('2008-11-11' AS DATETIME) + INTERVAL '5' MONTH;
|
||||||
|
CAST('2009-04-11 00:00:00' AS DATETIME);
|
||||||
|
|
||||||
|
datetime '1998-12-01' - interval '90' day;
|
||||||
|
CAST('1998-09-02 00:00:00' AS DATETIME);
|
||||||
|
|
||||||
|
CAST(x AS DATETIME) + interval '1' week;
|
||||||
|
CAST(x AS DATETIME) + INTERVAL '1' week;
|
||||||
|
|
||||||
--------------------------------------
|
--------------------------------------
|
||||||
-- Comparisons
|
-- Comparisons
|
||||||
--------------------------------------
|
--------------------------------------
|
||||||
|
|
8
tests/fixtures/optimizer/tpc-h/tpc-h.sql
vendored
8
tests/fixtures/optimizer/tpc-h/tpc-h.sql
vendored
|
@ -150,7 +150,6 @@ WHERE
|
||||||
"part"."p_size" = 15
|
"part"."p_size" = 15
|
||||||
AND "part"."p_type" LIKE '%BRASS'
|
AND "part"."p_type" LIKE '%BRASS'
|
||||||
AND "partsupp"."ps_supplycost" = "_u_0"."_col_0"
|
AND "partsupp"."ps_supplycost" = "_u_0"."_col_0"
|
||||||
AND NOT "_u_0"."_u_1" IS NULL
|
|
||||||
ORDER BY
|
ORDER BY
|
||||||
"s_acctbal" DESC,
|
"s_acctbal" DESC,
|
||||||
"n_name",
|
"n_name",
|
||||||
|
@ -1008,7 +1007,7 @@ JOIN "part" AS "part"
|
||||||
LEFT JOIN "_u_0" AS "_u_0"
|
LEFT JOIN "_u_0" AS "_u_0"
|
||||||
ON "_u_0"."_u_1" = "part"."p_partkey"
|
ON "_u_0"."_u_1" = "part"."p_partkey"
|
||||||
WHERE
|
WHERE
|
||||||
"lineitem"."l_quantity" < "_u_0"."_col_0" AND NOT "_u_0"."_u_1" IS NULL;
|
"lineitem"."l_quantity" < "_u_0"."_col_0";
|
||||||
|
|
||||||
--------------------------------------
|
--------------------------------------
|
||||||
-- TPC-H 18
|
-- TPC-H 18
|
||||||
|
@ -1253,10 +1252,7 @@ WITH "_u_0" AS (
|
||||||
LEFT JOIN "_u_3" AS "_u_3"
|
LEFT JOIN "_u_3" AS "_u_3"
|
||||||
ON "partsupp"."ps_partkey" = "_u_3"."p_partkey"
|
ON "partsupp"."ps_partkey" = "_u_3"."p_partkey"
|
||||||
WHERE
|
WHERE
|
||||||
"partsupp"."ps_availqty" > "_u_0"."_col_0"
|
"partsupp"."ps_availqty" > "_u_0"."_col_0" AND NOT "_u_3"."p_partkey" IS NULL
|
||||||
AND NOT "_u_0"."_u_1" IS NULL
|
|
||||||
AND NOT "_u_0"."_u_2" IS NULL
|
|
||||||
AND NOT "_u_3"."p_partkey" IS NULL
|
|
||||||
GROUP BY
|
GROUP BY
|
||||||
"partsupp"."ps_suppkey"
|
"partsupp"."ps_suppkey"
|
||||||
)
|
)
|
||||||
|
|
59
tests/fixtures/optimizer/unnest_subqueries.sql
vendored
59
tests/fixtures/optimizer/unnest_subqueries.sql
vendored
|
@ -22,6 +22,8 @@ WHERE
|
||||||
AND x.a > ANY (SELECT y.a FROM y)
|
AND x.a > ANY (SELECT y.a FROM y)
|
||||||
AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a LIMIT 10)
|
AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a LIMIT 10)
|
||||||
AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a OFFSET 10)
|
AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a OFFSET 10)
|
||||||
|
AND x.a > ALL (SELECT y.c FROM y WHERE y.a = x.a)
|
||||||
|
AND x.a > (SELECT COUNT(*) as d FROM y WHERE y.a = x.a)
|
||||||
;
|
;
|
||||||
SELECT
|
SELECT
|
||||||
*
|
*
|
||||||
|
@ -130,37 +132,42 @@ LEFT JOIN (
|
||||||
y.a
|
y.a
|
||||||
) AS _u_15
|
) AS _u_15
|
||||||
ON x.a = _u_15.a
|
ON x.a = _u_15.a
|
||||||
|
LEFT JOIN (
|
||||||
|
SELECT
|
||||||
|
ARRAY_AGG(c),
|
||||||
|
y.a AS _u_20
|
||||||
|
FROM y
|
||||||
|
WHERE
|
||||||
|
TRUE
|
||||||
|
GROUP BY
|
||||||
|
y.a
|
||||||
|
) AS _u_19
|
||||||
|
ON _u_19._u_20 = x.a
|
||||||
|
LEFT JOIN (
|
||||||
|
SELECT
|
||||||
|
COUNT(*) AS d,
|
||||||
|
y.a AS _u_22
|
||||||
|
FROM y
|
||||||
|
WHERE
|
||||||
|
TRUE
|
||||||
|
GROUP BY
|
||||||
|
y.a
|
||||||
|
) AS _u_21
|
||||||
|
ON _u_21._u_22 = x.a
|
||||||
WHERE
|
WHERE
|
||||||
x.a = _u_0.a
|
x.a = _u_0.a
|
||||||
AND NOT "_u_1"."a" IS NULL
|
AND NOT "_u_1"."a" IS NULL
|
||||||
AND NOT "_u_2"."b" IS NULL
|
AND NOT "_u_2"."b" IS NULL
|
||||||
AND NOT "_u_3"."a" IS NULL
|
AND NOT "_u_3"."a" IS NULL
|
||||||
|
AND x.a = _u_4.b
|
||||||
|
AND x.a > _u_6.b
|
||||||
|
AND x.a = _u_8.a
|
||||||
|
AND NOT x.a = _u_9.a
|
||||||
|
AND ARRAY_ANY(_u_10.a, _x -> _x = x.a)
|
||||||
AND (
|
AND (
|
||||||
x.a = _u_4.b AND NOT _u_4._u_5 IS NULL
|
x.a < _u_12.a AND ARRAY_ANY(_u_12._u_14, "_x" -> _x <> x.d)
|
||||||
)
|
|
||||||
AND (
|
|
||||||
x.a > _u_6.b AND NOT _u_6._u_7 IS NULL
|
|
||||||
)
|
|
||||||
AND (
|
|
||||||
None = _u_8.a AND NOT _u_8.a IS NULL
|
|
||||||
)
|
|
||||||
AND NOT (
|
|
||||||
x.a = _u_9.a AND NOT _u_9.a IS NULL
|
|
||||||
)
|
|
||||||
AND (
|
|
||||||
ARRAY_ANY(_u_10.a, _x -> _x = x.a) AND NOT _u_10._u_11 IS NULL
|
|
||||||
)
|
|
||||||
AND (
|
|
||||||
(
|
|
||||||
(
|
|
||||||
x.a < _u_12.a AND NOT _u_12._u_13 IS NULL
|
|
||||||
) AND NOT _u_12._u_13 IS NULL
|
|
||||||
)
|
|
||||||
AND ARRAY_ANY(_u_12._u_14, "_x" -> _x <> x.d)
|
|
||||||
)
|
|
||||||
AND (
|
|
||||||
NOT _u_15.a IS NULL AND NOT _u_15.a IS NULL
|
|
||||||
)
|
)
|
||||||
|
AND NOT _u_15.a IS NULL
|
||||||
AND x.a IN (
|
AND x.a IN (
|
||||||
SELECT
|
SELECT
|
||||||
y.a AS a
|
y.a AS a
|
||||||
|
@ -199,4 +206,6 @@ WHERE
|
||||||
WHERE
|
WHERE
|
||||||
y.a = x.a
|
y.a = x.a
|
||||||
OFFSET 10
|
OFFSET 10
|
||||||
);
|
)
|
||||||
|
AND ARRAY_ALL(_u_19."", _x -> _x = x.a)
|
||||||
|
AND x.a > COALESCE(_u_21.d, 0);
|
||||||
|
|
|
@ -27,8 +27,7 @@ def assert_logger_contains(message, logger, level="error"):
|
||||||
|
|
||||||
def load_sql_fixtures(filename):
|
def load_sql_fixtures(filename):
|
||||||
with open(os.path.join(FIXTURES_DIR, filename), encoding="utf-8") as f:
|
with open(os.path.join(FIXTURES_DIR, filename), encoding="utf-8") as f:
|
||||||
for sql in _filter_comments(f.read()).splitlines():
|
yield from _filter_comments(f.read()).splitlines()
|
||||||
yield sql
|
|
||||||
|
|
||||||
|
|
||||||
def load_sql_fixture_pairs(filename):
|
def load_sql_fixture_pairs(filename):
|
||||||
|
|
|
@ -401,6 +401,36 @@ class TestExecutor(unittest.TestCase):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_correlated_count(self):
|
||||||
|
tables = {
|
||||||
|
"parts": [{"pnum": 0, "qoh": 1}],
|
||||||
|
"supplies": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
schema = {
|
||||||
|
"parts": {"pnum": "int", "qoh": "int"},
|
||||||
|
"supplies": {"pnum": "int", "shipdate": "int"},
|
||||||
|
}
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
execute(
|
||||||
|
"""
|
||||||
|
select *
|
||||||
|
from parts
|
||||||
|
where parts.qoh >= (
|
||||||
|
select count(supplies.shipdate) + 1
|
||||||
|
from supplies
|
||||||
|
where supplies.pnum = parts.pnum and supplies.shipdate < 10
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
tables=tables,
|
||||||
|
schema=schema,
|
||||||
|
).rows,
|
||||||
|
[
|
||||||
|
(0, 1),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
def test_table_depth_mismatch(self):
|
def test_table_depth_mismatch(self):
|
||||||
tables = {"table": []}
|
tables = {"table": []}
|
||||||
schema = {"db": {"table": {"col": "VARCHAR"}}}
|
schema = {"db": {"table": {"col": "VARCHAR"}}}
|
||||||
|
|
|
@ -646,3 +646,72 @@ FROM foo""",
|
||||||
exp.Column(this=exp.to_identifier("colb")),
|
exp.Column(this=exp.to_identifier("colb")),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_values(self):
|
||||||
|
self.assertEqual(
|
||||||
|
exp.values([(1, 2), (3, 4)], "t", ["a", "b"]).sql(),
|
||||||
|
"(VALUES (1, 2), (3, 4)) AS t(a, b)",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
exp.values(
|
||||||
|
[(1, 2), (3, 4)],
|
||||||
|
"t",
|
||||||
|
{"a": exp.DataType.build("TEXT"), "b": exp.DataType.build("TEXT")},
|
||||||
|
).sql(),
|
||||||
|
"(VALUES (CAST(1 AS TEXT), CAST(2 AS TEXT)), (3, 4)) AS t(a, b)",
|
||||||
|
)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
exp.values([(1, 2), (3, 4)], columns=["a"])
|
||||||
|
|
||||||
|
def test_data_type_builder(self):
|
||||||
|
self.assertEqual(exp.DataType.build("TEXT").sql(), "TEXT")
|
||||||
|
self.assertEqual(exp.DataType.build("DECIMAL(10, 2)").sql(), "DECIMAL(10, 2)")
|
||||||
|
self.assertEqual(exp.DataType.build("VARCHAR(255)").sql(), "VARCHAR(255)")
|
||||||
|
self.assertEqual(exp.DataType.build("ARRAY<INT>").sql(), "ARRAY<INT>")
|
||||||
|
self.assertEqual(exp.DataType.build("CHAR").sql(), "CHAR")
|
||||||
|
self.assertEqual(exp.DataType.build("NCHAR").sql(), "CHAR")
|
||||||
|
self.assertEqual(exp.DataType.build("VARCHAR").sql(), "VARCHAR")
|
||||||
|
self.assertEqual(exp.DataType.build("NVARCHAR").sql(), "VARCHAR")
|
||||||
|
self.assertEqual(exp.DataType.build("TEXT").sql(), "TEXT")
|
||||||
|
self.assertEqual(exp.DataType.build("BINARY").sql(), "BINARY")
|
||||||
|
self.assertEqual(exp.DataType.build("VARBINARY").sql(), "VARBINARY")
|
||||||
|
self.assertEqual(exp.DataType.build("INT").sql(), "INT")
|
||||||
|
self.assertEqual(exp.DataType.build("TINYINT").sql(), "TINYINT")
|
||||||
|
self.assertEqual(exp.DataType.build("SMALLINT").sql(), "SMALLINT")
|
||||||
|
self.assertEqual(exp.DataType.build("BIGINT").sql(), "BIGINT")
|
||||||
|
self.assertEqual(exp.DataType.build("FLOAT").sql(), "FLOAT")
|
||||||
|
self.assertEqual(exp.DataType.build("DOUBLE").sql(), "DOUBLE")
|
||||||
|
self.assertEqual(exp.DataType.build("DECIMAL").sql(), "DECIMAL")
|
||||||
|
self.assertEqual(exp.DataType.build("BOOLEAN").sql(), "BOOLEAN")
|
||||||
|
self.assertEqual(exp.DataType.build("JSON").sql(), "JSON")
|
||||||
|
self.assertEqual(exp.DataType.build("JSONB").sql(), "JSONB")
|
||||||
|
self.assertEqual(exp.DataType.build("INTERVAL").sql(), "INTERVAL")
|
||||||
|
self.assertEqual(exp.DataType.build("TIME").sql(), "TIME")
|
||||||
|
self.assertEqual(exp.DataType.build("TIMESTAMP").sql(), "TIMESTAMP")
|
||||||
|
self.assertEqual(exp.DataType.build("TIMESTAMPTZ").sql(), "TIMESTAMPTZ")
|
||||||
|
self.assertEqual(exp.DataType.build("TIMESTAMPLTZ").sql(), "TIMESTAMPLTZ")
|
||||||
|
self.assertEqual(exp.DataType.build("DATE").sql(), "DATE")
|
||||||
|
self.assertEqual(exp.DataType.build("DATETIME").sql(), "DATETIME")
|
||||||
|
self.assertEqual(exp.DataType.build("ARRAY").sql(), "ARRAY")
|
||||||
|
self.assertEqual(exp.DataType.build("MAP").sql(), "MAP")
|
||||||
|
self.assertEqual(exp.DataType.build("UUID").sql(), "UUID")
|
||||||
|
self.assertEqual(exp.DataType.build("GEOGRAPHY").sql(), "GEOGRAPHY")
|
||||||
|
self.assertEqual(exp.DataType.build("GEOMETRY").sql(), "GEOMETRY")
|
||||||
|
self.assertEqual(exp.DataType.build("STRUCT").sql(), "STRUCT")
|
||||||
|
self.assertEqual(exp.DataType.build("NULLABLE").sql(), "NULLABLE")
|
||||||
|
self.assertEqual(exp.DataType.build("HLLSKETCH").sql(), "HLLSKETCH")
|
||||||
|
self.assertEqual(exp.DataType.build("HSTORE").sql(), "HSTORE")
|
||||||
|
self.assertEqual(exp.DataType.build("SUPER").sql(), "SUPER")
|
||||||
|
self.assertEqual(exp.DataType.build("SERIAL").sql(), "SERIAL")
|
||||||
|
self.assertEqual(exp.DataType.build("SMALLSERIAL").sql(), "SMALLSERIAL")
|
||||||
|
self.assertEqual(exp.DataType.build("BIGSERIAL").sql(), "BIGSERIAL")
|
||||||
|
self.assertEqual(exp.DataType.build("XML").sql(), "XML")
|
||||||
|
self.assertEqual(exp.DataType.build("UNIQUEIDENTIFIER").sql(), "UNIQUEIDENTIFIER")
|
||||||
|
self.assertEqual(exp.DataType.build("MONEY").sql(), "MONEY")
|
||||||
|
self.assertEqual(exp.DataType.build("SMALLMONEY").sql(), "SMALLMONEY")
|
||||||
|
self.assertEqual(exp.DataType.build("ROWVERSION").sql(), "ROWVERSION")
|
||||||
|
self.assertEqual(exp.DataType.build("IMAGE").sql(), "IMAGE")
|
||||||
|
self.assertEqual(exp.DataType.build("VARIANT").sql(), "VARIANT")
|
||||||
|
self.assertEqual(exp.DataType.build("OBJECT").sql(), "OBJECT")
|
||||||
|
self.assertEqual(exp.DataType.build("NULL").sql(), "NULL")
|
||||||
|
self.assertEqual(exp.DataType.build("UNKNOWN").sql(), "UNKNOWN")
|
||||||
|
|
|
@ -299,10 +299,10 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
|
||||||
|
|
||||||
self.assertEqual(set(scopes[6].sources), {"q", "z", "r", "s"})
|
self.assertEqual(set(scopes[6].sources), {"q", "z", "r", "s"})
|
||||||
self.assertEqual(len(scopes[6].columns), 6)
|
self.assertEqual(len(scopes[6].columns), 6)
|
||||||
self.assertEqual(set(c.table for c in scopes[6].columns), {"r", "s"})
|
self.assertEqual({c.table for c in scopes[6].columns}, {"r", "s"})
|
||||||
self.assertEqual(scopes[6].source_columns("q"), [])
|
self.assertEqual(scopes[6].source_columns("q"), [])
|
||||||
self.assertEqual(len(scopes[6].source_columns("r")), 2)
|
self.assertEqual(len(scopes[6].source_columns("r")), 2)
|
||||||
self.assertEqual(set(c.table for c in scopes[6].source_columns("r")), {"r"})
|
self.assertEqual({c.table for c in scopes[6].source_columns("r")}, {"r"})
|
||||||
|
|
||||||
self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"})
|
self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"})
|
||||||
self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b")
|
self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b")
|
||||||
|
@ -578,3 +578,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
|
||||||
scope_t, scope_y = build_scope(query).cte_scopes
|
scope_t, scope_y = build_scope(query).cte_scopes
|
||||||
self.assertEqual(set(scope_t.cte_sources), {"t"})
|
self.assertEqual(set(scope_t.cte_sources), {"t"})
|
||||||
self.assertEqual(set(scope_y.cte_sources), {"t", "y"})
|
self.assertEqual(set(scope_y.cte_sources), {"t", "y"})
|
||||||
|
|
||||||
|
def test_schema_with_spaces(self):
|
||||||
|
schema = {
|
||||||
|
"a": {
|
||||||
|
"b c": "text",
|
||||||
|
'"d e"': "text",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
optimizer.optimize(parse_one("SELECT * FROM a"), schema=schema),
|
||||||
|
parse_one('SELECT "a"."b c" AS "b c", "a"."d e" AS "d e" FROM "a" AS "a"'),
|
||||||
|
)
|
||||||
|
|
|
@ -8,7 +8,8 @@ from tests.helpers import assert_logger_contains
|
||||||
|
|
||||||
class TestParser(unittest.TestCase):
|
class TestParser(unittest.TestCase):
|
||||||
def test_parse_empty(self):
|
def test_parse_empty(self):
|
||||||
self.assertIsNone(parse_one(""))
|
with self.assertRaises(ParseError) as ctx:
|
||||||
|
parse_one("")
|
||||||
|
|
||||||
def test_parse_into(self):
|
def test_parse_into(self):
|
||||||
self.assertIsInstance(parse_one("left join foo", into=exp.Join), exp.Join)
|
self.assertIsInstance(parse_one("left join foo", into=exp.Join), exp.Join)
|
||||||
|
@ -90,6 +91,9 @@ class TestParser(unittest.TestCase):
|
||||||
parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),
|
parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),
|
||||||
"""SELECT * FROM x, z CROSS JOIN y LATERAL VIEW EXPLODE(y)""",
|
"""SELECT * FROM x, z CROSS JOIN y LATERAL VIEW EXPLODE(y)""",
|
||||||
)
|
)
|
||||||
|
self.assertIsNone(
|
||||||
|
parse_one("create table a as (select b from c) index").find(exp.TableAlias)
|
||||||
|
)
|
||||||
|
|
||||||
def test_command(self):
|
def test_command(self):
|
||||||
expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1", read="hive")
|
expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1", read="hive")
|
||||||
|
@ -155,6 +159,11 @@ class TestParser(unittest.TestCase):
|
||||||
assert expressions[0].args["from"].expressions[0].this.name == "a"
|
assert expressions[0].args["from"].expressions[0].this.name == "a"
|
||||||
assert expressions[1].args["from"].expressions[0].this.name == "b"
|
assert expressions[1].args["from"].expressions[0].this.name == "b"
|
||||||
|
|
||||||
|
expressions = parse("SELECT 1; ; SELECT 2")
|
||||||
|
|
||||||
|
assert len(expressions) == 3
|
||||||
|
assert expressions[1] is None
|
||||||
|
|
||||||
def test_expression(self):
|
def test_expression(self):
|
||||||
ignore = Parser(error_level=ErrorLevel.IGNORE)
|
ignore = Parser(error_level=ErrorLevel.IGNORE)
|
||||||
self.assertIsInstance(ignore.expression(exp.Hint, expressions=[""]), exp.Hint)
|
self.assertIsInstance(ignore.expression(exp.Hint, expressions=[""]), exp.Hint)
|
||||||
|
|
|
@ -184,3 +184,19 @@ class TestSchema(unittest.TestCase):
|
||||||
|
|
||||||
schema = MappingSchema({"foo": {"bar": parse_one("INT", into=exp.DataType)}})
|
schema = MappingSchema({"foo": {"bar": parse_one("INT", into=exp.DataType)}})
|
||||||
self.assertEqual(schema.get_column_type("foo", "bar").this, exp.DataType.Type.INT)
|
self.assertEqual(schema.get_column_type("foo", "bar").this, exp.DataType.Type.INT)
|
||||||
|
|
||||||
|
def test_schema_normalization(self):
|
||||||
|
schema = MappingSchema(
|
||||||
|
schema={"x": {"`y`": {"Z": {"a": "INT", "`B`": "VARCHAR"}, "w": {"C": "INT"}}}},
|
||||||
|
dialect="spark",
|
||||||
|
)
|
||||||
|
|
||||||
|
table_z = exp.Table(this="z", db="y", catalog="x")
|
||||||
|
table_w = exp.Table(this="w", db="y", catalog="x")
|
||||||
|
|
||||||
|
self.assertEqual(schema.column_names(table_z), ["a", "B"])
|
||||||
|
self.assertEqual(schema.column_names(table_w), ["c"])
|
||||||
|
|
||||||
|
# Clickhouse supports both `` and "" for identifier quotes; sqlglot uses "" when generating sql
|
||||||
|
schema = MappingSchema(schema={"x": {"`y`": "INT"}}, dialect="clickhouse")
|
||||||
|
self.assertEqual(schema.column_names(exp.Table(this="x")), ["y"])
|
||||||
|
|
33
tests/test_serde.py
Normal file
33
tests/test_serde.py
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
import json
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from sqlglot import exp, parse_one
|
||||||
|
from sqlglot.optimizer.annotate_types import annotate_types
|
||||||
|
from tests.helpers import load_sql_fixtures
|
||||||
|
|
||||||
|
|
||||||
|
class CustomExpression(exp.Expression):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class TestSerDe(unittest.TestCase):
|
||||||
|
def dump_load(self, expression):
|
||||||
|
return exp.Expression.load(json.loads(json.dumps(expression.dump())))
|
||||||
|
|
||||||
|
def test_serde(self):
|
||||||
|
for sql in load_sql_fixtures("identity.sql"):
|
||||||
|
with self.subTest(sql):
|
||||||
|
before = parse_one(sql)
|
||||||
|
after = self.dump_load(before)
|
||||||
|
self.assertEqual(before, after)
|
||||||
|
|
||||||
|
def test_custom_expression(self):
|
||||||
|
before = CustomExpression()
|
||||||
|
after = self.dump_load(before)
|
||||||
|
self.assertEqual(before, after)
|
||||||
|
|
||||||
|
def test_type_annotations(self):
|
||||||
|
before = annotate_types(parse_one("CAST('1' AS INT)"))
|
||||||
|
after = self.dump_load(before)
|
||||||
|
self.assertEqual(before.type, after.type)
|
||||||
|
self.assertEqual(before.this.type, after.this.type)
|
|
@ -1,7 +1,11 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from sqlglot import parse_one
|
from sqlglot import parse_one
|
||||||
from sqlglot.transforms import eliminate_distinct_on, unalias_group
|
from sqlglot.transforms import (
|
||||||
|
eliminate_distinct_on,
|
||||||
|
remove_precision_parameterized_types,
|
||||||
|
unalias_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestTime(unittest.TestCase):
|
class TestTime(unittest.TestCase):
|
||||||
|
@ -62,3 +66,10 @@ class TestTime(unittest.TestCase):
|
||||||
"SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC",
|
"SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC",
|
||||||
'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE "_row_number_2" = 1',
|
'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE "_row_number_2" = 1',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_remove_precision_parameterized_types(self):
|
||||||
|
self.validate(
|
||||||
|
remove_precision_parameterized_types,
|
||||||
|
"SELECT CAST(1 AS DECIMAL(10, 2)), CAST('13' AS VARCHAR(10))",
|
||||||
|
"SELECT CAST(1 AS DECIMAL), CAST('13' AS VARCHAR)",
|
||||||
|
)
|
||||||
|
|
|
@ -117,6 +117,11 @@ class TestTranspile(unittest.TestCase):
|
||||||
"select x from foo -- x",
|
"select x from foo -- x",
|
||||||
"SELECT x FROM foo /* x */",
|
"SELECT x FROM foo /* x */",
|
||||||
)
|
)
|
||||||
|
self.validate(
|
||||||
|
"""select x, --
|
||||||
|
from foo""",
|
||||||
|
"SELECT x FROM foo",
|
||||||
|
)
|
||||||
self.validate(
|
self.validate(
|
||||||
"""
|
"""
|
||||||
-- comment 1
|
-- comment 1
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue