
Merging upstream version 10.1.3.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:56:25 +01:00
parent 582b160275
commit a5128ea109
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
57 changed files with 1542 additions and 529 deletions


@ -1,6 +1,47 @@
Changelog Changelog
========= =========
v10.1.0
------
Changes:
- Breaking: [refactored](https://github.com/tobymao/sqlglot/commit/6b0da1e1a2b5d6bdf7b5b918400456422d30a1d4) the way SQL comments are handled. Previously, at most one comment could be attached to an expression; now multiple comments can be stored in a list.
- Breaking: [refactored](https://github.com/tobymao/sqlglot/commit/be332d10404f36b43ea6ece956a73bf451348641) the way properties are represented and parsed. The argument `this` now stores a property's attributes instead of its name.
- New: added structured ParseError properties.
- New: the executor now handles set operations.
- New: sqlglot can [now execute SQL queries](https://github.com/tobymao/sqlglot/commit/62d3496e761a4f38dfa61af793062690923dce74) using Python objects.
- New: added support for the [Drill dialect](https://github.com/tobymao/sqlglot/commit/543eca314546e0bd42f97c354807b4e398ab36ec).
- New: added a `canonicalize` method which leverages existing type information for an expression to apply various transformations to it.
- New: TRIM function support for Snowflake and BigQuery.
- New: added support for SQLite primary key ordering constraints (ASC, DESC).
- New: added support for Redshift DISTKEY / SORTKEY / DISTSTYLE properties.
- New: added support for MySQL SET TRANSACTION statements.
- New: added `null`, `true`, `false` helper methods to avoid using singleton expressions (see the short sketch after this list).
- Improvement: allow multiple aggregations in an expression.
- Improvement: execution of left / right joins.
- Improvement: execution of aggregations without the GROUP BY clause.
- Improvement: static query execution (e.g. SELECT 1, SELECT CONCAT('a', 'b') AS x, etc.).
- Improvement: include a rule for type inference in the optimizer.
- Improvement: transaction and commit expressions are parsed [at finer granularity](https://github.com/tobymao/sqlglot/commit/148282e710fd79512bb7d32e6e519d631df8115d).
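As a quick illustration of the new helper methods mentioned in the list above, here is a minimal sketch. It assumes the helpers are exposed as `exp.null()`, `exp.true()` and `exp.false()`; that location is an assumption, since this diff only shows the changelog entry.
```python
from sqlglot import exp

# Minimal sketch: build a comparison without reusing singleton expression
# instances. exp.false() is assumed to return a fresh Boolean node.
condition = exp.EQ(this=exp.column("deleted"), expression=exp.false())
print(condition.sql())  # deleted = FALSE
```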
v10.0.0 v10.0.0
------ ------


@ -25,6 +25,7 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/
* [AST Introspection](#ast-introspection) * [AST Introspection](#ast-introspection)
* [AST Diff](#ast-diff) * [AST Diff](#ast-diff)
* [Custom Dialects](#custom-dialects) * [Custom Dialects](#custom-dialects)
* [SQL Execution](#sql-execution)
* [Benchmarks](#benchmarks) * [Benchmarks](#benchmarks)
* [Optional Dependencies](#optional-dependencies) * [Optional Dependencies](#optional-dependencies)
@ -147,9 +148,9 @@ print(sqlglot.transpile(sql, read='mysql', pretty=True)[0])
*/ */
SELECT SELECT
tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */, tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
CAST(x AS INT), -- comment 3 CAST(x AS INT), /* comment 3 */
y -- comment 4 y /* comment 4 */
FROM bar /* comment 5 */, tbl /* comment 6*/ FROM bar /* comment 5 */, tbl /* comment 6 */
``` ```
@ -189,6 +190,28 @@ sqlglot.errors.ParseError: Expecting ). Line 1, Col: 13.
~~~~ ~~~~
``` ```
Structured syntax errors are accessible for programmatic use:
```python
import sqlglot
try:
sqlglot.transpile("SELECT foo( FROM bar")
except sqlglot.errors.ParseError as e:
print(e.errors)
```
Output:
```python
[{
'description': 'Expecting )',
'line': 1,
'col': 13,
'start_context': 'SELECT foo( ',
'highlight': 'FROM',
'end_context': ' bar'
}]
```
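Building on the structure shown above, here is a minimal sketch of collapsing one of these dictionaries into a one-line diagnostic; the helper name and formatting are hypothetical, only the field names come from the output above:
```python
def format_parse_error(error: dict) -> str:
    # `error` is one of the dictionaries from ParseError.errors shown above.
    return (
        f"line {error['line']}, col {error['col']}: "
        f"{error['description']} near {error['highlight']!r}"
    )

# e.g. format_parse_error(e.errors[0]) -> "line 1, col 13: Expecting ) near 'FROM'"
```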
### Unsupported Errors ### Unsupported Errors
Presto `APPROX_DISTINCT` supports the accuracy argument which is not supported in Hive: Presto `APPROX_DISTINCT` supports the accuracy argument which is not supported in Hive:
@ -372,6 +395,53 @@ print(Dialect["custom"])
<class '__main__.Custom'> <class '__main__.Custom'>
``` ```
### SQL Execution
One can even interpret SQL queries using SQLGlot, where the tables are represented as Python dictionaries. Although the engine is not very fast (it's not supposed to be) and is in a relatively early stage of development, it can be useful for unit testing and running SQL natively across Python objects. Additionally, the foundation can be easily integrated with fast compute kernels (arrow, pandas). Below is an example showcasing the execution of a SELECT expression that involves aggregations and JOINs:
```python
from sqlglot.executor import execute
tables = {
"sushi": [
{"id": 1, "price": 1.0},
{"id": 2, "price": 2.0},
{"id": 3, "price": 3.0},
],
"order_items": [
{"sushi_id": 1, "order_id": 1},
{"sushi_id": 1, "order_id": 1},
{"sushi_id": 2, "order_id": 1},
{"sushi_id": 3, "order_id": 2},
],
"orders": [
{"id": 1, "user_id": 1},
{"id": 2, "user_id": 2},
],
}
execute(
"""
SELECT
o.user_id,
SUM(s.price) AS price
FROM orders o
JOIN order_items i
ON o.id = i.order_id
JOIN sushi s
ON i.sushi_id = s.id
GROUP BY o.user_id
""",
tables=tables
)
```
```python
user_id price
1 4.0
2 3.0
```
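The changelog above also notes that the executor now handles set operations and static queries. A minimal sketch under that assumption (not taken from the upstream docs), using the same `execute` entry point:
```python
from sqlglot.executor import execute

# Static selects and set operations reference no tables, so no `tables`
# argument should be needed here (assumption based on the changelog).
print(execute("SELECT 1 AS x UNION ALL SELECT 2 AS x"))
```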
## Benchmarks ## Benchmarks
[Benchmarks](benchmarks) run on Python 3.10.5 in seconds. [Benchmarks](benchmarks) run on Python 3.10.5 in seconds.


@ -30,7 +30,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType from sqlglot.tokens import Tokenizer, TokenType
__version__ = "10.0.8" __version__ = "10.1.3"
pretty = False pretty = False


@ -56,12 +56,12 @@ def _derived_table_values_to_unnest(self, expression):
def _returnsproperty_sql(self, expression): def _returnsproperty_sql(self, expression):
value = expression.args.get("value") this = expression.this
if isinstance(value, exp.Schema): if isinstance(this, exp.Schema):
value = f"{value.this} <{self.expressions(value)}>" this = f"{this.this} <{self.expressions(this)}>"
else: else:
value = self.sql(value) this = self.sql(this)
return f"RETURNS {value}" return f"RETURNS {this}"
def _create_sql(self, expression): def _create_sql(self, expression):
@ -142,6 +142,11 @@ class BigQuery(Dialect):
), ),
} }
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
}
FUNCTION_PARSERS.pop("TRIM")
NO_PAREN_FUNCTIONS = { NO_PAREN_FUNCTIONS = {
**parser.Parser.NO_PAREN_FUNCTIONS, **parser.Parser.NO_PAREN_FUNCTIONS,
TokenType.CURRENT_DATETIME: exp.CurrentDatetime, TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
@ -174,6 +179,7 @@ class BigQuery(Dialect):
exp.Values: _derived_table_values_to_unnest, exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql, exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql, exp.Create: _create_sql,
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE" if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC", else "NOT DETERMINISTIC",
@ -200,9 +206,7 @@ class BigQuery(Dialect):
exp.VolatilityProperty, exp.VolatilityProperty,
} }
WITH_PROPERTIES = { WITH_PROPERTIES = {exp.Property}
exp.AnonymousProperty,
}
EXPLICIT_UNION = True EXPLICIT_UNION = True


@ -21,14 +21,15 @@ class ClickHouse(Dialect):
KEYWORDS = { KEYWORDS = {
**tokens.Tokenizer.KEYWORDS, **tokens.Tokenizer.KEYWORDS,
"FINAL": TokenType.FINAL, "ASOF": TokenType.ASOF,
"DATETIME64": TokenType.DATETIME, "DATETIME64": TokenType.DATETIME,
"INT8": TokenType.TINYINT, "FINAL": TokenType.FINAL,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
"INT16": TokenType.SMALLINT, "INT16": TokenType.SMALLINT,
"INT32": TokenType.INT, "INT32": TokenType.INT,
"INT64": TokenType.BIGINT, "INT64": TokenType.BIGINT,
"FLOAT32": TokenType.FLOAT, "INT8": TokenType.TINYINT,
"FLOAT64": TokenType.DOUBLE,
"TUPLE": TokenType.STRUCT, "TUPLE": TokenType.STRUCT,
} }
@ -38,6 +39,10 @@ class ClickHouse(Dialect):
"MAP": parse_var_map, "MAP": parse_var_map,
} }
JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF}
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY}
def _parse_table(self, schema=False): def _parse_table(self, schema=False):
this = super()._parse_table(schema) this = super()._parse_table(schema)


@ -289,19 +289,19 @@ def struct_extract_sql(self, expression):
return f"{this}.{struct_key}" return f"{this}.{struct_key}"
def var_map_sql(self, expression): def var_map_sql(self, expression, map_func_name="MAP"):
keys = expression.args["keys"] keys = expression.args["keys"]
values = expression.args["values"] values = expression.args["values"]
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
self.unsupported("Cannot convert array columns into map.") self.unsupported("Cannot convert array columns into map.")
return f"MAP({self.format_args(keys, values)})" return f"{map_func_name}({self.format_args(keys, values)})"
args = [] args = []
for key, value in zip(keys.expressions, values.expressions): for key, value in zip(keys.expressions, values.expressions):
args.append(self.sql(key)) args.append(self.sql(key))
args.append(self.sql(value)) args.append(self.sql(value))
return f"MAP({self.format_args(*args)})" return f"{map_func_name}({self.format_args(*args)})"
def format_time_lambda(exp_class, dialect, default=None): def format_time_lambda(exp_class, dialect, default=None):
@ -336,18 +336,13 @@ def create_with_partitions_sql(self, expression):
if has_schema and is_partitionable: if has_schema and is_partitionable:
expression = expression.copy() expression = expression.copy()
prop = expression.find(exp.PartitionedByProperty) prop = expression.find(exp.PartitionedByProperty)
value = prop and prop.args.get("value") this = prop and prop.this
if prop and not isinstance(value, exp.Schema): if prop and not isinstance(this, exp.Schema):
schema = expression.this schema = expression.this
columns = {v.name.upper() for v in value.expressions} columns = {v.name.upper() for v in this.expressions}
partitions = [col for col in schema.expressions if col.name.upper() in columns] partitions = [col for col in schema.expressions if col.name.upper() in columns]
schema.set( schema.set("expressions", [e for e in schema.expressions if e not in partitions])
"expressions", prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
[e for e in schema.expressions if e not in partitions],
)
prop.replace(
exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
)
expression.set("this", schema) expression.set("this", schema)
return self.create_sql(expression) return self.create_sql(expression)


@ -153,7 +153,7 @@ class Drill(Dialect):
exp.If: if_sql, exp.If: if_sql,
exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}", exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}", exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Pivot: no_pivot_sql, exp.Pivot: no_pivot_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.StrPosition: str_position_sql, exp.StrPosition: str_position_sql,


@ -61,9 +61,7 @@ def _array_sort(self, expression):
def _property_sql(self, expression): def _property_sql(self, expression):
key = expression.name return f"'{expression.name}'={self.sql(expression, 'value')}"
value = self.sql(expression, "value")
return f"'{key}'={value}"
def _str_to_unix(self, expression): def _str_to_unix(self, expression):
@ -250,7 +248,7 @@ class Hive(Dialect):
TRANSFORMS = { TRANSFORMS = {
**generator.Generator.TRANSFORMS, **generator.Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore
exp.AnonymousProperty: _property_sql, exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql, exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"), exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.ArrayConcat: rename_func("CONCAT"), exp.ArrayConcat: rename_func("CONCAT"),
@ -262,7 +260,7 @@ class Hive(Dialect):
exp.DateStrToDate: rename_func("TO_DATE"), exp.DateStrToDate: rename_func("TO_DATE"),
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)", exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
exp.FileFormatProperty: lambda self, e: f"STORED AS {e.text('value').upper()}", exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}",
exp.If: if_sql, exp.If: if_sql,
exp.Index: _index_sql, exp.Index: _index_sql,
exp.ILike: no_ilike_sql, exp.ILike: no_ilike_sql,
@ -285,7 +283,7 @@ class Hive(Dialect):
exp.StrToTime: _str_to_time, exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix, exp.StrToUnix: _str_to_unix,
exp.StructExtract: struct_extract_sql, exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'value')}", exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
exp.TimeStrToDate: rename_func("TO_DATE"), exp.TimeStrToDate: rename_func("TO_DATE"),
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@ -298,11 +296,11 @@ class Hive(Dialect):
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})", exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}", exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
exp.NumberToStr: rename_func("FORMAT_NUMBER"), exp.NumberToStr: rename_func("FORMAT_NUMBER"),
} }
WITH_PROPERTIES = {exp.AnonymousProperty} WITH_PROPERTIES = {exp.Property}
ROOT_PROPERTIES = { ROOT_PROPERTIES = {
exp.PartitionedByProperty, exp.PartitionedByProperty,


@ -453,6 +453,7 @@ class MySQL(Dialect):
exp.CharacterSetProperty, exp.CharacterSetProperty,
exp.CollateProperty, exp.CollateProperty,
exp.SchemaCommentProperty, exp.SchemaCommentProperty,
exp.LikeProperty,
} }
WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set() WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()


@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from sqlglot import exp, generator, tokens, transforms from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func
from sqlglot.helper import csv from sqlglot.helper import csv
from sqlglot.tokens import TokenType from sqlglot.tokens import TokenType
@ -37,6 +37,12 @@ class Oracle(Dialect):
"YYYY": "%Y", # 2015 "YYYY": "%Y", # 2015
} }
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DECODE": exp.Matches.from_arg_list,
}
class Generator(generator.Generator): class Generator(generator.Generator):
TYPE_MAPPING = { TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING,
@ -58,6 +64,7 @@ class Oracle(Dialect):
**transforms.UNALIAS_GROUP, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore
exp.ILike: no_ilike_sql, exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql, exp.Limit: _limit_sql,
exp.Matches: rename_func("DECODE"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",


@ -74,6 +74,27 @@ def _trim_sql(self, expression):
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def _string_agg_sql(self, expression):
expression = expression.copy()
separator = expression.args.get("separator") or exp.Literal.string(",")
order = ""
this = expression.this
if isinstance(this, exp.Order):
if this.this:
this = this.this
this.pop()
order = self.sql(expression.this) # Order has a leading space
return f"STRING_AGG({self.format_args(this, separator)}{order})"
def _datatype_sql(self, expression):
if expression.this == exp.DataType.Type.ARRAY:
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)
def _auto_increment_to_serial(expression): def _auto_increment_to_serial(expression):
auto = expression.find(exp.AutoIncrementColumnConstraint) auto = expression.find(exp.AutoIncrementColumnConstraint)
@ -191,25 +212,27 @@ class Postgres(Dialect):
KEYWORDS = { KEYWORDS = {
**tokens.Tokenizer.KEYWORDS, **tokens.Tokenizer.KEYWORDS,
"ALWAYS": TokenType.ALWAYS, "ALWAYS": TokenType.ALWAYS,
"BY DEFAULT": TokenType.BY_DEFAULT,
"IDENTITY": TokenType.IDENTITY,
"GENERATED": TokenType.GENERATED,
"DOUBLE PRECISION": TokenType.DOUBLE,
"BIGSERIAL": TokenType.BIGSERIAL,
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
"UUID": TokenType.UUID,
"TEMP": TokenType.TEMPORARY,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BEGIN": TokenType.COMMAND, "BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BIGSERIAL": TokenType.BIGSERIAL,
"BY DEFAULT": TokenType.BY_DEFAULT,
"COMMENT ON": TokenType.COMMAND, "COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND, "DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND, "DO": TokenType.COMMAND,
"DOUBLE PRECISION": TokenType.DOUBLE,
"GENERATED": TokenType.GENERATED,
"GRANT": TokenType.COMMAND,
"HSTORE": TokenType.HSTORE,
"IDENTITY": TokenType.IDENTITY,
"JSONB": TokenType.JSONB,
"REFRESH": TokenType.COMMAND, "REFRESH": TokenType.COMMAND,
"REINDEX": TokenType.COMMAND, "REINDEX": TokenType.COMMAND,
"RESET": TokenType.COMMAND, "RESET": TokenType.COMMAND,
"REVOKE": TokenType.COMMAND, "REVOKE": TokenType.COMMAND,
"GRANT": TokenType.COMMAND, "SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
"TEMP": TokenType.TEMPORARY,
"UUID": TokenType.UUID,
**{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES}, **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
**{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES}, **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
} }
@ -265,4 +288,7 @@ class Postgres(Dialect):
exp.Trim: _trim_sql, exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql, exp.TryCast: no_trycast_sql,
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.DataType: _datatype_sql,
exp.GroupConcat: _string_agg_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
} }


@ -171,16 +171,7 @@ class Presto(Dialect):
STRUCT_DELIMITER = ("(", ")") STRUCT_DELIMITER = ("(", ")")
ROOT_PROPERTIES = { ROOT_PROPERTIES = {exp.SchemaCommentProperty}
exp.SchemaCommentProperty,
}
WITH_PROPERTIES = {
exp.PartitionedByProperty,
exp.FileFormatProperty,
exp.AnonymousProperty,
exp.TableFormatProperty,
}
TYPE_MAPPING = { TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING,
@ -231,7 +222,8 @@ class Presto(Dialect):
exp.StrToTime: _str_to_time_sql, exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StructExtract: struct_extract_sql, exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT = '{e.text('value').upper()}'", exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.TimeStrToDate: _date_parse_sql, exp.TimeStrToDate: _date_parse_sql,
exp.TimeStrToTime: _date_parse_sql, exp.TimeStrToTime: _date_parse_sql,
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))", exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",


@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from sqlglot import exp from sqlglot import exp, transforms
from sqlglot.dialects.postgres import Postgres from sqlglot.dialects.postgres import Postgres
from sqlglot.tokens import TokenType from sqlglot.tokens import TokenType
@ -18,12 +18,14 @@ class Redshift(Postgres):
KEYWORDS = { KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS, # type: ignore **Postgres.Tokenizer.KEYWORDS, # type: ignore
"COPY": TokenType.COMMAND,
"GEOMETRY": TokenType.GEOMETRY, "GEOMETRY": TokenType.GEOMETRY,
"GEOGRAPHY": TokenType.GEOGRAPHY, "GEOGRAPHY": TokenType.GEOGRAPHY,
"HLLSKETCH": TokenType.HLLSKETCH, "HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER, "SUPER": TokenType.SUPER,
"TIME": TokenType.TIMESTAMP, "TIME": TokenType.TIMESTAMP,
"TIMETZ": TokenType.TIMESTAMPTZ, "TIMETZ": TokenType.TIMESTAMPTZ,
"UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY, "VARBYTE": TokenType.VARBINARY,
"SIMILAR TO": TokenType.SIMILAR_TO, "SIMILAR TO": TokenType.SIMILAR_TO,
} }
@ -35,3 +37,17 @@ class Redshift(Postgres):
exp.DataType.Type.VARBINARY: "VARBYTE", exp.DataType.Type.VARBINARY: "VARBYTE",
exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.INT: "INTEGER",
} }
ROOT_PROPERTIES = {
exp.DistKeyProperty,
exp.SortKeyProperty,
exp.DistStyleProperty,
}
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
}


@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
format_time_lambda, format_time_lambda,
inline_array_sql, inline_array_sql,
rename_func, rename_func,
var_map_sql,
) )
from sqlglot.expressions import Literal from sqlglot.expressions import Literal
from sqlglot.helper import seq_get from sqlglot.helper import seq_get
@ -100,6 +101,14 @@ def _parse_date_part(self):
return self.expression(exp.Extract, this=this, expression=expression) return self.expression(exp.Extract, this=this, expression=expression)
def _datatype_sql(self, expression):
if expression.this == exp.DataType.Type.ARRAY:
return "ARRAY"
elif expression.this == exp.DataType.Type.MAP:
return "OBJECT"
return self.datatype_sql(expression)
class Snowflake(Dialect): class Snowflake(Dialect):
null_ordering = "nulls_are_large" null_ordering = "nulls_are_large"
time_format = "'yyyy-mm-dd hh24:mi:ss'" time_format = "'yyyy-mm-dd hh24:mi:ss'"
@ -142,6 +151,8 @@ class Snowflake(Dialect):
"TO_TIMESTAMP": _snowflake_to_timestamp, "TO_TIMESTAMP": _snowflake_to_timestamp,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list, "ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list,
"DECODE": exp.Matches.from_arg_list,
"OBJECT_CONSTRUCT": parser.parse_var_map,
} }
FUNCTION_PARSERS = { FUNCTION_PARSERS = {
@ -195,16 +206,20 @@ class Snowflake(Dialect):
TRANSFORMS = { TRANSFORMS = {
**generator.Generator.TRANSFORMS, **generator.Generator.TRANSFORMS,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.If: rename_func("IFF"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Array: inline_array_sql, exp.Array: inline_array_sql,
exp.StrPosition: rename_func("POSITION"), exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}", exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}", exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Matches: rename_func("DECODE"),
exp.StrPosition: rename_func("POSITION"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})", exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
exp.UnixToTime: _unix_to_time_sql,
} }
TYPE_MAPPING = { TYPE_MAPPING = {


@ -98,7 +98,7 @@ class Spark(Hive):
TRANSFORMS = { TRANSFORMS = {
**Hive.Generator.TRANSFORMS, # type: ignore **Hive.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}", exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),


@ -13,6 +13,23 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType from sqlglot.tokens import TokenType
# https://www.sqlite.org/lang_aggfunc.html#group_concat
def _group_concat_sql(self, expression):
this = expression.this
distinct = expression.find(exp.Distinct)
if distinct:
this = distinct.expressions[0]
distinct = "DISTINCT "
if isinstance(expression.this, exp.Order):
self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
if expression.this.this and not distinct:
this = expression.this.this
separator = expression.args.get("separator")
return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
class SQLite(Dialect): class SQLite(Dialect):
class Tokenizer(tokens.Tokenizer): class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"] IDENTIFIERS = ['"', ("[", "]"), "`"]
@ -62,6 +79,7 @@ class SQLite(Dialect):
exp.Levenshtein: rename_func("EDITDIST3"), exp.Levenshtein: rename_func("EDITDIST3"),
exp.TableSample: no_tablesample_sql, exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql, exp.TryCast: no_trycast_sql,
exp.GroupConcat: _group_concat_sql,
} }
def transaction_sql(self, expression): def transaction_sql(self, expression):


@ -17,6 +17,7 @@ FULL_FORMAT_TIME_MAPPING = {
"mm": "%B", "mm": "%B",
"m": "%B", "m": "%B",
} }
DATE_DELTA_INTERVAL = { DATE_DELTA_INTERVAL = {
"year": "year", "year": "year",
"yyyy": "year", "yyyy": "year",
@ -37,11 +38,12 @@ DATE_DELTA_INTERVAL = {
DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})") DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})")
# N = Numeric, C=Currency # N = Numeric, C=Currency
TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"} TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None): def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
def _format_time(args): def _format_time(args):
return exp_class( return exp_class(
this=seq_get(args, 1), this=seq_get(args, 1),
@ -58,7 +60,7 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
return _format_time return _format_time
def parse_format(args): def _parse_format(args):
fmt = seq_get(args, 1) fmt = seq_get(args, 1)
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this) number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
if number_fmt: if number_fmt:
@ -78,7 +80,7 @@ def generate_date_delta_with_unit_sql(self, e):
return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})" return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
def generate_format_sql(self, e): def _format_sql(self, e):
fmt = ( fmt = (
e.args["format"] e.args["format"]
if isinstance(e, exp.NumberToStr) if isinstance(e, exp.NumberToStr)
@ -87,6 +89,28 @@ def generate_format_sql(self, e):
return f"FORMAT({self.format_args(e.this, fmt)})" return f"FORMAT({self.format_args(e.this, fmt)})"
def _string_agg_sql(self, e):
e = e.copy()
this = e.this
distinct = e.find(exp.Distinct)
if distinct:
# exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.")
this = distinct.expressions[0]
distinct.pop()
order = ""
if isinstance(e.this, exp.Order):
if e.this.this:
this = e.this.this
e.this.this.pop()
order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space
separator = e.args.get("separator") or exp.Literal.string(",")
return f"STRING_AGG({self.format_args(this, separator)}){order}"
class TSQL(Dialect): class TSQL(Dialect):
null_ordering = "nulls_are_small" null_ordering = "nulls_are_small"
time_format = "'yyyy-mm-dd hh:mm:ss'" time_format = "'yyyy-mm-dd hh:mm:ss'"
@ -228,14 +252,14 @@ class TSQL(Dialect):
"ISNULL": exp.Coalesce.from_arg_list, "ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True), "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": tsql_format_time_lambda(exp.TimeToStr), "DATEPART": _format_time_lambda(exp.TimeToStr),
"GETDATE": exp.CurrentDate.from_arg_list, "GETDATE": exp.CurrentDate.from_arg_list,
"IIF": exp.If.from_arg_list, "IIF": exp.If.from_arg_list,
"LEN": exp.Length.from_arg_list, "LEN": exp.Length.from_arg_list,
"REPLICATE": exp.Repeat.from_arg_list, "REPLICATE": exp.Repeat.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list, "JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
"FORMAT": parse_format, "FORMAT": _parse_format,
} }
VAR_LENGTH_DATATYPES = { VAR_LENGTH_DATATYPES = {
@ -298,6 +322,7 @@ class TSQL(Dialect):
exp.DateDiff: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"), exp.CurrentDate: rename_func("GETDATE"),
exp.If: rename_func("IIF"), exp.If: rename_func("IIF"),
exp.NumberToStr: generate_format_sql, exp.NumberToStr: _format_sql,
exp.TimeToStr: generate_format_sql, exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
} }


@ -22,7 +22,40 @@ class UnsupportedError(SqlglotError):
class ParseError(SqlglotError): class ParseError(SqlglotError):
pass def __init__(
self,
message: str,
errors: t.Optional[t.List[t.Dict[str, t.Any]]] = None,
):
super().__init__(message)
self.errors = errors or []
@classmethod
def new(
cls,
message: str,
description: t.Optional[str] = None,
line: t.Optional[int] = None,
col: t.Optional[int] = None,
start_context: t.Optional[str] = None,
highlight: t.Optional[str] = None,
end_context: t.Optional[str] = None,
into_expression: t.Optional[str] = None,
) -> ParseError:
return cls(
message,
[
{
"description": description,
"line": line,
"col": col,
"start_context": start_context,
"highlight": highlight,
"end_context": end_context,
"into_expression": into_expression,
}
],
)
class TokenError(SqlglotError): class TokenError(SqlglotError):
@ -41,9 +74,13 @@ class ExecuteError(SqlglotError):
pass pass
def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str: def concat_messages(errors: t.Sequence[t.Any], maximum: int) -> str:
msg = [str(e) for e in errors[:maximum]] msg = [str(e) for e in errors[:maximum]]
remaining = len(errors) - maximum remaining = len(errors) - maximum
if remaining > 0: if remaining > 0:
msg.append(f"... and {remaining} more") msg.append(f"... and {remaining} more")
return "\n\n".join(msg) return "\n\n".join(msg)
def merge_errors(errors: t.Sequence[ParseError]) -> t.List[t.Dict[str, t.Any]]:
return [e_dict for error in errors for e_dict in error.errors]


@ -122,7 +122,6 @@ def interval(this, unit):
ENV = { ENV = {
"__builtins__": {},
"exp": exp, "exp": exp,
# aggs # aggs
"SUM": filter_nulls(sum), "SUM": filter_nulls(sum),


@ -115,6 +115,9 @@ class PythonExecutor:
sink = self.table(context.columns) sink = self.table(context.columns)
for reader in table_iter: for reader in table_iter:
if len(sink) >= step.limit:
break
if condition and not context.eval(condition): if condition and not context.eval(condition):
continue continue
@ -123,9 +126,6 @@ class PythonExecutor:
else: else:
sink.append(reader.row) sink.append(reader.row)
if len(sink) >= step.limit:
break
return self.context({step.name: sink}) return self.context({step.name: sink})
def static(self): def static(self):
@ -288,7 +288,13 @@ class PythonExecutor:
end = 1 end = 1
length = len(context.table) length = len(context.table)
table = self.table(list(step.group) + step.aggregations) table = self.table(list(step.group) + step.aggregations)
condition = self.generate(step.condition)
def add_row():
if not condition or context.eval(condition):
table.append(group + context.eval_tuple(aggregations))
if length:
for i in range(length): for i in range(length):
context.set_index(i) context.set_index(i)
key = context.eval_tuple(group_by) key = context.eval_tuple(group_by)
@ -296,12 +302,17 @@ class PythonExecutor:
end += 1 end += 1
if key != group: if key != group:
context.set_range(start, end - 2) context.set_range(start, end - 2)
table.append(group + context.eval_tuple(aggregations)) add_row()
group = key group = key
start = end - 2 start = end - 2
if len(table.rows) >= step.limit:
break
if i == length - 1: if i == length - 1:
context.set_range(start, end - 1) context.set_range(start, end - 1)
table.append(group + context.eval_tuple(aggregations)) add_row()
elif step.limit > 0:
context.set_range(0, 0)
table.append(context.eval_tuple(group_by) + context.eval_tuple(aggregations))
context = self.context({step.name: table, **{name: table for name in context.tables}}) context = self.context({step.name: table, **{name: table for name in context.tables}})
@ -311,11 +322,9 @@ class PythonExecutor:
def sort(self, step, context): def sort(self, step, context):
projections = self.generate_tuple(step.projections) projections = self.generate_tuple(step.projections)
projection_columns = [p.alias_or_name for p in step.projections] projection_columns = [p.alias_or_name for p in step.projections]
all_columns = list(context.columns) + projection_columns all_columns = list(context.columns) + projection_columns
sink = self.table(all_columns) sink = self.table(all_columns)
for reader, ctx in context: for reader, ctx in context:
sink.append(reader.row + ctx.eval_tuple(projections)) sink.append(reader.row + ctx.eval_tuple(projections))
@ -401,8 +410,9 @@ class Python(Dialect):
exp.Boolean: lambda self, e: "True" if e.this else "False", exp.Boolean: lambda self, e: "True" if e.this else "False",
exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})", exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]", exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}", exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
exp.Is: lambda self, e: self.binary(e, "is"), exp.Is: lambda self, e: self.binary(e, "is"),
exp.Not: lambda self, e: f"not {self.sql(e.this)}", exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None", exp.Null: lambda *_: "None",


@ -43,14 +43,14 @@ class Expression(metaclass=_Expression):
key = "Expression" key = "Expression"
arg_types = {"this": True} arg_types = {"this": True}
__slots__ = ("args", "parent", "arg_key", "type", "comment") __slots__ = ("args", "parent", "arg_key", "type", "comments")
def __init__(self, **args): def __init__(self, **args):
self.args = args self.args = args
self.parent = None self.parent = None
self.arg_key = None self.arg_key = None
self.type = None self.type = None
self.comment = None self.comments = None
for arg_key, value in self.args.items(): for arg_key, value in self.args.items():
self._set_parent(arg_key, value) self._set_parent(arg_key, value)
@ -88,19 +88,6 @@ class Expression(metaclass=_Expression):
return field.this return field.this
return "" return ""
def find_comment(self, key: str) -> str:
"""
Finds the comment that is attached to a specified child node.
Args:
key: the key of the target child node (e.g. "this", "expression", etc).
Returns:
The comment attached to the child node, or the empty string, if it doesn't exist.
"""
field = self.args.get(key)
return field.comment if isinstance(field, Expression) else ""
@property @property
def is_string(self): def is_string(self):
return isinstance(self, Literal) and self.args["is_string"] return isinstance(self, Literal) and self.args["is_string"]
@ -137,7 +124,7 @@ class Expression(metaclass=_Expression):
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
copy = self.__class__(**deepcopy(self.args)) copy = self.__class__(**deepcopy(self.args))
copy.comment = self.comment copy.comments = self.comments
copy.type = self.type copy.type = self.type
return copy return copy
@ -369,7 +356,7 @@ class Expression(metaclass=_Expression):
) )
for k, vs in self.args.items() for k, vs in self.args.items()
} }
args["comment"] = self.comment args["comments"] = self.comments
args["type"] = self.type args["type"] = self.type
args = {k: v for k, v in args.items() if v or not hide_missing} args = {k: v for k, v in args.items() if v or not hide_missing}
@ -767,7 +754,7 @@ class NotNullColumnConstraint(ColumnConstraintKind):
class PrimaryKeyColumnConstraint(ColumnConstraintKind): class PrimaryKeyColumnConstraint(ColumnConstraintKind):
pass arg_types = {"desc": False}
class UniqueColumnConstraint(ColumnConstraintKind): class UniqueColumnConstraint(ColumnConstraintKind):
@ -819,6 +806,12 @@ class Unique(Expression):
arg_types = {"expressions": True} arg_types = {"expressions": True}
# https://www.postgresql.org/docs/9.1/sql-selectinto.html
# https://docs.aws.amazon.com/redshift/latest/dg/r_SELECT_INTO.html#r_SELECT_INTO-examples
class Into(Expression):
arg_types = {"this": True, "temporary": False, "unlogged": False}
class From(Expression): class From(Expression):
arg_types = {"expressions": True} arg_types = {"expressions": True}
@ -1065,67 +1058,67 @@ class Property(Expression):
class TableFormatProperty(Property): class TableFormatProperty(Property):
pass arg_types = {"this": True}
class PartitionedByProperty(Property): class PartitionedByProperty(Property):
pass arg_types = {"this": True}
class FileFormatProperty(Property): class FileFormatProperty(Property):
pass arg_types = {"this": True}
class DistKeyProperty(Property): class DistKeyProperty(Property):
pass arg_types = {"this": True}
class SortKeyProperty(Property): class SortKeyProperty(Property):
pass arg_types = {"this": True, "compound": False}
class DistStyleProperty(Property): class DistStyleProperty(Property):
pass arg_types = {"this": True}
class LikeProperty(Property):
arg_types = {"this": True, "expressions": False}
class LocationProperty(Property): class LocationProperty(Property):
pass arg_types = {"this": True}
class EngineProperty(Property): class EngineProperty(Property):
pass arg_types = {"this": True}
class AutoIncrementProperty(Property): class AutoIncrementProperty(Property):
pass arg_types = {"this": True}
class CharacterSetProperty(Property): class CharacterSetProperty(Property):
arg_types = {"this": True, "value": True, "default": True} arg_types = {"this": True, "default": True}
class CollateProperty(Property): class CollateProperty(Property):
pass arg_types = {"this": True}
class SchemaCommentProperty(Property): class SchemaCommentProperty(Property):
pass arg_types = {"this": True}
class AnonymousProperty(Property):
pass
class ReturnsProperty(Property): class ReturnsProperty(Property):
arg_types = {"this": True, "value": True, "is_table": False} arg_types = {"this": True, "is_table": False}
class LanguageProperty(Property): class LanguageProperty(Property):
pass arg_types = {"this": True}
class ExecuteAsProperty(Property): class ExecuteAsProperty(Property):
pass arg_types = {"this": True}
class VolatilityProperty(Property): class VolatilityProperty(Property):
@ -1135,27 +1128,36 @@ class VolatilityProperty(Property):
class Properties(Expression): class Properties(Expression):
arg_types = {"expressions": True} arg_types = {"expressions": True}
PROPERTY_KEY_MAPPING = { NAME_TO_PROPERTY = {
"AUTO_INCREMENT": AutoIncrementProperty, "AUTO_INCREMENT": AutoIncrementProperty,
"CHARACTER_SET": CharacterSetProperty, "CHARACTER SET": CharacterSetProperty,
"COLLATE": CollateProperty, "COLLATE": CollateProperty,
"COMMENT": SchemaCommentProperty, "COMMENT": SchemaCommentProperty,
"ENGINE": EngineProperty,
"FORMAT": FileFormatProperty,
"LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty,
"TABLE_FORMAT": TableFormatProperty,
"DISTKEY": DistKeyProperty, "DISTKEY": DistKeyProperty,
"DISTSTYLE": DistStyleProperty, "DISTSTYLE": DistStyleProperty,
"ENGINE": EngineProperty,
"EXECUTE AS": ExecuteAsProperty,
"FORMAT": FileFormatProperty,
"LANGUAGE": LanguageProperty,
"LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty,
"RETURNS": ReturnsProperty,
"SORTKEY": SortKeyProperty, "SORTKEY": SortKeyProperty,
"TABLE_FORMAT": TableFormatProperty,
} }
PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}
@classmethod @classmethod
def from_dict(cls, properties_dict) -> Properties: def from_dict(cls, properties_dict) -> Properties:
expressions = [] expressions = []
for key, value in properties_dict.items(): for key, value in properties_dict.items():
property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty) property_cls = cls.NAME_TO_PROPERTY.get(key.upper())
expressions.append(property_cls(this=Literal.string(key), value=convert(value))) if property_cls:
expressions.append(property_cls(this=convert(value)))
else:
expressions.append(Property(this=Literal.string(key), value=convert(value)))
return cls(expressions=expressions) return cls(expressions=expressions)
@ -1383,6 +1385,7 @@ class Select(Subqueryable):
"expressions": False, "expressions": False,
"hint": False, "hint": False,
"distinct": False, "distinct": False,
"into": False,
"from": False, "from": False,
**QUERY_MODIFIERS, **QUERY_MODIFIERS,
} }
@ -2015,6 +2018,7 @@ class DataType(Expression):
DECIMAL = auto() DECIMAL = auto()
BOOLEAN = auto() BOOLEAN = auto()
JSON = auto() JSON = auto()
JSONB = auto()
INTERVAL = auto() INTERVAL = auto()
TIMESTAMP = auto() TIMESTAMP = auto()
TIMESTAMPTZ = auto() TIMESTAMPTZ = auto()
@ -2029,6 +2033,7 @@ class DataType(Expression):
STRUCT = auto() STRUCT = auto()
NULLABLE = auto() NULLABLE = auto()
HLLSKETCH = auto() HLLSKETCH = auto()
HSTORE = auto()
SUPER = auto() SUPER = auto()
SERIAL = auto() SERIAL = auto()
SMALLSERIAL = auto() SMALLSERIAL = auto()
@ -2109,7 +2114,7 @@ class Transaction(Command):
class Commit(Command): class Commit(Command):
arg_types = {} # type: ignore arg_types = {"chain": False}
class Rollback(Command): class Rollback(Command):
@ -2442,7 +2447,7 @@ class ArrayFilter(Func):
class ArraySize(Func): class ArraySize(Func):
pass arg_types = {"this": True, "expression": False}
class ArraySort(Func): class ArraySort(Func):
@ -2726,6 +2731,16 @@ class VarMap(Func):
is_var_len_args = True is_var_len_args = True
class Matches(Func):
"""Oracle/Snowflake decode.
https://docs.oracle.com/cd/B19306_01/server.102/b14200/functions040.htm
Pattern matching MATCHES(value, search1, result1, ...searchN, resultN, else)
"""
arg_types = {"this": True, "expressions": True}
is_var_len_args = True
class Max(AggFunc): class Max(AggFunc):
pass pass
@ -2785,6 +2800,10 @@ class Round(Func):
arg_types = {"this": True, "decimals": False} arg_types = {"this": True, "decimals": False}
class RowNumber(Func):
arg_types: t.Dict[str, t.Any] = {}
class SafeDivide(Func): class SafeDivide(Func):
arg_types = {"this": True, "expression": True} arg_types = {"this": True, "expression": True}


@ -1,19 +1,16 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import re
import typing as t import typing as t
from sqlglot import exp from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv from sqlglot.helper import apply_index_offset, csv
from sqlglot.time import format_time from sqlglot.time import format_time
from sqlglot.tokens import TokenType from sqlglot.tokens import TokenType
logger = logging.getLogger("sqlglot") logger = logging.getLogger("sqlglot")
NEWLINE_RE = re.compile("\r\n?|\n")
class Generator: class Generator:
""" """
@ -58,11 +55,11 @@ class Generator:
""" """
TRANSFORMS = { TRANSFORMS = {
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})", exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})", exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})", exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})", exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})",
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e), exp.ReturnsProperty: lambda self, e: self.naked_property(e),
@ -97,16 +94,17 @@ class Generator:
exp.DistStyleProperty, exp.DistStyleProperty,
exp.DistKeyProperty, exp.DistKeyProperty,
exp.SortKeyProperty, exp.SortKeyProperty,
exp.LikeProperty,
} }
WITH_PROPERTIES = { WITH_PROPERTIES = {
exp.AnonymousProperty, exp.Property,
exp.FileFormatProperty, exp.FileFormatProperty,
exp.PartitionedByProperty, exp.PartitionedByProperty,
exp.TableFormatProperty, exp.TableFormatProperty,
} }
WITH_SEPARATED_COMMENTS = (exp.Select,) WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
__slots__ = ( __slots__ = (
"time_mapping", "time_mapping",
@ -211,7 +209,7 @@ class Generator:
for msg in self.unsupported_messages: for msg in self.unsupported_messages:
logger.warning(msg) logger.warning(msg)
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages: elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported)) raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))
return sql return sql
@ -226,25 +224,24 @@ class Generator:
def seg(self, sql, sep=" "): def seg(self, sql, sep=" "):
return f"{self.sep(sep)}{sql}" return f"{self.sep(sep)}{sql}"
def maybe_comment(self, sql, expression, single_line=False): def pad_comment(self, comment):
comment = expression.comment if self._comments else None
if not comment:
return sql
comment = " " + comment if comment[0].strip() else comment comment = " " + comment if comment[0].strip() else comment
comment = comment + " " if comment[-1].strip() else comment comment = comment + " " if comment[-1].strip() else comment
return comment
def maybe_comment(self, sql, expression):
comments = expression.comments if self._comments else None
if not comments:
return sql
sep = "\n" if self.pretty else " "
comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments)
if isinstance(expression, self.WITH_SEPARATED_COMMENTS): if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
return f"/*{comment}*/{self.sep()}{sql}" return f"{comments}{self.sep()}{sql}"
if not self.pretty: return f"{sql} {comments}"
return f"{sql} /*{comment}*/"
if not NEWLINE_RE.search(comment):
return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/"
def wrap(self, expression): def wrap(self, expression):
this_sql = self.indent( this_sql = self.indent(
@ -387,8 +384,11 @@ class Generator:
def notnullcolumnconstraint_sql(self, _): def notnullcolumnconstraint_sql(self, _):
return "NOT NULL" return "NOT NULL"
def primarykeycolumnconstraint_sql(self, _): def primarykeycolumnconstraint_sql(self, expression):
return "PRIMARY KEY" desc = expression.args.get("desc")
if desc is not None:
return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
return f"PRIMARY KEY"
def uniquecolumnconstraint_sql(self, _): def uniquecolumnconstraint_sql(self, _):
return "UNIQUE" return "UNIQUE"
@ -546,36 +546,33 @@ class Generator:
def root_properties(self, properties): def root_properties(self, properties):
if properties.expressions: if properties.expressions:
return self.sep() + self.expressions( return self.sep() + self.expressions(properties, indent=False, sep=" ")
properties,
indent=False,
sep=" ",
)
return "" return ""
def properties(self, properties, prefix="", sep=", "): def properties(self, properties, prefix="", sep=", "):
if properties.expressions: if properties.expressions:
expressions = self.expressions( expressions = self.expressions(properties, sep=sep, indent=False)
properties,
sep=sep,
indent=False,
)
return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}" return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}"
return "" return ""
def with_properties(self, properties): def with_properties(self, properties):
return self.properties( return self.properties(properties, prefix="WITH")
properties,
prefix="WITH",
)
def property_sql(self, expression): def property_sql(self, expression):
if isinstance(expression.this, exp.Literal): property_cls = expression.__class__
key = expression.this.this if property_cls == exp.Property:
else: return f"{expression.name}={self.sql(expression, 'value')}"
key = expression.name
value = self.sql(expression, "value") property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls)
return f"{key}={value}" if not property_name:
self.unsupported(f"Unsupported property {property_name}")
return f"{property_name}={self.sql(expression, 'this')}"
def likeproperty_sql(self, expression):
options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
options = f" {options}" if options else ""
return f"LIKE {self.sql(expression, 'this')}{options}"
     def insert_sql(self, expression):
         overwrite = expression.args.get("overwrite")
@@ -700,6 +697,11 @@ class Generator:
     def var_sql(self, expression):
         return self.sql(expression, "this")

+    def into_sql(self, expression):
+        temporary = " TEMPORARY" if expression.args.get("temporary") else ""
+        unlogged = " UNLOGGED" if expression.args.get("unlogged") else ""
+        return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}"
+
     def from_sql(self, expression):
         expressions = self.expressions(expression, flat=True)
         return f"{self.seg('FROM')} {expressions}"
@@ -883,6 +885,7 @@ class Generator:
         sql = self.query_modifiers(
             expression,
             f"SELECT{hint}{distinct}{expressions}",
+            self.sql(expression, "into", comment=False),
             self.sql(expression, "from", comment=False),
         )
         return self.prepend_ctes(expression, sql)
@@ -1061,6 +1064,11 @@ class Generator:
         else:
             return f"TRIM({target})"

+    def concat_sql(self, expression):
+        if len(expression.expressions) == 1:
+            return self.sql(expression.expressions[0])
+        return self.function_fallback_sql(expression)
+
     def check_sql(self, expression):
         this = self.sql(expression, key="this")
         return f"CHECK ({this})"
@@ -1125,7 +1133,10 @@ class Generator:
         return self.prepend_ctes(expression, sql)

     def neg_sql(self, expression):
-        return f"-{self.sql(expression, 'this')}"
+        # This makes sure we don't convert "- - 5" to "--5", which is a comment
+        this_sql = self.sql(expression, "this")
+        sep = " " if this_sql[0] == "-" else ""
+        return f"-{sep}{this_sql}"

     def not_sql(self, expression):
         return f"NOT {self.sql(expression, 'this')}"
@@ -1191,8 +1202,12 @@ class Generator:
     def transaction_sql(self, *_):
         return "BEGIN"

-    def commit_sql(self, *_):
-        return "COMMIT"
+    def commit_sql(self, expression):
+        chain = expression.args.get("chain")
+        if chain is not None:
+            chain = " AND CHAIN" if chain else " AND NO CHAIN"
+        return f"COMMIT{chain or ''}"

     def rollback_sql(self, expression):
         savepoint = expression.args.get("savepoint")
@@ -1334,15 +1349,15 @@ class Generator:
         result_sqls = []
         for i, e in enumerate(expressions):
             sql = self.sql(e, comment=False)
-            comment = self.maybe_comment("", e, single_line=True)
+            comments = self.maybe_comment("", e)
             if self.pretty:
                 if self._leading_comma:
-                    result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}")
+                    result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}")
                 else:
-                    result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}")
+                    result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}")
             else:
-                result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}")
+                result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}")
         result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
         return self.indent(result_sqls, skip_first=False) if indent else result_sqls
@@ -1354,7 +1369,10 @@ class Generator:
         return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"

     def naked_property(self, expression):
-        return f"{expression.name} {self.sql(expression, 'value')}"
+        property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__)
+        if not property_name:
+            self.unsupported(f"Unsupported property {expression.__class__.__name__}")
+        return f"{property_name} {self.sql(expression, 'this')}"

     def set_operation(self, expression, op):
         this = self.sql(expression, "this")
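A quick, hedged sketch of two of the generator changes above (the single-argument `CONCAT` collapse and the comment-safe `neg_sql`); the exact output strings are whatever the generator emits, the comments only describe the expected shape:

```python
import sqlglot

# Single-argument CONCAT should now collapse to its argument
# (see the CONCAT(a) dialect test further down).
print(sqlglot.transpile("SELECT CONCAT(a) FROM t", write="mysql")[0])

# Double negation should keep a space so the output is not mistaken
# for a "--" line comment.
print(sqlglot.transpile("SELECT - -5")[0])
```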


@@ -68,6 +68,9 @@ def eliminate_subqueries(expression):
     for cte_scope in root.cte_scopes:
         # Append all the new CTEs from this existing CTE
         for scope in cte_scope.traverse():
+            if scope is cte_scope:
+                # Don't try to eliminate this CTE itself
+                continue
             new_cte = _eliminate(scope, existing_ctes, taken)
             if new_cte:
                 new_ctes.append(new_cte)
@@ -97,6 +100,9 @@ def _eliminate(scope, existing_ctes, taken):
     if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
         return _eliminate_derived_table(scope, existing_ctes, taken)

+    if scope.is_cte:
+        return _eliminate_cte(scope, existing_ctes, taken)
+
 def _eliminate_union(scope, existing_ctes, taken):
     duplicate_cte_alias = existing_ctes.get(scope.expression)
@@ -127,26 +133,61 @@ def _eliminate_union(scope, existing_ctes, taken):
 def _eliminate_derived_table(scope, existing_ctes, taken):
+    parent = scope.expression.parent
+    name, cte = _new_cte(scope, existing_ctes, taken)
+
+    table = exp.alias_(exp.table_(name), alias=parent.alias or name)
+    parent.replace(table)
+
+    return cte
+
+def _eliminate_cte(scope, existing_ctes, taken):
+    parent = scope.expression.parent
+    name, cte = _new_cte(scope, existing_ctes, taken)
+
+    with_ = parent.parent
+    parent.pop()
+    if not with_.expressions:
+        with_.pop()
+
+    # Rename references to this CTE
+    for child_scope in scope.parent.traverse():
+        for table, source in child_scope.selected_sources.values():
+            if source is scope:
+                new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
+                table.replace(new_table)
+
+    return cte
+
+def _new_cte(scope, existing_ctes, taken):
+    """
+    Returns:
+        tuple of (name, cte)
+        where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
+        If this CTE duplicates an existing CTE, `cte` will be None.
+    """
     duplicate_cte_alias = existing_ctes.get(scope.expression)
     parent = scope.expression.parent
-    name = alias = parent.alias
+    name = parent.alias

-    if not alias:
-        name = alias = find_new_name(taken=taken, base="cte")
+    if not name:
+        name = find_new_name(taken=taken, base="cte")

     if duplicate_cte_alias:
         name = duplicate_cte_alias
-    elif taken.get(alias):
-        name = find_new_name(taken=taken, base=alias)
+    elif taken.get(name):
+        name = find_new_name(taken=taken, base=name)

     taken[name] = scope

-    table = exp.alias_(exp.table_(name), alias=alias)
-    parent.replace(table)
-
     if not duplicate_cte_alias:
         existing_ctes[scope.expression] = name
-        return exp.CTE(
+        cte = exp.CTE(
             this=scope.expression,
             alias=exp.TableAlias(this=exp.to_identifier(name)),
         )
+    else:
+        cte = None
+    return name, cte
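A small sketch of the rule in isolation (the import path comes from the optimizer module below; the generated CTE names are picked by `find_new_name`, so the comment only describes the expected shape, not the exact SQL):

```python
import sqlglot
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

# The two identical derived tables should be folded into a single CTE
# that is referenced twice under its original aliases.
expression = sqlglot.parse_one(
    "SELECT y.a FROM (SELECT a FROM x) AS y JOIN (SELECT a FROM x) AS z ON y.a = z.a"
)
print(eliminate_subqueries(expression).sql(pretty=True))
```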


@ -0,0 +1,92 @@
from sqlglot import exp
from sqlglot.helper import ensure_collection
def lower_identities(expression):
"""
Convert all unquoted identifiers to lower case.
Assuming the schema is all lower case, this essentially makes identifiers case-insensitive.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
>>> lower_identities(expression).sql()
'SELECT bar.a AS A FROM "Foo".bar'
Args:
expression (sqlglot.Expression): expression to quote
Returns:
sqlglot.Expression: quoted expression
"""
# We need to leave the output aliases unchanged, so the selects need special handling
_lower_selects(expression)
# These clauses can reference output aliases and also need special handling
_lower_order(expression)
_lower_having(expression)
# We've already handled these args, so don't traverse into them
traversed = {"expressions", "order", "having"}
if isinstance(expression, exp.Subquery):
# Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1
lower_identities(expression.this)
traversed |= {"this"}
if isinstance(expression, exp.Union):
# Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X
lower_identities(expression.left)
lower_identities(expression.right)
traversed |= {"this", "expression"}
for k, v in expression.args.items():
if k in traversed:
continue
for child in ensure_collection(v):
if isinstance(child, exp.Expression):
child.transform(_lower, copy=False)
return expression
def _lower_selects(expression):
for e in expression.expressions:
# Leave output aliases as-is
e.unalias().transform(_lower, copy=False)
def _lower_order(expression):
order = expression.args.get("order")
if not order:
return
output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)}
for ordered in order.expressions:
# Don't lower references to output aliases
if not (
isinstance(ordered.this, exp.Column)
and not ordered.this.table
and ordered.this.name in output_aliases
):
ordered.transform(_lower, copy=False)
def _lower_having(expression):
having = expression.args.get("having")
if not having:
return
# Don't lower references to output aliases
for agg in having.find_all(exp.AggFunc):
agg.transform(_lower, copy=False)
def _lower(node):
if isinstance(node, exp.Identifier) and not node.quoted:
node.set("this", node.this.lower())
return node


@@ -6,6 +6,7 @@ from sqlglot.optimizer.eliminate_joins import eliminate_joins
 from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
 from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
 from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
+from sqlglot.optimizer.lower_identities import lower_identities
 from sqlglot.optimizer.merge_subqueries import merge_subqueries
 from sqlglot.optimizer.normalize import normalize
 from sqlglot.optimizer.optimize_joins import optimize_joins
@@ -17,6 +18,7 @@ from sqlglot.optimizer.quote_identities import quote_identities
 from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

 RULES = (
+    lower_identities,
     qualify_tables,
     isolate_table_selects,
     qualify_columns,
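The new rule now runs first in the optimizer pipeline; it can also be called on its own, reusing the example from its docstring above:

```python
import sqlglot
from sqlglot.optimizer.lower_identities import lower_identities

# Unquoted identifiers are lowered; quoted identifiers and output aliases are kept.
expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
print(lower_identities(expression).sql())
# 'SELECT bar.a AS A FROM "Foo".bar'
```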


@@ -1,16 +1,15 @@
 import itertools

 from sqlglot import exp
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.scope import ScopeType, traverse_scope

 def unnest_subqueries(expression):
     """
     Rewrite sqlglot AST to convert some predicates with subqueries into joins.

-    Convert the subquery into a group by so it is not a many to many left join.
-    Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
-    Unnesting non correlated subqueries only happens on IN statements or = ANY statements.
+    Convert scalar subqueries into cross joins.
+    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.

     Example:
         >>> import sqlglot
@@ -29,21 +28,43 @@ def unnest_subqueries(expression):
     for scope in traverse_scope(expression):
         select = scope.expression
         parent = select.parent_select
+        if not parent:
+            continue
         if scope.external_columns:
             decorrelate(select, parent, scope.external_columns, sequence)
-        else:
+        elif scope.scope_type == ScopeType.SUBQUERY:
             unnest(select, parent, sequence)

     return expression

 def unnest(select, parent_select, sequence):
-    predicate = select.find_ancestor(exp.In, exp.Any)
+    if len(select.selects) > 1:
+        return
+
+    predicate = select.find_ancestor(exp.Condition)
+    alias = _alias(sequence)

     if not predicate or parent_select is not predicate.parent_select:
         return

-    if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
+    # this subquery returns a scalar and can just be converted to a cross join
+    if not isinstance(predicate, (exp.In, exp.Any)):
+        having = predicate.find_ancestor(exp.Having)
+        column = exp.column(select.selects[0].alias_or_name, alias)
+        if having and having.parent_select is parent_select:
+            column = exp.Max(this=column)
+        _replace(select.parent, column)
+
+        parent_select.join(
+            select,
+            join_type="CROSS",
+            join_alias=alias,
+            copy=False,
+        )
+        return
+
+    if select.find(exp.Limit, exp.Offset):
         return

     if isinstance(predicate, exp.Any):
@@ -54,7 +75,6 @@ def unnest(select, parent_select, sequence):
     column = _other_operand(predicate)
     value = select.selects[0]

-    alias = _alias(sequence)
     on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
     _replace(predicate, f"NOT {on.right} IS NULL")
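A hedged sketch of the new scalar path; in practice this rule runs inside the optimizer pipeline, and the generated join alias is picked by `_alias`, so the comment describes the expected rewrite rather than an exact string:

```python
import sqlglot
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

# A single-column subquery that is not part of IN / ANY is treated as a scalar
# and should be rewritten into a CROSS JOIN against an aliased derived table.
expression = sqlglot.parse_one("SELECT * FROM x WHERE (SELECT MAX(a) FROM y) = x.a")
print(unnest_subqueries(expression).sql(pretty=True))
```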


@ -4,7 +4,7 @@ import logging
import typing as t import typing as t
from sqlglot import exp from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_errors from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
from sqlglot.helper import apply_index_offset, ensure_collection, seq_get from sqlglot.helper import apply_index_offset, ensure_collection, seq_get
from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie from sqlglot.trie import in_trie, new_trie
@ -104,6 +104,7 @@ class Parser(metaclass=_Parser):
TokenType.BINARY, TokenType.BINARY,
TokenType.VARBINARY, TokenType.VARBINARY,
TokenType.JSON, TokenType.JSON,
TokenType.JSONB,
TokenType.INTERVAL, TokenType.INTERVAL,
TokenType.TIMESTAMP, TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ, TokenType.TIMESTAMPTZ,
@ -115,6 +116,7 @@ class Parser(metaclass=_Parser):
TokenType.GEOGRAPHY, TokenType.GEOGRAPHY,
TokenType.GEOMETRY, TokenType.GEOMETRY,
TokenType.HLLSKETCH, TokenType.HLLSKETCH,
TokenType.HSTORE,
TokenType.SUPER, TokenType.SUPER,
TokenType.SERIAL, TokenType.SERIAL,
TokenType.SMALLSERIAL, TokenType.SMALLSERIAL,
@ -153,6 +155,7 @@ class Parser(metaclass=_Parser):
TokenType.COLLATE, TokenType.COLLATE,
TokenType.COMMAND, TokenType.COMMAND,
TokenType.COMMIT, TokenType.COMMIT,
TokenType.COMPOUND,
TokenType.CONSTRAINT, TokenType.CONSTRAINT,
TokenType.CURRENT_TIME, TokenType.CURRENT_TIME,
TokenType.DEFAULT, TokenType.DEFAULT,
@ -194,6 +197,7 @@ class Parser(metaclass=_Parser):
TokenType.RANGE, TokenType.RANGE,
TokenType.REFERENCES, TokenType.REFERENCES,
TokenType.RETURNS, TokenType.RETURNS,
TokenType.ROW,
TokenType.ROWS, TokenType.ROWS,
TokenType.SCHEMA, TokenType.SCHEMA,
TokenType.SCHEMA_COMMENT, TokenType.SCHEMA_COMMENT,
@ -213,6 +217,7 @@ class Parser(metaclass=_Parser):
TokenType.TRUE, TokenType.TRUE,
TokenType.UNBOUNDED, TokenType.UNBOUNDED,
TokenType.UNIQUE, TokenType.UNIQUE,
TokenType.UNLOGGED,
TokenType.UNPIVOT, TokenType.UNPIVOT,
TokenType.PROPERTIES, TokenType.PROPERTIES,
TokenType.PROCEDURE, TokenType.PROCEDURE,
@ -400,9 +405,17 @@ class Parser(metaclass=_Parser):
TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()), TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
TokenType.BEGIN: lambda self: self._parse_transaction(), TokenType.BEGIN: lambda self: self._parse_transaction(),
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
} }
UNARY_PARSERS = {
TokenType.PLUS: lambda self: self._parse_unary(), # Unary + is handled as a no-op
TokenType.NOT: lambda self: self.expression(exp.Not, this=self._parse_equality()),
TokenType.TILDA: lambda self: self.expression(exp.BitwiseNot, this=self._parse_unary()),
TokenType.DASH: lambda self: self.expression(exp.Neg, this=self._parse_unary()),
}
PRIMARY_PARSERS = { PRIMARY_PARSERS = {
TokenType.STRING: lambda self, token: self.expression( TokenType.STRING: lambda self, token: self.expression(
exp.Literal, this=token.text, is_string=True exp.Literal, this=token.text, is_string=True
@ -446,19 +459,20 @@ class Parser(metaclass=_Parser):
} }
PROPERTY_PARSERS = { PROPERTY_PARSERS = {
TokenType.AUTO_INCREMENT: lambda self: self._parse_auto_increment(), TokenType.AUTO_INCREMENT: lambda self: self._parse_property_assignment(
TokenType.CHARACTER_SET: lambda self: self._parse_character_set(), exp.AutoIncrementProperty
TokenType.LOCATION: lambda self: self.expression(
exp.LocationProperty,
this=exp.Literal.string("LOCATION"),
value=self._parse_string(),
), ),
TokenType.CHARACTER_SET: lambda self: self._parse_character_set(),
TokenType.LOCATION: lambda self: self._parse_property_assignment(exp.LocationProperty),
TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(), TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(),
TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(), TokenType.SCHEMA_COMMENT: lambda self: self._parse_property_assignment(
TokenType.STORED: lambda self: self._parse_stored(), exp.SchemaCommentProperty
),
TokenType.STORED: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
TokenType.DISTKEY: lambda self: self._parse_distkey(), TokenType.DISTKEY: lambda self: self._parse_distkey(),
TokenType.DISTSTYLE: lambda self: self._parse_diststyle(), TokenType.DISTSTYLE: lambda self: self._parse_property_assignment(exp.DistStyleProperty),
TokenType.SORTKEY: lambda self: self._parse_sortkey(), TokenType.SORTKEY: lambda self: self._parse_sortkey(),
TokenType.LIKE: lambda self: self._parse_create_like(),
TokenType.RETURNS: lambda self: self._parse_returns(), TokenType.RETURNS: lambda self: self._parse_returns(),
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty), TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
@ -468,7 +482,7 @@ class Parser(metaclass=_Parser):
), ),
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty), TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty), TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
TokenType.EXECUTE: lambda self: self._parse_execute_as(), TokenType.EXECUTE: lambda self: self._parse_property_assignment(exp.ExecuteAsProperty),
TokenType.DETERMINISTIC: lambda self: self.expression( TokenType.DETERMINISTIC: lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE") exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
), ),
@ -489,6 +503,7 @@ class Parser(metaclass=_Parser):
), ),
TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(), TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(),
TokenType.UNIQUE: lambda self: self._parse_unique(), TokenType.UNIQUE: lambda self: self._parse_unique(),
TokenType.LIKE: lambda self: self._parse_create_like(),
} }
NO_PAREN_FUNCTION_PARSERS = { NO_PAREN_FUNCTION_PARSERS = {
@ -505,6 +520,7 @@ class Parser(metaclass=_Parser):
"TRIM": lambda self: self._parse_trim(), "TRIM": lambda self: self._parse_trim(),
"CAST": lambda self: self._parse_cast(self.STRICT_CAST), "CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"TRY_CAST": lambda self: self._parse_cast(False), "TRY_CAST": lambda self: self._parse_cast(False),
"STRING_AGG": lambda self: self._parse_string_agg(),
} }
QUERY_MODIFIER_PARSERS = { QUERY_MODIFIER_PARSERS = {
@ -556,7 +572,7 @@ class Parser(metaclass=_Parser):
"_curr", "_curr",
"_next", "_next",
"_prev", "_prev",
"_prev_comment", "_prev_comments",
"_show_trie", "_show_trie",
"_set_trie", "_set_trie",
) )
@ -589,7 +605,7 @@ class Parser(metaclass=_Parser):
self._curr = None self._curr = None
self._next = None self._next = None
self._prev = None self._prev = None
self._prev_comment = None self._prev_comments = None
def parse(self, raw_tokens, sql=None): def parse(self, raw_tokens, sql=None):
""" """
@ -608,6 +624,7 @@ class Parser(metaclass=_Parser):
) )
def parse_into(self, expression_types, raw_tokens, sql=None): def parse_into(self, expression_types, raw_tokens, sql=None):
errors = []
for expression_type in ensure_collection(expression_types): for expression_type in ensure_collection(expression_types):
parser = self.EXPRESSION_PARSERS.get(expression_type) parser = self.EXPRESSION_PARSERS.get(expression_type)
if not parser: if not parser:
@ -615,8 +632,12 @@ class Parser(metaclass=_Parser):
try: try:
return self._parse(parser, raw_tokens, sql) return self._parse(parser, raw_tokens, sql)
except ParseError as e: except ParseError as e:
error = e e.errors[0]["into_expression"] = expression_type
raise ParseError(f"Failed to parse into {expression_types}") from error errors.append(e)
raise ParseError(
f"Failed to parse into {expression_types}",
errors=merge_errors(errors),
) from errors[-1]
def _parse(self, parse_method, raw_tokens, sql=None): def _parse(self, parse_method, raw_tokens, sql=None):
self.reset() self.reset()
@ -650,7 +671,10 @@ class Parser(metaclass=_Parser):
for error in self.errors: for error in self.errors:
logger.error(str(error)) logger.error(str(error))
elif self.error_level == ErrorLevel.RAISE and self.errors: elif self.error_level == ErrorLevel.RAISE and self.errors:
raise ParseError(concat_errors(self.errors, self.max_errors)) raise ParseError(
concat_messages(self.errors, self.max_errors),
errors=merge_errors(self.errors),
)
def raise_error(self, message, token=None): def raise_error(self, message, token=None):
token = token or self._curr or self._prev or Token.string("") token = token or self._curr or self._prev or Token.string("")
@ -659,19 +683,27 @@ class Parser(metaclass=_Parser):
start_context = self.sql[max(start - self.error_message_context, 0) : start] start_context = self.sql[max(start - self.error_message_context, 0) : start]
highlight = self.sql[start:end] highlight = self.sql[start:end]
end_context = self.sql[end : end + self.error_message_context] end_context = self.sql[end : end + self.error_message_context]
error = ParseError( error = ParseError.new(
f"{message}. Line {token.line}, Col: {token.col}.\n" f"{message}. Line {token.line}, Col: {token.col}.\n"
f" {start_context}\033[4m{highlight}\033[0m{end_context}" f" {start_context}\033[4m{highlight}\033[0m{end_context}",
description=message,
line=token.line,
col=token.col,
start_context=start_context,
highlight=highlight,
end_context=end_context,
) )
if self.error_level == ErrorLevel.IMMEDIATE: if self.error_level == ErrorLevel.IMMEDIATE:
raise error raise error
self.errors.append(error) self.errors.append(error)
def expression(self, exp_class, **kwargs): def expression(self, exp_class, comments=None, **kwargs):
instance = exp_class(**kwargs) instance = exp_class(**kwargs)
if self._prev_comment: if self._prev_comments:
instance.comment = self._prev_comment instance.comments = self._prev_comments
self._prev_comment = None self._prev_comments = None
if comments:
instance.comments = comments
self.validate_expression(instance) self.validate_expression(instance)
return instance return instance
@ -714,10 +746,10 @@ class Parser(metaclass=_Parser):
self._next = seq_get(self._tokens, self._index + 1) self._next = seq_get(self._tokens, self._index + 1)
if self._index > 0: if self._index > 0:
self._prev = self._tokens[self._index - 1] self._prev = self._tokens[self._index - 1]
self._prev_comment = self._prev.comment self._prev_comments = self._prev.comments
else: else:
self._prev = None self._prev = None
self._prev_comment = None self._prev_comments = None
def _retreat(self, index): def _retreat(self, index):
self._advance(index - self._index) self._advance(index - self._index)
@ -768,7 +800,7 @@ class Parser(metaclass=_Parser):
) )
def _parse_create(self): def _parse_create(self):
replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE) replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
temporary = self._match(TokenType.TEMPORARY) temporary = self._match(TokenType.TEMPORARY)
transient = self._match(TokenType.TRANSIENT) transient = self._match(TokenType.TRANSIENT)
unique = self._match(TokenType.UNIQUE) unique = self._match(TokenType.UNIQUE)
@ -822,97 +854,57 @@ class Parser(metaclass=_Parser):
def _parse_property(self): def _parse_property(self):
if self._match_set(self.PROPERTY_PARSERS): if self._match_set(self.PROPERTY_PARSERS):
return self.PROPERTY_PARSERS[self._prev.token_type](self) return self.PROPERTY_PARSERS[self._prev.token_type](self)
if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET): if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
return self._parse_character_set(True) return self._parse_character_set(True)
if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False): if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
key = self._parse_var().this return self._parse_sortkey(compound=True)
self._match(TokenType.EQ)
return self.expression( if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False):
exp.AnonymousProperty, key = self._parse_var()
this=exp.Literal.string(key), self._match(TokenType.EQ)
value=self._parse_column(), return self.expression(exp.Property, this=key, value=self._parse_column())
)
return None return None
def _parse_property_assignment(self, exp_class): def _parse_property_assignment(self, exp_class):
prop = self._prev.text
self._match(TokenType.EQ) self._match(TokenType.EQ)
return self.expression(exp_class, this=prop, value=self._parse_var_or_string()) self._match(TokenType.ALIAS)
return self.expression(exp_class, this=self._parse_var_or_string() or self._parse_number())
def _parse_partitioned_by(self): def _parse_partitioned_by(self):
self._match(TokenType.EQ) self._match(TokenType.EQ)
return self.expression( return self.expression(
exp.PartitionedByProperty, exp.PartitionedByProperty,
this=exp.Literal.string("PARTITIONED_BY"), this=self._parse_schema() or self._parse_bracket(self._parse_field()),
value=self._parse_schema() or self._parse_bracket(self._parse_field()),
)
def _parse_stored(self):
self._match(TokenType.ALIAS)
self._match(TokenType.EQ)
return self.expression(
exp.FileFormatProperty,
this=exp.Literal.string("FORMAT"),
value=exp.Literal.string(self._parse_var_or_string().name),
) )
def _parse_distkey(self): def _parse_distkey(self):
self._match_l_paren() return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_var))
this = exp.Literal.string("DISTKEY")
value = exp.Literal.string(self._parse_var().name)
self._match_r_paren()
return self.expression(
exp.DistKeyProperty,
this=this,
value=value,
)
def _parse_sortkey(self): def _parse_create_like(self):
self._match_l_paren() table = self._parse_table(schema=True)
this = exp.Literal.string("SORTKEY") options = []
value = exp.Literal.string(self._parse_var().name) while self._match_texts(("INCLUDING", "EXCLUDING")):
self._match_r_paren() options.append(
return self.expression( self.expression(
exp.SortKeyProperty, exp.Property,
this=this, this=self._prev.text.upper(),
value=value, value=exp.Var(this=self._parse_id_var().this.upper()),
) )
def _parse_diststyle(self):
this = exp.Literal.string("DISTSTYLE")
value = exp.Literal.string(self._parse_var().name)
return self.expression(
exp.DistStyleProperty,
this=this,
value=value,
) )
return self.expression(exp.LikeProperty, this=table, expressions=options)
def _parse_auto_increment(self): def _parse_sortkey(self, compound=False):
self._match(TokenType.EQ)
return self.expression( return self.expression(
exp.AutoIncrementProperty, exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_var), compound=compound
this=exp.Literal.string("AUTO_INCREMENT"),
value=self._parse_number(),
)
def _parse_schema_comment(self):
self._match(TokenType.EQ)
return self.expression(
exp.SchemaCommentProperty,
this=exp.Literal.string("COMMENT"),
value=self._parse_string(),
) )
def _parse_character_set(self, default=False): def _parse_character_set(self, default=False):
self._match(TokenType.EQ) self._match(TokenType.EQ)
return self.expression( return self.expression(
exp.CharacterSetProperty, exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default
this=exp.Literal.string("CHARACTER_SET"),
value=self._parse_var_or_string(),
default=default,
) )
def _parse_returns(self): def _parse_returns(self):
@ -931,20 +923,7 @@ class Parser(metaclass=_Parser):
else: else:
value = self._parse_types() value = self._parse_types()
return self.expression( return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)
exp.ReturnsProperty,
this=exp.Literal.string("RETURNS"),
value=value,
is_table=is_table,
)
def _parse_execute_as(self):
self._match(TokenType.ALIAS)
return self.expression(
exp.ExecuteAsProperty,
this=exp.Literal.string("EXECUTE AS"),
value=self._parse_var(),
)
def _parse_properties(self): def _parse_properties(self):
properties = [] properties = []
@ -956,7 +935,7 @@ class Parser(metaclass=_Parser):
properties.extend( properties.extend(
self._parse_wrapped_csv( self._parse_wrapped_csv(
lambda: self.expression( lambda: self.expression(
exp.AnonymousProperty, exp.Property,
this=self._parse_string(), this=self._parse_string(),
value=self._match(TokenType.EQ) and self._parse_string(), value=self._match(TokenType.EQ) and self._parse_string(),
) )
@ -1076,7 +1055,12 @@ class Parser(metaclass=_Parser):
options = [] options = []
if self._match(TokenType.OPTIONS): if self._match(TokenType.OPTIONS):
options = self._parse_wrapped_csv(self._parse_string, sep=TokenType.EQ) self._match_l_paren()
k = self._parse_string()
self._match(TokenType.EQ)
v = self._parse_string()
options = [k, v]
self._match_r_paren()
self._match(TokenType.ALIAS) self._match(TokenType.ALIAS)
return self.expression( return self.expression(
@ -1116,7 +1100,7 @@ class Parser(metaclass=_Parser):
self.raise_error(f"{this.key} does not support CTE") self.raise_error(f"{this.key} does not support CTE")
this = cte this = cte
elif self._match(TokenType.SELECT): elif self._match(TokenType.SELECT):
comment = self._prev_comment comments = self._prev_comments
hint = self._parse_hint() hint = self._parse_hint()
all_ = self._match(TokenType.ALL) all_ = self._match(TokenType.ALL)
@ -1141,10 +1125,16 @@ class Parser(metaclass=_Parser):
expressions=expressions, expressions=expressions,
limit=limit, limit=limit,
) )
this.comment = comment this.comments = comments
into = self._parse_into()
if into:
this.set("into", into)
from_ = self._parse_from() from_ = self._parse_from()
if from_: if from_:
this.set("from", from_) this.set("from", from_)
self._parse_query_modifiers(this) self._parse_query_modifiers(this)
elif (table or nested) and self._match(TokenType.L_PAREN): elif (table or nested) and self._match(TokenType.L_PAREN):
this = self._parse_table() if table else self._parse_select(nested=True) this = self._parse_table() if table else self._parse_select(nested=True)
@@ -1248,11 +1238,24 @@
             return self.expression(exp.Hint, expressions=hints)
         return None

+    def _parse_into(self):
+        if not self._match(TokenType.INTO):
+            return None
+
+        temp = self._match(TokenType.TEMPORARY)
+        unlogged = self._match(TokenType.UNLOGGED)
+        self._match(TokenType.TABLE)
+
+        return self.expression(
+            exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged
+        )
+
     def _parse_from(self):
         if not self._match(TokenType.FROM):
             return None
-        return self.expression(exp.From, expressions=self._parse_csv(self._parse_table))
+
+        return self.expression(
+            exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table)
+        )

     def _parse_lateral(self):
         outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY)
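A small sketch of the new `SELECT ... INTO` handling (assuming the statement parses in the default dialect; the `temporary`/`unlogged` flags mirror the tokens consumed by `_parse_into`, and the regenerated SQL comes from the matching `into_sql` generator above):

```python
import sqlglot

select = sqlglot.parse_one("SELECT * INTO TEMPORARY newtab FROM tbl")
into = select.args["into"]

# The target table plus the TEMPORARY / UNLOGGED flags captured by the parser.
print(into.this.sql(), into.args.get("temporary"), into.args.get("unlogged"))

# Round-trip back to SQL, INTO clause included.
print(select.sql())
```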
@@ -1515,7 +1518,9 @@
     def _parse_where(self, skip_where_token=False):
         if not skip_where_token and not self._match(TokenType.WHERE):
             return None
-        return self.expression(exp.Where, this=self._parse_conjunction())
+
+        return self.expression(
+            exp.Where, comments=self._prev_comments, this=self._parse_conjunction()
+        )

     def _parse_group(self, skip_group_by_token=False):
         if not skip_group_by_token and not self._match(TokenType.GROUP_BY):
@@ -1737,12 +1742,8 @@
         return self._parse_tokens(self._parse_unary, self.FACTOR)

     def _parse_unary(self):
-        if self._match(TokenType.NOT):
-            return self.expression(exp.Not, this=self._parse_equality())
-        if self._match(TokenType.TILDA):
-            return self.expression(exp.BitwiseNot, this=self._parse_unary())
-        if self._match(TokenType.DASH):
-            return self.expression(exp.Neg, this=self._parse_unary())
+        if self._match_set(self.UNARY_PARSERS):
+            return self.UNARY_PARSERS[self._prev.token_type](self)
         return self._parse_at_time_zone(self._parse_type())

     def _parse_type(self):
@ -1775,17 +1776,6 @@ class Parser(metaclass=_Parser):
expressions = None expressions = None
maybe_func = False maybe_func = False
if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
return exp.DataType(
this=exp.DataType.Type.ARRAY,
expressions=[exp.DataType.build(type_token.value)],
nested=True,
)
if self._match(TokenType.L_BRACKET):
self._retreat(index)
return None
if self._match(TokenType.L_PAREN): if self._match(TokenType.L_PAREN):
if is_struct: if is_struct:
expressions = self._parse_csv(self._parse_struct_kwargs) expressions = self._parse_csv(self._parse_struct_kwargs)
@ -1801,6 +1791,17 @@ class Parser(metaclass=_Parser):
self._match_r_paren() self._match_r_paren()
maybe_func = True maybe_func = True
if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
return exp.DataType(
this=exp.DataType.Type.ARRAY,
expressions=[exp.DataType.build(type_token.value, expressions=expressions)],
nested=True,
)
if self._match(TokenType.L_BRACKET):
self._retreat(index)
return None
if nested and self._match(TokenType.LT): if nested and self._match(TokenType.LT):
if is_struct: if is_struct:
expressions = self._parse_csv(self._parse_struct_kwargs) expressions = self._parse_csv(self._parse_struct_kwargs)
@ -1904,7 +1905,7 @@ class Parser(metaclass=_Parser):
return exp.Literal.number(f"0.{self._prev.text}") return exp.Literal.number(f"0.{self._prev.text}")
if self._match(TokenType.L_PAREN): if self._match(TokenType.L_PAREN):
comment = self._prev_comment comments = self._prev_comments
query = self._parse_select() query = self._parse_select()
if query: if query:
@ -1924,8 +1925,8 @@ class Parser(metaclass=_Parser):
this = self.expression(exp.Tuple, expressions=expressions) this = self.expression(exp.Tuple, expressions=expressions)
else: else:
this = self.expression(exp.Paren, this=this) this = self.expression(exp.Paren, this=this)
if comment: if comments:
this.comment = comment this.comments = comments
return this return this
return None return None
@@ -2098,7 +2099,10 @@
         elif self._match(TokenType.SCHEMA_COMMENT):
             kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
         elif self._match(TokenType.PRIMARY_KEY):
-            kind = exp.PrimaryKeyColumnConstraint()
+            desc = None
+            if self._match(TokenType.ASC) or self._match(TokenType.DESC):
+                desc = self._prev.token_type == TokenType.DESC
+            kind = exp.PrimaryKeyColumnConstraint(desc=desc)
         elif self._match(TokenType.UNIQUE):
             kind = exp.UniqueColumnConstraint()
         elif self._match(TokenType.GENERATED):
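A hedged sketch of the new primary-key ordering constraint: the `desc` flag is read straight off the parsed constraint node, and `True` is the value one would expect for a trailing `DESC`:

```python
import sqlglot
from sqlglot import exp

create = sqlglot.parse_one("CREATE TABLE t (id INTEGER PRIMARY KEY DESC)")
constraint = create.find(exp.PrimaryKeyColumnConstraint)

print(constraint.args.get("desc"))  # expected: True
print(create.sql())
```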
@ -2189,7 +2193,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.R_BRACKET): if not self._match(TokenType.R_BRACKET):
self.raise_error("Expected ]") self.raise_error("Expected ]")
this.comment = self._prev_comment this.comments = self._prev_comments
return self._parse_bracket(this) return self._parse_bracket(this)
def _parse_case(self): def _parse_case(self):
@ -2256,6 +2260,33 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
def _parse_string_agg(self):
if self._match(TokenType.DISTINCT):
args = self._parse_csv(self._parse_conjunction)
expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)])
else:
args = self._parse_csv(self._parse_conjunction)
expression = seq_get(args, 0)
index = self._index
if not self._match(TokenType.R_PAREN):
# postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]])
order = self._parse_order(this=expression)
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
# Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]).
# This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that
# the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them.
if not self._match(TokenType.WITHIN_GROUP):
self._retreat(index)
this = exp.GroupConcat.from_arg_list(args)
self.validate_expression(this, args)
return this
self._match_l_paren() # The corresponding match_r_paren will be called in parse_function (caller)
order = self._parse_order(this=expression)
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
def _parse_convert(self, strict): def _parse_convert(self, strict):
this = self._parse_column() this = self._parse_column()
if self._match(TokenType.USING): if self._match(TokenType.USING):
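The new `STRING_AGG` parser above is easiest to check against the dialect tests further down; both expected strings below are taken from those tests:

```python
import sqlglot

# Identity round-trip for Postgres STRING_AGG with an inline ORDER BY.
print(sqlglot.transpile("STRING_AGG(x, ',' ORDER BY y)", read="postgres", write="postgres")[0])

# MySQL GROUP_CONCAT transpiled to Postgres, as exercised in test_mysql.
print(sqlglot.transpile("GROUP_CONCAT(x ORDER BY y SEPARATOR z)", read="mysql", write="postgres")[0])
# STRING_AGG(x, z ORDER BY y NULLS FIRST)
```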
@ -2511,8 +2542,8 @@ class Parser(metaclass=_Parser):
items = [parse_result] if parse_result is not None else [] items = [parse_result] if parse_result is not None else []
while self._match(sep): while self._match(sep):
if parse_result and self._prev_comment is not None: if parse_result and self._prev_comments:
parse_result.comment = self._prev_comment parse_result.comments = self._prev_comments
parse_result = parse_method() parse_result = parse_method()
if parse_result is not None: if parse_result is not None:
@ -2525,7 +2556,10 @@ class Parser(metaclass=_Parser):
while self._match_set(expressions): while self._match_set(expressions):
this = self.expression( this = self.expression(
expressions[self._prev.token_type], this=this, expression=parse_method() expressions[self._prev.token_type],
this=this,
comments=self._prev_comments,
expression=parse_method(),
) )
return this return this
@@ -2566,6 +2600,7 @@
         return self.expression(exp.Transaction, this=this, modes=modes)

     def _parse_commit_or_rollback(self):
+        chain = None
         savepoint = None
         is_rollback = self._prev.token_type == TokenType.ROLLBACK
@@ -2575,9 +2610,13 @@
             self._match_text_seq("SAVEPOINT")
             savepoint = self._parse_id_var()

+        if self._match(TokenType.AND):
+            chain = not self._match_text_seq("NO")
+            self._match_text_seq("CHAIN")
+
         if is_rollback:
             return self.expression(exp.Rollback, savepoint=savepoint)
-        return self.expression(exp.Commit)
+
+        return self.expression(exp.Commit, chain=chain)

     def _parse_show(self):
         parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
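A minimal sketch of the chain handling end to end, pairing this parser change with the chain-aware `commit_sql` generator above (output formatting is whatever the generator emits):

```python
import sqlglot

# chain is True for AND CHAIN, False for AND NO CHAIN, and None when absent.
for sql in ("COMMIT", "COMMIT AND CHAIN", "COMMIT AND NO CHAIN"):
    commit = sqlglot.parse_one(sql)
    print(commit.args.get("chain"), "->", commit.sql())
```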
@ -2651,14 +2690,14 @@ class Parser(metaclass=_Parser):
def _match_l_paren(self, expression=None): def _match_l_paren(self, expression=None):
if not self._match(TokenType.L_PAREN): if not self._match(TokenType.L_PAREN):
self.raise_error("Expecting (") self.raise_error("Expecting (")
if expression and self._prev_comment: if expression and self._prev_comments:
expression.comment = self._prev_comment expression.comments = self._prev_comments
def _match_r_paren(self, expression=None): def _match_r_paren(self, expression=None):
if not self._match(TokenType.R_PAREN): if not self._match(TokenType.R_PAREN):
self.raise_error("Expecting )") self.raise_error("Expecting )")
if expression and self._prev_comment: if expression and self._prev_comments:
expression.comment = self._prev_comment expression.comments = self._prev_comments
def _match_texts(self, texts): def _match_texts(self, texts):
if self._curr and self._curr.text.upper() in texts: if self._curr and self._curr.text.upper() in texts:


@ -130,18 +130,20 @@ class Step:
aggregations = [] aggregations = []
sequence = itertools.count() sequence = itertools.count()
for e in expression.expressions: def extract_agg_operands(expression):
aggregation = e.find(exp.AggFunc) for agg in expression.find_all(exp.AggFunc):
for operand in agg.unnest_operands():
if aggregation:
projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
aggregations.append(e)
for operand in aggregation.unnest_operands():
if isinstance(operand, exp.Column): if isinstance(operand, exp.Column):
continue continue
if operand not in operands: if operand not in operands:
operands[operand] = f"_a_{next(sequence)}" operands[operand] = f"_a_{next(sequence)}"
operand.replace(exp.column(operands[operand], quoted=True)) operand.replace(exp.column(operands[operand], quoted=True))
for e in expression.expressions:
if e.find(exp.AggFunc):
projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
aggregations.append(e)
extract_agg_operands(e)
else: else:
projections.append(e) projections.append(e)
@ -156,6 +158,13 @@ class Step:
aggregate = Aggregate() aggregate = Aggregate()
aggregate.source = step.name aggregate.source = step.name
aggregate.name = step.name aggregate.name = step.name
having = expression.args.get("having")
if having:
extract_agg_operands(having)
aggregate.condition = having.this
aggregate.operands = tuple( aggregate.operands = tuple(
alias(operand, alias_) for operand, alias_ in operands.items() alias(operand, alias_) for operand, alias_ in operands.items()
) )
@ -172,11 +181,6 @@ class Step:
aggregate.add_dependency(step) aggregate.add_dependency(step)
step = aggregate step = aggregate
having = expression.args.get("having")
if having:
step.condition = having.this
order = expression.args.get("order") order = expression.args.get("order")
if order: if order:
@ -188,6 +192,17 @@ class Step:
step.projections = projections step.projections = projections
if isinstance(expression, exp.Select) and expression.args.get("distinct"):
distinct = Aggregate()
distinct.source = step.name
distinct.name = step.name
distinct.group = {
e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
for e in projections or expression.expressions
}
distinct.add_dependency(step)
step = distinct
limit = expression.args.get("limit") limit = expression.args.get("limit")
if limit: if limit:
@ -231,6 +246,9 @@ class Step:
if self.condition: if self.condition:
lines.append(f"{nested}Condition: {self.condition.sql()}") lines.append(f"{nested}Condition: {self.condition.sql()}")
if self.limit is not math.inf:
lines.append(f"{nested}Limit: {self.limit}")
if self.dependencies: if self.dependencies:
lines.append(f"{nested}Dependencies:") lines.append(f"{nested}Dependencies:")
for dependency in self.dependencies: for dependency in self.dependencies:
@ -258,12 +276,7 @@ class Scan(Step):
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step: ) -> Step:
table = expression table = expression
alias_ = expression.alias alias_ = expression.alias_or_name
if not alias_:
raise UnsupportedError(
"Tables/Subqueries must be aliased. Run it through the optimizer"
)
if isinstance(expression, exp.Subquery): if isinstance(expression, exp.Subquery):
table = expression.this table = expression.this
@ -338,6 +351,9 @@ class Aggregate(Step):
lines.append(f"{indent}Group:") lines.append(f"{indent}Group:")
for expression in self.group.values(): for expression in self.group.values():
lines.append(f"{indent} - {expression.sql()}") lines.append(f"{indent} - {expression.sql()}")
if self.condition:
lines.append(f"{indent}Having:")
lines.append(f"{indent} - {self.condition.sql()}")
if self.operands: if self.operands:
lines.append(f"{indent}Operands:") lines.append(f"{indent}Operands:")
for expression in self.operands: for expression in self.operands:
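The planner changes above (HAVING attached to the aggregate's condition, DISTINCT planned as an extra Aggregate step) feed the executor. A hedged sketch, assuming the executor exposes an `execute` helper that accepts a `tables` mapping of plain Python rows:

```python
from sqlglot.executor import execute

tables = {"t": [{"a": 1, "b": 1}, {"a": 1, "b": 2}, {"a": 2, "b": 3}]}

# DISTINCT should now be planned and executed as a grouping step.
print(execute("SELECT DISTINCT a FROM t", tables=tables))

# HAVING should be evaluated as the aggregate step's condition.
print(execute("SELECT a, SUM(b) AS s FROM t GROUP BY a HAVING SUM(b) > 2", tables=tables))
```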


@ -81,6 +81,7 @@ class TokenType(AutoName):
BINARY = auto() BINARY = auto()
VARBINARY = auto() VARBINARY = auto()
JSON = auto() JSON = auto()
JSONB = auto()
TIMESTAMP = auto() TIMESTAMP = auto()
TIMESTAMPTZ = auto() TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto() TIMESTAMPLTZ = auto()
@ -91,6 +92,7 @@ class TokenType(AutoName):
NULLABLE = auto() NULLABLE = auto()
GEOMETRY = auto() GEOMETRY = auto()
HLLSKETCH = auto() HLLSKETCH = auto()
HSTORE = auto()
SUPER = auto() SUPER = auto()
SERIAL = auto() SERIAL = auto()
SMALLSERIAL = auto() SMALLSERIAL = auto()
@ -113,6 +115,7 @@ class TokenType(AutoName):
APPLY = auto() APPLY = auto()
ARRAY = auto() ARRAY = auto()
ASC = auto() ASC = auto()
ASOF = auto()
AT_TIME_ZONE = auto() AT_TIME_ZONE = auto()
AUTO_INCREMENT = auto() AUTO_INCREMENT = auto()
BEGIN = auto() BEGIN = auto()
@ -130,6 +133,7 @@ class TokenType(AutoName):
COMMAND = auto() COMMAND = auto()
COMMENT = auto() COMMENT = auto()
COMMIT = auto() COMMIT = auto()
COMPOUND = auto()
CONSTRAINT = auto() CONSTRAINT = auto()
CREATE = auto() CREATE = auto()
CROSS = auto() CROSS = auto()
@ -271,6 +275,7 @@ class TokenType(AutoName):
UNBOUNDED = auto() UNBOUNDED = auto()
UNCACHE = auto() UNCACHE = auto()
UNION = auto() UNION = auto()
UNLOGGED = auto()
UNNEST = auto() UNNEST = auto()
UNPIVOT = auto() UNPIVOT = auto()
UPDATE = auto() UPDATE = auto()
@ -291,7 +296,7 @@ class TokenType(AutoName):
class Token: class Token:
__slots__ = ("token_type", "text", "line", "col", "comment") __slots__ = ("token_type", "text", "line", "col", "comments")
@classmethod @classmethod
def number(cls, number: int) -> Token: def number(cls, number: int) -> Token:
@ -319,13 +324,13 @@ class Token:
text: str, text: str,
line: int = 1, line: int = 1,
col: int = 1, col: int = 1,
comment: t.Optional[str] = None, comments: t.List[str] = [],
) -> None: ) -> None:
self.token_type = token_type self.token_type = token_type
self.text = text self.text = text
self.line = line self.line = line
self.col = max(col - len(text), 1) self.col = max(col - len(text), 1)
self.comment = comment self.comments = comments
def __repr__(self) -> str: def __repr__(self) -> str:
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__) attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
@ -452,6 +457,7 @@ class Tokenizer(metaclass=_Tokenizer):
"COLLATE": TokenType.COLLATE, "COLLATE": TokenType.COLLATE,
"COMMENT": TokenType.SCHEMA_COMMENT, "COMMENT": TokenType.SCHEMA_COMMENT,
"COMMIT": TokenType.COMMIT, "COMMIT": TokenType.COMMIT,
"COMPOUND": TokenType.COMPOUND,
"CONSTRAINT": TokenType.CONSTRAINT, "CONSTRAINT": TokenType.CONSTRAINT,
"CREATE": TokenType.CREATE, "CREATE": TokenType.CREATE,
"CROSS": TokenType.CROSS, "CROSS": TokenType.CROSS,
@ -582,8 +588,9 @@ class Tokenizer(metaclass=_Tokenizer):
"TRAILING": TokenType.TRAILING, "TRAILING": TokenType.TRAILING,
"UNBOUNDED": TokenType.UNBOUNDED, "UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION, "UNION": TokenType.UNION,
"UNPIVOT": TokenType.UNPIVOT, "UNLOGGED": TokenType.UNLOGGED,
"UNNEST": TokenType.UNNEST, "UNNEST": TokenType.UNNEST,
"UNPIVOT": TokenType.UNPIVOT,
"UPDATE": TokenType.UPDATE, "UPDATE": TokenType.UPDATE,
"USE": TokenType.USE, "USE": TokenType.USE,
"USING": TokenType.USING, "USING": TokenType.USING,
@ -686,12 +693,12 @@ class Tokenizer(metaclass=_Tokenizer):
"_current", "_current",
"_line", "_line",
"_col", "_col",
"_comment", "_comments",
"_char", "_char",
"_end", "_end",
"_peek", "_peek",
"_prev_token_line", "_prev_token_line",
"_prev_token_comment", "_prev_token_comments",
"_prev_token_type", "_prev_token_type",
"_replace_backslash", "_replace_backslash",
) )
@ -708,13 +715,13 @@ class Tokenizer(metaclass=_Tokenizer):
self._current = 0 self._current = 0
self._line = 1 self._line = 1
self._col = 1 self._col = 1
self._comment = None self._comments: t.List[str] = []
self._char = None self._char = None
self._end = None self._end = None
self._peek = None self._peek = None
self._prev_token_line = -1 self._prev_token_line = -1
self._prev_token_comment = None self._prev_token_comments: t.List[str] = []
self._prev_token_type = None self._prev_token_type = None
def tokenize(self, sql: str) -> t.List[Token]: def tokenize(self, sql: str) -> t.List[Token]:
@ -767,7 +774,7 @@ class Tokenizer(metaclass=_Tokenizer):
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None: def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line self._prev_token_line = self._line
self._prev_token_comment = self._comment self._prev_token_comments = self._comments
self._prev_token_type = token_type # type: ignore self._prev_token_type = token_type # type: ignore
self.tokens.append( self.tokens.append(
Token( Token(
@ -775,10 +782,10 @@ class Tokenizer(metaclass=_Tokenizer):
self._text if text is None else text, self._text if text is None else text,
self._line, self._line,
self._col, self._col,
self._comment, self._comments,
) )
) )
self._comment = None self._comments = []
if token_type in self.COMMANDS and ( if token_type in self.COMMANDS and (
len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
@ -857,22 +864,18 @@ class Tokenizer(metaclass=_Tokenizer):
while not self._end and self._chars(comment_end_size) != comment_end: while not self._end and self._chars(comment_end_size) != comment_end:
self._advance() self._advance()
self._comment = self._text[comment_start_size : -comment_end_size + 1] # type: ignore self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore
self._advance(comment_end_size - 1) self._advance(comment_end_size - 1)
else: else:
while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore
self._advance() self._advance()
self._comment = self._text[comment_start_size:] # type: ignore self._comments.append(self._text[comment_start_size:]) # type: ignore
# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both
# types of comment can be attached to a token, the trailing one is discarded in favour of the leading one.
# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
# Multiple consecutive comments are preserved by appending them to the current comments list.
if comment_start_line == self._prev_token_line: if comment_start_line == self._prev_token_line:
if self._prev_token_comment is None: self.tokens[-1].comments.extend(self._comments)
self.tokens[-1].comment = self._comment self._comments = []
self._prev_token_comment = self._comment
self._comment = None
return True return True


@ -2,6 +2,8 @@ from __future__ import annotations
import typing as t import typing as t
from sqlglot.helper import find_new_name
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from sqlglot.generator import Generator from sqlglot.generator import Generator
@ -43,6 +45,43 @@ def unalias_group(expression: exp.Expression) -> exp.Expression:
return expression return expression
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
"""
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Args:
expression: the expression that will be transformed.
Returns:
The transformed expression.
"""
if (
isinstance(expression, exp.Select)
and expression.args.get("distinct")
and expression.args["distinct"].args.get("on")
and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
):
distinct_cols = [e.copy() for e in expression.args["distinct"].args["on"].expressions]
outer_selects = [e.copy() for e in expression.expressions]
nested = expression.copy()
nested.args["distinct"].pop()
row_number = find_new_name(expression.named_selects, "_row_number")
window = exp.Window(
this=exp.RowNumber(),
partition_by=distinct_cols,
)
order = nested.args.get("order")
if order:
window.set("order", order.copy())
order.pop()
window = exp.alias_(window, row_number)
nested.select(window, copy=False)
return exp.select(*outer_selects).from_(nested.subquery()).where(f'"{row_number}" = 1')
return expression
def preprocess( def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
to_sql: t.Callable[[Generator, exp.Expression], str], to_sql: t.Callable[[Generator, exp.Expression], str],
@ -81,3 +120,4 @@ def delegate(attr: str) -> t.Callable:
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))} UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}
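A short sketch of the new transform, assuming it lives in `sqlglot.transforms` alongside the existing `UNALIAS_GROUP` helper; the rewritten query uses a `ROW_NUMBER()` window in a subquery, as described in the docstring above:

```python
import sqlglot
from sqlglot.transforms import eliminate_distinct_on

expression = sqlglot.parse_one(
    "SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, b", read="postgres"
)
print(eliminate_distinct_on(expression).sql(pretty=True))
```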


@ -1276,7 +1276,7 @@ class TestFunctions(unittest.TestCase):
col = SF.concat(SF.col("cola"), SF.col("colb")) col = SF.concat(SF.col("cola"), SF.col("colb"))
self.assertEqual("CONCAT(cola, colb)", col.sql()) self.assertEqual("CONCAT(cola, colb)", col.sql())
col_single = SF.concat("cola") col_single = SF.concat("cola")
self.assertEqual("CONCAT(cola)", col_single.sql()) self.assertEqual("cola", col_single.sql())
def test_array_position(self): def test_array_position(self):
col_str = SF.array_position("cola", SF.col("colb")) col_str = SF.array_position("cola", SF.col("colb"))


@ -10,6 +10,10 @@ class TestClickhouse(Validator):
self.validate_identity("SELECT * FROM x AS y FINAL") self.validate_identity("SELECT * FROM x AS y FINAL")
self.validate_identity("'a' IN mapKeys(map('a', 1, 'b', 2))") self.validate_identity("'a' IN mapKeys(map('a', 1, 'b', 2))")
self.validate_identity("CAST((1, 2) AS Tuple(a Int8, b Int16))") self.validate_identity("CAST((1, 2) AS Tuple(a Int8, b Int16))")
self.validate_identity("SELECT * FROM foo LEFT ANY JOIN bla")
self.validate_identity("SELECT * FROM foo LEFT ASOF JOIN bla")
self.validate_identity("SELECT * FROM foo ASOF JOIN bla")
self.validate_identity("SELECT * FROM foo ANY JOIN bla")
self.validate_all( self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",


@ -997,6 +997,13 @@ class TestDialect(Validator):
"spark": "CONCAT_WS('-', x)", "spark": "CONCAT_WS('-', x)",
}, },
) )
self.validate_all(
"CONCAT(a)",
write={
"mysql": "a",
"tsql": "a",
},
)
self.validate_all( self.validate_all(
"IF(x > 1, 1, 0)", "IF(x > 1, 1, 0)",
write={ write={
@ -1263,8 +1270,8 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"""/* comment1 */ """/* comment1 */
SELECT SELECT
x, -- comment2 x, /* comment2 */
y -- comment3""", y /* comment3 */""",
read={ read={
"mysql": """SELECT # comment1 "mysql": """SELECT # comment1
x, # comment2 x, # comment2


@ -89,6 +89,8 @@ class TestDuckDB(Validator):
"presto": "CAST(COL AS ARRAY(BIGINT))", "presto": "CAST(COL AS ARRAY(BIGINT))",
"hive": "CAST(COL AS ARRAY<BIGINT>)", "hive": "CAST(COL AS ARRAY<BIGINT>)",
"spark": "CAST(COL AS ARRAY<LONG>)", "spark": "CAST(COL AS ARRAY<LONG>)",
"postgres": "CAST(COL AS BIGINT[])",
"snowflake": "CAST(COL AS ARRAY)",
}, },
) )
@ -104,6 +106,10 @@ class TestDuckDB(Validator):
"spark": "ARRAY(0, 1, 2)", "spark": "ARRAY(0, 1, 2)",
}, },
) )
self.validate_all(
"SELECT ARRAY_LENGTH([0], 1) AS x",
write={"duckdb": "SELECT ARRAY_LENGTH(LIST_VALUE(0), 1) AS x"},
)
self.validate_all( self.validate_all(
"REGEXP_MATCHES(x, y)", "REGEXP_MATCHES(x, y)",
write={ write={


@ -139,7 +139,7 @@ class TestHive(Validator):
"CREATE TABLE test STORED AS parquet TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1", "CREATE TABLE test STORED AS parquet TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1",
write={ write={
"duckdb": "CREATE TABLE test AS SELECT 1", "duckdb": "CREATE TABLE test AS SELECT 1",
"presto": "CREATE TABLE test WITH (FORMAT='parquet', x='1', Z='2') AS SELECT 1", "presto": "CREATE TABLE test WITH (FORMAT='PARQUET', x='1', Z='2') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1", "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1",
"spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1", "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1",
}, },
@ -459,6 +459,7 @@ class TestHive(Validator):
"hive": "MAP(a, b, c, d)", "hive": "MAP(a, b, c, d)",
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])", "presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
"spark": "MAP(a, b, c, d)", "spark": "MAP(a, b, c, d)",
"snowflake": "OBJECT_CONSTRUCT(a, b, c, d)",
}, },
write={ write={
"": "MAP(ARRAY(a, c), ARRAY(b, d))", "": "MAP(ARRAY(a, c), ARRAY(b, d))",
@ -467,6 +468,7 @@ class TestHive(Validator):
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])", "presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
"hive": "MAP(a, b, c, d)", "hive": "MAP(a, b, c, d)",
"spark": "MAP(a, b, c, d)", "spark": "MAP(a, b, c, d)",
"snowflake": "OBJECT_CONSTRUCT(a, b, c, d)",
}, },
) )
self.validate_all( self.validate_all(
@ -476,6 +478,7 @@ class TestHive(Validator):
"presto": "MAP(ARRAY[a], ARRAY[b])", "presto": "MAP(ARRAY[a], ARRAY[b])",
"hive": "MAP(a, b)", "hive": "MAP(a, b)",
"spark": "MAP(a, b)", "spark": "MAP(a, b)",
"snowflake": "OBJECT_CONSTRUCT(a, b)",
}, },
) )
self.validate_all( self.validate_all(

View file

@ -23,6 +23,8 @@ class TestMySQL(Validator):
self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')") self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')")
self.validate_identity("@@GLOBAL.max_connections") self.validate_identity("@@GLOBAL.max_connections")
self.validate_identity("CREATE TABLE A LIKE B")
# SET Commands # SET Commands
self.validate_identity("SET @var_name = expr") self.validate_identity("SET @var_name = expr")
self.validate_identity("SET @name = 43") self.validate_identity("SET @name = 43")
@ -177,14 +179,27 @@ class TestMySQL(Validator):
"GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
write={ write={
"mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')", "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')",
"sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", "sqlite": "GROUP_CONCAT(DISTINCT x)",
"tsql": "STRING_AGG(x, ',') WITHIN GROUP (ORDER BY y DESC)",
"postgres": "STRING_AGG(DISTINCT x, ',' ORDER BY y DESC NULLS LAST)",
},
)
self.validate_all(
"GROUP_CONCAT(x ORDER BY y SEPARATOR z)",
write={
"mysql": "GROUP_CONCAT(x ORDER BY y SEPARATOR z)",
"sqlite": "GROUP_CONCAT(x, z)",
"tsql": "STRING_AGG(x, z) WITHIN GROUP (ORDER BY y)",
"postgres": "STRING_AGG(x, z ORDER BY y NULLS FIRST)",
}, },
) )
self.validate_all(
"GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
write={
"mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
"sqlite": "GROUP_CONCAT(DISTINCT x, '')",
"tsql": "STRING_AGG(x, '') WITHIN GROUP (ORDER BY y DESC)",
"postgres": "STRING_AGG(DISTINCT x, '' ORDER BY y DESC NULLS LAST)",
},
)
self.validate_identity(

View file

@ -6,6 +6,9 @@ class TestPostgres(Validator):
dialect = "postgres" dialect = "postgres"
def test_ddl(self): def test_ddl(self):
self.validate_identity("CREATE TABLE test (foo HSTORE)")
self.validate_identity("CREATE TABLE test (foo JSONB)")
self.validate_identity("CREATE TABLE test (foo VARCHAR(64)[])")
self.validate_all( self.validate_all(
"CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)", "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
write={ write={
@ -60,6 +63,12 @@ class TestPostgres(Validator):
) )
def test_postgres(self): def test_postgres(self):
self.validate_identity("SELECT ARRAY[1, 2, 3]")
self.validate_identity("SELECT ARRAY_LENGTH(ARRAY[1, 2, 3], 1)")
self.validate_identity("STRING_AGG(x, y)")
self.validate_identity("STRING_AGG(x, ',' ORDER BY y)")
self.validate_identity("STRING_AGG(x, ',' ORDER BY y DESC)")
self.validate_identity("STRING_AGG(DISTINCT x, ',' ORDER BY y DESC)")
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END") self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
self.validate_identity( self.validate_identity(
"SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END" "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END"
@ -86,6 +95,14 @@ class TestPostgres(Validator):
self.validate_identity("SELECT e'\\xDEADBEEF'") self.validate_identity("SELECT e'\\xDEADBEEF'")
self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)") self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
self.validate_all(
"END WORK AND NO CHAIN",
write={"postgres": "COMMIT AND NO CHAIN"},
)
self.validate_all(
"END AND CHAIN",
write={"postgres": "COMMIT AND CHAIN"},
)
self.validate_all( self.validate_all(
"CREATE TABLE x (a UUID, b BYTEA)", "CREATE TABLE x (a UUID, b BYTEA)",
write={ write={
@ -95,6 +112,10 @@ class TestPostgres(Validator):
"spark": "CREATE TABLE x (a UUID, b BINARY)", "spark": "CREATE TABLE x (a UUID, b BINARY)",
}, },
) )
self.validate_identity(
"CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)"
)
self.validate_all( self.validate_all(
"SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS 1 PRECEDING)", "SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS 1 PRECEDING)",
write={ write={

View file

@ -13,6 +13,7 @@ class TestPresto(Validator):
"duckdb": "CAST(a AS INT[])", "duckdb": "CAST(a AS INT[])",
"presto": "CAST(a AS ARRAY(INTEGER))", "presto": "CAST(a AS ARRAY(INTEGER))",
"spark": "CAST(a AS ARRAY<INT>)", "spark": "CAST(a AS ARRAY<INT>)",
"snowflake": "CAST(a AS ARRAY)",
}, },
) )
self.validate_all( self.validate_all(
@ -31,6 +32,7 @@ class TestPresto(Validator):
"duckdb": "CAST(LIST_VALUE(1, 2) AS BIGINT[])", "duckdb": "CAST(LIST_VALUE(1, 2) AS BIGINT[])",
"presto": "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))", "presto": "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))",
"spark": "CAST(ARRAY(1, 2) AS ARRAY<LONG>)", "spark": "CAST(ARRAY(1, 2) AS ARRAY<LONG>)",
"snowflake": "CAST([1, 2] AS ARRAY)",
}, },
) )
self.validate_all( self.validate_all(
@ -41,6 +43,7 @@ class TestPresto(Validator):
"presto": "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INTEGER, INTEGER))", "presto": "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INTEGER, INTEGER))",
"hive": "CAST(MAP(1, 1) AS MAP<INT, INT>)", "hive": "CAST(MAP(1, 1) AS MAP<INT, INT>)",
"spark": "CAST(MAP_FROM_ARRAYS(ARRAY(1), ARRAY(1)) AS MAP<INT, INT>)", "spark": "CAST(MAP_FROM_ARRAYS(ARRAY(1), ARRAY(1)) AS MAP<INT, INT>)",
"snowflake": "CAST(OBJECT_CONSTRUCT(1, 1) AS OBJECT)",
}, },
) )
self.validate_all( self.validate_all(
@ -51,6 +54,7 @@ class TestPresto(Validator):
"presto": "CAST(MAP(ARRAY['a', 'b', 'c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INTEGER)))", "presto": "CAST(MAP(ARRAY['a', 'b', 'c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INTEGER)))",
"hive": "CAST(MAP('a', ARRAY(1), 'b', ARRAY(2), 'c', ARRAY(3)) AS MAP<STRING, ARRAY<INT>>)", "hive": "CAST(MAP('a', ARRAY(1), 'b', ARRAY(2), 'c', ARRAY(3)) AS MAP<STRING, ARRAY<INT>>)",
"spark": "CAST(MAP_FROM_ARRAYS(ARRAY('a', 'b', 'c'), ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))) AS MAP<STRING, ARRAY<INT>>)", "spark": "CAST(MAP_FROM_ARRAYS(ARRAY('a', 'b', 'c'), ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))) AS MAP<STRING, ARRAY<INT>>)",
"snowflake": "CAST(OBJECT_CONSTRUCT('a', [1], 'b', [2], 'c', [3]) AS OBJECT)",
}, },
) )
self.validate_all( self.validate_all(
@ -393,6 +397,7 @@ class TestPresto(Validator):
write={ write={
"hive": UnsupportedError, "hive": UnsupportedError,
"spark": "MAP_FROM_ARRAYS(a, b)", "spark": "MAP_FROM_ARRAYS(a, b)",
"snowflake": UnsupportedError,
}, },
) )
self.validate_all( self.validate_all(
@ -401,6 +406,7 @@ class TestPresto(Validator):
"hive": "MAP(a, c, b, d)", "hive": "MAP(a, c, b, d)",
"presto": "MAP(ARRAY[a, b], ARRAY[c, d])", "presto": "MAP(ARRAY[a, b], ARRAY[c, d])",
"spark": "MAP_FROM_ARRAYS(ARRAY(a, b), ARRAY(c, d))", "spark": "MAP_FROM_ARRAYS(ARRAY(a, b), ARRAY(c, d))",
"snowflake": "OBJECT_CONSTRUCT(a, c, b, d)",
}, },
) )
self.validate_all( self.validate_all(
@ -409,6 +415,7 @@ class TestPresto(Validator):
"hive": "MAP('a', 'b')", "hive": "MAP('a', 'b')",
"presto": "MAP(ARRAY['a'], ARRAY['b'])", "presto": "MAP(ARRAY['a'], ARRAY['b'])",
"spark": "MAP_FROM_ARRAYS(ARRAY('a'), ARRAY('b'))", "spark": "MAP_FROM_ARRAYS(ARRAY('a'), ARRAY('b'))",
"snowflake": "OBJECT_CONSTRUCT('a', 'b')",
}, },
) )
self.validate_all( self.validate_all(

View file

@ -50,6 +50,12 @@ class TestRedshift(Validator):
"redshift": 'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\\\_%\' LIMIT 5' "redshift": 'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\\\_%\' LIMIT 5'
}, },
) )
self.validate_all(
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
write={
"redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
},
)
def test_identity(self): def test_identity(self):
self.validate_identity("CAST('bla' AS SUPER)") self.validate_identity("CAST('bla' AS SUPER)")
@ -64,3 +70,13 @@ class TestRedshift(Validator):
self.validate_identity( self.validate_identity(
"SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'" "SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
) )
self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE AUTO")
self.validate_identity(
"CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid)"
)
self.validate_identity(
"COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'"
)
self.validate_identity(
"UNLOAD ('select * from venue') TO 's3://mybucket/unload/' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'"
)
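
The `DISTINCT ON` rewrite exercised above can also be reproduced directly with `transpile`; a minimal sketch, assuming the 10.1.x API merged in this commit (the expected output is taken from the test case itself):

```python
import sqlglot

sql = "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC"
# Redshift has no DISTINCT ON, so it is rewritten through ROW_NUMBER()
print(sqlglot.transpile(sql, read="redshift", write="redshift")[0])
# SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1
```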

View file

@ -172,13 +172,28 @@ class TestSnowflake(Validator):
self.validate_all(
"trim(date_column, 'UTC')",
write={
"bigquery": "TRIM(date_column, 'UTC')",
"snowflake": "TRIM(date_column, 'UTC')",
"postgres": "TRIM('UTC' FROM date_column)",
},
)
self.validate_all(
"trim(date_column)",
write={
"snowflake": "TRIM(date_column)",
"bigquery": "TRIM(date_column)",
},
)
self.validate_all(
"DECODE(x, a, b, c, d)",
read={
"": "MATCHES(x, a, b, c, d)",
},
write={
"": "MATCHES(x, a, b, c, d)",
"oracle": "DECODE(x, a, b, c, d)",
"snowflake": "DECODE(x, a, b, c, d)",
},
)
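
As a usage sketch of the `DECODE`/`MATCHES` mapping above (dialect names as merged here): a Snowflake `DECODE` round-trips to Oracle unchanged, while the dialect-neutral `MATCHES` form is written out as `DECODE` for Snowflake.

```python
import sqlglot

print(sqlglot.transpile("DECODE(x, a, b, c, d)", read="snowflake", write="oracle")[0])
# DECODE(x, a, b, c, d)
print(sqlglot.transpile("MATCHES(x, a, b, c, d)", write="snowflake")[0])
# DECODE(x, a, b, c, d)
```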
def test_null_treatment(self):
@ -370,7 +385,8 @@ class TestSnowflake(Validator):
)
self.validate_all(
r"""SELECT * FROM TABLE(?)""",
write={"snowflake": r"""SELECT * FROM TABLE(?)"""},
)
self.validate_all(

View file

@ -32,13 +32,14 @@ class TestSpark(Validator):
"presto": "CREATE TABLE db.example_table (col_a ARRAY(INTEGER), col_b ARRAY(ARRAY(INTEGER)))", "presto": "CREATE TABLE db.example_table (col_a ARRAY(INTEGER), col_b ARRAY(ARRAY(INTEGER)))",
"hive": "CREATE TABLE db.example_table (col_a ARRAY<INT>, col_b ARRAY<ARRAY<INT>>)", "hive": "CREATE TABLE db.example_table (col_a ARRAY<INT>, col_b ARRAY<ARRAY<INT>>)",
"spark": "CREATE TABLE db.example_table (col_a ARRAY<INT>, col_b ARRAY<ARRAY<INT>>)", "spark": "CREATE TABLE db.example_table (col_a ARRAY<INT>, col_b ARRAY<ARRAY<INT>>)",
"snowflake": "CREATE TABLE db.example_table (col_a ARRAY, col_b ARRAY)",
}, },
) )
self.validate_all( self.validate_all(
"CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
write={ write={
"duckdb": "CREATE TABLE x", "duckdb": "CREATE TABLE x",
"presto": "CREATE TABLE x WITH (TABLE_FORMAT = 'ICEBERG', PARTITIONED_BY=ARRAY['MONTHS'])", "presto": "CREATE TABLE x WITH (TABLE_FORMAT='ICEBERG', PARTITIONED_BY=ARRAY['MONTHS'])",
"hive": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", "hive": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
"spark": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", "spark": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
}, },
@ -94,6 +95,13 @@ TBLPROPERTIES (
pretty=True, pretty=True,
) )
self.validate_all(
"CACHE TABLE testCache OPTIONS ('storageLevel' 'DISK_ONLY') SELECT * FROM testData",
write={
"spark": "CACHE TABLE testCache OPTIONS('storageLevel' = 'DISK_ONLY') AS SELECT * FROM testData"
},
)
def test_to_date(self): def test_to_date(self):
self.validate_all( self.validate_all(
"TO_DATE(x, 'yyyy-MM-dd')", "TO_DATE(x, 'yyyy-MM-dd')",
@ -271,6 +279,7 @@ TBLPROPERTIES (
"presto": "MAP(ARRAY[1], c)", "presto": "MAP(ARRAY[1], c)",
"hive": "MAP(ARRAY(1), c)", "hive": "MAP(ARRAY(1), c)",
"spark": "MAP_FROM_ARRAYS(ARRAY(1), c)", "spark": "MAP_FROM_ARRAYS(ARRAY(1), c)",
"snowflake": "OBJECT_CONSTRUCT([1], c)",
}, },
) )
self.validate_all( self.validate_all(

View file

@ -5,6 +5,10 @@ class TestSQLite(Validator):
dialect = "sqlite" dialect = "sqlite"
def test_ddl(self): def test_ddl(self):
self.validate_all(
"CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)",
write={"sqlite": "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)"},
)
self.validate_all( self.validate_all(
""" """
CREATE TABLE "Track" CREATE TABLE "Track"

View file

@ -17,7 +17,6 @@ class TestTSQL(Validator):
"spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo", "spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo",
}, },
) )
self.validate_all( self.validate_all(
"CONVERT(INT, CONVERT(NUMERIC, '444.75'))", "CONVERT(INT, CONVERT(NUMERIC, '444.75'))",
write={ write={
@ -25,6 +24,33 @@ class TestTSQL(Validator):
"tsql": "CAST(CAST('444.75' AS NUMERIC) AS INTEGER)", "tsql": "CAST(CAST('444.75' AS NUMERIC) AS INTEGER)",
}, },
) )
self.validate_all(
"STRING_AGG(x, y) WITHIN GROUP (ORDER BY z DESC)",
write={
"tsql": "STRING_AGG(x, y) WITHIN GROUP (ORDER BY z DESC)",
"mysql": "GROUP_CONCAT(x ORDER BY z DESC SEPARATOR y)",
"sqlite": "GROUP_CONCAT(x, y)",
"postgres": "STRING_AGG(x, y ORDER BY z DESC NULLS LAST)",
},
)
self.validate_all(
"STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z ASC)",
write={
"tsql": "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z)",
"mysql": "GROUP_CONCAT(x ORDER BY z SEPARATOR '|')",
"sqlite": "GROUP_CONCAT(x, '|')",
"postgres": "STRING_AGG(x, '|' ORDER BY z NULLS FIRST)",
},
)
self.validate_all(
"STRING_AGG(x, '|')",
write={
"tsql": "STRING_AGG(x, '|')",
"mysql": "GROUP_CONCAT(x SEPARATOR '|')",
"sqlite": "GROUP_CONCAT(x, '|')",
"postgres": "STRING_AGG(x, '|')",
},
)
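
The `STRING_AGG` cases above translate between the T-SQL, MySQL, SQLite and Postgres aggregation syntaxes; a minimal sketch (outputs as asserted in the tests above):

```python
import sqlglot

sql = "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z ASC)"
print(sqlglot.transpile(sql, read="tsql", write="mysql")[0])     # GROUP_CONCAT(x ORDER BY z SEPARATOR '|')
print(sqlglot.transpile(sql, read="tsql", write="postgres")[0])  # STRING_AGG(x, '|' ORDER BY z NULLS FIRST)
```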
def test_types(self):
self.validate_identity("CAST(x AS XML)")

View file

@ -34,6 +34,7 @@ x >> 1
x >> 1 | 1 & 1 ^ 1 x >> 1 | 1 & 1 ^ 1
x || y x || y
1 - -1 1 - -1
- -5
dec.x + y dec.x + y
a.filter a.filter
a.b.c a.b.c
@ -438,6 +439,7 @@ SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(a, b)
SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b) SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b)
SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score) SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score)
SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score) SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score)
CREATE TABLE foo (id INT PRIMARY KEY ASC)
CREATE TABLE a.b AS SELECT 1 CREATE TABLE a.b AS SELECT 1
CREATE TABLE a.b AS SELECT a FROM a.c CREATE TABLE a.b AS SELECT a FROM a.c
CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d
@ -579,6 +581,7 @@ SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)
SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl) SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)
SELECT CAST(x AS INT) /* comment */ FROM foo SELECT CAST(x AS INT) /* comment */ FROM foo
SELECT a /* x */, b /* x */ SELECT a /* x */, b /* x */
SELECT a /* x */ /* y */ /* z */, b /* k */ /* m */
SELECT * FROM foo /* x */, bla /* x */ SELECT * FROM foo /* x */, bla /* x */
SELECT 1 /* comment */ + 1 SELECT 1 /* comment */ + 1
SELECT 1 /* c1 */ + 2 /* c2 */ SELECT 1 /* c1 */ + 2 /* c2 */
@ -588,3 +591,7 @@ SELECT x FROM a.b.c /* x */, e.f.g /* x */
SELECT FOO(x /* c */) /* FOO */, b /* b */ SELECT FOO(x /* c */) /* FOO */, b /* b */
SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */ SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */
SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b' SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b'
SELECT x AS INTO FROM bla
SELECT * INTO newevent FROM event
SELECT * INTO TEMPORARY newevent FROM event
SELECT * INTO UNLOGGED newevent FROM event

View file

@ -77,3 +77,15 @@ WITH x_2 AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT x.id FROM x
-- Existing duplicate CTE -- Existing duplicate CTE
WITH y AS (SELECT a FROM x) SELECT a FROM (SELECT a FROM x) AS y JOIN y AS z; WITH y AS (SELECT a FROM x) SELECT a FROM (SELECT a FROM x) AS y JOIN y AS z;
WITH y AS (SELECT a FROM x) SELECT a FROM y AS y JOIN y AS z; WITH y AS (SELECT a FROM x) SELECT a FROM y AS y JOIN y AS z;
-- Nested CTE
WITH cte1 AS (SELECT a FROM x) SELECT a FROM (WITH cte2 AS (SELECT a FROM cte1) SELECT a FROM cte2);
WITH cte1 AS (SELECT a FROM x), cte2 AS (SELECT a FROM cte1), cte AS (SELECT a FROM cte2 AS cte2) SELECT a FROM cte AS cte;
-- Nested CTE inside CTE
WITH cte1 AS (WITH cte2 AS (SELECT a FROM x) SELECT t.a FROM cte2 AS t) SELECT a FROM cte1;
WITH cte2 AS (SELECT a FROM x), cte1 AS (SELECT t.a FROM cte2 AS t) SELECT a FROM cte1;
-- Duplicate CTE nested in CTE
WITH cte1 AS (SELECT a FROM x), cte2 AS (WITH cte3 AS (SELECT a FROM x) SELECT a FROM cte3) SELECT a FROM cte2;
WITH cte1 AS (SELECT a FROM x), cte2 AS (SELECT a FROM cte1 AS cte3) SELECT a FROM cte2;

View file

@ -0,0 +1,41 @@
SELECT a FROM x;
SELECT a FROM x;
SELECT "A" FROM "X";
SELECT "A" FROM "X";
SELECT a AS A FROM x;
SELECT a AS A FROM x;
SELECT * FROM x;
SELECT * FROM x;
SELECT A FROM x;
SELECT a FROM x;
SELECT a FROM X;
SELECT a FROM x;
SELECT A AS A FROM (SELECT a AS A FROM x);
SELECT a AS A FROM (SELECT a AS a FROM x);
SELECT a AS B FROM x ORDER BY B;
SELECT a AS B FROM x ORDER BY B;
SELECT A FROM x ORDER BY A;
SELECT a FROM x ORDER BY a;
SELECT A AS B FROM X GROUP BY A HAVING SUM(B) > 0;
SELECT a AS B FROM x GROUP BY a HAVING SUM(b) > 0;
SELECT A AS B, SUM(B) AS C FROM X GROUP BY A HAVING C > 0;
SELECT a AS B, SUM(b) AS C FROM x GROUP BY a HAVING C > 0;
SELECT A FROM X UNION SELECT A FROM X;
SELECT a FROM x UNION SELECT a FROM x;
SELECT A AS A FROM X UNION SELECT A AS A FROM X;
SELECT a AS A FROM x UNION SELECT a AS A FROM x;
(SELECT A AS A FROM X);
(SELECT a AS A FROM x);
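
These fixture pairs drive the new `lower_identities` optimizer rule; a minimal sketch of applying it on its own (the expected output is taken from the pairs above):

```python
from sqlglot import parse_one
from sqlglot.optimizer.lower_identities import lower_identities

expression = parse_one("SELECT A AS B FROM X GROUP BY A HAVING SUM(B) > 0")
# Unquoted identifiers are lowercased; quoted ones like "A" are left untouched
print(lower_identities(expression).sql())
# SELECT a AS B FROM x GROUP BY a HAVING SUM(b) > 0
```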

View file

@ -276,3 +276,18 @@ SELECT /*+ COALESCE(3),
FROM `x` AS `x` FROM `x` AS `x`
JOIN `y` AS `y` JOIN `y` AS `y`
ON `x`.`b` = `y`.`b`; ON `x`.`b` = `y`.`b`;
WITH cte1 AS (
WITH cte2 AS (
SELECT a, b FROM x
)
SELECT a1
FROM (
WITH cte3 AS (SELECT 1)
SELECT a AS a1, b AS b1 FROM cte2
)
)
SELECT a1 FROM cte1;
SELECT
"x"."a" AS "a1"
FROM "x" AS "x";

View file

@ -274,6 +274,15 @@ TRUE;
-(-1);
1;
- -+1;
1;
+-1;
-1;
++1;
1;
0.06 - 0.01;
0.05;
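
The unary-sign pairs added above appear to go through the simplify rule; a minimal sketch, assuming `sqlglot.optimizer.simplify` is the rule exercised by this fixture file:

```python
from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

# Chains of unary +/- collapse to a single literal, per the fixture pairs above
print(simplify(parse_one("- -+1")).sql())  # 1
print(simplify(parse_one("+-1")).sql())    # -1
```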

View file

@ -666,19 +666,7 @@ WITH "supplier_2" AS (
FROM "nation" AS "nation" FROM "nation" AS "nation"
WHERE WHERE
"nation"."n_name" = 'GERMANY' "nation"."n_name" = 'GERMANY'
) ), "_u_0" AS (
SELECT
"partsupp"."ps_partkey" AS "ps_partkey",
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value"
FROM "partsupp" AS "partsupp"
JOIN "supplier_2" AS "supplier"
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "nation_2" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
GROUP BY
"partsupp"."ps_partkey"
HAVING
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > (
SELECT SELECT
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0" SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0"
FROM "partsupp" AS "partsupp" FROM "partsupp" AS "partsupp"
@ -686,7 +674,20 @@ HAVING
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "nation_2" AS "nation" JOIN "nation_2" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey" ON "supplier"."s_nationkey" = "nation"."n_nationkey"
) )
SELECT
"partsupp"."ps_partkey" AS "ps_partkey",
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value"
FROM "partsupp" AS "partsupp"
CROSS JOIN "_u_0" AS "_u_0"
JOIN "supplier_2" AS "supplier"
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "nation_2" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
GROUP BY
"partsupp"."ps_partkey"
HAVING
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > MAX("_u_0"."_col_0")
ORDER BY ORDER BY
"value" DESC; "value" DESC;
@ -880,6 +881,10 @@ WITH "revenue" AS (
AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE) AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE)
GROUP BY GROUP BY
"lineitem"."l_suppkey" "lineitem"."l_suppkey"
), "_u_0" AS (
SELECT
MAX("revenue"."total_revenue") AS "_col_0"
FROM "revenue"
) )
SELECT SELECT
"supplier"."s_suppkey" AS "s_suppkey", "supplier"."s_suppkey" AS "s_suppkey",
@ -889,12 +894,9 @@ SELECT
"revenue"."total_revenue" AS "total_revenue" "revenue"."total_revenue" AS "total_revenue"
FROM "supplier" AS "supplier" FROM "supplier" AS "supplier"
JOIN "revenue" JOIN "revenue"
ON "revenue"."total_revenue" = ( ON "supplier"."s_suppkey" = "revenue"."supplier_no"
SELECT JOIN "_u_0" AS "_u_0"
MAX("revenue"."total_revenue") AS "_col_0" ON "revenue"."total_revenue" = "_u_0"."_col_0"
FROM "revenue"
)
AND "supplier"."s_suppkey" = "revenue"."supplier_no"
ORDER BY ORDER BY
"s_suppkey"; "s_suppkey";
@ -1395,7 +1397,14 @@ order by
cntrycode; cntrycode;
WITH "_u_0" AS ( WITH "_u_0" AS (
SELECT SELECT
"orders"."o_custkey" AS "_u_1" AVG("customer"."c_acctbal") AS "_col_0"
FROM "customer" AS "customer"
WHERE
"customer"."c_acctbal" > 0.00
AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
), "_u_1" AS (
SELECT
"orders"."o_custkey" AS "_u_2"
FROM "orders" AS "orders" FROM "orders" AS "orders"
GROUP BY GROUP BY
"orders"."o_custkey" "orders"."o_custkey"
@ -1405,18 +1414,12 @@ SELECT
COUNT(*) AS "numcust", COUNT(*) AS "numcust",
SUM("customer"."c_acctbal") AS "totacctbal" SUM("customer"."c_acctbal") AS "totacctbal"
FROM "customer" AS "customer" FROM "customer" AS "customer"
LEFT JOIN "_u_0" AS "_u_0" JOIN "_u_0" AS "_u_0"
ON "_u_0"."_u_1" = "customer"."c_custkey" ON "customer"."c_acctbal" > "_u_0"."_col_0"
LEFT JOIN "_u_1" AS "_u_1"
ON "_u_1"."_u_2" = "customer"."c_custkey"
WHERE WHERE
"_u_0"."_u_1" IS NULL "_u_1"."_u_2" IS NULL
AND "customer"."c_acctbal" > (
SELECT
AVG("customer"."c_acctbal") AS "_col_0"
FROM "customer" AS "customer"
WHERE
"customer"."c_acctbal" > 0.00
AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
)
AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
GROUP BY GROUP BY
SUBSTRING("customer"."c_phone", 1, 2) SUBSTRING("customer"."c_phone", 1, 2)

View file

@ -1,10 +1,12 @@
--SELECT x.a > (SELECT SUM(y.a) AS b FROM y) FROM x;
-------------------------------------- --------------------------------------
-- Unnest Subqueries -- Unnest Subqueries
-------------------------------------- --------------------------------------
SELECT * SELECT *
FROM x AS x FROM x AS x
WHERE WHERE
x.a IN (SELECT y.a AS a FROM y) x.a = (SELECT SUM(y.a) AS a FROM y)
AND x.a IN (SELECT y.a AS a FROM y)
AND x.a IN (SELECT y.b AS b FROM y) AND x.a IN (SELECT y.b AS b FROM y)
AND x.a = ANY (SELECT y.a AS a FROM y) AND x.a = ANY (SELECT y.a AS a FROM y)
AND x.a = (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a) AND x.a = (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a)
@ -24,62 +26,57 @@ WHERE
SELECT SELECT
* *
FROM x AS x FROM x AS x
CROSS JOIN (
SELECT
SUM(y.a) AS a
FROM y
) AS "_u_0"
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
y.a AS a y.a AS a
FROM y FROM y
GROUP BY GROUP BY
y.a y.a
) AS "_u_0" ) AS "_u_1"
ON x.a = "_u_0"."a" ON x.a = "_u_1"."a"
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
y.b AS b y.b AS b
FROM y FROM y
GROUP BY GROUP BY
y.b y.b
) AS "_u_1" ) AS "_u_2"
ON x.a = "_u_1"."b" ON x.a = "_u_2"."b"
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
y.a AS a y.a AS a
FROM y FROM y
GROUP BY GROUP BY
y.a y.a
) AS "_u_2"
ON x.a = "_u_2"."a"
LEFT JOIN (
SELECT
SUM(y.b) AS b,
y.a AS _u_4
FROM y
WHERE
TRUE
GROUP BY
y.a
) AS "_u_3" ) AS "_u_3"
ON x.a = "_u_3"."_u_4" ON x.a = "_u_3"."a"
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
SUM(y.b) AS b, SUM(y.b) AS b,
y.a AS _u_6 y.a AS _u_5
FROM y FROM y
WHERE WHERE
TRUE TRUE
GROUP BY GROUP BY
y.a y.a
) AS "_u_5" ) AS "_u_4"
ON x.a = "_u_5"."_u_6" ON x.a = "_u_4"."_u_5"
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
y.a AS a SUM(y.b) AS b,
y.a AS _u_7
FROM y FROM y
WHERE WHERE
TRUE TRUE
GROUP BY GROUP BY
y.a y.a
) AS "_u_7" ) AS "_u_6"
ON "_u_7".a = x.a ON x.a = "_u_6"."_u_7"
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
y.a AS a y.a AS a
@ -90,29 +87,39 @@ LEFT JOIN (
y.a y.a
) AS "_u_8" ) AS "_u_8"
ON "_u_8".a = x.a ON "_u_8".a = x.a
LEFT JOIN (
SELECT
y.a AS a
FROM y
WHERE
TRUE
GROUP BY
y.a
) AS "_u_9"
ON "_u_9".a = x.a
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
ARRAY_AGG(y.a) AS a, ARRAY_AGG(y.a) AS a,
y.b AS _u_10 y.b AS _u_11
FROM y FROM y
WHERE WHERE
TRUE TRUE
GROUP BY GROUP BY
y.b y.b
) AS "_u_9" ) AS "_u_10"
ON "_u_9"."_u_10" = x.a ON "_u_10"."_u_11" = x.a
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
SUM(y.a) AS a, SUM(y.a) AS a,
y.a AS _u_12, y.a AS _u_13,
ARRAY_AGG(y.b) AS _u_13 ARRAY_AGG(y.b) AS _u_14
FROM y FROM y
WHERE WHERE
TRUE AND TRUE AND TRUE TRUE AND TRUE AND TRUE
GROUP BY GROUP BY
y.a y.a
) AS "_u_11" ) AS "_u_12"
ON "_u_11"."_u_12" = x.a AND "_u_11"."_u_12" = x.b ON "_u_12"."_u_13" = x.a AND "_u_12"."_u_13" = x.b
LEFT JOIN ( LEFT JOIN (
SELECT SELECT
y.a AS a y.a AS a
@ -121,37 +128,38 @@ LEFT JOIN (
TRUE TRUE
GROUP BY GROUP BY
y.a y.a
) AS "_u_14" ) AS "_u_15"
ON x.a = "_u_14".a ON x.a = "_u_15".a
WHERE WHERE
NOT "_u_0"."a" IS NULL x.a = "_u_0".a
AND NOT "_u_1"."b" IS NULL AND NOT "_u_1"."a" IS NULL
AND NOT "_u_2"."a" IS NULL AND NOT "_u_2"."b" IS NULL
AND NOT "_u_3"."a" IS NULL
AND ( AND (
x.a = "_u_3".b AND NOT "_u_3"."_u_4" IS NULL x.a = "_u_4".b AND NOT "_u_4"."_u_5" IS NULL
) )
AND ( AND (
x.a > "_u_5".b AND NOT "_u_5"."_u_6" IS NULL x.a > "_u_6".b AND NOT "_u_6"."_u_7" IS NULL
) )
AND ( AND (
None = "_u_7".a AND NOT "_u_7".a IS NULL None = "_u_8".a AND NOT "_u_8".a IS NULL
) )
AND NOT ( AND NOT (
x.a = "_u_8".a AND NOT "_u_8".a IS NULL x.a = "_u_9".a AND NOT "_u_9".a IS NULL
) )
AND ( AND (
ARRAY_ANY("_u_9".a, _x -> _x = x.a) AND NOT "_u_9"."_u_10" IS NULL ARRAY_ANY("_u_10".a, _x -> _x = x.a) AND NOT "_u_10"."_u_11" IS NULL
) )
AND ( AND (
( (
( (
x.a < "_u_11".a AND NOT "_u_11"."_u_12" IS NULL x.a < "_u_12".a AND NOT "_u_12"."_u_13" IS NULL
) AND NOT "_u_11"."_u_12" IS NULL ) AND NOT "_u_12"."_u_13" IS NULL
) )
AND ARRAY_ANY("_u_11"."_u_13", "_x" -> "_x" <> x.d) AND ARRAY_ANY("_u_12"."_u_14", "_x" -> "_x" <> x.d)
) )
AND ( AND (
NOT "_u_14".a IS NULL AND NOT "_u_14".a IS NULL NOT "_u_15".a IS NULL AND NOT "_u_15".a IS NULL
) )
AND x.a IN ( AND x.a IN (
SELECT SELECT

View file

@ -68,13 +68,13 @@ class TestExecutor(unittest.TestCase):
def test_execute_tpch(self):
def to_csv(expression):
if isinstance(expression, exp.Table) and expression.name not in ("revenue"):
return parse_one(
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
)
return expression
for i, (sql, _) in enumerate(self.sqls[0:16]):
with self.subTest(f"tpch-h {i + 1}"):
a = self.cached_execute(sql)
sql = parse_one(sql).transform(to_csv).sql(pretty=True)
@ -165,6 +165,39 @@ class TestExecutor(unittest.TestCase):
["a"], ["a"],
[("a",)], [("a",)],
), ),
(
"SELECT DISTINCT a FROM (SELECT 1 AS a UNION ALL SELECT 1 AS a)",
["a"],
[(1,)],
),
(
"SELECT DISTINCT a, SUM(b) AS b "
"FROM (SELECT 'a' AS a, 1 AS b UNION ALL SELECT 'a' AS a, 2 AS b UNION ALL SELECT 'b' AS a, 1 AS b) "
"GROUP BY a "
"LIMIT 1",
["a", "b"],
[("a", 3)],
),
(
"SELECT COUNT(1) AS a FROM (SELECT 1)",
["a"],
[(1,)],
),
(
"SELECT COUNT(1) AS a FROM (SELECT 1) LIMIT 0",
["a"],
[],
),
(
"SELECT a FROM x GROUP BY a LIMIT 0",
["a"],
[],
),
(
"SELECT a FROM x LIMIT 0",
["a"],
[],
),
]:
with self.subTest(sql):
result = execute(sql, schema=schema, tables=tables)
@ -346,6 +379,28 @@ class TestExecutor(unittest.TestCase):
],
)
def test_execute_subqueries(self):
tables = {
"table": [
{"a": 1, "b": 1},
{"a": 2, "b": 2},
],
}
self.assertEqual(
execute(
"""
SELECT *
FROM table
WHERE a = (SELECT MAX(a) FROM table)
""",
tables=tables,
).rows,
[
(2, 2),
],
)
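
The executor tests above run entirely against in-memory Python rows; a minimal usage sketch of the same scalar-subquery support (table name and data here are illustrative):

```python
from sqlglot.executor import execute

tables = {
    "t": [
        {"a": 1, "b": 1},
        {"a": 2, "b": 2},
    ],
}

# Scalar subqueries are now executed, mirroring test_execute_subqueries above
result = execute("SELECT * FROM t WHERE a = (SELECT MAX(a) FROM t)", tables=tables)
print(result.rows)  # [(2, 2)]
```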
def test_table_depth_mismatch(self):
tables = {"table": []}
schema = {"db": {"table": {"col": "VARCHAR"}}}
@ -401,6 +456,7 @@ class TestExecutor(unittest.TestCase):
("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]), ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]), ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]), ("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
("SELECT SUM(x) FROM (SELECT 1 AS x WHERE FALSE)", ["_col_0"], [(0,)]),
]: ]:
result = execute(sql) result = execute(sql)
self.assertEqual(result.columns, tuple(cols)) self.assertEqual(result.columns, tuple(cols))
@ -462,7 +518,18 @@ class TestExecutor(unittest.TestCase):
("IF(false, 1, 0)", 0), ("IF(false, 1, 0)", 0),
("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"), ("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"),
("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)), ("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)),
("1 IN (1, 2, 3)", True),
("1 IN (2, 3)", False),
("NULL IS NULL", True),
("NULL IS NOT NULL", False),
("NULL = NULL", None),
("NULL <> NULL", None),
]:
with self.subTest(sql):
result = execute(f"SELECT {sql}")
self.assertEqual(result.rows, [(expected,)])
def test_case_sensitivity(self):
result = execute("SELECT A AS A FROM X", tables={"x": [{"a": 1}]})
self.assertEqual(result.columns, ("A",))
self.assertEqual(result.rows, [(1,)])

View file

@ -525,24 +525,14 @@ class TestExpressions(unittest.TestCase):
),
exp.Properties(
expressions=[
exp.FileFormatProperty(this=exp.Literal.string("parquet")),
exp.PartitionedByProperty(
this=exp.Tuple(expressions=[exp.to_identifier("a"), exp.to_identifier("b")])
),
exp.Property(this=exp.Literal.string("custom"), value=exp.Literal.number(1)),
exp.TableFormatProperty(this=exp.to_identifier("test_format")),
exp.EngineProperty(this=exp.null()),
exp.CollateProperty(this=exp.true()),
]
),
)
@ -609,9 +599,9 @@ FROM foo""",
"""SELECT """SELECT
a, a,
b AS B, b AS B,
c, -- comment c, /* comment */
d AS D, -- another comment d AS D, /* another comment */
CAST(x AS INT) -- final comment CAST(x AS INT) /* final comment */
FROM foo""", FROM foo""",
) )

View file

@ -85,9 +85,8 @@ class TestOptimizer(unittest.TestCase):
if leave_tables_isolated is not None:
func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated)
with self.subTest(title):
optimized = func(parse_one(sql, read=dialect), **func_kwargs)
self.assertEqual(
expected,
optimized.sql(pretty=pretty, dialect=dialect),
@ -168,6 +167,9 @@ class TestOptimizer(unittest.TestCase):
def test_quote_identities(self):
self.check_file("quote_identities", optimizer.quote_identities.quote_identities)
def test_lower_identities(self):
self.check_file("lower_identities", optimizer.lower_identities.lower_identities)
def test_pushdown_projection(self):
def pushdown_projections(expression, **kwargs):
expression = optimizer.qualify_tables.qualify_tables(expression)

View file

@ -15,6 +15,51 @@ class TestParser(unittest.TestCase):
self.assertIsInstance(parse_one("int", into=exp.DataType), exp.DataType) self.assertIsInstance(parse_one("int", into=exp.DataType), exp.DataType)
self.assertIsInstance(parse_one("array<int>", into=exp.DataType), exp.DataType) self.assertIsInstance(parse_one("array<int>", into=exp.DataType), exp.DataType)
def test_parse_into_error(self):
expected_message = "Failed to parse into [<class 'sqlglot.expressions.From'>]"
expected_errors = [
{
"description": "Invalid expression / Unexpected token",
"line": 1,
"col": 1,
"start_context": "",
"highlight": "SELECT",
"end_context": " 1;",
"into_expression": exp.From,
}
]
with self.assertRaises(ParseError) as ctx:
parse_one("SELECT 1;", "sqlite", [exp.From])
self.assertEqual(str(ctx.exception), expected_message)
self.assertEqual(ctx.exception.errors, expected_errors)
def test_parse_into_errors(self):
expected_message = "Failed to parse into [<class 'sqlglot.expressions.From'>, <class 'sqlglot.expressions.Join'>]"
expected_errors = [
{
"description": "Invalid expression / Unexpected token",
"line": 1,
"col": 1,
"start_context": "",
"highlight": "SELECT",
"end_context": " 1;",
"into_expression": exp.From,
},
{
"description": "Invalid expression / Unexpected token",
"line": 1,
"col": 1,
"start_context": "",
"highlight": "SELECT",
"end_context": " 1;",
"into_expression": exp.Join,
},
]
with self.assertRaises(ParseError) as ctx:
parse_one("SELECT 1;", "sqlite", [exp.From, exp.Join])
self.assertEqual(str(ctx.exception), expected_message)
self.assertEqual(ctx.exception.errors, expected_errors)
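
Beyond the formatted message, the structured error dictionaries checked above are exposed on the exception; a minimal sketch using the new `into` argument:

```python
from sqlglot import exp, parse_one
from sqlglot.errors import ParseError

try:
    parse_one("SELECT 1;", read="sqlite", into=exp.From)
except ParseError as e:
    # Each structured error records which target expression failed to parse
    print(e.errors[0]["into_expression"])  # <class 'sqlglot.expressions.From'>
```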
def test_column(self):
columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all(exp.Column)
assert len(list(columns)) == 1
@ -24,6 +69,9 @@ class TestParser(unittest.TestCase):
def test_float(self):
self.assertEqual(parse_one(".2"), parse_one("0.2"))
def test_unary_plus(self):
self.assertEqual(parse_one("+15"), exp.Literal.number(15))
def test_table(self):
tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
self.assertEqual(tables, ["a", "b.c", "d"])
@ -157,8 +205,9 @@ class TestParser(unittest.TestCase):
def test_comments(self):
expression = parse_one(
"""
--comment1.1
--comment1.2
SELECT /*comment1.3*/
a, --comment2
b as B, --comment3:testing
"test--annotation",
@ -169,13 +218,13 @@ class TestParser(unittest.TestCase):
"""
)
self.assertEqual(expression.comments, ["comment1.1", "comment1.2", "comment1.3"])
self.assertEqual(expression.expressions[0].comments, ["comment2"])
self.assertEqual(expression.expressions[1].comments, ["comment3:testing"])
self.assertEqual(expression.expressions[2].comments, None)
self.assertEqual(expression.expressions[3].comments, ["comment4 --foo"])
self.assertEqual(expression.expressions[4].comments, [""])
self.assertEqual(expression.expressions[5].comments, [" space"])
def test_type_literals(self):
self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))

View file

@ -7,13 +7,13 @@ class TestTokens(unittest.TestCase):
def test_comment_attachment(self):
tokenizer = Tokenizer()
sql_comment = [
("/*comment*/ foo", ["comment"]),
("/*comment*/ foo --test", ["comment", "test"]),
("--comment\nfoo --test", ["comment", "test"]),
("foo --comment", ["comment"]),
("foo", []),
("foo /*comment 1*/ /*comment 2*/", ["comment 1", "comment 2"]),
]
for sql, comment in sql_comment:
self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment)
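
A short sketch of the new list-based attachment (input and expected list taken from the table above):

```python
from sqlglot.tokens import Tokenizer

token = Tokenizer().tokenize("foo /*comment 1*/ /*comment 2*/")[0]
# Multiple comments now accumulate on the token instead of overwriting each other
print(token.comments)  # ['comment 1', 'comment 2']
```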

View file

@ -1,7 +1,7 @@
import unittest
from sqlglot import parse_one
from sqlglot.transforms import eliminate_distinct_on, unalias_group
class TestTime(unittest.TestCase):
@ -35,3 +35,30 @@ class TestTime(unittest.TestCase):
"SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date",
"SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date",
)
def test_eliminate_distinct_on(self):
self.validate(
eliminate_distinct_on,
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
)
self.validate(
eliminate_distinct_on,
"SELECT DISTINCT ON (a) a, b FROM x",
'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a) AS "_row_number" FROM x) WHERE "_row_number" = 1',
)
self.validate(
eliminate_distinct_on,
"SELECT DISTINCT ON (a, b) a, b FROM x ORDER BY c DESC",
'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
)
self.validate(
eliminate_distinct_on,
"SELECT DISTINCT a, b FROM x ORDER BY c DESC",
"SELECT DISTINCT a, b FROM x ORDER BY c DESC",
)
self.validate(
eliminate_distinct_on,
"SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC",
'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS "_row_number_2" FROM x) WHERE "_row_number_2" = 1',
)
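
The same transform can be applied to any parsed expression; a minimal sketch, assuming the validate helper above applies transforms via `Expression.transform`:

```python
from sqlglot import parse_one
from sqlglot.transforms import eliminate_distinct_on

sql = "SELECT DISTINCT ON (a) a, b FROM x"
print(parse_one(sql).transform(eliminate_distinct_on).sql())
# SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a) AS "_row_number" FROM x) WHERE "_row_number" = 1
```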

View file

@ -26,6 +26,7 @@ class TestTranspile(unittest.TestCase):
) )
self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date") self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date")
self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime") self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
self.assertEqual(transpile("SELECT 1 row")[0], "SELECT 1 AS row")
for key in ("union", "filter", "over", "from", "join"): for key in ("union", "filter", "over", "from", "join"):
with self.subTest(f"alias {key}"): with self.subTest(f"alias {key}"):
@ -38,6 +39,11 @@ class TestTranspile(unittest.TestCase):
def test_asc(self): def test_asc(self):
self.validate("SELECT x FROM y ORDER BY x ASC", "SELECT x FROM y ORDER BY x") self.validate("SELECT x FROM y ORDER BY x ASC", "SELECT x FROM y ORDER BY x")
def test_unary(self):
self.validate("+++1", "1")
self.validate("+-1", "-1")
self.validate("+- - -1", "- - -1")
def test_paren(self): def test_paren(self):
with self.assertRaises(ParseError): with self.assertRaises(ParseError):
transpile("1 + (2 + 3") transpile("1 + (2 + 3")
@ -58,7 +64,7 @@ class TestTranspile(unittest.TestCase):
)
self.validate(
"SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ",
"SELECT\n FOO /* x */\n , BAR /* y */\n , BAZ",
leading_comma=True,
pretty=True,
)
@ -78,7 +84,8 @@ class TestTranspile(unittest.TestCase):
def test_comments(self):
self.validate("SELECT */*comment*/", "SELECT * /* comment */")
self.validate(
"SELECT * FROM table /*comment 1*/ /*comment 2*/",
"SELECT * FROM table /* comment 1 */ /* comment 2 */",
)
self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
@ -112,6 +119,53 @@ class TestTranspile(unittest.TestCase):
) )
self.validate( self.validate(
""" """
-- comment 1
-- comment 2
-- comment 3
SELECT * FROM foo
""",
"/* comment 1 */ /* comment 2 */ /* comment 3 */ SELECT * FROM foo",
)
self.validate(
"""
-- comment 1
-- comment 2
-- comment 3
SELECT * FROM foo""",
"""/* comment 1 */
/* comment 2 */
/* comment 3 */
SELECT
*
FROM foo""",
pretty=True,
)
self.validate(
"""
SELECT * FROM tbl /*line1
line2
line3*/ /*another comment*/ where 1=1 -- comment at the end""",
"""SELECT * FROM tbl /* line1
line2
line3 */ /* another comment */ WHERE 1 = 1 /* comment at the end */""",
)
self.validate(
"""
SELECT * FROM tbl /*line1
line2
line3*/ /*another comment*/ where 1=1 -- comment at the end""",
"""SELECT
*
FROM tbl /* line1
line2
line3 */
/* another comment */
WHERE
1 = 1 /* comment at the end */""",
pretty=True,
)
self.validate(
"""
/* multi
line
comment
@ -130,8 +184,8 @@ class TestTranspile(unittest.TestCase):
*/
SELECT
tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
CAST(x AS INT), /* comment 3 */
y /* comment 4 */
FROM bar /* comment 5 */, tbl /* comment 6 */""",
read="mysql",
pretty=True,
@ -364,33 +418,79 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
@mock.patch("sqlglot.parser.logger") @mock.patch("sqlglot.parser.logger")
def test_error_level(self, logger): def test_error_level(self, logger):
invalid = "x + 1. (" invalid = "x + 1. ("
errors = [ expected_messages = [
"Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>. Line 1, Col: 8.\n x + 1. \033[4m(\033[0m", "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>. Line 1, Col: 8.\n x + 1. \033[4m(\033[0m",
"Expecting ). Line 1, Col: 8.\n x + 1. \033[4m(\033[0m", "Expecting ). Line 1, Col: 8.\n x + 1. \033[4m(\033[0m",
] ]
expected_errors = [
{
"description": "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>",
"line": 1,
"col": 8,
"start_context": "x + 1. ",
"highlight": "(",
"end_context": "",
"into_expression": None,
},
{
"description": "Expecting )",
"line": 1,
"col": 8,
"start_context": "x + 1. ",
"highlight": "(",
"end_context": "",
"into_expression": None,
},
]
transpile(invalid, error_level=ErrorLevel.WARN)
for error in expected_messages:
assert_logger_contains(error, logger)
with self.assertRaises(ParseError) as ctx:
transpile(invalid, error_level=ErrorLevel.IMMEDIATE)
self.assertEqual(str(ctx.exception), expected_messages[0])
self.assertEqual(ctx.exception.errors[0], expected_errors[0])
with self.assertRaises(ParseError) as ctx:
transpile(invalid, error_level=ErrorLevel.RAISE)
self.assertEqual(str(ctx.exception), "\n\n".join(expected_messages))
self.assertEqual(ctx.exception.errors, expected_errors)
more_than_max_errors = "(((("
expected_messages = (
"Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
"Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>. Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
"Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
"... and 2 more"
)
expected_errors = [
{
"description": "Expecting )",
"line": 1,
"col": 4,
"start_context": "(((",
"highlight": "(",
"end_context": "",
"into_expression": None,
},
{
"description": "Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>",
"line": 1,
"col": 4,
"start_context": "(((",
"highlight": "(",
"end_context": "",
"into_expression": None,
},
]
# Also expect three trailing structured errors that match the first
expected_errors += [expected_errors[0]] * 3
with self.assertRaises(ParseError) as ctx:
transpile(more_than_max_errors, error_level=ErrorLevel.RAISE)
self.assertEqual(str(ctx.exception), expected_messages)
self.assertEqual(ctx.exception.errors, expected_errors)
@mock.patch("sqlglot.generator.logger")
def test_unsupported_level(self, logger):