
Merging upstream version 10.5.2.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:03:38 +01:00
parent 77197f1e44
commit e0f3bbb5f3
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
58 changed files with 1480 additions and 383 deletions

View file

@ -1,6 +1,35 @@
Changelog
=========
v10.5.0
------
Changes:
- Breaking: Added Python type hints in the parser module, which may result in some mypy errors.
- New: SQLGlot expressions can [now be serialized / deserialized into JSON](https://github.com/tobymao/sqlglot/commit/bac38151a8d72687247922e6898696be43ff4992), as shown in the sketch after this list.
- New: Added support for T-SQL [hints](https://github.com/tobymao/sqlglot/commit/3220ec1adb1e1130b109677d03c9be947b03f9ca) and [EOMONTH](https://github.com/tobymao/sqlglot/commit/1ac05d9265667c883b9f6db5d825a6d864c95c73).
- New: Added support for ClickHouse's parametric function syntax.
- New: Added [wider support](https://github.com/tobymao/sqlglot/commit/beb660f943b73c730f1b06fce4986e26642ee8dc) for timestr and datestr.
- New: CLI now accepts a flag [for parsing SQL from the standard input stream](https://github.com/tobymao/sqlglot/commit/f89b38ebf3e24ba951ee8b249d73bbf48685928a).
- Improvement: Fixed BigQuery transpilation for [parameterized types and unnest](https://github.com/tobymao/sqlglot/pull/924).
- Improvement: Hive / Spark identifiers can now begin with a digit.
- Improvement: Bug fixes in [date/datetime simplification](https://github.com/tobymao/sqlglot/commit/b26b8d88af14f72d90c0019ec332d268a23b078f).
- Improvement: Bug fixes in [merge_subquery](https://github.com/tobymao/sqlglot/commit/e30e21b6c572d0931bfb5873cc6ac3949c6ef5aa).
- Improvement: Schema identifiers are now [converted to lowercase](https://github.com/tobymao/sqlglot/commit/8212032968a519c199b461eba1a2618e89bf0326) unless they're quoted.
- Improvement: Identifiers with a leading underscore are now regarded as [safe](https://github.com/tobymao/sqlglot/commit/de3b0804bb7606673d0bbb989997c13744957f7c#diff-7857fedd1d1451b1b9a5b8efaa1cc292c02e7ee4f0d04d7e2f9d5bfb9565802c) and hence are not quoted.
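The new serialization hooks can be exercised roughly like this (a minimal sketch, not part of the upstream changelog; the query is invented):

```python
import json

import sqlglot
from sqlglot import exp

# Parse a query, dump the AST to JSON-serializable primitives, then rebuild it.
ast = sqlglot.parse_one("SELECT a FROM t WHERE b = 1")
payload = json.dumps(ast.dump())

restored = exp.Expression.load(json.loads(payload))
print(restored.sql())  # SELECT a FROM t WHERE b = 1
```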
v10.4.0
------

View file

@ -1,12 +1,12 @@
# SQLGlot
- SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
+ SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [18 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks) while being written purely in Python.
You can easily [customize](#custom-dialects) the parser, [analyze](#metadata) queries, traverse expression trees, and programmatically [build](#build-and-modify-sql) SQL.
- Syntax [errors](#parser-errors) are highlighted and dialect incompatibilities can warn or raise depending on configurations.
+ Syntax [errors](#parser-errors) are highlighted and dialect incompatibilities can warn or raise depending on configurations. However, it should be noted that the parser is very lenient when it comes to detecting errors, because it aims to consume as much SQL as possible. On one hand, this makes its implementation simpler, and thus more comprehensible, but on the other hand it means that syntax errors may sometimes go unnoticed.
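As a quick illustration of the dialect translation described above (an editorial sketch, not part of the upstream README; the input query is arbitrary):

```python
import sqlglot

# Translate a DuckDB expression into Hive SQL; transpile() returns a list of statements.
print(sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read="duckdb", write="hive")[0])
```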
Contributions are very welcome in SQLGlot; read the [contribution guide](https://github.com/tobymao/sqlglot/blob/main/CONTRIBUTING.md) to get started!
@ -25,6 +25,7 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/
* [AST Diff](#ast-diff)
* [Custom Dialects](#custom-dialects)
* [SQL Execution](#sql-execution)
* [Used By](#used-by)
* [Documentation](#documentation)
* [Run Tests and Lint](#run-tests-and-lint)
* [Benchmarks](#benchmarks)
@ -165,7 +166,7 @@ for table in parse_one("SELECT * FROM x JOIN y JOIN z").find_all(exp.Table):
### Parser Errors
- A syntax error will result in a parser error:
+ When the parser detects an error in the syntax, it raises a ParserError:
```python
import sqlglot
@ -283,13 +284,13 @@ print(
```sql
SELECT
(
- "x"."A" OR "x"."B" OR "x"."C"
+ "x"."a" OR "x"."b" OR "x"."c"
) AND (
- "x"."A" OR "x"."B" OR "x"."D"
+ "x"."a" OR "x"."b" OR "x"."d"
) AS "_col_0"
FROM "x" AS "x"
WHERE
- "x"."Z" = CAST('2021-02-01' AS DATE)
+ CAST("x"."z" AS DATE) = CAST('2021-02-01' AS DATE)
```
### AST Introspection
@ -432,6 +433,14 @@ user_id price
2 3.0
```
## Used By
* [Fugue](https://github.com/fugue-project/fugue)
* [ibis](https://github.com/ibis-project/ibis)
* [mysql-mimic](https://github.com/kelsin/mysql-mimic)
* [Querybook](https://github.com/pinterest/querybook)
* [Quokka](https://github.com/marsupialtail/quokka)
* [Splink](https://github.com/moj-analytical-services/splink)
## Documentation
SQLGlot uses [pdocs](https://pdoc.dev/) to serve its API documentation:

View file

@ -23,7 +23,7 @@ SELECT
"e"."phone_number" AS "Phone", "e"."phone_number" AS "Phone",
TO_CHAR("e"."hire_date", 'MM/DD/YYYY') AS "Hire Date", TO_CHAR("e"."hire_date", 'MM/DD/YYYY') AS "Hire Date",
TO_CHAR("e"."salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') AS "Salary", TO_CHAR("e"."salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') AS "Salary",
"e"."commission_pct" AS "Comission %", "e"."commission_pct" AS "Commission %",
'works as ' || "j"."job_title" || ' in ' || "d"."department_name" || ' department (manager: ' || "dm"."first_name" || ' ' || "dm"."last_name" || ') and immediate supervisor: ' || "m"."first_name" || ' ' || "m"."last_name" AS "Current Job", 'works as ' || "j"."job_title" || ' in ' || "d"."department_name" || ' department (manager: ' || "dm"."first_name" || ' ' || "dm"."last_name" || ') and immediate supervisor: ' || "m"."first_name" || ' ' || "m"."last_name" AS "Current Job",
TO_CHAR("j"."min_salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') || ' - ' || TO_CHAR("j"."max_salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') AS "Current Salary", TO_CHAR("j"."min_salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') || ' - ' || TO_CHAR("j"."max_salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') AS "Current Salary",
"l"."street_address" || ', ' || "l"."postal_code" || ', ' || "l"."city" || ', ' || "l"."state_province" || ', ' || "c"."country_name" || ' (' || "r"."region_name" || ')' AS "Location", "l"."street_address" || ', ' || "l"."postal_code" || ', ' || "l"."city" || ', ' || "l"."state_province" || ', ' || "c"."country_name" || ' (' || "r"."region_name" || ')' AS "Location",

View file

@ -14,13 +14,13 @@ This post will cover [why](#why) I went through the effort of creating a Python
* [Executing](#executing)
## Why?
I started working on SQLGlot because of my work on the [experimentation and metrics platform](https://netflixtechblog.com/reimagining-experimentation-analysis-at-netflix-71356393af21) at Netflix, where I built tools that allowed data scientists to define and compute SQL-based metrics. Netflix relied on multiple engines to query data (Spark, Presto, and Druid), so my team built the metrics platform around [PyPika](https://github.com/kayak/pypika), a Python SQL query builder. This way, definitions could be reused across multiple engines. However, it became quickly apparent that writing Python code to programmatically generate SQL was challenging for data scientists, especially those with academic backgrounds, since they were mostly familiar with R and SQL. At the time, the only Python SQL parser was [sqlparse](https://github.com/andialbrecht/sqlparse), which is not actually a parser but a tokenizer, so having users write raw SQL into the platform wasn't really an option. Some time later, I randomly stumbled across [Crafting Interpreters](https://craftinginterpreters.com/) and realized that I could use it as a guide towards creating my own SQL parser/transpiler.
Why did I do this? Isn't a Python SQL engine going to be extremely slow?
The main reason why I ended up building a SQL engine was...just for **entertainment**. It's been fun learning about all the things required to actually run a SQL query, and seeing it actually work is extremely rewarding. Before SQLGlot, I had zero experience with lexers, parsers, or compilers.
In terms of practical use cases, I planned to use the Python SQL engine for unit testing SQL pipelines. Big data pipelines are tough to test because many of the engines are not open source and cannot be run locally. With SQLGlot, you can take a SQL query targeting a warehouse such as [Snowflake](https://www.snowflake.com/en/) and seamlessly run it in CI on mock Python data. It's easy to mock data and create arbitrary [UDFs](https://en.wikipedia.org/wiki/User-defined_function) because everything is just Python. Although the implementation is slow and unsuitable for large amounts of data (> 1 million rows), there's very little overhead/startup and you can run queries on test data in a couple of milliseconds.
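A minimal sketch of that unit-testing idea, assuming a made-up `orders` table held as plain Python rows:

```python
from sqlglot.executor import execute

# Mock table: a list of rows, each row a plain dict.
tables = {
    "orders": [
        {"user_id": 1, "price": 1.0},
        {"user_id": 2, "price": 3.0},
    ],
}

result = execute(
    "SELECT user_id, SUM(price) AS price FROM orders GROUP BY user_id",
    tables=tables,
)
print(result)
```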
Finally, the components that have been built to support execution can be used as a **foundation** for a faster engine. I'm inspired by what [Apache Calcite](https://github.com/apache/calcite) has done for the JVM world. Even though Python is commonly used for data, there hasn't been a Calcite for Python. So, you could say that SQLGlot aims to be that framework. For example, it wouldn't take much work to replace the Python execution engine with numpy/pandas/arrow to become a respectably-performing query engine. The implementation would be able to leverage the parser, optimizer, and logical planner, only needing to implement physical execution. There is a lot of work in the Python ecosystem around high performance vectorized computation, which I think could benefit from a pure Python-based [AST](https://en.wikipedia.org/wiki/Abstract_syntax_tree)/[plan](https://en.wikipedia.org/wiki/Query_plan). Parsing and planning doesn't have to be fast when the bottleneck of running queries is processing terabytes of data. So, having a Python-based ecosystem around SQL is beneficial given the ease of development in Python, despite not having bare metal performance.
@ -77,7 +77,7 @@ Once we have our AST, we can transform it into an equivalent query that produces
1. It's easier to debug and [validate](https://github.com/tobymao/sqlglot/blob/main/tests/fixtures/optimizer) the optimizations when the input and output are both SQL.
2. Rules can be applied a la carte to transform SQL into a more desirable form.
3. I wanted a way to generate 'canonical sql'. Having a canonical representation of SQL is useful for understanding if two queries are semantically equivalent (e.g. `SELECT 1 + 1` and `SELECT 2`).
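Constant folding already gives a small taste of that canonicalization (a sketch; the exact output formatting may differ):

```python
import sqlglot
from sqlglot.optimizer.simplify import simplify

# Both queries collapse to the same canonical projection.
print(simplify(sqlglot.parse_one("SELECT 1 + 1")).sql())  # SELECT 2
print(simplify(sqlglot.parse_one("SELECT 2")).sql())      # SELECT 2
```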

View file

@ -28,7 +28,7 @@ setup(
"black", "black",
"duckdb", "duckdb",
"isort", "isort",
"mypy", "mypy>=0.990",
"pandas", "pandas",
"pyspark", "pyspark",
"python-dateutil", "python-dateutil",

View file

@ -32,7 +32,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType from sqlglot.tokens import Tokenizer, TokenType
__version__ = "10.4.2" __version__ = "10.5.2"
pretty = False pretty = False
@ -60,9 +60,9 @@ def parse(
def parse_one( def parse_one(
sql: str, sql: str,
read: t.Optional[str | Dialect] = None, read: t.Optional[str | Dialect] = None,
into: t.Optional[Expression | str] = None, into: t.Optional[t.Type[Expression] | str] = None,
**opts, **opts,
) -> t.Optional[Expression]: ) -> Expression:
""" """
Parses the given SQL string and returns a syntax tree for the first parsed SQL statement. Parses the given SQL string and returns a syntax tree for the first parsed SQL statement.
@ -83,7 +83,12 @@ def parse_one(
else: else:
result = dialect.parse(sql, **opts) result = dialect.parse(sql, **opts)
return result[0] if result else None for expression in result:
if not expression:
raise ParseError(f"No expression was parsed from '{sql}'")
return expression
else:
raise ParseError(f"No expression was parsed from '{sql}'")
def transpile( def transpile(

View file

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from sqlglot import exp, generator, parser, tokens from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import ( from sqlglot.dialects.dialect import (
Dialect, Dialect,
datestrtodate_sql, datestrtodate_sql,
@ -46,8 +46,9 @@ def _date_add_sql(data_type, kind):
def _derived_table_values_to_unnest(self, expression): def _derived_table_values_to_unnest(self, expression):
if not isinstance(expression.unnest().parent, exp.From): if not isinstance(expression.unnest().parent, exp.From):
expression = transforms.remove_precision_parameterized_types(expression)
return self.values_sql(expression) return self.values_sql(expression)
rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.find_all(exp.Tuple)] rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
structs = [] structs = []
for row in rows: for row in rows:
aliases = [ aliases = [
@ -118,6 +119,7 @@ class BigQuery(Dialect):
"BEGIN TRANSACTION": TokenType.BEGIN, "BEGIN TRANSACTION": TokenType.BEGIN,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME, "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"CURRENT_TIME": TokenType.CURRENT_TIME, "CURRENT_TIME": TokenType.CURRENT_TIME,
"DECLARE": TokenType.COMMAND,
"GEOGRAPHY": TokenType.GEOGRAPHY, "GEOGRAPHY": TokenType.GEOGRAPHY,
"FLOAT64": TokenType.DOUBLE, "FLOAT64": TokenType.DOUBLE,
"INT64": TokenType.BIGINT, "INT64": TokenType.BIGINT,
@ -166,6 +168,7 @@ class BigQuery(Dialect):
class Generator(generator.Generator): class Generator(generator.Generator):
TRANSFORMS = { TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore **generator.Generator.TRANSFORMS, # type: ignore
**transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore
exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"), exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"), exp.DateSub: _date_add_sql("DATE", "SUB"),

View file

@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
from sqlglot.parser import parse_var_map from sqlglot.parser import parse_var_map
@ -22,6 +24,7 @@ class ClickHouse(Dialect):
KEYWORDS = { KEYWORDS = {
**tokens.Tokenizer.KEYWORDS, **tokens.Tokenizer.KEYWORDS,
"ASOF": TokenType.ASOF, "ASOF": TokenType.ASOF,
"GLOBAL": TokenType.GLOBAL,
"DATETIME64": TokenType.DATETIME, "DATETIME64": TokenType.DATETIME,
"FINAL": TokenType.FINAL, "FINAL": TokenType.FINAL,
"FLOAT32": TokenType.FLOAT, "FLOAT32": TokenType.FLOAT,
@ -37,14 +40,32 @@ class ClickHouse(Dialect):
FUNCTIONS = { FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore **parser.Parser.FUNCTIONS, # type: ignore
"MAP": parse_var_map, "MAP": parse_var_map,
"QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params),
"QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args),
"QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args),
}
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
TokenType.GLOBAL: lambda self, this: self._match(TokenType.IN)
and self._parse_in(this, is_global=True),
} }
JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
def _parse_table(self, schema=False): def _parse_in(
this = super()._parse_table(schema) self, this: t.Optional[exp.Expression], is_global: bool = False
) -> exp.Expression:
this = super()._parse_in(this)
this.set("is_global", is_global)
return this
def _parse_table(
self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
) -> t.Optional[exp.Expression]:
this = super()._parse_table(schema=schema, alias_tokens=alias_tokens)
if self._match(TokenType.FINAL): if self._match(TokenType.FINAL):
this = self.expression(exp.Final, this=this) this = self.expression(exp.Final, this=this)
@ -76,6 +97,16 @@ class ClickHouse(Dialect):
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}",
exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}",
exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}",
} }
EXPLICIT_UNION = True EXPLICIT_UNION = True
def _param_args_sql(
self, expression: exp.Expression, params_name: str, args_name: str
) -> str:
params = self.format_args(self.expressions(expression, params_name))
args = self.format_args(self.expressions(expression, args_name))
return f"({params})({args})"

View file

@ -381,3 +381,20 @@ def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:
def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str: def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
return f"CAST({self.sql(expression, 'this')} AS DATE)" return f"CAST({self.sql(expression, 'this')} AS DATE)"
def trim_sql(self, expression):
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
remove_chars = self.sql(expression, "expression")
collation = self.sql(expression, "collation")
# Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
if not remove_chars and not collation:
return self.trim_sql(expression)
trim_type = f"{trim_type} " if trim_type else ""
remove_chars = f"{remove_chars} " if remove_chars else ""
from_part = "FROM " if trim_type or remove_chars else ""
collation = f" COLLATE {collation}" if collation else ""
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"

View file

@ -175,14 +175,6 @@ class Hive(Dialect):
ESCAPES = ["\\"] ESCAPES = ["\\"]
ENCODE = "utf-8" ENCODE = "utf-8"
NUMERIC_LITERALS = {
"L": "BIGINT",
"S": "SMALLINT",
"Y": "TINYINT",
"D": "DOUBLE",
"F": "FLOAT",
"BD": "DECIMAL",
}
KEYWORDS = { KEYWORDS = {
**tokens.Tokenizer.KEYWORDS, **tokens.Tokenizer.KEYWORDS,
"ADD ARCHIVE": TokenType.COMMAND, "ADD ARCHIVE": TokenType.COMMAND,
@ -191,9 +183,21 @@ class Hive(Dialect):
"ADD FILES": TokenType.COMMAND, "ADD FILES": TokenType.COMMAND,
"ADD JAR": TokenType.COMMAND, "ADD JAR": TokenType.COMMAND,
"ADD JARS": TokenType.COMMAND, "ADD JARS": TokenType.COMMAND,
"MSCK REPAIR": TokenType.COMMAND,
"WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES, "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
} }
NUMERIC_LITERALS = {
"L": "BIGINT",
"S": "SMALLINT",
"Y": "TINYINT",
"D": "DOUBLE",
"F": "FLOAT",
"BD": "DECIMAL",
}
IDENTIFIER_CAN_START_WITH_DIGIT = True
class Parser(parser.Parser): class Parser(parser.Parser):
STRICT_CAST = False STRICT_CAST = False
@ -315,6 +319,7 @@ class Hive(Dialect):
exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}", exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}",
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"), exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"), exp.NumberToStr: rename_func("FORMAT_NUMBER"),
exp.LastDateOfMonth: rename_func("LAST_DAY"),
} }
WITH_PROPERTIES = {exp.Property} WITH_PROPERTIES = {exp.Property}
@ -342,4 +347,6 @@ class Hive(Dialect):
and not expression.expressions and not expression.expressions
): ):
expression = exp.DataType.build("text") expression = exp.DataType.build("text")
elif expression.this in exp.DataType.TEMPORAL_TYPES:
expression = exp.DataType.build(expression.this)
return super().datatype_sql(expression) return super().datatype_sql(expression)

View file

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from sqlglot import exp, generator, parser, tokens, transforms from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql
from sqlglot.helper import csv from sqlglot.helper import csv
from sqlglot.tokens import TokenType from sqlglot.tokens import TokenType
@ -64,6 +64,7 @@ class Oracle(Dialect):
**transforms.UNALIAS_GROUP, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore
exp.ILike: no_ilike_sql, exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql, exp.Limit: _limit_sql,
exp.Trim: trim_sql,
exp.Matches: rename_func("DECODE"), exp.Matches: rename_func("DECODE"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",

View file

@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql, no_tablesample_sql,
no_trycast_sql, no_trycast_sql,
str_position_sql, str_position_sql,
trim_sql,
) )
from sqlglot.helper import seq_get from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType from sqlglot.tokens import TokenType
@ -81,23 +82,6 @@ def _substring_sql(self, expression):
return f"SUBSTRING({this}{from_part}{for_part})" return f"SUBSTRING({this}{from_part}{for_part})"
def _trim_sql(self, expression):
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
remove_chars = self.sql(expression, "expression")
collation = self.sql(expression, "collation")
# Use TRIM/LTRIM/RTRIM syntax if the expression isn't postgres-specific
if not remove_chars and not collation:
return self.trim_sql(expression)
trim_type = f"{trim_type} " if trim_type else ""
remove_chars = f"{remove_chars} " if remove_chars else ""
from_part = "FROM " if trim_type or remove_chars else ""
collation = f" COLLATE {collation}" if collation else ""
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def _string_agg_sql(self, expression): def _string_agg_sql(self, expression):
expression = expression.copy() expression = expression.copy()
separator = expression.args.get("separator") or exp.Literal.string(",") separator = expression.args.get("separator") or exp.Literal.string(",")
@ -248,7 +232,6 @@ class Postgres(Dialect):
"COMMENT ON": TokenType.COMMAND, "COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND, "DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND, "DO": TokenType.COMMAND,
"DOUBLE PRECISION": TokenType.DOUBLE,
"GENERATED": TokenType.GENERATED, "GENERATED": TokenType.GENERATED,
"GRANT": TokenType.COMMAND, "GRANT": TokenType.COMMAND,
"HSTORE": TokenType.HSTORE, "HSTORE": TokenType.HSTORE,
@ -318,7 +301,7 @@ class Postgres(Dialect):
exp.Substring: _substring_sql, exp.Substring: _substring_sql,
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TableSample: no_tablesample_sql, exp.TableSample: no_tablesample_sql,
exp.Trim: _trim_sql, exp.Trim: trim_sql,
exp.TryCast: no_trycast_sql, exp.TryCast: no_trycast_sql,
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.DataType: _datatype_sql, exp.DataType: _datatype_sql,

View file

@ -195,7 +195,6 @@ class Snowflake(Dialect):
KEYWORDS = { KEYWORDS = {
**tokens.Tokenizer.KEYWORDS, **tokens.Tokenizer.KEYWORDS,
"QUALIFY": TokenType.QUALIFY, "QUALIFY": TokenType.QUALIFY,
"DOUBLE PRECISION": TokenType.DOUBLE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
@ -294,3 +293,10 @@ class Snowflake(Dialect):
) )
return self.no_identify(lambda: super(self.__class__, self).select_sql(expression)) return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
return super().select_sql(expression) return super().select_sql(expression)
def describe_sql(self, expression: exp.Describe) -> str:
# Default to table if kind is unknown
kind_value = expression.args.get("kind") or "TABLE"
kind = f" {kind_value}" if kind_value else ""
this = f" {self.sql(expression, 'this')}"
return f"DESCRIBE{kind}{this}"

View file

@ -75,6 +75,20 @@ def _parse_format(args):
) )
def _parse_eomonth(args):
date = seq_get(args, 0)
month_lag = seq_get(args, 1)
unit = DATE_DELTA_INTERVAL.get("month")
if month_lag is None:
return exp.LastDateOfMonth(this=date)
# Remove the month lag argument in the parser, as it's compared with the number of arguments of the resulting class
args.remove(month_lag)
return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
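A sketch of the intended translation (the date literal is arbitrary and the exact output may vary):

```python
import sqlglot

# T-SQL EOMONTH maps to LAST_DAY in Hive-flavored SQL.
sql = "SELECT EOMONTH(CAST('2021-01-15' AS DATE))"
print(sqlglot.transpile(sql, read="tsql", write="hive")[0])
```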
def generate_date_delta_with_unit_sql(self, e): def generate_date_delta_with_unit_sql(self, e):
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF" func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})" return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
@ -256,12 +270,14 @@ class TSQL(Dialect):
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True), "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": _format_time_lambda(exp.TimeToStr), "DATEPART": _format_time_lambda(exp.TimeToStr),
"GETDATE": exp.CurrentDate.from_arg_list, "GETDATE": exp.CurrentTimestamp.from_arg_list,
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
"IIF": exp.If.from_arg_list, "IIF": exp.If.from_arg_list,
"LEN": exp.Length.from_arg_list, "LEN": exp.Length.from_arg_list,
"REPLICATE": exp.Repeat.from_arg_list, "REPLICATE": exp.Repeat.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list, "JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
"FORMAT": _parse_format, "FORMAT": _parse_format,
"EOMONTH": _parse_eomonth,
} }
VAR_LENGTH_DATATYPES = { VAR_LENGTH_DATATYPES = {
@ -271,6 +287,9 @@ class TSQL(Dialect):
DataType.Type.NCHAR, DataType.Type.NCHAR,
} }
# https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table
TABLE_PREFIX_TOKENS = {TokenType.HASH}
def _parse_convert(self, strict): def _parse_convert(self, strict):
to = self._parse_types() to = self._parse_types()
self._match(TokenType.COMMA) self._match(TokenType.COMMA)
@ -323,6 +342,7 @@ class TSQL(Dialect):
exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"), exp.CurrentDate: rename_func("GETDATE"),
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.If: rename_func("IIF"), exp.If: rename_func("IIF"),
exp.NumberToStr: _format_sql, exp.NumberToStr: _format_sql,
exp.TimeToStr: _format_sql, exp.TimeToStr: _format_sql,

View file

@ -22,6 +22,7 @@ from sqlglot.helper import (
split_num_words, split_num_words,
subclasses, subclasses,
) )
from sqlglot.tokens import Token
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import Dialect from sqlglot.dialects.dialect import Dialect
@ -457,6 +458,23 @@ class Expression(metaclass=_Expression):
assert isinstance(self, type_) assert isinstance(self, type_)
return self return self
def dump(self):
"""
Dump this Expression to a JSON-serializable dict.
"""
from sqlglot.serde import dump
return dump(self)
@classmethod
def load(cls, obj):
"""
Load a dict (as returned by `Expression.dump`) into an Expression instance.
"""
from sqlglot.serde import load
return load(obj)
class Condition(Expression): class Condition(Expression):
def and_(self, *expressions, dialect=None, **opts): def and_(self, *expressions, dialect=None, **opts):
@ -631,11 +649,15 @@ class Create(Expression):
"replace": False, "replace": False,
"unique": False, "unique": False,
"materialized": False, "materialized": False,
"data": False,
"statistics": False,
"no_primary_index": False,
"indexes": False,
} }
class Describe(Expression): class Describe(Expression):
pass arg_types = {"this": True, "kind": False}
class Set(Expression): class Set(Expression):
@ -731,7 +753,7 @@ class Column(Condition):
class ColumnDef(Expression): class ColumnDef(Expression):
arg_types = { arg_types = {
"this": True, "this": True,
"kind": True, "kind": False,
"constraints": False, "constraints": False,
"exists": False, "exists": False,
} }
@ -879,7 +901,15 @@ class Identifier(Expression):
class Index(Expression): class Index(Expression):
arg_types = {"this": False, "table": False, "where": False, "columns": False} arg_types = {
"this": False,
"table": False,
"where": False,
"columns": False,
"unique": False,
"primary": False,
"amp": False, # teradata
}
class Insert(Expression): class Insert(Expression):
@ -1361,6 +1391,7 @@ class Table(Expression):
"laterals": False, "laterals": False,
"joins": False, "joins": False,
"pivots": False, "pivots": False,
"hints": False,
} }
@ -1818,7 +1849,12 @@ class Select(Subqueryable):
join.this.replace(join.this.subquery()) join.this.replace(join.this.subquery())
if join_type: if join_type:
natural: t.Optional[Token]
side: t.Optional[Token]
kind: t.Optional[Token]
natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore
if natural: if natural:
join.set("natural", True) join.set("natural", True)
if side: if side:
@ -2111,6 +2147,7 @@ class DataType(Expression):
JSON = auto() JSON = auto()
JSONB = auto() JSONB = auto()
INTERVAL = auto() INTERVAL = auto()
TIME = auto()
TIMESTAMP = auto() TIMESTAMP = auto()
TIMESTAMPTZ = auto() TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto() TIMESTAMPLTZ = auto()
@ -2171,11 +2208,24 @@ class DataType(Expression):
} }
@classmethod @classmethod
def build(cls, dtype, **kwargs) -> DataType: def build(
return DataType( cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs
this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()], ) -> DataType:
**kwargs, from sqlglot import parse_one
)
if isinstance(dtype, str):
data_type_exp: t.Optional[Expression]
if dtype.upper() in cls.Type.__members__:
data_type_exp = DataType(this=DataType.Type[dtype.upper()])
else:
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
if data_type_exp is None:
raise ValueError(f"Unparsable data type value: {dtype}")
elif isinstance(dtype, DataType.Type):
data_type_exp = DataType(this=dtype)
else:
raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
return DataType(**{**data_type_exp.args, **kwargs})
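A sketch of what the reworked builder accepts (the type strings are examples):

```python
from sqlglot import exp

# Plain enum names still work, but parameterized and nested types are parsed too.
print(exp.DataType.build("decimal(10, 2)").sql())
print(exp.DataType.build("array<int64>", dialect="bigquery").sql(dialect="bigquery"))
```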
# https://www.postgresql.org/docs/15/datatype-pseudo.html # https://www.postgresql.org/docs/15/datatype-pseudo.html
@ -2429,6 +2479,7 @@ class In(Predicate):
"query": False, "query": False,
"unnest": False, "unnest": False,
"field": False, "field": False,
"is_global": False,
} }
@ -2678,6 +2729,10 @@ class DatetimeTrunc(Func, TimeUnit):
arg_types = {"this": True, "unit": True, "zone": False} arg_types = {"this": True, "unit": True, "zone": False}
class LastDateOfMonth(Func):
pass
class Extract(Func): class Extract(Func):
arg_types = {"this": True, "expression": True} arg_types = {"this": True, "expression": True}
@ -2815,7 +2870,13 @@ class Length(Func):
class Levenshtein(Func): class Levenshtein(Func):
arg_types = {"this": True, "expression": False} arg_types = {
"this": True,
"expression": False,
"ins_cost": False,
"del_cost": False,
"sub_cost": False,
}
class Ln(Func): class Ln(Func):
@ -2890,6 +2951,16 @@ class Quantile(AggFunc):
arg_types = {"this": True, "quantile": True} arg_types = {"this": True, "quantile": True}
# Clickhouse-specific:
# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles
class Quantiles(AggFunc):
arg_types = {"parameters": True, "expressions": True}
class QuantileIf(AggFunc):
arg_types = {"parameters": True, "expressions": True}
class ApproxQuantile(Quantile): class ApproxQuantile(Quantile):
arg_types = {"this": True, "quantile": True, "accuracy": False} arg_types = {"this": True, "quantile": True, "accuracy": False}
@ -2962,8 +3033,10 @@ class StrToTime(Func):
arg_types = {"this": True, "format": True} arg_types = {"this": True, "format": True}
# Spark allows unix_timestamp()
# https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.unix_timestamp.html
class StrToUnix(Func): class StrToUnix(Func):
arg_types = {"this": True, "format": True} arg_types = {"this": False, "format": False}
class NumberToStr(Func): class NumberToStr(Func):
@ -3131,7 +3204,7 @@ def maybe_parse(
dialect=None, dialect=None,
prefix=None, prefix=None,
**opts, **opts,
) -> t.Optional[Expression]: ) -> Expression:
"""Gracefully handle a possible string or expression. """Gracefully handle a possible string or expression.
Example: Example:
@ -3627,11 +3700,11 @@ def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
if not isinstance(sql_path, str): if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for a table: {type(sql_path)}") raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")
catalog, db, table_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 3)] catalog, db, table_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 3))
return Table(this=table_name, db=db, catalog=catalog, **kwargs) return Table(this=table_name, db=db, catalog=catalog, **kwargs)
def to_column(sql_path: str, **kwargs) -> Column: def to_column(sql_path: str | Column, **kwargs) -> Column:
""" """
Create a column from a `[table].[column]` sql path. Schema is optional. Create a column from a `[table].[column]` sql path. Schema is optional.
@ -3646,7 +3719,7 @@ def to_column(sql_path: str, **kwargs) -> Column:
return sql_path return sql_path
if not isinstance(sql_path, str): if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for column: {type(sql_path)}") raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
table_name, column_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 2)] table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2))
return Column(this=column_name, table=table_name, **kwargs) return Column(this=column_name, table=table_name, **kwargs)
@ -3748,7 +3821,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
def values( def values(
values: t.Iterable[t.Tuple[t.Any, ...]], values: t.Iterable[t.Tuple[t.Any, ...]],
alias: t.Optional[str] = None, alias: t.Optional[str] = None,
columns: t.Optional[t.Iterable[str]] = None, columns: t.Optional[t.Iterable[str] | t.Dict[str, DataType]] = None,
) -> Values: ) -> Values:
"""Build VALUES statement. """Build VALUES statement.
@ -3759,7 +3832,10 @@ def values(
Args: Args:
values: values statements that will be converted to SQL values: values statements that will be converted to SQL
alias: optional alias alias: optional alias
columns: Optional list of ordered column names. An alias is required when providing column names. columns: Optional list of ordered column names or ordered dictionary of column names to types.
If either are provided then an alias is also required.
If a dictionary is provided then the first column of the values will be casted to the expected type
in order to help with type inference.
Returns: Returns:
Values: the Values expression object Values: the Values expression object
@ -3771,8 +3847,15 @@ def values(
if columns if columns
else TableAlias(this=to_identifier(alias) if alias else None) else TableAlias(this=to_identifier(alias) if alias else None)
) )
expressions = [convert(tup) for tup in values]
if columns and isinstance(columns, dict):
types = list(columns.values())
expressions[0].set(
"expressions",
[Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)],
)
return Values( return Values(
expressions=[convert(tup) for tup in values], expressions=expressions,
alias=table_alias, alias=table_alias,
) )
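A sketch of the new typed-columns form (column names and types are invented):

```python
from sqlglot import exp

# An ordered dict of column -> type makes the first row carry explicit casts.
node = exp.values(
    [(1, "alice"), (2, "bob")],
    alias="t",
    columns={"id": exp.DataType.build("bigint"), "name": exp.DataType.build("text")},
)
print(node.sql())
```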

View file

@ -50,7 +50,7 @@ class Generator:
The default is on the smaller end because the length only represents a segment and not the true The default is on the smaller end because the length only represents a segment and not the true
line length. line length.
Default: 80 Default: 80
comments: Whether or not to preserve comments in the ouput SQL code. comments: Whether or not to preserve comments in the output SQL code.
Default: True Default: True
""" """
@ -236,7 +236,10 @@ class Generator:
return sql return sql
sep = "\n" if self.pretty else " " sep = "\n" if self.pretty else " "
comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments) comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments if comment)
if not comments:
return sql
if isinstance(expression, self.WITH_SEPARATED_COMMENTS): if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
return f"{comments}{self.sep()}{sql}" return f"{comments}{self.sep()}{sql}"
@ -362,10 +365,10 @@ class Generator:
kind = self.sql(expression, "kind") kind = self.sql(expression, "kind")
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True) constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
exists = "IF NOT EXISTS " if expression.args.get("exists") else "" exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
kind = f" {kind}" if kind else ""
constraints = f" {constraints}" if constraints else ""
if not constraints: return f"{exists}{column}{kind}{constraints}"
return f"{exists}{column} {kind}"
return f"{exists}{column} {kind} {constraints}"
def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str: def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
this = self.sql(expression, "this") this = self.sql(expression, "this")
@ -416,7 +419,7 @@ class Generator:
this = self.sql(expression, "this") this = self.sql(expression, "this")
kind = self.sql(expression, "kind").upper() kind = self.sql(expression, "kind").upper()
expression_sql = self.sql(expression, "expression") expression_sql = self.sql(expression, "expression")
expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else "" expression_sql = f" AS{self.sep()}{expression_sql}" if expression_sql else ""
temporary = " TEMPORARY" if expression.args.get("temporary") else "" temporary = " TEMPORARY" if expression.args.get("temporary") else ""
transient = ( transient = (
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
@ -427,6 +430,40 @@ class Generator:
unique = " UNIQUE" if expression.args.get("unique") else "" unique = " UNIQUE" if expression.args.get("unique") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else "" materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
properties = self.sql(expression, "properties") properties = self.sql(expression, "properties")
data = expression.args.get("data")
if data is None:
data = ""
elif data:
data = " WITH DATA"
else:
data = " WITH NO DATA"
statistics = expression.args.get("statistics")
if statistics is None:
statistics = ""
elif statistics:
statistics = " AND STATISTICS"
else:
statistics = " AND NO STATISTICS"
no_primary_index = " NO PRIMARY INDEX" if expression.args.get("no_primary_index") else ""
indexes = expression.args.get("indexes")
index_sql = ""
if indexes is not None:
indexes_sql = []
for index in indexes:
ind_unique = " UNIQUE" if index.args.get("unique") else ""
ind_primary = " PRIMARY" if index.args.get("primary") else ""
ind_amp = " AMP" if index.args.get("amp") else ""
ind_name = f" {index.name}" if index.name else ""
ind_columns = (
f' ({self.expressions(index, key="columns", flat=True)})'
if index.args.get("columns")
else ""
)
indexes_sql.append(
f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
)
index_sql = "".join(indexes_sql)
modifiers = "".join( modifiers = "".join(
( (
@ -438,7 +475,10 @@ class Generator:
materialized, materialized,
) )
) )
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties} {expression_sql}"
post_expression_modifiers = "".join((data, statistics, no_primary_index))
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}"
return self.prepend_ctes(expression, expression_sql) return self.prepend_ctes(expression, expression_sql)
def describe_sql(self, expression: exp.Describe) -> str: def describe_sql(self, expression: exp.Describe) -> str:
@ -668,6 +708,8 @@ class Generator:
alias = self.sql(expression, "alias") alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else "" alias = f"{sep}{alias}" if alias else ""
hints = self.expressions(expression, key="hints", sep=", ", flat=True)
hints = f" WITH ({hints})" if hints else ""
laterals = self.expressions(expression, key="laterals", sep="") laterals = self.expressions(expression, key="laterals", sep="")
joins = self.expressions(expression, key="joins", sep="") joins = self.expressions(expression, key="joins", sep="")
pivots = self.expressions(expression, key="pivots", sep="") pivots = self.expressions(expression, key="pivots", sep="")
@ -676,7 +718,7 @@ class Generator:
pivots = f"{pivots}{alias}" pivots = f"{pivots}{alias}"
alias = "" alias = ""
return f"{table}{alias}{laterals}{joins}{pivots}" return f"{table}{alias}{hints}{laterals}{joins}{pivots}"
def tablesample_sql(self, expression: exp.TableSample) -> str: def tablesample_sql(self, expression: exp.TableSample) -> str:
if self.alias_post_tablesample and expression.this.alias: if self.alias_post_tablesample and expression.this.alias:
@ -1020,7 +1062,9 @@ class Generator:
if not partition and not order and not spec and alias: if not partition and not order and not spec and alias:
return f"{this} {alias}" return f"{this} {alias}"
return f"{this} ({alias}{partition_sql}{order_sql}{spec_sql})" window_args = alias + partition_sql + order_sql + spec_sql
return f"{this} ({window_args.strip()})"
def window_spec_sql(self, expression: exp.WindowSpec) -> str: def window_spec_sql(self, expression: exp.WindowSpec) -> str:
kind = self.sql(expression, "kind") kind = self.sql(expression, "kind")
@ -1130,6 +1174,8 @@ class Generator:
query = expression.args.get("query") query = expression.args.get("query")
unnest = expression.args.get("unnest") unnest = expression.args.get("unnest")
field = expression.args.get("field") field = expression.args.get("field")
is_global = " GLOBAL" if expression.args.get("is_global") else ""
if query: if query:
in_sql = self.wrap(query) in_sql = self.wrap(query)
elif unnest: elif unnest:
@ -1138,7 +1184,8 @@ class Generator:
in_sql = self.sql(field) in_sql = self.sql(field)
else: else:
in_sql = f"({self.expressions(expression, flat=True)})" in_sql = f"({self.expressions(expression, flat=True)})"
return f"{self.sql(expression, 'this')} IN {in_sql}"
return f"{self.sql(expression, 'this')}{is_global} IN {in_sql}"
def in_unnest_op(self, unnest: exp.Unnest) -> str: def in_unnest_op(self, unnest: exp.Unnest) -> str:
return f"(SELECT {self.sql(unnest)})" return f"(SELECT {self.sql(unnest)})"
@ -1433,7 +1480,7 @@ class Generator:
result_sqls = [] result_sqls = []
for i, e in enumerate(expressions): for i, e in enumerate(expressions):
sql = self.sql(e, comment=False) sql = self.sql(e, comment=False)
comments = self.maybe_comment("", e) comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else ""
if self.pretty: if self.pretty:
if self._leading_comma: if self._leading_comma:

View file

@ -131,7 +131,7 @@ def subclasses(
] ]
def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]: def apply_index_offset(expressions: t.List[t.Optional[E]], offset: int) -> t.List[t.Optional[E]]:
""" """
Applies an offset to a given integer literal expression. Applies an offset to a given integer literal expression.
@ -148,10 +148,10 @@ def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
expression = expressions[0] expression = expressions[0]
if expression.is_int: if expression and expression.is_int:
expression = expression.copy() expression = expression.copy()
logger.warning("Applying array index offset (%s)", offset) logger.warning("Applying array index offset (%s)", offset)
expression.args["this"] = str(int(expression.this) + offset) expression.args["this"] = str(int(expression.this) + offset) # type: ignore
return [expression] return [expression]
return expressions return expressions
@ -225,7 +225,7 @@ def open_file(file_name: str) -> t.TextIO:
return gzip.open(file_name, "rt", newline="") return gzip.open(file_name, "rt", newline="")
return open(file_name, "rt", encoding="utf-8", newline="") return open(file_name, encoding="utf-8", newline="")
@contextmanager @contextmanager
@ -256,7 +256,7 @@ def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
file.close() file.close()
def find_new_name(taken: t.Sequence[str], base: str) -> str: def find_new_name(taken: t.Collection[str], base: str) -> str:
""" """
Searches for a new name. Searches for a new name.
@ -356,6 +356,15 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any,
yield value yield value
def count_params(function: t.Callable) -> int:
"""
Returns the number of formal parameters expected by a function, without counting "self"
and "cls", in case of instance and class methods, respectively.
"""
count = function.__code__.co_argcount
return count - 1 if inspect.ismethod(function) else count
def dict_depth(d: t.Dict) -> int: def dict_depth(d: t.Dict) -> int:
""" """
Get the nesting depth of a dictionary. Get the nesting depth of a dictionary.
@ -374,6 +383,7 @@ def dict_depth(d: t.Dict) -> int:
Args: Args:
d (dict): dictionary d (dict): dictionary
Returns: Returns:
int: depth int: depth
""" """

View file

@ -43,7 +43,7 @@ class TypeAnnotator:
}, },
exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr), exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
exp.Alias: lambda self, expr: self._annotate_unary(expr), exp.Alias: lambda self, expr: self._annotate_unary(expr),
exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),

View file

@ -57,7 +57,7 @@ def _join_is_used(scope, join, alias):
# But columns in the ON clause shouldn't count. # But columns in the ON clause shouldn't count.
on = join.args.get("on") on = join.args.get("on")
if on: if on:
on_clause_columns = set(id(column) for column in on.find_all(exp.Column)) on_clause_columns = {id(column) for column in on.find_all(exp.Column)}
else: else:
on_clause_columns = set() on_clause_columns = set()
return any( return any(
@ -71,7 +71,7 @@ def _is_joined_on_all_unique_outputs(scope, join):
return False return False
_, join_keys, _ = join_condition(join) _, join_keys, _ = join_condition(join)
remaining_unique_outputs = unique_outputs - set(c.name for c in join_keys) remaining_unique_outputs = unique_outputs - {c.name for c in join_keys}
return not remaining_unique_outputs return not remaining_unique_outputs

View file

@ -67,11 +67,9 @@ def merge_ctes(expression, leave_tables_isolated=False):
singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1] singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
for outer_scope, inner_scope, table in singular_cte_selections: for outer_scope, inner_scope, table in singular_cte_selections:
inner_select = inner_scope.expression.unnest()
from_or_join = table.find_ancestor(exp.From, exp.Join) from_or_join = table.find_ancestor(exp.From, exp.Join)
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
alias = table.alias_or_name alias = table.alias_or_name
_rename_inner_sources(outer_scope, inner_scope, alias) _rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, table, alias) _merge_from(outer_scope, inner_scope, table, alias)
_merge_expressions(outer_scope, inner_scope, alias) _merge_expressions(outer_scope, inner_scope, alias)
@ -80,18 +78,17 @@ def merge_ctes(expression, leave_tables_isolated=False):
_merge_order(outer_scope, inner_scope) _merge_order(outer_scope, inner_scope)
_merge_hints(outer_scope, inner_scope) _merge_hints(outer_scope, inner_scope)
_pop_cte(inner_scope) _pop_cte(inner_scope)
outer_scope.clear_cache()
return expression return expression
def merge_derived_tables(expression, leave_tables_isolated=False): def merge_derived_tables(expression, leave_tables_isolated=False):
for outer_scope in traverse_scope(expression): for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables: for subquery in outer_scope.derived_tables:
inner_select = subquery.unnest()
from_or_join = subquery.find_ancestor(exp.From, exp.Join) from_or_join = subquery.find_ancestor(exp.From, exp.Join)
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
alias = subquery.alias_or_name alias = subquery.alias_or_name
inner_scope = outer_scope.sources[alias] inner_scope = outer_scope.sources[alias]
if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
_rename_inner_sources(outer_scope, inner_scope, alias) _rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery, alias) _merge_from(outer_scope, inner_scope, subquery, alias)
_merge_expressions(outer_scope, inner_scope, alias) _merge_expressions(outer_scope, inner_scope, alias)
@ -99,21 +96,23 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
_merge_where(outer_scope, inner_scope, from_or_join) _merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope) _merge_order(outer_scope, inner_scope)
_merge_hints(outer_scope, inner_scope) _merge_hints(outer_scope, inner_scope)
outer_scope.clear_cache()
return expression return expression
def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
""" """
Return True if `inner_select` can be merged into outer query. Return True if `inner_select` can be merged into outer query.
Args: Args:
outer_scope (Scope) outer_scope (Scope)
inner_select (exp.Select) inner_scope (Scope)
leave_tables_isolated (bool) leave_tables_isolated (bool)
from_or_join (exp.From|exp.Join) from_or_join (exp.From|exp.Join)
Returns: Returns:
bool: True if can be merged bool: True if can be merged
""" """
inner_select = inner_scope.expression.unnest()
def _is_a_window_expression_in_unmergable_operation(): def _is_a_window_expression_in_unmergable_operation():
window_expressions = inner_select.find_all(exp.Window) window_expressions = inner_select.find_all(exp.Window)
@ -133,10 +132,40 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
] ]
return any(window_expressions_in_unmergable) return any(window_expressions_in_unmergable)
def _outer_select_joins_on_inner_select_join():
"""
All columns from the inner select in the ON clause must be from the first FROM table.
That is, this can be merged:
SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
^^^ ^
But this can't:
SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
^^^ ^
"""
if not isinstance(from_or_join, exp.Join):
return False
alias = from_or_join.this.alias_or_name
on = from_or_join.args.get("on")
if not on:
return False
selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
inner_from = inner_scope.expression.args.get("from")
if not inner_from:
return False
inner_from_table = inner_from.expressions[0].alias_or_name
inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
return any(
col.table != inner_from_table
for selection in selections
for col in inner_projections[selection].find_all(exp.Column)
)
return ( return (
isinstance(outer_scope.expression, exp.Select) isinstance(outer_scope.expression, exp.Select)
and isinstance(inner_select, exp.Select) and isinstance(inner_select, exp.Select)
and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from") and inner_select.args.get("from")
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
@ -153,6 +182,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []) j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
) )
) )
and not _outer_select_joins_on_inner_select_join()
and not _is_a_window_expression_in_unmergable_operation() and not _is_a_window_expression_in_unmergable_operation()
) )
@ -168,7 +198,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
""" """
taken = set(outer_scope.selected_sources) taken = set(outer_scope.selected_sources)
conflicts = taken.intersection(set(inner_scope.selected_sources)) conflicts = taken.intersection(set(inner_scope.selected_sources))
conflicts = conflicts - {alias} conflicts -= {alias}
for conflict in conflicts: for conflict in conflicts:
new_name = find_new_name(taken, conflict) new_name = find_new_name(taken, conflict)
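As a quick, hedged illustration of the new guard (the input/output pair is taken verbatim from the updated merge_subqueries fixtures further below):

    from sqlglot import parse_one
    from sqlglot.optimizer.merge_subqueries import merge_subqueries

    # The CTE joins two tables and the outer query joins on a column that comes from
    # the second of them (x, not the first FROM table y), so merging is now skipped.
    sql = (
        "WITH i AS (SELECT x.a AS a FROM y AS y JOIN x AS x ON y.b = x.b) "
        "SELECT x.a AS a FROM x AS x LEFT JOIN i AS i ON x.a = i.a"
    )
    print(merge_subqueries(parse_one(sql)).sql())  # unchanged: the CTE is left unmerged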


@ -15,6 +15,7 @@ from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
RULES = ( RULES = (
lower_identities, lower_identities,
@ -51,12 +52,13 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
If no schema is provided then the default schema defined at `sqlglot.schema` will be used If no schema is provided then the default schema defined at `sqlglot.schema` will be used
db (str): specify the default database, as might be set by a `USE DATABASE db` statement db (str): specify the default database, as might be set by a `USE DATABASE db` statement
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
rules (list): sequence of optimizer rules to use rules (sequence): sequence of optimizer rules to use
**kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in. **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
Returns: Returns:
sqlglot.Expression: optimized expression sqlglot.Expression: optimized expression
""" """
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs} schema = ensure_schema(schema or sqlglot.schema)
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
expression = expression.copy() expression = expression.copy()
for rule in rules: for rule in rules:


@ -79,6 +79,7 @@ def _remove_unused_selections(scope, parent_selections):
order_refs = set() order_refs = set()
new_selections = [] new_selections = []
removed = False
for i, selection in enumerate(scope.selects): for i, selection in enumerate(scope.selects):
if ( if (
SELECT_ALL in parent_selections SELECT_ALL in parent_selections
@ -88,12 +89,15 @@ def _remove_unused_selections(scope, parent_selections):
new_selections.append(selection) new_selections.append(selection)
else: else:
removed_indexes.append(i) removed_indexes.append(i)
removed = True
# If there are no remaining selections, just select a single constant # If there are no remaining selections, just select a single constant
if not new_selections: if not new_selections:
new_selections.append(DEFAULT_SELECTION.copy()) new_selections.append(DEFAULT_SELECTION.copy())
scope.expression.set("expressions", new_selections) scope.expression.set("expressions", new_selections)
if removed:
scope.clear_cache()
return removed_indexes return removed_indexes


@ -365,9 +365,9 @@ class _Resolver:
def all_columns(self): def all_columns(self):
"""All available columns of all sources in this scope""" """All available columns of all sources in this scope"""
if self._all_columns is None: if self._all_columns is None:
self._all_columns = set( self._all_columns = {
column for columns in self._get_all_source_columns().values() for column in columns column for columns in self._get_all_source_columns().values() for column in columns
) }
return self._all_columns return self._all_columns
def get_source_columns(self, name, only_visible=False): def get_source_columns(self, name, only_visible=False):


@ -361,7 +361,7 @@ def _simplify_binary(expression, a, b):
return boolean return boolean
elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval): elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
a, b = extract_date(a), extract_interval(b) a, b = extract_date(a), extract_interval(b)
if b: if a and b:
if isinstance(expression, exp.Add): if isinstance(expression, exp.Add):
return date_literal(a + b) return date_literal(a + b)
if isinstance(expression, exp.Sub): if isinstance(expression, exp.Sub):
@ -369,7 +369,7 @@ def _simplify_binary(expression, a, b):
elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast): elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
a, b = extract_interval(a), extract_date(b) a, b = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval # you cannot subtract a date from an interval
if a and isinstance(expression, exp.Add): if a and b and isinstance(expression, exp.Add):
return date_literal(a + b) return date_literal(a + b)
return None return None
@ -424,8 +424,14 @@ def eval_boolean(expression, a, b):
def extract_date(cast): def extract_date(cast):
# The "fromisoformat" conversion could fail if the cast is used on an identifier,
# so in that case we can't extract the date.
try:
if cast.args["to"].this == exp.DataType.Type.DATE: if cast.args["to"].this == exp.DataType.Type.DATE:
return datetime.date.fromisoformat(cast.name) return datetime.date.fromisoformat(cast.name)
if cast.args["to"].this == exp.DataType.Type.DATETIME:
return datetime.datetime.fromisoformat(cast.name)
except ValueError:
return None return None
@ -450,7 +456,8 @@ def extract_interval(interval):
def date_literal(date): def date_literal(date):
return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE")) expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE")
return exp.Cast(this=exp.Literal.string(date), to=expr_type)
def boolean_literal(condition): def boolean_literal(condition):
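A small sketch of the datetime folding this enables; the expected output below comes from the updated simplify fixtures later in this diff:

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    # DATETIME casts are now folded with intervals, not just DATE casts.
    expr = parse_one("CAST('2008-11-11' AS DATETIME) + INTERVAL '5' MONTH")
    print(simplify(expr).sql())  # CAST('2009-04-11 00:00:00' AS DATETIME)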


@ -15,8 +15,7 @@ def unnest_subqueries(expression):
>>> import sqlglot >>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql() >>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\ 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
Args: Args:
expression (sqlglot.Expression): expression to unnest expression (sqlglot.Expression): expression to unnest
@ -173,10 +172,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
other = _other_operand(parent_predicate) other = _other_operand(parent_predicate)
if isinstance(parent_predicate, exp.Exists): if isinstance(parent_predicate, exp.Exists):
if value.this in group_by: alias = exp.column(list(key_aliases.values())[0], table_alias)
parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL") parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
else:
parent_predicate = _replace(parent_predicate, "TRUE")
elif isinstance(parent_predicate, exp.All): elif isinstance(parent_predicate, exp.All):
parent_predicate = _replace( parent_predicate = _replace(
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})" parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
@ -197,6 +194,23 @@ def decorrelate(select, parent_select, external_columns, sequence):
else: else:
if is_subquery_projection: if is_subquery_projection:
alias = exp.alias_(alias, select.parent.alias) alias = exp.alias_(alias, select.parent.alias)
# COUNT always returns 0 on empty datasets, so we need to take that into consideration here
# by transforming all counts into 0 and using that as the coalesced value
if value.find(exp.Count):
def remove_aggs(node):
if isinstance(node, exp.Count):
return exp.Literal.number(0)
elif isinstance(node, exp.AggFunc):
return exp.null()
return node
alias = exp.Coalesce(
this=alias,
expressions=[value.this.transform(remove_aggs)],
)
select.parent.replace(alias) select.parent.replace(alias)
for key, column, predicate in keys: for key, column, predicate in keys:
@ -209,9 +223,6 @@ def decorrelate(select, parent_select, external_columns, sequence):
if key in group_by: if key in group_by:
key.replace(nested) key.replace(nested)
parent_predicate = _replace(
parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
)
elif isinstance(predicate, exp.EQ): elif isinstance(predicate, exp.EQ):
parent_predicate = _replace( parent_predicate = _replace(
parent_predicate, parent_predicate,
@ -245,7 +256,14 @@ def _other_operand(expression):
if isinstance(expression, exp.In): if isinstance(expression, exp.In):
return expression.this return expression.this
if isinstance(expression, (exp.Any, exp.All)):
return _other_operand(expression.parent)
if isinstance(expression, exp.Binary): if isinstance(expression, exp.Binary):
return expression.right if expression.arg_key == "this" else expression.left return (
expression.right
if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
else expression.left
)
return None return None
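A rough sketch of the correlated-COUNT handling described in the comment above (the exact generated aliases appear in the updated optimizer fixtures further down):

    from sqlglot import parse_one
    from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

    sql = "SELECT * FROM x WHERE x.a > (SELECT COUNT(*) AS d FROM y WHERE y.a = x.a)"
    # The correlated subquery becomes a LEFT JOIN on an aggregated derived table, and the
    # outer reference to the COUNT is wrapped in COALESCE(..., 0) so empty groups compare as 0.
    print(unnest_subqueries(parse_one(sql)).sql())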

File diff suppressed because it is too large.


@ -3,6 +3,7 @@ from __future__ import annotations
import abc import abc
import typing as t import typing as t
import sqlglot
from sqlglot import expressions as exp from sqlglot import expressions as exp
from sqlglot.errors import SchemaError from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth from sqlglot.helper import dict_depth
@ -157,10 +158,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
visible: t.Optional[t.Dict] = None, visible: t.Optional[t.Dict] = None,
dialect: t.Optional[str] = None, dialect: t.Optional[str] = None,
) -> None: ) -> None:
super().__init__(schema)
self.visible = visible or {}
self.dialect = dialect self.dialect = dialect
self.visible = visible or {}
self._type_mapping_cache: t.Dict[str, exp.DataType] = {} self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
super().__init__(self._normalize(schema or {}))
@classmethod @classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
@ -180,6 +181,33 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
} }
) )
def _normalize(self, schema: t.Dict) -> t.Dict:
"""
Converts all identifiers in the schema into lowercase, unless they're quoted.
Args:
schema: the schema to normalize.
Returns:
The normalized schema mapping.
"""
flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
normalized_mapping: t.Dict = {}
for keys in flattened_schema:
columns = _nested_get(schema, *zip(keys, keys))
assert columns is not None
normalized_keys = [self._normalize_name(key) for key in keys]
for column_name, column_type in columns.items():
_nested_set(
normalized_mapping,
normalized_keys + [self._normalize_name(column_name)],
column_type,
)
return normalized_mapping
def add_table( def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None: ) -> None:
@ -204,6 +232,19 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
) )
self.mapping_trie = self._build_trie(self.mapping) self.mapping_trie = self._build_trie(self.mapping)
def _normalize_name(self, name: str) -> str:
try:
identifier: t.Optional[exp.Expression] = sqlglot.parse_one(
name, read=self.dialect, into=exp.Identifier
)
except:
identifier = exp.to_identifier(name)
assert isinstance(identifier, exp.Identifier)
if identifier.quoted:
return identifier.name
return identifier.name.lower()
def _depth(self) -> int: def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those # The columns themselves are a mapping, but we don't want to include those
return super()._depth() - 1 return super()._depth() - 1
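A brief illustration of the normalization (the names are made up; the real cases are in test_schema_normalization below):

    from sqlglot import exp
    from sqlglot.schema import MappingSchema

    # Unquoted identifiers are lowercased; quoted ones keep their casing.
    schema = MappingSchema({"Foo": {"ID": "INT", '"Bar"': "TEXT"}})
    schema.column_names(exp.Table(this="foo"))  # ['id', 'Bar']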

sqlglot/serde.py (new file, 67 lines)

@ -0,0 +1,67 @@
from __future__ import annotations
import typing as t
from sqlglot import expressions as exp
if t.TYPE_CHECKING:
JSON = t.Union[dict, list, str, float, int, bool]
Node = t.Union[t.List["Node"], exp.DataType.Type, exp.Expression, JSON]
def dump(node: Node) -> JSON:
"""
Recursively dump an AST into a JSON-serializable dict.
"""
if isinstance(node, list):
return [dump(i) for i in node]
if isinstance(node, exp.DataType.Type):
return {
"class": "DataType.Type",
"value": node.value,
}
if isinstance(node, exp.Expression):
klass = node.__class__.__qualname__
if node.__class__.__module__ != exp.__name__:
klass = f"{node.__module__}.{klass}"
obj = {
"class": klass,
"args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []},
}
if node.type:
obj["type"] = node.type.sql()
if node.comments:
obj["comments"] = node.comments
return obj
return node
def load(obj: JSON) -> Node:
"""
Recursively load a dict (as returned by `dump`) into an AST.
"""
if isinstance(obj, list):
return [load(i) for i in obj]
if isinstance(obj, dict):
class_name = obj["class"]
if class_name == "DataType.Type":
return exp.DataType.Type(obj["value"])
if "." in class_name:
module_path, class_name = class_name.rsplit(".", maxsplit=1)
module = __import__(module_path, fromlist=[class_name])
else:
module = exp
klass = getattr(module, class_name)
expression = klass(**{k: load(v) for k, v in obj["args"].items()})
type_ = obj.get("type")
if type_:
expression.type = exp.DataType.build(type_)
comments = obj.get("comments")
if comments:
expression.comments = load(comments)
return expression
return obj
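A minimal round-trip sketch with the new module, mirroring the helper in tests/test_serde.py below:

    import json
    from sqlglot import exp, parse_one

    expression = parse_one("SELECT a + 1 AS b FROM x")
    payload = json.dumps(expression.dump())              # plain JSON string
    restored = exp.Expression.load(json.loads(payload))  # back to an equivalent AST
    assert restored == expression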


@ -86,6 +86,7 @@ class TokenType(AutoName):
VARBINARY = auto() VARBINARY = auto()
JSON = auto() JSON = auto()
JSONB = auto() JSONB = auto()
TIME = auto()
TIMESTAMP = auto() TIMESTAMP = auto()
TIMESTAMPTZ = auto() TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto() TIMESTAMPLTZ = auto()
@ -181,6 +182,7 @@ class TokenType(AutoName):
FUNCTION = auto() FUNCTION = auto()
FROM = auto() FROM = auto()
GENERATED = auto() GENERATED = auto()
GLOBAL = auto()
GROUP_BY = auto() GROUP_BY = auto()
GROUPING_SETS = auto() GROUPING_SETS = auto()
HAVING = auto() HAVING = auto()
@ -656,6 +658,7 @@ class Tokenizer(metaclass=_Tokenizer):
"FLOAT4": TokenType.FLOAT, "FLOAT4": TokenType.FLOAT,
"FLOAT8": TokenType.DOUBLE, "FLOAT8": TokenType.DOUBLE,
"DOUBLE": TokenType.DOUBLE, "DOUBLE": TokenType.DOUBLE,
"DOUBLE PRECISION": TokenType.DOUBLE,
"JSON": TokenType.JSON, "JSON": TokenType.JSON,
"CHAR": TokenType.CHAR, "CHAR": TokenType.CHAR,
"NCHAR": TokenType.NCHAR, "NCHAR": TokenType.NCHAR,
@ -671,6 +674,7 @@ class Tokenizer(metaclass=_Tokenizer):
"BLOB": TokenType.VARBINARY, "BLOB": TokenType.VARBINARY,
"BYTEA": TokenType.VARBINARY, "BYTEA": TokenType.VARBINARY,
"VARBINARY": TokenType.VARBINARY, "VARBINARY": TokenType.VARBINARY,
"TIME": TokenType.TIME,
"TIMESTAMP": TokenType.TIMESTAMP, "TIMESTAMP": TokenType.TIMESTAMP,
"TIMESTAMPTZ": TokenType.TIMESTAMPTZ, "TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
@ -721,6 +725,8 @@ class Tokenizer(metaclass=_Tokenizer):
COMMENTS = ["--", ("/*", "*/")] COMMENTS = ["--", ("/*", "*/")]
KEYWORD_TRIE = None # autofilled KEYWORD_TRIE = None # autofilled
IDENTIFIER_CAN_START_WITH_DIGIT = False
__slots__ = ( __slots__ = (
"sql", "sql",
"size", "size",
@ -938,17 +944,24 @@ class Tokenizer(metaclass=_Tokenizer):
elif self._peek.upper() == "E" and not scientific: # type: ignore elif self._peek.upper() == "E" and not scientific: # type: ignore
scientific += 1 scientific += 1
self._advance() self._advance()
elif self._peek.isalpha(): # type: ignore elif self._peek.isidentifier(): # type: ignore
self._add(TokenType.NUMBER) number_text = self._text
literal = [] literal = []
while self._peek.isalpha(): # type: ignore while self._peek.isidentifier(): # type: ignore
literal.append(self._peek.upper()) # type: ignore literal.append(self._peek.upper()) # type: ignore
self._advance() self._advance()
literal = "".join(literal) # type: ignore literal = "".join(literal) # type: ignore
token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore
if token_type: if token_type:
self._add(TokenType.NUMBER, number_text)
self._add(TokenType.DCOLON, "::") self._add(TokenType.DCOLON, "::")
return self._add(token_type, literal) # type: ignore return self._add(token_type, literal) # type: ignore
elif self.IDENTIFIER_CAN_START_WITH_DIGIT:
return self._add(TokenType.VAR)
self._add(TokenType.NUMBER, number_text)
return self._advance(-len(literal)) return self._advance(-len(literal))
else: else:
return self._add(TokenType.NUMBER) return self._add(TokenType.NUMBER)
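A short sketch of what IDENTIFIER_CAN_START_WITH_DIGIT allows for Hive/Spark, matching the new Hive tests below:

    import sqlglot

    # Hive and Spark identifiers may now begin with a digit.
    print(sqlglot.transpile("SELECT 1_a AS a FROM test_table", read="hive", write="spark")[0])
    # SELECT 1_a AS a FROM test_table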


@ -82,6 +82,27 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
return expression return expression
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
"""
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions.
This transform removes the precision from parameterized types in expressions.
"""
return expression.transform(
lambda node: exp.DataType(
**{
**node.args,
"expressions": [
node_expression
for node_expression in node.expressions
if isinstance(node_expression, exp.DataType)
],
}
)
if isinstance(node, exp.DataType)
else node,
)
def preprocess( def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
to_sql: t.Callable[[Generator, exp.Expression], str], to_sql: t.Callable[[Generator, exp.Expression], str],
@ -121,3 +142,6 @@ def delegate(attr: str) -> t.Callable:
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))} UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))} ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}
REMOVE_PRECISION_PARAMETERIZED_TYPES = {
exp.Cast: preprocess([remove_precision_parameterized_types], delegate("cast_sql"))
}
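For context, a sketch of the effect when transpiling to BigQuery; the expected strings come from the new BigQuery tests below:

    import sqlglot

    # Precision is dropped inside expressions but kept in DDL.
    print(sqlglot.transpile("SELECT CAST(1 AS NUMERIC(10, 2))", read="bigquery", write="bigquery")[0])
    # SELECT CAST(1 AS NUMERIC)
    print(sqlglot.transpile("CREATE TABLE test (a NUMERIC(10, 2))", read="bigquery", write="bigquery")[0])
    # CREATE TABLE test (a NUMERIC(10, 2))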


@ -52,7 +52,7 @@ def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]:
Returns: Returns:
A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the search stops, and `value` A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the search stops, and `value`
is either 0 (search was unsuccessfull), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key is in `trie`). is either 0 (search was unsuccessful), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key is in `trie`).
""" """
if not key: if not key:
return (0, trie) return (0, trie)
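The return codes documented above, in a tiny example:

    from sqlglot.trie import in_trie, new_trie

    trie = new_trie(["cat"])
    in_trie(trie, "bob")  # (0, ...) unsuccessful search
    in_trie(trie, "ca")   # (1, ...) "ca" is a prefix of a keyword
    in_trie(trie, "cat")  # (2, ...) "cat" is in the trie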


@ -1152,17 +1152,17 @@ class TestFunctions(unittest.TestCase):
def test_regexp_extract(self): def test_regexp_extract(self):
col_str = SF.regexp_extract("cola", r"(\d+)-(\d+)", 1) col_str = SF.regexp_extract("cola", r"(\d+)-(\d+)", 1)
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)', 1)", col_str.sql()) self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col_str.sql())
col = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)", 1) col = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)", 1)
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)', 1)", col.sql()) self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col.sql())
col_no_idx = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)") col_no_idx = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)")
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)')", col_no_idx.sql()) self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)')", col_no_idx.sql())
def test_regexp_replace(self): def test_regexp_replace(self):
col_str = SF.regexp_replace("cola", r"(\d+)", "--") col_str = SF.regexp_replace("cola", r"(\d+)", "--")
self.assertEqual("REGEXP_REPLACE(cola, '(\\\d+)', '--')", col_str.sql()) self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col_str.sql())
col = SF.regexp_replace(SF.col("cola"), r"(\d+)", "--") col = SF.regexp_replace(SF.col("cola"), r"(\d+)", "--")
self.assertEqual("REGEXP_REPLACE(cola, '(\\\d+)', '--')", col.sql()) self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col.sql())
def test_initcap(self): def test_initcap(self):
col_str = SF.initcap("cola") col_str = SF.initcap("cola")


@ -15,11 +15,11 @@ class TestDataframeWindow(unittest.TestCase):
def test_window_spec_rows_between(self): def test_window_spec_rows_between(self):
rows_between = WindowSpec().rowsBetween(3, 5) rows_between = WindowSpec().rowsBetween(3, 5)
self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql()) self.assertEqual("OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
def test_window_spec_range_between(self): def test_window_spec_range_between(self):
range_between = WindowSpec().rangeBetween(3, 5) range_between = WindowSpec().rangeBetween(3, 5)
self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql()) self.assertEqual("OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
def test_window_partition_by(self): def test_window_partition_by(self):
partition_by = Window.partitionBy(F.col("cola"), F.col("colb")) partition_by = Window.partitionBy(F.col("cola"), F.col("colb"))
@ -31,46 +31,46 @@ class TestDataframeWindow(unittest.TestCase):
def test_window_rows_between(self): def test_window_rows_between(self):
rows_between = Window.rowsBetween(3, 5) rows_between = Window.rowsBetween(3, 5)
self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql()) self.assertEqual("OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
def test_window_range_between(self): def test_window_range_between(self):
range_between = Window.rangeBetween(3, 5) range_between = Window.rangeBetween(3, 5)
self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql()) self.assertEqual("OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
def test_window_rows_unbounded(self): def test_window_rows_unbounded(self):
rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2) rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
self.assertEqual( self.assertEqual(
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", "OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
rows_between_unbounded_start.sql(), rows_between_unbounded_start.sql(),
) )
rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing) rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
self.assertEqual( self.assertEqual(
"OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", "OVER (ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
rows_between_unbounded_end.sql(), rows_between_unbounded_end.sql(),
) )
rows_between_unbounded_both = Window.rowsBetween( rows_between_unbounded_both = Window.rowsBetween(
Window.unboundedPreceding, Window.unboundedFollowing Window.unboundedPreceding, Window.unboundedFollowing
) )
self.assertEqual( self.assertEqual(
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", "OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
rows_between_unbounded_both.sql(), rows_between_unbounded_both.sql(),
) )
def test_window_range_unbounded(self): def test_window_range_unbounded(self):
range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2) range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
self.assertEqual( self.assertEqual(
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", "OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
range_between_unbounded_start.sql(), range_between_unbounded_start.sql(),
) )
range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing) range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
self.assertEqual( self.assertEqual(
"OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", "OVER (RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
range_between_unbounded_end.sql(), range_between_unbounded_end.sql(),
) )
range_between_unbounded_both = Window.rangeBetween( range_between_unbounded_both = Window.rangeBetween(
Window.unboundedPreceding, Window.unboundedFollowing Window.unboundedPreceding, Window.unboundedFollowing
) )
self.assertEqual( self.assertEqual(
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", "OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
range_between_unbounded_both.sql(), range_between_unbounded_both.sql(),
) )


@ -125,7 +125,7 @@ class TestBigQuery(Validator):
}, },
) )
self.validate_all( self.validate_all(
"CURRENT_DATE", "CURRENT_TIMESTAMP()",
read={ read={
"tsql": "GETDATE()", "tsql": "GETDATE()",
}, },
@ -299,6 +299,14 @@ class TestBigQuery(Validator):
"snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", "snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
}, },
) )
self.validate_all(
"SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)",
write={
"spark": "SELECT cola, colb, colc FROM VALUES (1, 'test', NULL) AS tab(cola, colb, colc)",
"bigquery": "SELECT cola, colb, colc FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb, NULL AS colc)])",
"snowflake": "SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)",
},
)
self.validate_all( self.validate_all(
"SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) d, COUNT(*) e FOR c IN ('x', 'y'))", "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) d, COUNT(*) e FOR c IN ('x', 'y'))",
write={ write={
@ -324,3 +332,35 @@ class TestBigQuery(Validator):
"SELECT a, GROUP_CONCAT(b) FROM table GROUP BY a", "SELECT a, GROUP_CONCAT(b) FROM table GROUP BY a",
write={"bigquery": "SELECT a, STRING_AGG(b) FROM table GROUP BY a"}, write={"bigquery": "SELECT a, STRING_AGG(b) FROM table GROUP BY a"},
) )
def test_remove_precision_parameterized_types(self):
self.validate_all(
"SELECT CAST(1 AS NUMERIC(10, 2))",
write={
"bigquery": "SELECT CAST(1 AS NUMERIC)",
},
)
self.validate_all(
"CREATE TABLE test (a NUMERIC(10, 2))",
write={
"bigquery": "CREATE TABLE test (a NUMERIC(10, 2))",
},
)
self.validate_all(
"SELECT CAST('1' AS STRING(10)) UNION ALL SELECT CAST('2' AS STRING(10))",
write={
"bigquery": "SELECT CAST('1' AS STRING) UNION ALL SELECT CAST('2' AS STRING)",
},
)
self.validate_all(
"SELECT cola FROM (SELECT CAST('1' AS STRING(10)) AS cola UNION ALL SELECT CAST('2' AS STRING(10)) AS cola)",
write={
"bigquery": "SELECT cola FROM (SELECT CAST('1' AS STRING) AS cola UNION ALL SELECT CAST('2' AS STRING) AS cola)",
},
)
self.validate_all(
"INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING(10)), CAST(14 AS STRING(10)))",
write={
"bigquery": "INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING), CAST(14 AS STRING))",
},
)


@ -14,6 +14,9 @@ class TestClickhouse(Validator):
self.validate_identity("SELECT * FROM foo LEFT ASOF JOIN bla") self.validate_identity("SELECT * FROM foo LEFT ASOF JOIN bla")
self.validate_identity("SELECT * FROM foo ASOF JOIN bla") self.validate_identity("SELECT * FROM foo ASOF JOIN bla")
self.validate_identity("SELECT * FROM foo ANY JOIN bla") self.validate_identity("SELECT * FROM foo ANY JOIN bla")
self.validate_identity("SELECT quantile(0.5)(a)")
self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t")
self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)")
self.validate_all( self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
@ -38,3 +41,9 @@ class TestClickhouse(Validator):
"SELECT x #! comment", "SELECT x #! comment",
write={"": "SELECT x /* comment */"}, write={"": "SELECT x /* comment */"},
) )
self.validate_all(
"SELECT quantileIf(0.5)(a, true)",
write={
"clickhouse": "SELECT quantileIf(0.5)(a, TRUE)",
},
)


@ -85,7 +85,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"CAST(a AS BINARY(4))", "CAST(a AS BINARY(4))",
write={ write={
"bigquery": "CAST(a AS BINARY(4))", "bigquery": "CAST(a AS BINARY)",
"clickhouse": "CAST(a AS BINARY(4))", "clickhouse": "CAST(a AS BINARY(4))",
"drill": "CAST(a AS VARBINARY(4))", "drill": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS BINARY(4))", "duckdb": "CAST(a AS BINARY(4))",
@ -104,7 +104,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"CAST(a AS VARBINARY(4))", "CAST(a AS VARBINARY(4))",
write={ write={
"bigquery": "CAST(a AS VARBINARY(4))", "bigquery": "CAST(a AS VARBINARY)",
"clickhouse": "CAST(a AS VARBINARY(4))", "clickhouse": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS VARBINARY(4))", "duckdb": "CAST(a AS VARBINARY(4))",
"mysql": "CAST(a AS VARBINARY(4))", "mysql": "CAST(a AS VARBINARY(4))",
@ -181,7 +181,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"CAST(a AS VARCHAR(3))", "CAST(a AS VARCHAR(3))",
write={ write={
"bigquery": "CAST(a AS STRING(3))", "bigquery": "CAST(a AS STRING)",
"drill": "CAST(a AS VARCHAR(3))", "drill": "CAST(a AS VARCHAR(3))",
"duckdb": "CAST(a AS TEXT(3))", "duckdb": "CAST(a AS TEXT(3))",
"mysql": "CAST(a AS VARCHAR(3))", "mysql": "CAST(a AS VARCHAR(3))",


@ -338,6 +338,24 @@ class TestHive(Validator):
) )
def test_hive(self): def test_hive(self):
self.validate_all(
"SELECT A.1a AS b FROM test_a AS A",
write={
"spark": "SELECT A.1a AS b FROM test_a AS A",
},
)
self.validate_all(
"SELECT 1_a AS a FROM test_table",
write={
"spark": "SELECT 1_a AS a FROM test_table",
},
)
self.validate_all(
"SELECT a_b AS 1_a FROM test_table",
write={
"spark": "SELECT a_b AS 1_a FROM test_table",
},
)
self.validate_all( self.validate_all(
"PERCENTILE(x, 0.5)", "PERCENTILE(x, 0.5)",
write={ write={
@ -411,7 +429,7 @@ class TestHive(Validator):
"INITCAP('new york')", "INITCAP('new york')",
write={ write={
"duckdb": "INITCAP('new york')", "duckdb": "INITCAP('new york')",
"presto": "REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))", "presto": r"REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))",
"hive": "INITCAP('new york')", "hive": "INITCAP('new york')",
"spark": "INITCAP('new york')", "spark": "INITCAP('new york')",
}, },


@ -122,6 +122,10 @@ class TestPostgres(Validator):
"TO_TIMESTAMP(123::DOUBLE PRECISION)", "TO_TIMESTAMP(123::DOUBLE PRECISION)",
write={"postgres": "TO_TIMESTAMP(CAST(123 AS DOUBLE PRECISION))"}, write={"postgres": "TO_TIMESTAMP(CAST(123 AS DOUBLE PRECISION))"},
) )
self.validate_all(
"SELECT to_timestamp(123)::time without time zone",
write={"postgres": "SELECT CAST(TO_TIMESTAMP(123) AS TIME)"},
)
self.validate_identity( self.validate_identity(
"CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)" "CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)"


@ -60,11 +60,11 @@ class TestPresto(Validator):
self.validate_all( self.validate_all(
"CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
write={ write={
"bigquery": "CAST(x AS TIMESTAMPTZ(9))", "bigquery": "CAST(x AS TIMESTAMPTZ)",
"duckdb": "CAST(x AS TIMESTAMPTZ(9))", "duckdb": "CAST(x AS TIMESTAMPTZ(9))",
"presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", "presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
"hive": "CAST(x AS TIMESTAMPTZ(9))", "hive": "CAST(x AS TIMESTAMPTZ)",
"spark": "CAST(x AS TIMESTAMPTZ(9))", "spark": "CAST(x AS TIMESTAMPTZ)",
}, },
) )


@ -523,3 +523,33 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERA
"spark": "SELECT `c0`, `c1` FROM (VALUES (1, 2), (3, 4)) AS `t0`(`c0`, `c1`)", "spark": "SELECT `c0`, `c1` FROM (VALUES (1, 2), (3, 4)) AS `t0`(`c0`, `c1`)",
}, },
) )
def test_describe_table(self):
self.validate_all(
"DESCRIBE TABLE db.table",
write={
"snowflake": "DESCRIBE TABLE db.table",
"spark": "DESCRIBE db.table",
},
)
self.validate_all(
"DESCRIBE db.table",
write={
"snowflake": "DESCRIBE TABLE db.table",
"spark": "DESCRIBE db.table",
},
)
self.validate_all(
"DESC TABLE db.table",
write={
"snowflake": "DESCRIBE TABLE db.table",
"spark": "DESCRIBE db.table",
},
)
self.validate_all(
"DESC VIEW db.table",
write={
"snowflake": "DESCRIBE VIEW db.table",
"spark": "DESCRIBE db.table",
},
)

View file

@ -207,6 +207,7 @@ TBLPROPERTIES (
) )
def test_spark(self): def test_spark(self):
self.validate_identity("SELECT UNIX_TIMESTAMP()")
self.validate_all( self.validate_all(
"ARRAY_SORT(x, (left, right) -> -1)", "ARRAY_SORT(x, (left, right) -> -1)",
write={ write={


@ -6,6 +6,8 @@ class TestTSQL(Validator):
def test_tsql(self): def test_tsql(self):
self.validate_identity('SELECT "x"."y" FROM foo') self.validate_identity('SELECT "x"."y" FROM foo')
self.validate_identity("SELECT * FROM #foo")
self.validate_identity("SELECT * FROM ##foo")
self.validate_identity( self.validate_identity(
"SELECT DISTINCT DepartmentName, PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY BaseRate) OVER (PARTITION BY DepartmentName) AS MedianCont FROM dbo.DimEmployee" "SELECT DISTINCT DepartmentName, PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY BaseRate) OVER (PARTITION BY DepartmentName) AS MedianCont FROM dbo.DimEmployee"
) )
@ -71,6 +73,12 @@ class TestTSQL(Validator):
"tsql": "CAST(x AS DATETIME2)", "tsql": "CAST(x AS DATETIME2)",
}, },
) )
self.validate_all(
"CAST(x AS DATETIME2(6))",
write={
"hive": "CAST(x AS TIMESTAMP)",
},
)
def test_charindex(self): def test_charindex(self):
self.validate_all( self.validate_all(
@ -300,6 +308,12 @@ class TestTSQL(Validator):
"spark": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y", "spark": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y",
}, },
) )
self.validate_all(
"SELECT CAST((SELECT x FROM y) AS VARCHAR) AS test",
write={
"spark": "SELECT CAST((SELECT x FROM y) AS STRING) AS test",
},
)
def test_add_date(self): def test_add_date(self):
self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')") self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')")
@ -441,3 +455,13 @@ class TestTSQL(Validator):
"SELECT '''test'''", "SELECT '''test'''",
write={"spark": r"SELECT '\'test\''"}, write={"spark": r"SELECT '\'test\''"},
) )
def test_eomonth(self):
self.validate_all(
"EOMONTH(GETDATE())",
write={"spark": "LAST_DAY(CURRENT_TIMESTAMP())"},
)
self.validate_all(
"EOMONTH(GETDATE(), -1)",
write={"spark": "LAST_DAY(ADD_MONTHS(CURRENT_TIMESTAMP(), -1))"},
)


@ -89,6 +89,7 @@ POSEXPLODE("x") AS ("a", "b")
POSEXPLODE("x") AS ("a", "b", "c") POSEXPLODE("x") AS ("a", "b", "c")
STR_POSITION(x, 'a') STR_POSITION(x, 'a')
STR_POSITION(x, 'a', 3) STR_POSITION(x, 'a', 3)
LEVENSHTEIN('gumbo', 'gambol', 2, 1, 1)
SPLIT(SPLIT(referrer, 'utm_source=')[OFFSET(1)], "&")[OFFSET(0)] SPLIT(SPLIT(referrer, 'utm_source=')[OFFSET(1)], "&")[OFFSET(0)]
x[ORDINAL(1)][SAFE_OFFSET(2)] x[ORDINAL(1)][SAFE_OFFSET(2)]
x LIKE SUBSTR('abc', 1, 1) x LIKE SUBSTR('abc', 1, 1)
@ -425,6 +426,7 @@ SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 AND 3) SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 AND 3)
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND 3) SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND 3)
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING) SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING)
SELECT AVG(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM t
SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x) AS y SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x) AS y
SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x DESC) SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x DESC)
SELECT SUM(x) FILTER(WHERE x > 1) SELECT SUM(x) FILTER(WHERE x > 1)
@ -450,14 +452,24 @@ SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(a, b)
SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b) SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b)
SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score) SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score)
SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score) SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score)
SELECT * FROM t WITH (TABLOCK, INDEX(myindex))
SELECT * FROM t WITH (NOWAIT)
CREATE TABLE foo AS (SELECT 1) UNION ALL (SELECT 2)
CREATE TABLE foo (id INT PRIMARY KEY ASC) CREATE TABLE foo (id INT PRIMARY KEY ASC)
CREATE TABLE a.b AS SELECT 1 CREATE TABLE a.b AS SELECT 1
CREATE TABLE a.b AS SELECT 1 WITH DATA AND STATISTICS
CREATE TABLE a.b AS SELECT 1 WITH NO DATA AND NO STATISTICS
CREATE TABLE a.b AS (SELECT 1) NO PRIMARY INDEX
CREATE TABLE a.b AS (SELECT 1) UNIQUE PRIMARY INDEX index1 (a) UNIQUE INDEX index2 (b)
CREATE TABLE a.b AS (SELECT 1) PRIMARY AMP INDEX index1 (a) UNIQUE INDEX index2 (b)
CREATE TABLE a.b AS SELECT a FROM a.c CREATE TABLE a.b AS SELECT a FROM a.c
CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d
CREATE TEMPORARY TABLE x AS SELECT a FROM d CREATE TEMPORARY TABLE x AS SELECT a FROM d
CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d
CREATE VIEW x AS SELECT a FROM b CREATE VIEW x AS SELECT a FROM b
CREATE VIEW IF NOT EXISTS x AS SELECT a FROM b CREATE VIEW IF NOT EXISTS x AS SELECT a FROM b
CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c') AS SELECT a, b, c FROM d
CREATE VIEW IF NOT EXISTS z (a, b COMMENT 'b', c COMMENT 'c') AS SELECT a, b, c FROM d
CREATE OR REPLACE VIEW x AS SELECT * CREATE OR REPLACE VIEW x AS SELECT *
CREATE OR REPLACE TEMPORARY VIEW x AS SELECT * CREATE OR REPLACE TEMPORARY VIEW x AS SELECT *
CREATE TEMPORARY VIEW x AS SELECT a FROM d CREATE TEMPORARY VIEW x AS SELECT a FROM d
@ -490,6 +502,8 @@ CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT)
CREATE TABLE z (a INT REFERENCES parent(b, c)) CREATE TABLE z (a INT REFERENCES parent(b, c))
CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id)) CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id))
CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c)) CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c))
CREATE VIEW z (a, b)
CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c')
CREATE TEMPORARY FUNCTION f CREATE TEMPORARY FUNCTION f
CREATE TEMPORARY FUNCTION f AS 'g' CREATE TEMPORARY FUNCTION f AS 'g'
CREATE FUNCTION f CREATE FUNCTION f
@ -559,6 +573,7 @@ INSERT INTO x.z IF EXISTS SELECT * FROM y
INSERT INTO x VALUES (1, 'a', 2.0) INSERT INTO x VALUES (1, 'a', 2.0)
INSERT INTO x VALUES (1, 'a', 2.0), (1, 'a', 3.0), (X(), y[1], z.x) INSERT INTO x VALUES (1, 'a', 2.0), (1, 'a', 3.0), (X(), y[1], z.x)
INSERT INTO y (a, b, c) SELECT a, b, c FROM x INSERT INTO y (a, b, c) SELECT a, b, c FROM x
INSERT INTO y (SELECT 1) UNION (SELECT 2)
INSERT OVERWRITE TABLE x IF EXISTS SELECT * FROM y INSERT OVERWRITE TABLE x IF EXISTS SELECT * FROM y
INSERT OVERWRITE TABLE a.b IF EXISTS SELECT * FROM y INSERT OVERWRITE TABLE a.b IF EXISTS SELECT * FROM y
INSERT OVERWRITE DIRECTORY 'x' SELECT 1 INSERT OVERWRITE DIRECTORY 'x' SELECT 1
@ -627,3 +642,4 @@ ALTER TABLE integers ALTER COLUMN i SET DEFAULT 10
ALTER TABLE integers ALTER COLUMN i DROP DEFAULT ALTER TABLE integers ALTER COLUMN i DROP DEFAULT
ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B
ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT
SELECT div.a FROM test_table AS div


@ -311,3 +311,42 @@ FROM
ON ON
t1.cola = t2.cola; t1.cola = t2.cola;
SELECT /*+ BROADCAST(a2) */ a1.cola AS cola, a2.cola AS cola FROM VALUES (1) AS a1(cola) JOIN VALUES (1) AS a2(cola) ON a1.cola = a2.cola; SELECT /*+ BROADCAST(a2) */ a1.cola AS cola, a2.cola AS cola FROM VALUES (1) AS a1(cola) JOIN VALUES (1) AS a2(cola) ON a1.cola = a2.cola;
# title: Nested subquery selects from same table as another subquery
WITH i AS (
SELECT
x.a AS a
FROM x AS x
), j AS (
SELECT
x.a,
x.b
FROM x AS x
), k AS (
SELECT
j.a,
j.b
FROM j AS j
)
SELECT
i.a,
k.b
FROM i AS i
LEFT JOIN k AS k
ON i.a = k.a;
SELECT x.a AS a, x_2.b AS b FROM x AS x LEFT JOIN x AS x_2 ON x.a = x_2.a;
# title: Outer select joins on inner select join
WITH i AS (
SELECT
x.a AS a
FROM y AS y
JOIN x AS x
ON y.b = x.b
)
SELECT
x.a AS a
FROM x AS x
LEFT JOIN i AS i
ON x.a = i.a;
WITH i AS (SELECT x.a AS a FROM y AS y JOIN x AS x ON y.b = x.b) SELECT x.a AS a FROM x AS x LEFT JOIN i AS i ON x.a = i.a;


@ -105,7 +105,7 @@ LEFT JOIN "_u_0" AS "_u_0"
JOIN "y" AS "y" JOIN "y" AS "y"
ON "x"."b" = "y"."b" ON "x"."b" = "y"."b"
WHERE WHERE
"_u_0"."_col_0" >= 0 AND "x"."a" > 1 AND NOT "_u_0"."_u_1" IS NULL "_u_0"."_col_0" >= 0 AND "x"."a" > 1
GROUP BY GROUP BY
"x"."a"; "x"."a";


@ -54,3 +54,6 @@ WITH t1 AS (SELECT q.cola AS cola FROM UNNEST(ARRAY(STRUCT(1 AS cola, 'test' AS
SELECT x FROM VALUES(1, 2) AS q(x, y); SELECT x FROM VALUES(1, 2) AS q(x, y);
SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y); SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y);
SELECT i.a FROM x AS i LEFT JOIN (SELECT a, b FROM (SELECT a, b FROM x)) AS j ON i.a = j.a;
SELECT i.a AS a FROM x AS i LEFT JOIN (SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0) AS j ON i.a = j.a;


@ -375,6 +375,18 @@ CAST('1998-12-01' AS DATE) - INTERVAL '90' foo;
date '1998-12-01' + interval '90' foo; date '1998-12-01' + interval '90' foo;
CAST('1998-12-01' AS DATE) + INTERVAL '90' foo; CAST('1998-12-01' AS DATE) + INTERVAL '90' foo;
CAST(x AS DATE) + interval '1' week;
CAST(x AS DATE) + INTERVAL '1' week;
CAST('2008-11-11' AS DATETIME) + INTERVAL '5' MONTH;
CAST('2009-04-11 00:00:00' AS DATETIME);
datetime '1998-12-01' - interval '90' day;
CAST('1998-09-02 00:00:00' AS DATETIME);
CAST(x AS DATETIME) + interval '1' week;
CAST(x AS DATETIME) + INTERVAL '1' week;
-------------------------------------- --------------------------------------
-- Comparisons -- Comparisons
-------------------------------------- --------------------------------------


@ -150,7 +150,6 @@ WHERE
"part"."p_size" = 15 "part"."p_size" = 15
AND "part"."p_type" LIKE '%BRASS' AND "part"."p_type" LIKE '%BRASS'
AND "partsupp"."ps_supplycost" = "_u_0"."_col_0" AND "partsupp"."ps_supplycost" = "_u_0"."_col_0"
AND NOT "_u_0"."_u_1" IS NULL
ORDER BY ORDER BY
"s_acctbal" DESC, "s_acctbal" DESC,
"n_name", "n_name",
@ -1008,7 +1007,7 @@ JOIN "part" AS "part"
LEFT JOIN "_u_0" AS "_u_0" LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."_u_1" = "part"."p_partkey" ON "_u_0"."_u_1" = "part"."p_partkey"
WHERE WHERE
"lineitem"."l_quantity" < "_u_0"."_col_0" AND NOT "_u_0"."_u_1" IS NULL; "lineitem"."l_quantity" < "_u_0"."_col_0";
-------------------------------------- --------------------------------------
-- TPC-H 18 -- TPC-H 18
@ -1253,10 +1252,7 @@ WITH "_u_0" AS (
LEFT JOIN "_u_3" AS "_u_3" LEFT JOIN "_u_3" AS "_u_3"
ON "partsupp"."ps_partkey" = "_u_3"."p_partkey" ON "partsupp"."ps_partkey" = "_u_3"."p_partkey"
WHERE WHERE
"partsupp"."ps_availqty" > "_u_0"."_col_0" "partsupp"."ps_availqty" > "_u_0"."_col_0" AND NOT "_u_3"."p_partkey" IS NULL
AND NOT "_u_0"."_u_1" IS NULL
AND NOT "_u_0"."_u_2" IS NULL
AND NOT "_u_3"."p_partkey" IS NULL
GROUP BY GROUP BY
"partsupp"."ps_suppkey" "partsupp"."ps_suppkey"
) )


@ -22,6 +22,8 @@ WHERE
AND x.a > ANY (SELECT y.a FROM y) AND x.a > ANY (SELECT y.a FROM y)
AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a LIMIT 10) AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a LIMIT 10)
AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a OFFSET 10) AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a OFFSET 10)
AND x.a > ALL (SELECT y.c FROM y WHERE y.a = x.a)
AND x.a > (SELECT COUNT(*) as d FROM y WHERE y.a = x.a)
; ;
SELECT SELECT
* *
@ -130,37 +132,42 @@ LEFT JOIN (
y.a y.a
) AS _u_15 ) AS _u_15
ON x.a = _u_15.a ON x.a = _u_15.a
LEFT JOIN (
SELECT
ARRAY_AGG(c),
y.a AS _u_20
FROM y
WHERE
TRUE
GROUP BY
y.a
) AS _u_19
ON _u_19._u_20 = x.a
LEFT JOIN (
SELECT
COUNT(*) AS d,
y.a AS _u_22
FROM y
WHERE
TRUE
GROUP BY
y.a
) AS _u_21
ON _u_21._u_22 = x.a
WHERE WHERE
x.a = _u_0.a x.a = _u_0.a
AND NOT "_u_1"."a" IS NULL AND NOT "_u_1"."a" IS NULL
AND NOT "_u_2"."b" IS NULL AND NOT "_u_2"."b" IS NULL
AND NOT "_u_3"."a" IS NULL AND NOT "_u_3"."a" IS NULL
AND x.a = _u_4.b
AND x.a > _u_6.b
AND x.a = _u_8.a
AND NOT x.a = _u_9.a
AND ARRAY_ANY(_u_10.a, _x -> _x = x.a)
AND ( AND (
x.a = _u_4.b AND NOT _u_4._u_5 IS NULL x.a < _u_12.a AND ARRAY_ANY(_u_12._u_14, "_x" -> _x <> x.d)
)
AND (
x.a > _u_6.b AND NOT _u_6._u_7 IS NULL
)
AND (
None = _u_8.a AND NOT _u_8.a IS NULL
)
AND NOT (
x.a = _u_9.a AND NOT _u_9.a IS NULL
)
AND (
ARRAY_ANY(_u_10.a, _x -> _x = x.a) AND NOT _u_10._u_11 IS NULL
)
AND (
(
(
x.a < _u_12.a AND NOT _u_12._u_13 IS NULL
) AND NOT _u_12._u_13 IS NULL
)
AND ARRAY_ANY(_u_12._u_14, "_x" -> _x <> x.d)
)
AND (
NOT _u_15.a IS NULL AND NOT _u_15.a IS NULL
) )
AND NOT _u_15.a IS NULL
AND x.a IN ( AND x.a IN (
SELECT SELECT
y.a AS a y.a AS a
@ -199,4 +206,6 @@ WHERE
WHERE WHERE
y.a = x.a y.a = x.a
OFFSET 10 OFFSET 10
); )
AND ARRAY_ALL(_u_19."", _x -> _x = x.a)
AND x.a > COALESCE(_u_21.d, 0);


@ -27,8 +27,7 @@ def assert_logger_contains(message, logger, level="error"):
def load_sql_fixtures(filename): def load_sql_fixtures(filename):
with open(os.path.join(FIXTURES_DIR, filename), encoding="utf-8") as f: with open(os.path.join(FIXTURES_DIR, filename), encoding="utf-8") as f:
for sql in _filter_comments(f.read()).splitlines(): yield from _filter_comments(f.read()).splitlines()
yield sql
def load_sql_fixture_pairs(filename): def load_sql_fixture_pairs(filename):


@ -401,6 +401,36 @@ class TestExecutor(unittest.TestCase):
], ],
) )
def test_correlated_count(self):
tables = {
"parts": [{"pnum": 0, "qoh": 1}],
"supplies": [],
}
schema = {
"parts": {"pnum": "int", "qoh": "int"},
"supplies": {"pnum": "int", "shipdate": "int"},
}
self.assertEqual(
execute(
"""
select *
from parts
where parts.qoh >= (
select count(supplies.shipdate) + 1
from supplies
where supplies.pnum = parts.pnum and supplies.shipdate < 10
)
""",
tables=tables,
schema=schema,
).rows,
[
(0, 1),
],
)
def test_table_depth_mismatch(self): def test_table_depth_mismatch(self):
tables = {"table": []} tables = {"table": []}
schema = {"db": {"table": {"col": "VARCHAR"}}} schema = {"db": {"table": {"col": "VARCHAR"}}}


@ -646,3 +646,72 @@ FROM foo""",
exp.Column(this=exp.to_identifier("colb")), exp.Column(this=exp.to_identifier("colb")),
], ],
) )
def test_values(self):
self.assertEqual(
exp.values([(1, 2), (3, 4)], "t", ["a", "b"]).sql(),
"(VALUES (1, 2), (3, 4)) AS t(a, b)",
)
self.assertEqual(
exp.values(
[(1, 2), (3, 4)],
"t",
{"a": exp.DataType.build("TEXT"), "b": exp.DataType.build("TEXT")},
).sql(),
"(VALUES (CAST(1 AS TEXT), CAST(2 AS TEXT)), (3, 4)) AS t(a, b)",
)
with self.assertRaises(ValueError):
exp.values([(1, 2), (3, 4)], columns=["a"])
def test_data_type_builder(self):
self.assertEqual(exp.DataType.build("TEXT").sql(), "TEXT")
self.assertEqual(exp.DataType.build("DECIMAL(10, 2)").sql(), "DECIMAL(10, 2)")
self.assertEqual(exp.DataType.build("VARCHAR(255)").sql(), "VARCHAR(255)")
self.assertEqual(exp.DataType.build("ARRAY<INT>").sql(), "ARRAY<INT>")
self.assertEqual(exp.DataType.build("CHAR").sql(), "CHAR")
self.assertEqual(exp.DataType.build("NCHAR").sql(), "CHAR")
self.assertEqual(exp.DataType.build("VARCHAR").sql(), "VARCHAR")
self.assertEqual(exp.DataType.build("NVARCHAR").sql(), "VARCHAR")
self.assertEqual(exp.DataType.build("TEXT").sql(), "TEXT")
self.assertEqual(exp.DataType.build("BINARY").sql(), "BINARY")
self.assertEqual(exp.DataType.build("VARBINARY").sql(), "VARBINARY")
self.assertEqual(exp.DataType.build("INT").sql(), "INT")
self.assertEqual(exp.DataType.build("TINYINT").sql(), "TINYINT")
self.assertEqual(exp.DataType.build("SMALLINT").sql(), "SMALLINT")
self.assertEqual(exp.DataType.build("BIGINT").sql(), "BIGINT")
self.assertEqual(exp.DataType.build("FLOAT").sql(), "FLOAT")
self.assertEqual(exp.DataType.build("DOUBLE").sql(), "DOUBLE")
self.assertEqual(exp.DataType.build("DECIMAL").sql(), "DECIMAL")
self.assertEqual(exp.DataType.build("BOOLEAN").sql(), "BOOLEAN")
self.assertEqual(exp.DataType.build("JSON").sql(), "JSON")
self.assertEqual(exp.DataType.build("JSONB").sql(), "JSONB")
self.assertEqual(exp.DataType.build("INTERVAL").sql(), "INTERVAL")
self.assertEqual(exp.DataType.build("TIME").sql(), "TIME")
self.assertEqual(exp.DataType.build("TIMESTAMP").sql(), "TIMESTAMP")
self.assertEqual(exp.DataType.build("TIMESTAMPTZ").sql(), "TIMESTAMPTZ")
self.assertEqual(exp.DataType.build("TIMESTAMPLTZ").sql(), "TIMESTAMPLTZ")
self.assertEqual(exp.DataType.build("DATE").sql(), "DATE")
self.assertEqual(exp.DataType.build("DATETIME").sql(), "DATETIME")
self.assertEqual(exp.DataType.build("ARRAY").sql(), "ARRAY")
self.assertEqual(exp.DataType.build("MAP").sql(), "MAP")
self.assertEqual(exp.DataType.build("UUID").sql(), "UUID")
self.assertEqual(exp.DataType.build("GEOGRAPHY").sql(), "GEOGRAPHY")
self.assertEqual(exp.DataType.build("GEOMETRY").sql(), "GEOMETRY")
self.assertEqual(exp.DataType.build("STRUCT").sql(), "STRUCT")
self.assertEqual(exp.DataType.build("NULLABLE").sql(), "NULLABLE")
self.assertEqual(exp.DataType.build("HLLSKETCH").sql(), "HLLSKETCH")
self.assertEqual(exp.DataType.build("HSTORE").sql(), "HSTORE")
self.assertEqual(exp.DataType.build("SUPER").sql(), "SUPER")
self.assertEqual(exp.DataType.build("SERIAL").sql(), "SERIAL")
self.assertEqual(exp.DataType.build("SMALLSERIAL").sql(), "SMALLSERIAL")
self.assertEqual(exp.DataType.build("BIGSERIAL").sql(), "BIGSERIAL")
self.assertEqual(exp.DataType.build("XML").sql(), "XML")
self.assertEqual(exp.DataType.build("UNIQUEIDENTIFIER").sql(), "UNIQUEIDENTIFIER")
self.assertEqual(exp.DataType.build("MONEY").sql(), "MONEY")
self.assertEqual(exp.DataType.build("SMALLMONEY").sql(), "SMALLMONEY")
self.assertEqual(exp.DataType.build("ROWVERSION").sql(), "ROWVERSION")
self.assertEqual(exp.DataType.build("IMAGE").sql(), "IMAGE")
self.assertEqual(exp.DataType.build("VARIANT").sql(), "VARIANT")
self.assertEqual(exp.DataType.build("OBJECT").sql(), "OBJECT")
self.assertEqual(exp.DataType.build("NULL").sql(), "NULL")
self.assertEqual(exp.DataType.build("UNKNOWN").sql(), "UNKNOWN")


@ -299,10 +299,10 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(set(scopes[6].sources), {"q", "z", "r", "s"}) self.assertEqual(set(scopes[6].sources), {"q", "z", "r", "s"})
self.assertEqual(len(scopes[6].columns), 6) self.assertEqual(len(scopes[6].columns), 6)
self.assertEqual(set(c.table for c in scopes[6].columns), {"r", "s"}) self.assertEqual({c.table for c in scopes[6].columns}, {"r", "s"})
self.assertEqual(scopes[6].source_columns("q"), []) self.assertEqual(scopes[6].source_columns("q"), [])
self.assertEqual(len(scopes[6].source_columns("r")), 2) self.assertEqual(len(scopes[6].source_columns("r")), 2)
self.assertEqual(set(c.table for c in scopes[6].source_columns("r")), {"r"}) self.assertEqual({c.table for c in scopes[6].source_columns("r")}, {"r"})
self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"}) self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"})
self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b") self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b")
@ -578,3 +578,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
scope_t, scope_y = build_scope(query).cte_scopes scope_t, scope_y = build_scope(query).cte_scopes
self.assertEqual(set(scope_t.cte_sources), {"t"}) self.assertEqual(set(scope_t.cte_sources), {"t"})
self.assertEqual(set(scope_y.cte_sources), {"t", "y"}) self.assertEqual(set(scope_y.cte_sources), {"t", "y"})
def test_schema_with_spaces(self):
schema = {
"a": {
"b c": "text",
'"d e"': "text",
}
}
self.assertEqual(
optimizer.optimize(parse_one("SELECT * FROM a"), schema=schema),
parse_one('SELECT "a"."b c" AS "b c", "a"."d e" AS "d e" FROM "a" AS "a"'),
)
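A minimal sketch of the behaviour the new test pins down, using the same parse_one/optimize entry points; the expected output is taken from the assertion above:

    from sqlglot import parse_one
    from sqlglot.optimizer import optimize

    # Keys may be given unquoted ("b c") or pre-quoted ('"d e"'); both survive optimization.
    schema = {"a": {"b c": "text", '"d e"': "text"}}
    print(optimize(parse_one("SELECT * FROM a"), schema=schema).sql())
    # SELECT "a"."b c" AS "b c", "a"."d e" AS "d e" FROM "a" AS "a"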


@ -8,7 +8,8 @@ from tests.helpers import assert_logger_contains
class TestParser(unittest.TestCase):
def test_parse_empty(self):
-self.assertIsNone(parse_one(""))
+with self.assertRaises(ParseError) as ctx:
+    parse_one("")
def test_parse_into(self):
self.assertIsInstance(parse_one("left join foo", into=exp.Join), exp.Join)
@ -90,6 +91,9 @@ class TestParser(unittest.TestCase):
parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(), parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),
"""SELECT * FROM x, z CROSS JOIN y LATERAL VIEW EXPLODE(y)""", """SELECT * FROM x, z CROSS JOIN y LATERAL VIEW EXPLODE(y)""",
) )
self.assertIsNone(
parse_one("create table a as (select b from c) index").find(exp.TableAlias)
)
def test_command(self):
expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1", read="hive")
@ -155,6 +159,11 @@ class TestParser(unittest.TestCase):
assert expressions[0].args["from"].expressions[0].this.name == "a"
assert expressions[1].args["from"].expressions[0].this.name == "b"
expressions = parse("SELECT 1; ; SELECT 2")
assert len(expressions) == 3
assert expressions[1] is None
def test_expression(self):
ignore = Parser(error_level=ErrorLevel.IGNORE)
self.assertIsInstance(ignore.expression(exp.Hint, expressions=[""]), exp.Hint)
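Two caller-facing consequences of the parser changes above, sketched outside the test harness (the import paths are assumptions based on the public API): parse_one("") now raises instead of returning None, and empty statements show up as None entries in parse results.

    from sqlglot import parse, parse_one
    from sqlglot.errors import ParseError

    try:
        parse_one("")
    except ParseError:
        pass  # empty input is now an error rather than a silent None

    # Empty statements come back as None, so filter them before generating SQL.
    statements = [e for e in parse("SELECT 1; ; SELECT 2") if e is not None]
    print([s.sql() for s in statements])  # ['SELECT 1', 'SELECT 2']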


@ -184,3 +184,19 @@ class TestSchema(unittest.TestCase):
schema = MappingSchema({"foo": {"bar": parse_one("INT", into=exp.DataType)}}) schema = MappingSchema({"foo": {"bar": parse_one("INT", into=exp.DataType)}})
self.assertEqual(schema.get_column_type("foo", "bar").this, exp.DataType.Type.INT) self.assertEqual(schema.get_column_type("foo", "bar").this, exp.DataType.Type.INT)
def test_schema_normalization(self):
schema = MappingSchema(
schema={"x": {"`y`": {"Z": {"a": "INT", "`B`": "VARCHAR"}, "w": {"C": "INT"}}}},
dialect="spark",
)
table_z = exp.Table(this="z", db="y", catalog="x")
table_w = exp.Table(this="w", db="y", catalog="x")
self.assertEqual(schema.column_names(table_z), ["a", "B"])
self.assertEqual(schema.column_names(table_w), ["c"])
# Clickhouse supports both `` and "" for identifier quotes; sqlglot uses "" when generating sql
schema = MappingSchema(schema={"x": {"`y`": "INT"}}, dialect="clickhouse")
self.assertEqual(schema.column_names(exp.Table(this="x")), ["y"])

tests/test_serde.py (new file)

@ -0,0 +1,33 @@
import json
import unittest
from sqlglot import exp, parse_one
from sqlglot.optimizer.annotate_types import annotate_types
from tests.helpers import load_sql_fixtures
class CustomExpression(exp.Expression):
...
class TestSerDe(unittest.TestCase):
def dump_load(self, expression):
return exp.Expression.load(json.loads(json.dumps(expression.dump())))
def test_serde(self):
for sql in load_sql_fixtures("identity.sql"):
with self.subTest(sql):
before = parse_one(sql)
after = self.dump_load(before)
self.assertEqual(before, after)
def test_custom_expression(self):
before = CustomExpression()
after = self.dump_load(before)
self.assertEqual(before, after)
def test_type_annotations(self):
before = annotate_types(parse_one("CAST('1' AS INT)"))
after = self.dump_load(before)
self.assertEqual(before.type, after.type)
self.assertEqual(before.this.type, after.this.type)
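For context, a minimal sketch of the JSON round-trip this new test file covers, mirroring the dump/load calls above:

    import json

    from sqlglot import exp, parse_one

    expression = parse_one("SELECT a FROM t WHERE b > 1")
    payload = json.dumps(expression.dump())  # plain JSON, safe to store or send
    restored = exp.Expression.load(json.loads(payload))
    assert restored == expression
    print(restored.sql())  # SELECT a FROM t WHERE b > 1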


@ -1,7 +1,11 @@
import unittest
from sqlglot import parse_one
-from sqlglot.transforms import eliminate_distinct_on, unalias_group
+from sqlglot.transforms import (
+    eliminate_distinct_on,
+    remove_precision_parameterized_types,
+    unalias_group,
+)
class TestTime(unittest.TestCase):
@ -62,3 +66,10 @@ class TestTime(unittest.TestCase):
"SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC", "SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC",
'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE "_row_number_2" = 1', 'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE "_row_number_2" = 1',
) )
def test_remove_precision_parameterized_types(self):
self.validate(
remove_precision_parameterized_types,
"SELECT CAST(1 AS DECIMAL(10, 2)), CAST('13' AS VARCHAR(10))",
"SELECT CAST(1 AS DECIMAL), CAST('13' AS VARCHAR)",
)
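A sketch of applying the new transform directly, assuming (as the validate helper implies) that it takes and returns an expression:

    from sqlglot import parse_one
    from sqlglot.transforms import remove_precision_parameterized_types

    tree = parse_one("SELECT CAST(1 AS DECIMAL(10, 2))")
    # Precision and scale arguments are stripped from parameterized types.
    print(remove_precision_parameterized_types(tree).sql())  # SELECT CAST(1 AS DECIMAL)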


@ -117,6 +117,11 @@ class TestTranspile(unittest.TestCase):
"select x from foo -- x", "select x from foo -- x",
"SELECT x FROM foo /* x */", "SELECT x FROM foo /* x */",
) )
self.validate(
"""select x, --
from foo""",
"SELECT x FROM foo",
)
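The same input through the public transpile function, as a sketch; the expected string is taken from the test above:

    import sqlglot

    # A trailing comma followed by an empty comment is dropped cleanly.
    print(sqlglot.transpile("select x, --\nfrom foo")[0])  # SELECT x FROM foo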
self.validate(
"""
-- comment 1