1
0
Fork 0

Adding upstream version 10.1.3.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:55:11 +01:00
parent 87cdb8246e
commit b7601057ad
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
57 changed files with 1542 additions and 529 deletions

View file

@ -1,6 +1,47 @@
Changelog
=========
v10.1.0
------
Changes:
- Breaking: [refactored](https://github.com/tobymao/sqlglot/commit/6b0da1e1a2b5d6bdf7b5b918400456422d30a1d4) the way SQL comments are handled. Before at most one comment could be attached to an expression, now multiple comments may be stored in a list.
- Breaking: [refactored](https://github.com/tobymao/sqlglot/commit/be332d10404f36b43ea6ece956a73bf451348641) the way properties are represented and parsed. The argument `this` now stores a property's attributes instead of its name.
- New: added structured ParseError properties.
- New: the executor now handles set operations.
- New: sqlglot can [now execute SQL queries](https://github.com/tobymao/sqlglot/commit/62d3496e761a4f38dfa61af793062690923dce74) using python objects.
- New: added support for the [Drill dialect](https://github.com/tobymao/sqlglot/commit/543eca314546e0bd42f97c354807b4e398ab36ec).
- New: added a `canonicalize` method which leverages existing type information for an expression to apply various transformations to it.
- New: TRIM function support for Snowflake and Bigquery.
- New: added support for SQLite primary key ordering constraints (ASC, DESC).
- New: added support for Redshift DISTKEY / SORTKEY / DISTSTYLE properties.
- New: added support for SET TRANSACTION MySQL statements.
- New: added `null`, `true`, `false` helper methods to avoid using singleton expressions.
- Improvement: allow multiple aggregations in an expression.
- Improvement: execution of left / right joins.
- Improvement: execution of aggregations without the GROUP BY clause.
- Improvement: static query execution (e.g. SELECT 1, SELECT CONCAT('a', 'b') AS x, etc).
- Improvement: include a rule for type inference in the optimizer.
- Improvement: transaction, commit expressions parsed [at finer granularity](https://github.com/tobymao/sqlglot/commit/148282e710fd79512bb7d32e6e519d631df8115d).
v10.0.0
------

View file

@ -25,6 +25,7 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/
* [AST Introspection](#ast-introspection)
* [AST Diff](#ast-diff)
* [Custom Dialects](#custom-dialects)
* [SQL Execution](#sql-execution)
* [Benchmarks](#benchmarks)
* [Optional Dependencies](#optional-dependencies)
@ -147,9 +148,9 @@ print(sqlglot.transpile(sql, read='mysql', pretty=True)[0])
*/
SELECT
tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
CAST(x AS INT), -- comment 3
y -- comment 4
FROM bar /* comment 5 */, tbl /* comment 6*/
CAST(x AS INT), /* comment 3 */
y /* comment 4 */
FROM bar /* comment 5 */, tbl /* comment 6 */
```
@ -189,6 +190,28 @@ sqlglot.errors.ParseError: Expecting ). Line 1, Col: 13.
~~~~
```
Structured syntax errors are accessible for programmatic use:
```python
import sqlglot
try:
sqlglot.transpile("SELECT foo( FROM bar")
except sqlglot.errors.ParseError as e:
print(e.errors)
```
Output:
```python
[{
'description': 'Expecting )',
'line': 1,
'col': 13,
'start_context': 'SELECT foo( ',
'highlight': 'FROM',
'end_context': ' bar'
}]
```
### Unsupported Errors
Presto `APPROX_DISTINCT` supports the accuracy argument which is not supported in Hive:
@ -372,6 +395,53 @@ print(Dialect["custom"])
<class '__main__.Custom'>
```
### SQL Execution
One can even interpret SQL queries using SQLGlot, where the tables are represented as Python dictionaries. Although the engine is not very fast (it's not supposed to be) and is in a relatively early stage of development, it can be useful for unit testing and running SQL natively across Python objects. Additionally, the foundation can be easily integrated with fast compute kernels (arrow, pandas). Below is an example showcasing the execution of a SELECT expression that involves aggregations and JOINs:
```python
from sqlglot.executor import execute
tables = {
"sushi": [
{"id": 1, "price": 1.0},
{"id": 2, "price": 2.0},
{"id": 3, "price": 3.0},
],
"order_items": [
{"sushi_id": 1, "order_id": 1},
{"sushi_id": 1, "order_id": 1},
{"sushi_id": 2, "order_id": 1},
{"sushi_id": 3, "order_id": 2},
],
"orders": [
{"id": 1, "user_id": 1},
{"id": 2, "user_id": 2},
],
}
execute(
"""
SELECT
o.user_id,
SUM(s.price) AS price
FROM orders o
JOIN order_items i
ON o.id = i.order_id
JOIN sushi s
ON i.sushi_id = s.id
GROUP BY o.user_id
""",
tables=tables
)
```
```python
user_id price
1 4.0
2 3.0
```
## Benchmarks
[Benchmarks](benchmarks) run on Python 3.10.5 in seconds.

View file

@ -30,7 +30,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "10.0.8"
__version__ = "10.1.3"
pretty = False

View file

@ -56,12 +56,12 @@ def _derived_table_values_to_unnest(self, expression):
def _returnsproperty_sql(self, expression):
value = expression.args.get("value")
if isinstance(value, exp.Schema):
value = f"{value.this} <{self.expressions(value)}>"
this = expression.this
if isinstance(this, exp.Schema):
this = f"{this.this} <{self.expressions(this)}>"
else:
value = self.sql(value)
return f"RETURNS {value}"
this = self.sql(this)
return f"RETURNS {this}"
def _create_sql(self, expression):
@ -142,6 +142,11 @@ class BigQuery(Dialect):
),
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
}
FUNCTION_PARSERS.pop("TRIM")
NO_PAREN_FUNCTIONS = {
**parser.Parser.NO_PAREN_FUNCTIONS,
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
@ -174,6 +179,7 @@ class BigQuery(Dialect):
exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
@ -200,9 +206,7 @@ class BigQuery(Dialect):
exp.VolatilityProperty,
}
WITH_PROPERTIES = {
exp.AnonymousProperty,
}
WITH_PROPERTIES = {exp.Property}
EXPLICIT_UNION = True

View file

@ -21,14 +21,15 @@ class ClickHouse(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"FINAL": TokenType.FINAL,
"ASOF": TokenType.ASOF,
"DATETIME64": TokenType.DATETIME,
"INT8": TokenType.TINYINT,
"FINAL": TokenType.FINAL,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
"INT16": TokenType.SMALLINT,
"INT32": TokenType.INT,
"INT64": TokenType.BIGINT,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
"INT8": TokenType.TINYINT,
"TUPLE": TokenType.STRUCT,
}
@ -38,6 +39,10 @@ class ClickHouse(Dialect):
"MAP": parse_var_map,
}
JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF}
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY}
def _parse_table(self, schema=False):
this = super()._parse_table(schema)

View file

@ -289,19 +289,19 @@ def struct_extract_sql(self, expression):
return f"{this}.{struct_key}"
def var_map_sql(self, expression):
def var_map_sql(self, expression, map_func_name="MAP"):
keys = expression.args["keys"]
values = expression.args["values"]
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
self.unsupported("Cannot convert array columns into map.")
return f"MAP({self.format_args(keys, values)})"
return f"{map_func_name}({self.format_args(keys, values)})"
args = []
for key, value in zip(keys.expressions, values.expressions):
args.append(self.sql(key))
args.append(self.sql(value))
return f"MAP({self.format_args(*args)})"
return f"{map_func_name}({self.format_args(*args)})"
def format_time_lambda(exp_class, dialect, default=None):
@ -336,18 +336,13 @@ def create_with_partitions_sql(self, expression):
if has_schema and is_partitionable:
expression = expression.copy()
prop = expression.find(exp.PartitionedByProperty)
value = prop and prop.args.get("value")
if prop and not isinstance(value, exp.Schema):
this = prop and prop.this
if prop and not isinstance(this, exp.Schema):
schema = expression.this
columns = {v.name.upper() for v in value.expressions}
columns = {v.name.upper() for v in this.expressions}
partitions = [col for col in schema.expressions if col.name.upper() in columns]
schema.set(
"expressions",
[e for e in schema.expressions if e not in partitions],
)
prop.replace(
exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
)
schema.set("expressions", [e for e in schema.expressions if e not in partitions])
prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
expression.set("this", schema)
return self.create_sql(expression)

View file

@ -153,7 +153,7 @@ class Drill(Dialect):
exp.If: if_sql,
exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Pivot: no_pivot_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.StrPosition: str_position_sql,

View file

@ -61,9 +61,7 @@ def _array_sort(self, expression):
def _property_sql(self, expression):
key = expression.name
value = self.sql(expression, "value")
return f"'{key}'={value}"
return f"'{expression.name}'={self.sql(expression, 'value')}"
def _str_to_unix(self, expression):
@ -250,7 +248,7 @@ class Hive(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP, # type: ignore
exp.AnonymousProperty: _property_sql,
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.ArrayConcat: rename_func("CONCAT"),
@ -262,7 +260,7 @@ class Hive(Dialect):
exp.DateStrToDate: rename_func("TO_DATE"),
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
exp.FileFormatProperty: lambda self, e: f"STORED AS {e.text('value').upper()}",
exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}",
exp.If: if_sql,
exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
@ -285,7 +283,7 @@ class Hive(Dialect):
exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix,
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'value')}",
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@ -298,11 +296,11 @@ class Hive(Dialect):
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
}
WITH_PROPERTIES = {exp.AnonymousProperty}
WITH_PROPERTIES = {exp.Property}
ROOT_PROPERTIES = {
exp.PartitionedByProperty,

View file

@ -453,6 +453,7 @@ class MySQL(Dialect):
exp.CharacterSetProperty,
exp.CollateProperty,
exp.SchemaCommentProperty,
exp.LikeProperty,
}
WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()

View file

@ -1,7 +1,7 @@
from __future__ import annotations
from sqlglot import exp, generator, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func
from sqlglot.helper import csv
from sqlglot.tokens import TokenType
@ -37,6 +37,12 @@ class Oracle(Dialect):
"YYYY": "%Y", # 2015
}
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DECODE": exp.Matches.from_arg_list,
}
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -58,6 +64,7 @@ class Oracle(Dialect):
**transforms.UNALIAS_GROUP, # type: ignore
exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql,
exp.Matches: rename_func("DECODE"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",

View file

@ -74,6 +74,27 @@ def _trim_sql(self, expression):
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def _string_agg_sql(self, expression):
expression = expression.copy()
separator = expression.args.get("separator") or exp.Literal.string(",")
order = ""
this = expression.this
if isinstance(this, exp.Order):
if this.this:
this = this.this
this.pop()
order = self.sql(expression.this) # Order has a leading space
return f"STRING_AGG({self.format_args(this, separator)}{order})"
def _datatype_sql(self, expression):
if expression.this == exp.DataType.Type.ARRAY:
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)
def _auto_increment_to_serial(expression):
auto = expression.find(exp.AutoIncrementColumnConstraint)
@ -191,25 +212,27 @@ class Postgres(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ALWAYS": TokenType.ALWAYS,
"BY DEFAULT": TokenType.BY_DEFAULT,
"IDENTITY": TokenType.IDENTITY,
"GENERATED": TokenType.GENERATED,
"DOUBLE PRECISION": TokenType.DOUBLE,
"BIGSERIAL": TokenType.BIGSERIAL,
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
"UUID": TokenType.UUID,
"TEMP": TokenType.TEMPORARY,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BIGSERIAL": TokenType.BIGSERIAL,
"BY DEFAULT": TokenType.BY_DEFAULT,
"COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
"DOUBLE PRECISION": TokenType.DOUBLE,
"GENERATED": TokenType.GENERATED,
"GRANT": TokenType.COMMAND,
"HSTORE": TokenType.HSTORE,
"IDENTITY": TokenType.IDENTITY,
"JSONB": TokenType.JSONB,
"REFRESH": TokenType.COMMAND,
"REINDEX": TokenType.COMMAND,
"RESET": TokenType.COMMAND,
"REVOKE": TokenType.COMMAND,
"GRANT": TokenType.COMMAND,
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
"TEMP": TokenType.TEMPORARY,
"UUID": TokenType.UUID,
**{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
**{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
}
@ -265,4 +288,7 @@ class Postgres(Dialect):
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.DataType: _datatype_sql,
exp.GroupConcat: _string_agg_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
}

View file

@ -171,16 +171,7 @@ class Presto(Dialect):
STRUCT_DELIMITER = ("(", ")")
ROOT_PROPERTIES = {
exp.SchemaCommentProperty,
}
WITH_PROPERTIES = {
exp.PartitionedByProperty,
exp.FileFormatProperty,
exp.AnonymousProperty,
exp.TableFormatProperty,
}
ROOT_PROPERTIES = {exp.SchemaCommentProperty}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -231,7 +222,8 @@ class Presto(Dialect):
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT = '{e.text('value').upper()}'",
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.TimeStrToDate: _date_parse_sql,
exp.TimeStrToTime: _date_parse_sql,
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",

View file

@ -1,6 +1,6 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot import exp, transforms
from sqlglot.dialects.postgres import Postgres
from sqlglot.tokens import TokenType
@ -18,12 +18,14 @@ class Redshift(Postgres):
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS, # type: ignore
"COPY": TokenType.COMMAND,
"GEOMETRY": TokenType.GEOMETRY,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
"TIME": TokenType.TIMESTAMP,
"TIMETZ": TokenType.TIMESTAMPTZ,
"UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY,
"SIMILAR TO": TokenType.SIMILAR_TO,
}
@ -35,3 +37,17 @@ class Redshift(Postgres):
exp.DataType.Type.VARBINARY: "VARBYTE",
exp.DataType.Type.INT: "INTEGER",
}
ROOT_PROPERTIES = {
exp.DistKeyProperty,
exp.SortKeyProperty,
exp.DistStyleProperty,
}
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
}

View file

@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
inline_array_sql,
rename_func,
var_map_sql,
)
from sqlglot.expressions import Literal
from sqlglot.helper import seq_get
@ -100,6 +101,14 @@ def _parse_date_part(self):
return self.expression(exp.Extract, this=this, expression=expression)
def _datatype_sql(self, expression):
if expression.this == exp.DataType.Type.ARRAY:
return "ARRAY"
elif expression.this == exp.DataType.Type.MAP:
return "OBJECT"
return self.datatype_sql(expression)
class Snowflake(Dialect):
null_ordering = "nulls_are_large"
time_format = "'yyyy-mm-dd hh24:mi:ss'"
@ -142,6 +151,8 @@ class Snowflake(Dialect):
"TO_TIMESTAMP": _snowflake_to_timestamp,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
"DECODE": exp.Matches.from_arg_list,
"OBJECT_CONSTRUCT": parser.parse_var_map,
}
FUNCTION_PARSERS = {
@ -195,16 +206,20 @@ class Snowflake(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.If: rename_func("IFF"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Array: inline_array_sql,
exp.StrPosition: rename_func("POSITION"),
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Matches: rename_func("DECODE"),
exp.StrPosition: rename_func("POSITION"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
exp.UnixToTime: _unix_to_time_sql,
}
TYPE_MAPPING = {

View file

@ -98,7 +98,7 @@ class Spark(Hive):
TRANSFORMS = {
**Hive.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),

View file

@ -13,6 +13,23 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType
# https://www.sqlite.org/lang_aggfunc.html#group_concat
def _group_concat_sql(self, expression):
this = expression.this
distinct = expression.find(exp.Distinct)
if distinct:
this = distinct.expressions[0]
distinct = "DISTINCT "
if isinstance(expression.this, exp.Order):
self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
if expression.this.this and not distinct:
this = expression.this.this
separator = expression.args.get("separator")
return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
class SQLite(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
@ -62,6 +79,7 @@ class SQLite(Dialect):
exp.Levenshtein: rename_func("EDITDIST3"),
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
exp.GroupConcat: _group_concat_sql,
}
def transaction_sql(self, expression):

View file

@ -17,6 +17,7 @@ FULL_FORMAT_TIME_MAPPING = {
"mm": "%B",
"m": "%B",
}
DATE_DELTA_INTERVAL = {
"year": "year",
"yyyy": "year",
@ -37,11 +38,12 @@ DATE_DELTA_INTERVAL = {
DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})")
# N = Numeric, C=Currency
TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
def _format_time(args):
return exp_class(
this=seq_get(args, 1),
@ -58,7 +60,7 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
return _format_time
def parse_format(args):
def _parse_format(args):
fmt = seq_get(args, 1)
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
if number_fmt:
@ -78,7 +80,7 @@ def generate_date_delta_with_unit_sql(self, e):
return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
def generate_format_sql(self, e):
def _format_sql(self, e):
fmt = (
e.args["format"]
if isinstance(e, exp.NumberToStr)
@ -87,6 +89,28 @@ def generate_format_sql(self, e):
return f"FORMAT({self.format_args(e.this, fmt)})"
def _string_agg_sql(self, e):
e = e.copy()
this = e.this
distinct = e.find(exp.Distinct)
if distinct:
# exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.")
this = distinct.expressions[0]
distinct.pop()
order = ""
if isinstance(e.this, exp.Order):
if e.this.this:
this = e.this.this
e.this.this.pop()
order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space
separator = e.args.get("separator") or exp.Literal.string(",")
return f"STRING_AGG({self.format_args(this, separator)}){order}"
class TSQL(Dialect):
null_ordering = "nulls_are_small"
time_format = "'yyyy-mm-dd hh:mm:ss'"
@ -228,14 +252,14 @@ class TSQL(Dialect):
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": tsql_format_time_lambda(exp.TimeToStr),
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": _format_time_lambda(exp.TimeToStr),
"GETDATE": exp.CurrentDate.from_arg_list,
"IIF": exp.If.from_arg_list,
"LEN": exp.Length.from_arg_list,
"REPLICATE": exp.Repeat.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
"FORMAT": parse_format,
"FORMAT": _parse_format,
}
VAR_LENGTH_DATATYPES = {
@ -298,6 +322,7 @@ class TSQL(Dialect):
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
exp.If: rename_func("IIF"),
exp.NumberToStr: generate_format_sql,
exp.TimeToStr: generate_format_sql,
exp.NumberToStr: _format_sql,
exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
}

View file

@ -22,7 +22,40 @@ class UnsupportedError(SqlglotError):
class ParseError(SqlglotError):
pass
def __init__(
self,
message: str,
errors: t.Optional[t.List[t.Dict[str, t.Any]]] = None,
):
super().__init__(message)
self.errors = errors or []
@classmethod
def new(
cls,
message: str,
description: t.Optional[str] = None,
line: t.Optional[int] = None,
col: t.Optional[int] = None,
start_context: t.Optional[str] = None,
highlight: t.Optional[str] = None,
end_context: t.Optional[str] = None,
into_expression: t.Optional[str] = None,
) -> ParseError:
return cls(
message,
[
{
"description": description,
"line": line,
"col": col,
"start_context": start_context,
"highlight": highlight,
"end_context": end_context,
"into_expression": into_expression,
}
],
)
class TokenError(SqlglotError):
@ -41,9 +74,13 @@ class ExecuteError(SqlglotError):
pass
def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
def concat_messages(errors: t.Sequence[t.Any], maximum: int) -> str:
msg = [str(e) for e in errors[:maximum]]
remaining = len(errors) - maximum
if remaining > 0:
msg.append(f"... and {remaining} more")
return "\n\n".join(msg)
def merge_errors(errors: t.Sequence[ParseError]) -> t.List[t.Dict[str, t.Any]]:
return [e_dict for error in errors for e_dict in error.errors]

View file

@ -122,7 +122,6 @@ def interval(this, unit):
ENV = {
"__builtins__": {},
"exp": exp,
# aggs
"SUM": filter_nulls(sum),

View file

@ -115,6 +115,9 @@ class PythonExecutor:
sink = self.table(context.columns)
for reader in table_iter:
if len(sink) >= step.limit:
break
if condition and not context.eval(condition):
continue
@ -123,9 +126,6 @@ class PythonExecutor:
else:
sink.append(reader.row)
if len(sink) >= step.limit:
break
return self.context({step.name: sink})
def static(self):
@ -288,7 +288,13 @@ class PythonExecutor:
end = 1
length = len(context.table)
table = self.table(list(step.group) + step.aggregations)
condition = self.generate(step.condition)
def add_row():
if not condition or context.eval(condition):
table.append(group + context.eval_tuple(aggregations))
if length:
for i in range(length):
context.set_index(i)
key = context.eval_tuple(group_by)
@ -296,12 +302,17 @@ class PythonExecutor:
end += 1
if key != group:
context.set_range(start, end - 2)
table.append(group + context.eval_tuple(aggregations))
add_row()
group = key
start = end - 2
if len(table.rows) >= step.limit:
break
if i == length - 1:
context.set_range(start, end - 1)
table.append(group + context.eval_tuple(aggregations))
add_row()
elif step.limit > 0:
context.set_range(0, 0)
table.append(context.eval_tuple(group_by) + context.eval_tuple(aggregations))
context = self.context({step.name: table, **{name: table for name in context.tables}})
@ -311,11 +322,9 @@ class PythonExecutor:
def sort(self, step, context):
projections = self.generate_tuple(step.projections)
projection_columns = [p.alias_or_name for p in step.projections]
all_columns = list(context.columns) + projection_columns
sink = self.table(all_columns)
for reader, ctx in context:
sink.append(reader.row + ctx.eval_tuple(projections))
@ -401,8 +410,9 @@ class Python(Dialect):
exp.Boolean: lambda self, e: "True" if e.this else "False",
exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
exp.Is: lambda self, e: self.binary(e, "is"),
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",

View file

@ -43,14 +43,14 @@ class Expression(metaclass=_Expression):
key = "Expression"
arg_types = {"this": True}
__slots__ = ("args", "parent", "arg_key", "type", "comment")
__slots__ = ("args", "parent", "arg_key", "type", "comments")
def __init__(self, **args):
self.args = args
self.parent = None
self.arg_key = None
self.type = None
self.comment = None
self.comments = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
@ -88,19 +88,6 @@ class Expression(metaclass=_Expression):
return field.this
return ""
def find_comment(self, key: str) -> str:
"""
Finds the comment that is attached to a specified child node.
Args:
key: the key of the target child node (e.g. "this", "expression", etc).
Returns:
The comment attached to the child node, or the empty string, if it doesn't exist.
"""
field = self.args.get(key)
return field.comment if isinstance(field, Expression) else ""
@property
def is_string(self):
return isinstance(self, Literal) and self.args["is_string"]
@ -137,7 +124,7 @@ class Expression(metaclass=_Expression):
def __deepcopy__(self, memo):
copy = self.__class__(**deepcopy(self.args))
copy.comment = self.comment
copy.comments = self.comments
copy.type = self.type
return copy
@ -369,7 +356,7 @@ class Expression(metaclass=_Expression):
)
for k, vs in self.args.items()
}
args["comment"] = self.comment
args["comments"] = self.comments
args["type"] = self.type
args = {k: v for k, v in args.items() if v or not hide_missing}
@ -767,7 +754,7 @@ class NotNullColumnConstraint(ColumnConstraintKind):
class PrimaryKeyColumnConstraint(ColumnConstraintKind):
pass
arg_types = {"desc": False}
class UniqueColumnConstraint(ColumnConstraintKind):
@ -819,6 +806,12 @@ class Unique(Expression):
arg_types = {"expressions": True}
# https://www.postgresql.org/docs/9.1/sql-selectinto.html
# https://docs.aws.amazon.com/redshift/latest/dg/r_SELECT_INTO.html#r_SELECT_INTO-examples
class Into(Expression):
arg_types = {"this": True, "temporary": False, "unlogged": False}
class From(Expression):
arg_types = {"expressions": True}
@ -1065,67 +1058,67 @@ class Property(Expression):
class TableFormatProperty(Property):
pass
arg_types = {"this": True}
class PartitionedByProperty(Property):
pass
arg_types = {"this": True}
class FileFormatProperty(Property):
pass
arg_types = {"this": True}
class DistKeyProperty(Property):
pass
arg_types = {"this": True}
class SortKeyProperty(Property):
pass
arg_types = {"this": True, "compound": False}
class DistStyleProperty(Property):
pass
arg_types = {"this": True}
class LikeProperty(Property):
arg_types = {"this": True, "expressions": False}
class LocationProperty(Property):
pass
arg_types = {"this": True}
class EngineProperty(Property):
pass
arg_types = {"this": True}
class AutoIncrementProperty(Property):
pass
arg_types = {"this": True}
class CharacterSetProperty(Property):
arg_types = {"this": True, "value": True, "default": True}
arg_types = {"this": True, "default": True}
class CollateProperty(Property):
pass
arg_types = {"this": True}
class SchemaCommentProperty(Property):
pass
class AnonymousProperty(Property):
pass
arg_types = {"this": True}
class ReturnsProperty(Property):
arg_types = {"this": True, "value": True, "is_table": False}
arg_types = {"this": True, "is_table": False}
class LanguageProperty(Property):
pass
arg_types = {"this": True}
class ExecuteAsProperty(Property):
pass
arg_types = {"this": True}
class VolatilityProperty(Property):
@ -1135,27 +1128,36 @@ class VolatilityProperty(Property):
class Properties(Expression):
arg_types = {"expressions": True}
PROPERTY_KEY_MAPPING = {
NAME_TO_PROPERTY = {
"AUTO_INCREMENT": AutoIncrementProperty,
"CHARACTER_SET": CharacterSetProperty,
"CHARACTER SET": CharacterSetProperty,
"COLLATE": CollateProperty,
"COMMENT": SchemaCommentProperty,
"ENGINE": EngineProperty,
"FORMAT": FileFormatProperty,
"LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty,
"TABLE_FORMAT": TableFormatProperty,
"DISTKEY": DistKeyProperty,
"DISTSTYLE": DistStyleProperty,
"ENGINE": EngineProperty,
"EXECUTE AS": ExecuteAsProperty,
"FORMAT": FileFormatProperty,
"LANGUAGE": LanguageProperty,
"LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty,
"RETURNS": ReturnsProperty,
"SORTKEY": SortKeyProperty,
"TABLE_FORMAT": TableFormatProperty,
}
PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}
@classmethod
def from_dict(cls, properties_dict) -> Properties:
expressions = []
for key, value in properties_dict.items():
property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
expressions.append(property_cls(this=Literal.string(key), value=convert(value)))
property_cls = cls.NAME_TO_PROPERTY.get(key.upper())
if property_cls:
expressions.append(property_cls(this=convert(value)))
else:
expressions.append(Property(this=Literal.string(key), value=convert(value)))
return cls(expressions=expressions)
@ -1383,6 +1385,7 @@ class Select(Subqueryable):
"expressions": False,
"hint": False,
"distinct": False,
"into": False,
"from": False,
**QUERY_MODIFIERS,
}
@ -2015,6 +2018,7 @@ class DataType(Expression):
DECIMAL = auto()
BOOLEAN = auto()
JSON = auto()
JSONB = auto()
INTERVAL = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
@ -2029,6 +2033,7 @@ class DataType(Expression):
STRUCT = auto()
NULLABLE = auto()
HLLSKETCH = auto()
HSTORE = auto()
SUPER = auto()
SERIAL = auto()
SMALLSERIAL = auto()
@ -2109,7 +2114,7 @@ class Transaction(Command):
class Commit(Command):
arg_types = {} # type: ignore
arg_types = {"chain": False}
class Rollback(Command):
@ -2442,7 +2447,7 @@ class ArrayFilter(Func):
class ArraySize(Func):
pass
arg_types = {"this": True, "expression": False}
class ArraySort(Func):
@ -2726,6 +2731,16 @@ class VarMap(Func):
is_var_len_args = True
class Matches(Func):
"""Oracle/Snowflake decode.
https://docs.oracle.com/cd/B19306_01/server.102/b14200/functions040.htm
Pattern matching MATCHES(value, search1, result1, ...searchN, resultN, else)
"""
arg_types = {"this": True, "expressions": True}
is_var_len_args = True
class Max(AggFunc):
pass
@ -2785,6 +2800,10 @@ class Round(Func):
arg_types = {"this": True, "decimals": False}
class RowNumber(Func):
arg_types: t.Dict[str, t.Any] = {}
class SafeDivide(Func):
arg_types = {"this": True, "expression": True}

View file

@ -1,19 +1,16 @@
from __future__ import annotations
import logging
import re
import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
logger = logging.getLogger("sqlglot")
NEWLINE_RE = re.compile("\r\n?|\n")
class Generator:
"""
@ -58,11 +55,11 @@ class Generator:
"""
TRANSFORMS = {
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})",
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
@ -97,16 +94,17 @@ class Generator:
exp.DistStyleProperty,
exp.DistKeyProperty,
exp.SortKeyProperty,
exp.LikeProperty,
}
WITH_PROPERTIES = {
exp.AnonymousProperty,
exp.Property,
exp.FileFormatProperty,
exp.PartitionedByProperty,
exp.TableFormatProperty,
}
WITH_SEPARATED_COMMENTS = (exp.Select,)
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
__slots__ = (
"time_mapping",
@ -211,7 +209,7 @@ class Generator:
for msg in self.unsupported_messages:
logger.warning(msg)
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported))
raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))
return sql
@ -226,25 +224,24 @@ class Generator:
def seg(self, sql, sep=" "):
return f"{self.sep(sep)}{sql}"
def maybe_comment(self, sql, expression, single_line=False):
comment = expression.comment if self._comments else None
if not comment:
return sql
def pad_comment(self, comment):
comment = " " + comment if comment[0].strip() else comment
comment = comment + " " if comment[-1].strip() else comment
return comment
def maybe_comment(self, sql, expression):
comments = expression.comments if self._comments else None
if not comments:
return sql
sep = "\n" if self.pretty else " "
comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments)
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
return f"/*{comment}*/{self.sep()}{sql}"
return f"{comments}{self.sep()}{sql}"
if not self.pretty:
return f"{sql} /*{comment}*/"
if not NEWLINE_RE.search(comment):
return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/"
return f"{sql} {comments}"
def wrap(self, expression):
this_sql = self.indent(
@ -387,8 +384,11 @@ class Generator:
def notnullcolumnconstraint_sql(self, _):
return "NOT NULL"
def primarykeycolumnconstraint_sql(self, _):
return "PRIMARY KEY"
def primarykeycolumnconstraint_sql(self, expression):
desc = expression.args.get("desc")
if desc is not None:
return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
return f"PRIMARY KEY"
def uniquecolumnconstraint_sql(self, _):
return "UNIQUE"
@ -546,36 +546,33 @@ class Generator:
def root_properties(self, properties):
if properties.expressions:
return self.sep() + self.expressions(
properties,
indent=False,
sep=" ",
)
return self.sep() + self.expressions(properties, indent=False, sep=" ")
return ""
def properties(self, properties, prefix="", sep=", "):
if properties.expressions:
expressions = self.expressions(
properties,
sep=sep,
indent=False,
)
expressions = self.expressions(properties, sep=sep, indent=False)
return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}"
return ""
def with_properties(self, properties):
return self.properties(
properties,
prefix="WITH",
)
return self.properties(properties, prefix="WITH")
def property_sql(self, expression):
if isinstance(expression.this, exp.Literal):
key = expression.this.this
else:
key = expression.name
value = self.sql(expression, "value")
return f"{key}={value}"
property_cls = expression.__class__
if property_cls == exp.Property:
return f"{expression.name}={self.sql(expression, 'value')}"
property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls)
if not property_name:
self.unsupported(f"Unsupported property {property_name}")
return f"{property_name}={self.sql(expression, 'this')}"
def likeproperty_sql(self, expression):
options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
options = f" {options}" if options else ""
return f"LIKE {self.sql(expression, 'this')}{options}"
def insert_sql(self, expression):
overwrite = expression.args.get("overwrite")
@ -700,6 +697,11 @@ class Generator:
def var_sql(self, expression):
return self.sql(expression, "this")
def into_sql(self, expression):
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
unlogged = " UNLOGGED" if expression.args.get("unlogged") else ""
return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}"
def from_sql(self, expression):
expressions = self.expressions(expression, flat=True)
return f"{self.seg('FROM')} {expressions}"
@ -883,6 +885,7 @@ class Generator:
sql = self.query_modifiers(
expression,
f"SELECT{hint}{distinct}{expressions}",
self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
return self.prepend_ctes(expression, sql)
@ -1061,6 +1064,11 @@ class Generator:
else:
return f"TRIM({target})"
def concat_sql(self, expression):
if len(expression.expressions) == 1:
return self.sql(expression.expressions[0])
return self.function_fallback_sql(expression)
def check_sql(self, expression):
this = self.sql(expression, key="this")
return f"CHECK ({this})"
@ -1125,7 +1133,10 @@ class Generator:
return self.prepend_ctes(expression, sql)
def neg_sql(self, expression):
return f"-{self.sql(expression, 'this')}"
# This makes sure we don't convert "- - 5" to "--5", which is a comment
this_sql = self.sql(expression, "this")
sep = " " if this_sql[0] == "-" else ""
return f"-{sep}{this_sql}"
def not_sql(self, expression):
return f"NOT {self.sql(expression, 'this')}"
@ -1191,8 +1202,12 @@ class Generator:
def transaction_sql(self, *_):
return "BEGIN"
def commit_sql(self, *_):
return "COMMIT"
def commit_sql(self, expression):
chain = expression.args.get("chain")
if chain is not None:
chain = " AND CHAIN" if chain else " AND NO CHAIN"
return f"COMMIT{chain or ''}"
def rollback_sql(self, expression):
savepoint = expression.args.get("savepoint")
@ -1334,15 +1349,15 @@ class Generator:
result_sqls = []
for i, e in enumerate(expressions):
sql = self.sql(e, comment=False)
comment = self.maybe_comment("", e, single_line=True)
comments = self.maybe_comment("", e)
if self.pretty:
if self._leading_comma:
result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}")
result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}")
else:
result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}")
result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}")
else:
result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}")
result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}")
result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
return self.indent(result_sqls, skip_first=False) if indent else result_sqls
@ -1354,7 +1369,10 @@ class Generator:
return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"
def naked_property(self, expression):
return f"{expression.name} {self.sql(expression, 'value')}"
property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__)
if not property_name:
self.unsupported(f"Unsupported property {expression.__class__.__name__}")
return f"{property_name} {self.sql(expression, 'this')}"
def set_operation(self, expression, op):
this = self.sql(expression, "this")

View file

@ -68,6 +68,9 @@ def eliminate_subqueries(expression):
for cte_scope in root.cte_scopes:
# Append all the new CTEs from this existing CTE
for scope in cte_scope.traverse():
if scope is cte_scope:
# Don't try to eliminate this CTE itself
continue
new_cte = _eliminate(scope, existing_ctes, taken)
if new_cte:
new_ctes.append(new_cte)
@ -97,6 +100,9 @@ def _eliminate(scope, existing_ctes, taken):
if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
return _eliminate_derived_table(scope, existing_ctes, taken)
if scope.is_cte:
return _eliminate_cte(scope, existing_ctes, taken)
def _eliminate_union(scope, existing_ctes, taken):
duplicate_cte_alias = existing_ctes.get(scope.expression)
@ -127,26 +133,61 @@ def _eliminate_union(scope, existing_ctes, taken):
def _eliminate_derived_table(scope, existing_ctes, taken):
parent = scope.expression.parent
name, cte = _new_cte(scope, existing_ctes, taken)
table = exp.alias_(exp.table_(name), alias=parent.alias or name)
parent.replace(table)
return cte
def _eliminate_cte(scope, existing_ctes, taken):
parent = scope.expression.parent
name, cte = _new_cte(scope, existing_ctes, taken)
with_ = parent.parent
parent.pop()
if not with_.expressions:
with_.pop()
# Rename references to this CTE
for child_scope in scope.parent.traverse():
for table, source in child_scope.selected_sources.values():
if source is scope:
new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
table.replace(new_table)
return cte
def _new_cte(scope, existing_ctes, taken):
"""
Returns:
tuple of (name, cte)
where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
If this CTE duplicates an existing CTE, `cte` will be None.
"""
duplicate_cte_alias = existing_ctes.get(scope.expression)
parent = scope.expression.parent
name = alias = parent.alias
name = parent.alias
if not alias:
name = alias = find_new_name(taken=taken, base="cte")
if not name:
name = find_new_name(taken=taken, base="cte")
if duplicate_cte_alias:
name = duplicate_cte_alias
elif taken.get(alias):
name = find_new_name(taken=taken, base=alias)
elif taken.get(name):
name = find_new_name(taken=taken, base=name)
taken[name] = scope
table = exp.alias_(exp.table_(name), alias=alias)
parent.replace(table)
if not duplicate_cte_alias:
existing_ctes[scope.expression] = name
return exp.CTE(
cte = exp.CTE(
this=scope.expression,
alias=exp.TableAlias(this=exp.to_identifier(name)),
)
else:
cte = None
return name, cte

View file

@ -0,0 +1,92 @@
from sqlglot import exp
from sqlglot.helper import ensure_collection
def lower_identities(expression):
"""
Convert all unquoted identifiers to lower case.
Assuming the schema is all lower case, this essentially makes identifiers case-insensitive.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
>>> lower_identities(expression).sql()
'SELECT bar.a AS A FROM "Foo".bar'
Args:
expression (sqlglot.Expression): expression to quote
Returns:
sqlglot.Expression: quoted expression
"""
# We need to leave the output aliases unchanged, so the selects need special handling
_lower_selects(expression)
# These clauses can reference output aliases and also need special handling
_lower_order(expression)
_lower_having(expression)
# We've already handled these args, so don't traverse into them
traversed = {"expressions", "order", "having"}
if isinstance(expression, exp.Subquery):
# Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1
lower_identities(expression.this)
traversed |= {"this"}
if isinstance(expression, exp.Union):
# Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X
lower_identities(expression.left)
lower_identities(expression.right)
traversed |= {"this", "expression"}
for k, v in expression.args.items():
if k in traversed:
continue
for child in ensure_collection(v):
if isinstance(child, exp.Expression):
child.transform(_lower, copy=False)
return expression
def _lower_selects(expression):
for e in expression.expressions:
# Leave output aliases as-is
e.unalias().transform(_lower, copy=False)
def _lower_order(expression):
order = expression.args.get("order")
if not order:
return
output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)}
for ordered in order.expressions:
# Don't lower references to output aliases
if not (
isinstance(ordered.this, exp.Column)
and not ordered.this.table
and ordered.this.name in output_aliases
):
ordered.transform(_lower, copy=False)
def _lower_having(expression):
having = expression.args.get("having")
if not having:
return
# Don't lower references to output aliases
for agg in having.find_all(exp.AggFunc):
agg.transform(_lower, copy=False)
def _lower(node):
if isinstance(node, exp.Identifier) and not node.quoted:
node.set("this", node.this.lower())
return node

View file

@ -6,6 +6,7 @@ from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
@ -17,6 +18,7 @@ from sqlglot.optimizer.quote_identities import quote_identities
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
RULES = (
lower_identities,
qualify_tables,
isolate_table_selects,
qualify_columns,

View file

@ -1,16 +1,15 @@
import itertools
from sqlglot import exp
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.optimizer.scope import ScopeType, traverse_scope
def unnest_subqueries(expression):
"""
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
Convert the subquery into a group by so it is not a many to many left join.
Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
Unnesting non correlated subqueries only happens on IN statements or = ANY statements.
Convert scalar subqueries into cross joins.
Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
Example:
>>> import sqlglot
@ -29,21 +28,43 @@ def unnest_subqueries(expression):
for scope in traverse_scope(expression):
select = scope.expression
parent = select.parent_select
if not parent:
continue
if scope.external_columns:
decorrelate(select, parent, scope.external_columns, sequence)
else:
elif scope.scope_type == ScopeType.SUBQUERY:
unnest(select, parent, sequence)
return expression
def unnest(select, parent_select, sequence):
predicate = select.find_ancestor(exp.In, exp.Any)
if len(select.selects) > 1:
return
predicate = select.find_ancestor(exp.Condition)
alias = _alias(sequence)
if not predicate or parent_select is not predicate.parent_select:
return
if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
# this subquery returns a scalar and can just be converted to a cross join
if not isinstance(predicate, (exp.In, exp.Any)):
having = predicate.find_ancestor(exp.Having)
column = exp.column(select.selects[0].alias_or_name, alias)
if having and having.parent_select is parent_select:
column = exp.Max(this=column)
_replace(select.parent, column)
parent_select.join(
select,
join_type="CROSS",
join_alias=alias,
copy=False,
)
return
if select.find(exp.Limit, exp.Offset):
return
if isinstance(predicate, exp.Any):
@ -54,7 +75,6 @@ def unnest(select, parent_select, sequence):
column = _other_operand(predicate)
value = select.selects[0]
alias = _alias(sequence)
on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
_replace(predicate, f"NOT {on.right} IS NULL")

View file

@ -4,7 +4,7 @@ import logging
import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_errors
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
from sqlglot.helper import apply_index_offset, ensure_collection, seq_get
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie
@ -104,6 +104,7 @@ class Parser(metaclass=_Parser):
TokenType.BINARY,
TokenType.VARBINARY,
TokenType.JSON,
TokenType.JSONB,
TokenType.INTERVAL,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
@ -115,6 +116,7 @@ class Parser(metaclass=_Parser):
TokenType.GEOGRAPHY,
TokenType.GEOMETRY,
TokenType.HLLSKETCH,
TokenType.HSTORE,
TokenType.SUPER,
TokenType.SERIAL,
TokenType.SMALLSERIAL,
@ -153,6 +155,7 @@ class Parser(metaclass=_Parser):
TokenType.COLLATE,
TokenType.COMMAND,
TokenType.COMMIT,
TokenType.COMPOUND,
TokenType.CONSTRAINT,
TokenType.CURRENT_TIME,
TokenType.DEFAULT,
@ -194,6 +197,7 @@ class Parser(metaclass=_Parser):
TokenType.RANGE,
TokenType.REFERENCES,
TokenType.RETURNS,
TokenType.ROW,
TokenType.ROWS,
TokenType.SCHEMA,
TokenType.SCHEMA_COMMENT,
@ -213,6 +217,7 @@ class Parser(metaclass=_Parser):
TokenType.TRUE,
TokenType.UNBOUNDED,
TokenType.UNIQUE,
TokenType.UNLOGGED,
TokenType.UNPIVOT,
TokenType.PROPERTIES,
TokenType.PROCEDURE,
@ -400,9 +405,17 @@ class Parser(metaclass=_Parser):
TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
TokenType.BEGIN: lambda self: self._parse_transaction(),
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
}
UNARY_PARSERS = {
TokenType.PLUS: lambda self: self._parse_unary(), # Unary + is handled as a no-op
TokenType.NOT: lambda self: self.expression(exp.Not, this=self._parse_equality()),
TokenType.TILDA: lambda self: self.expression(exp.BitwiseNot, this=self._parse_unary()),
TokenType.DASH: lambda self: self.expression(exp.Neg, this=self._parse_unary()),
}
PRIMARY_PARSERS = {
TokenType.STRING: lambda self, token: self.expression(
exp.Literal, this=token.text, is_string=True
@ -446,19 +459,20 @@ class Parser(metaclass=_Parser):
}
PROPERTY_PARSERS = {
TokenType.AUTO_INCREMENT: lambda self: self._parse_auto_increment(),
TokenType.CHARACTER_SET: lambda self: self._parse_character_set(),
TokenType.LOCATION: lambda self: self.expression(
exp.LocationProperty,
this=exp.Literal.string("LOCATION"),
value=self._parse_string(),
TokenType.AUTO_INCREMENT: lambda self: self._parse_property_assignment(
exp.AutoIncrementProperty
),
TokenType.CHARACTER_SET: lambda self: self._parse_character_set(),
TokenType.LOCATION: lambda self: self._parse_property_assignment(exp.LocationProperty),
TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(),
TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(),
TokenType.STORED: lambda self: self._parse_stored(),
TokenType.SCHEMA_COMMENT: lambda self: self._parse_property_assignment(
exp.SchemaCommentProperty
),
TokenType.STORED: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
TokenType.DISTKEY: lambda self: self._parse_distkey(),
TokenType.DISTSTYLE: lambda self: self._parse_diststyle(),
TokenType.DISTSTYLE: lambda self: self._parse_property_assignment(exp.DistStyleProperty),
TokenType.SORTKEY: lambda self: self._parse_sortkey(),
TokenType.LIKE: lambda self: self._parse_create_like(),
TokenType.RETURNS: lambda self: self._parse_returns(),
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
@ -468,7 +482,7 @@ class Parser(metaclass=_Parser):
),
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
TokenType.EXECUTE: lambda self: self._parse_execute_as(),
TokenType.EXECUTE: lambda self: self._parse_property_assignment(exp.ExecuteAsProperty),
TokenType.DETERMINISTIC: lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
),
@ -489,6 +503,7 @@ class Parser(metaclass=_Parser):
),
TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(),
TokenType.UNIQUE: lambda self: self._parse_unique(),
TokenType.LIKE: lambda self: self._parse_create_like(),
}
NO_PAREN_FUNCTION_PARSERS = {
@ -505,6 +520,7 @@ class Parser(metaclass=_Parser):
"TRIM": lambda self: self._parse_trim(),
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"TRY_CAST": lambda self: self._parse_cast(False),
"STRING_AGG": lambda self: self._parse_string_agg(),
}
QUERY_MODIFIER_PARSERS = {
@ -556,7 +572,7 @@ class Parser(metaclass=_Parser):
"_curr",
"_next",
"_prev",
"_prev_comment",
"_prev_comments",
"_show_trie",
"_set_trie",
)
@ -589,7 +605,7 @@ class Parser(metaclass=_Parser):
self._curr = None
self._next = None
self._prev = None
self._prev_comment = None
self._prev_comments = None
def parse(self, raw_tokens, sql=None):
"""
@ -608,6 +624,7 @@ class Parser(metaclass=_Parser):
)
def parse_into(self, expression_types, raw_tokens, sql=None):
errors = []
for expression_type in ensure_collection(expression_types):
parser = self.EXPRESSION_PARSERS.get(expression_type)
if not parser:
@ -615,8 +632,12 @@ class Parser(metaclass=_Parser):
try:
return self._parse(parser, raw_tokens, sql)
except ParseError as e:
error = e
raise ParseError(f"Failed to parse into {expression_types}") from error
e.errors[0]["into_expression"] = expression_type
errors.append(e)
raise ParseError(
f"Failed to parse into {expression_types}",
errors=merge_errors(errors),
) from errors[-1]
def _parse(self, parse_method, raw_tokens, sql=None):
self.reset()
@ -650,7 +671,10 @@ class Parser(metaclass=_Parser):
for error in self.errors:
logger.error(str(error))
elif self.error_level == ErrorLevel.RAISE and self.errors:
raise ParseError(concat_errors(self.errors, self.max_errors))
raise ParseError(
concat_messages(self.errors, self.max_errors),
errors=merge_errors(self.errors),
)
def raise_error(self, message, token=None):
token = token or self._curr or self._prev or Token.string("")
@ -659,19 +683,27 @@ class Parser(metaclass=_Parser):
start_context = self.sql[max(start - self.error_message_context, 0) : start]
highlight = self.sql[start:end]
end_context = self.sql[end : end + self.error_message_context]
error = ParseError(
error = ParseError.new(
f"{message}. Line {token.line}, Col: {token.col}.\n"
f" {start_context}\033[4m{highlight}\033[0m{end_context}"
f" {start_context}\033[4m{highlight}\033[0m{end_context}",
description=message,
line=token.line,
col=token.col,
start_context=start_context,
highlight=highlight,
end_context=end_context,
)
if self.error_level == ErrorLevel.IMMEDIATE:
raise error
self.errors.append(error)
def expression(self, exp_class, **kwargs):
def expression(self, exp_class, comments=None, **kwargs):
instance = exp_class(**kwargs)
if self._prev_comment:
instance.comment = self._prev_comment
self._prev_comment = None
if self._prev_comments:
instance.comments = self._prev_comments
self._prev_comments = None
if comments:
instance.comments = comments
self.validate_expression(instance)
return instance
@ -714,10 +746,10 @@ class Parser(metaclass=_Parser):
self._next = seq_get(self._tokens, self._index + 1)
if self._index > 0:
self._prev = self._tokens[self._index - 1]
self._prev_comment = self._prev.comment
self._prev_comments = self._prev.comments
else:
self._prev = None
self._prev_comment = None
self._prev_comments = None
def _retreat(self, index):
self._advance(index - self._index)
@ -768,7 +800,7 @@ class Parser(metaclass=_Parser):
)
def _parse_create(self):
replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
temporary = self._match(TokenType.TEMPORARY)
transient = self._match(TokenType.TRANSIENT)
unique = self._match(TokenType.UNIQUE)
@ -822,97 +854,57 @@ class Parser(metaclass=_Parser):
def _parse_property(self):
if self._match_set(self.PROPERTY_PARSERS):
return self.PROPERTY_PARSERS[self._prev.token_type](self)
if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
return self._parse_character_set(True)
if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False):
key = self._parse_var().this
self._match(TokenType.EQ)
if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
return self._parse_sortkey(compound=True)
return self.expression(
exp.AnonymousProperty,
this=exp.Literal.string(key),
value=self._parse_column(),
)
if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False):
key = self._parse_var()
self._match(TokenType.EQ)
return self.expression(exp.Property, this=key, value=self._parse_column())
return None
def _parse_property_assignment(self, exp_class):
prop = self._prev.text
self._match(TokenType.EQ)
return self.expression(exp_class, this=prop, value=self._parse_var_or_string())
self._match(TokenType.ALIAS)
return self.expression(exp_class, this=self._parse_var_or_string() or self._parse_number())
def _parse_partitioned_by(self):
self._match(TokenType.EQ)
return self.expression(
exp.PartitionedByProperty,
this=exp.Literal.string("PARTITIONED_BY"),
value=self._parse_schema() or self._parse_bracket(self._parse_field()),
)
def _parse_stored(self):
self._match(TokenType.ALIAS)
self._match(TokenType.EQ)
return self.expression(
exp.FileFormatProperty,
this=exp.Literal.string("FORMAT"),
value=exp.Literal.string(self._parse_var_or_string().name),
this=self._parse_schema() or self._parse_bracket(self._parse_field()),
)
def _parse_distkey(self):
self._match_l_paren()
this = exp.Literal.string("DISTKEY")
value = exp.Literal.string(self._parse_var().name)
self._match_r_paren()
return self.expression(
exp.DistKeyProperty,
this=this,
value=value,
)
return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_var))
def _parse_sortkey(self):
self._match_l_paren()
this = exp.Literal.string("SORTKEY")
value = exp.Literal.string(self._parse_var().name)
self._match_r_paren()
return self.expression(
exp.SortKeyProperty,
this=this,
value=value,
def _parse_create_like(self):
table = self._parse_table(schema=True)
options = []
while self._match_texts(("INCLUDING", "EXCLUDING")):
options.append(
self.expression(
exp.Property,
this=self._prev.text.upper(),
value=exp.Var(this=self._parse_id_var().this.upper()),
)
def _parse_diststyle(self):
this = exp.Literal.string("DISTSTYLE")
value = exp.Literal.string(self._parse_var().name)
return self.expression(
exp.DistStyleProperty,
this=this,
value=value,
)
return self.expression(exp.LikeProperty, this=table, expressions=options)
def _parse_auto_increment(self):
self._match(TokenType.EQ)
def _parse_sortkey(self, compound=False):
return self.expression(
exp.AutoIncrementProperty,
this=exp.Literal.string("AUTO_INCREMENT"),
value=self._parse_number(),
)
def _parse_schema_comment(self):
self._match(TokenType.EQ)
return self.expression(
exp.SchemaCommentProperty,
this=exp.Literal.string("COMMENT"),
value=self._parse_string(),
exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_var), compound=compound
)
def _parse_character_set(self, default=False):
self._match(TokenType.EQ)
return self.expression(
exp.CharacterSetProperty,
this=exp.Literal.string("CHARACTER_SET"),
value=self._parse_var_or_string(),
default=default,
exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default
)
def _parse_returns(self):
@ -931,20 +923,7 @@ class Parser(metaclass=_Parser):
else:
value = self._parse_types()
return self.expression(
exp.ReturnsProperty,
this=exp.Literal.string("RETURNS"),
value=value,
is_table=is_table,
)
def _parse_execute_as(self):
self._match(TokenType.ALIAS)
return self.expression(
exp.ExecuteAsProperty,
this=exp.Literal.string("EXECUTE AS"),
value=self._parse_var(),
)
return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)
def _parse_properties(self):
properties = []
@ -956,7 +935,7 @@ class Parser(metaclass=_Parser):
properties.extend(
self._parse_wrapped_csv(
lambda: self.expression(
exp.AnonymousProperty,
exp.Property,
this=self._parse_string(),
value=self._match(TokenType.EQ) and self._parse_string(),
)
@ -1076,7 +1055,12 @@ class Parser(metaclass=_Parser):
options = []
if self._match(TokenType.OPTIONS):
options = self._parse_wrapped_csv(self._parse_string, sep=TokenType.EQ)
self._match_l_paren()
k = self._parse_string()
self._match(TokenType.EQ)
v = self._parse_string()
options = [k, v]
self._match_r_paren()
self._match(TokenType.ALIAS)
return self.expression(
@ -1116,7 +1100,7 @@ class Parser(metaclass=_Parser):
self.raise_error(f"{this.key} does not support CTE")
this = cte
elif self._match(TokenType.SELECT):
comment = self._prev_comment
comments = self._prev_comments
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
@ -1141,10 +1125,16 @@ class Parser(metaclass=_Parser):
expressions=expressions,
limit=limit,
)
this.comment = comment
this.comments = comments
into = self._parse_into()
if into:
this.set("into", into)
from_ = self._parse_from()
if from_:
this.set("from", from_)
self._parse_query_modifiers(this)
elif (table or nested) and self._match(TokenType.L_PAREN):
this = self._parse_table() if table else self._parse_select(nested=True)
@ -1248,11 +1238,24 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Hint, expressions=hints)
return None
def _parse_into(self):
if not self._match(TokenType.INTO):
return None
temp = self._match(TokenType.TEMPORARY)
unlogged = self._match(TokenType.UNLOGGED)
self._match(TokenType.TABLE)
return self.expression(
exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged
)
def _parse_from(self):
if not self._match(TokenType.FROM):
return None
return self.expression(exp.From, expressions=self._parse_csv(self._parse_table))
return self.expression(
exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table)
)
def _parse_lateral(self):
outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY)
@ -1515,7 +1518,9 @@ class Parser(metaclass=_Parser):
def _parse_where(self, skip_where_token=False):
if not skip_where_token and not self._match(TokenType.WHERE):
return None
return self.expression(exp.Where, this=self._parse_conjunction())
return self.expression(
exp.Where, comments=self._prev_comments, this=self._parse_conjunction()
)
def _parse_group(self, skip_group_by_token=False):
if not skip_group_by_token and not self._match(TokenType.GROUP_BY):
@ -1737,12 +1742,8 @@ class Parser(metaclass=_Parser):
return self._parse_tokens(self._parse_unary, self.FACTOR)
def _parse_unary(self):
if self._match(TokenType.NOT):
return self.expression(exp.Not, this=self._parse_equality())
if self._match(TokenType.TILDA):
return self.expression(exp.BitwiseNot, this=self._parse_unary())
if self._match(TokenType.DASH):
return self.expression(exp.Neg, this=self._parse_unary())
if self._match_set(self.UNARY_PARSERS):
return self.UNARY_PARSERS[self._prev.token_type](self)
return self._parse_at_time_zone(self._parse_type())
def _parse_type(self):
@ -1775,17 +1776,6 @@ class Parser(metaclass=_Parser):
expressions = None
maybe_func = False
if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
return exp.DataType(
this=exp.DataType.Type.ARRAY,
expressions=[exp.DataType.build(type_token.value)],
nested=True,
)
if self._match(TokenType.L_BRACKET):
self._retreat(index)
return None
if self._match(TokenType.L_PAREN):
if is_struct:
expressions = self._parse_csv(self._parse_struct_kwargs)
@ -1801,6 +1791,17 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
maybe_func = True
if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
return exp.DataType(
this=exp.DataType.Type.ARRAY,
expressions=[exp.DataType.build(type_token.value, expressions=expressions)],
nested=True,
)
if self._match(TokenType.L_BRACKET):
self._retreat(index)
return None
if nested and self._match(TokenType.LT):
if is_struct:
expressions = self._parse_csv(self._parse_struct_kwargs)
@ -1904,7 +1905,7 @@ class Parser(metaclass=_Parser):
return exp.Literal.number(f"0.{self._prev.text}")
if self._match(TokenType.L_PAREN):
comment = self._prev_comment
comments = self._prev_comments
query = self._parse_select()
if query:
@ -1924,8 +1925,8 @@ class Parser(metaclass=_Parser):
this = self.expression(exp.Tuple, expressions=expressions)
else:
this = self.expression(exp.Paren, this=this)
if comment:
this.comment = comment
if comments:
this.comments = comments
return this
return None
@ -2098,7 +2099,10 @@ class Parser(metaclass=_Parser):
elif self._match(TokenType.SCHEMA_COMMENT):
kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
elif self._match(TokenType.PRIMARY_KEY):
kind = exp.PrimaryKeyColumnConstraint()
desc = None
if self._match(TokenType.ASC) or self._match(TokenType.DESC):
desc = self._prev.token_type == TokenType.DESC
kind = exp.PrimaryKeyColumnConstraint(desc=desc)
elif self._match(TokenType.UNIQUE):
kind = exp.UniqueColumnConstraint()
elif self._match(TokenType.GENERATED):
@ -2189,7 +2193,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.R_BRACKET):
self.raise_error("Expected ]")
this.comment = self._prev_comment
this.comments = self._prev_comments
return self._parse_bracket(this)
def _parse_case(self):
@ -2256,6 +2260,33 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
def _parse_string_agg(self):
if self._match(TokenType.DISTINCT):
args = self._parse_csv(self._parse_conjunction)
expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)])
else:
args = self._parse_csv(self._parse_conjunction)
expression = seq_get(args, 0)
index = self._index
if not self._match(TokenType.R_PAREN):
# postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]])
order = self._parse_order(this=expression)
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
# Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]).
# This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that
# the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them.
if not self._match(TokenType.WITHIN_GROUP):
self._retreat(index)
this = exp.GroupConcat.from_arg_list(args)
self.validate_expression(this, args)
return this
self._match_l_paren() # The corresponding match_r_paren will be called in parse_function (caller)
order = self._parse_order(this=expression)
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
def _parse_convert(self, strict):
this = self._parse_column()
if self._match(TokenType.USING):
@ -2511,8 +2542,8 @@ class Parser(metaclass=_Parser):
items = [parse_result] if parse_result is not None else []
while self._match(sep):
if parse_result and self._prev_comment is not None:
parse_result.comment = self._prev_comment
if parse_result and self._prev_comments:
parse_result.comments = self._prev_comments
parse_result = parse_method()
if parse_result is not None:
@ -2525,7 +2556,10 @@ class Parser(metaclass=_Parser):
while self._match_set(expressions):
this = self.expression(
expressions[self._prev.token_type], this=this, expression=parse_method()
expressions[self._prev.token_type],
this=this,
comments=self._prev_comments,
expression=parse_method(),
)
return this
@ -2566,6 +2600,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Transaction, this=this, modes=modes)
def _parse_commit_or_rollback(self):
chain = None
savepoint = None
is_rollback = self._prev.token_type == TokenType.ROLLBACK
@ -2575,9 +2610,13 @@ class Parser(metaclass=_Parser):
self._match_text_seq("SAVEPOINT")
savepoint = self._parse_id_var()
if self._match(TokenType.AND):
chain = not self._match_text_seq("NO")
self._match_text_seq("CHAIN")
if is_rollback:
return self.expression(exp.Rollback, savepoint=savepoint)
return self.expression(exp.Commit)
return self.expression(exp.Commit, chain=chain)
def _parse_show(self):
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
@ -2651,14 +2690,14 @@ class Parser(metaclass=_Parser):
def _match_l_paren(self, expression=None):
if not self._match(TokenType.L_PAREN):
self.raise_error("Expecting (")
if expression and self._prev_comment:
expression.comment = self._prev_comment
if expression and self._prev_comments:
expression.comments = self._prev_comments
def _match_r_paren(self, expression=None):
if not self._match(TokenType.R_PAREN):
self.raise_error("Expecting )")
if expression and self._prev_comment:
expression.comment = self._prev_comment
if expression and self._prev_comments:
expression.comments = self._prev_comments
def _match_texts(self, texts):
if self._curr and self._curr.text.upper() in texts:

View file

@ -130,18 +130,20 @@ class Step:
aggregations = []
sequence = itertools.count()
for e in expression.expressions:
aggregation = e.find(exp.AggFunc)
if aggregation:
projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
aggregations.append(e)
for operand in aggregation.unnest_operands():
def extract_agg_operands(expression):
for agg in expression.find_all(exp.AggFunc):
for operand in agg.unnest_operands():
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = f"_a_{next(sequence)}"
operand.replace(exp.column(operands[operand], quoted=True))
for e in expression.expressions:
if e.find(exp.AggFunc):
projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
aggregations.append(e)
extract_agg_operands(e)
else:
projections.append(e)
@ -156,6 +158,13 @@ class Step:
aggregate = Aggregate()
aggregate.source = step.name
aggregate.name = step.name
having = expression.args.get("having")
if having:
extract_agg_operands(having)
aggregate.condition = having.this
aggregate.operands = tuple(
alias(operand, alias_) for operand, alias_ in operands.items()
)
@ -172,11 +181,6 @@ class Step:
aggregate.add_dependency(step)
step = aggregate
having = expression.args.get("having")
if having:
step.condition = having.this
order = expression.args.get("order")
if order:
@ -188,6 +192,17 @@ class Step:
step.projections = projections
if isinstance(expression, exp.Select) and expression.args.get("distinct"):
distinct = Aggregate()
distinct.source = step.name
distinct.name = step.name
distinct.group = {
e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
for e in projections or expression.expressions
}
distinct.add_dependency(step)
step = distinct
limit = expression.args.get("limit")
if limit:
@ -231,6 +246,9 @@ class Step:
if self.condition:
lines.append(f"{nested}Condition: {self.condition.sql()}")
if self.limit is not math.inf:
lines.append(f"{nested}Limit: {self.limit}")
if self.dependencies:
lines.append(f"{nested}Dependencies:")
for dependency in self.dependencies:
@ -258,12 +276,7 @@ class Scan(Step):
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
table = expression
alias_ = expression.alias
if not alias_:
raise UnsupportedError(
"Tables/Subqueries must be aliased. Run it through the optimizer"
)
alias_ = expression.alias_or_name
if isinstance(expression, exp.Subquery):
table = expression.this
@ -338,6 +351,9 @@ class Aggregate(Step):
lines.append(f"{indent}Group:")
for expression in self.group.values():
lines.append(f"{indent} - {expression.sql()}")
if self.condition:
lines.append(f"{indent}Having:")
lines.append(f"{indent} - {self.condition.sql()}")
if self.operands:
lines.append(f"{indent}Operands:")
for expression in self.operands:

View file

@ -81,6 +81,7 @@ class TokenType(AutoName):
BINARY = auto()
VARBINARY = auto()
JSON = auto()
JSONB = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
@ -91,6 +92,7 @@ class TokenType(AutoName):
NULLABLE = auto()
GEOMETRY = auto()
HLLSKETCH = auto()
HSTORE = auto()
SUPER = auto()
SERIAL = auto()
SMALLSERIAL = auto()
@ -113,6 +115,7 @@ class TokenType(AutoName):
APPLY = auto()
ARRAY = auto()
ASC = auto()
ASOF = auto()
AT_TIME_ZONE = auto()
AUTO_INCREMENT = auto()
BEGIN = auto()
@ -130,6 +133,7 @@ class TokenType(AutoName):
COMMAND = auto()
COMMENT = auto()
COMMIT = auto()
COMPOUND = auto()
CONSTRAINT = auto()
CREATE = auto()
CROSS = auto()
@ -271,6 +275,7 @@ class TokenType(AutoName):
UNBOUNDED = auto()
UNCACHE = auto()
UNION = auto()
UNLOGGED = auto()
UNNEST = auto()
UNPIVOT = auto()
UPDATE = auto()
@ -291,7 +296,7 @@ class TokenType(AutoName):
class Token:
__slots__ = ("token_type", "text", "line", "col", "comment")
__slots__ = ("token_type", "text", "line", "col", "comments")
@classmethod
def number(cls, number: int) -> Token:
@ -319,13 +324,13 @@ class Token:
text: str,
line: int = 1,
col: int = 1,
comment: t.Optional[str] = None,
comments: t.List[str] = [],
) -> None:
self.token_type = token_type
self.text = text
self.line = line
self.col = max(col - len(text), 1)
self.comment = comment
self.comments = comments
def __repr__(self) -> str:
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
@ -452,6 +457,7 @@ class Tokenizer(metaclass=_Tokenizer):
"COLLATE": TokenType.COLLATE,
"COMMENT": TokenType.SCHEMA_COMMENT,
"COMMIT": TokenType.COMMIT,
"COMPOUND": TokenType.COMPOUND,
"CONSTRAINT": TokenType.CONSTRAINT,
"CREATE": TokenType.CREATE,
"CROSS": TokenType.CROSS,
@ -582,8 +588,9 @@ class Tokenizer(metaclass=_Tokenizer):
"TRAILING": TokenType.TRAILING,
"UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION,
"UNPIVOT": TokenType.UNPIVOT,
"UNLOGGED": TokenType.UNLOGGED,
"UNNEST": TokenType.UNNEST,
"UNPIVOT": TokenType.UNPIVOT,
"UPDATE": TokenType.UPDATE,
"USE": TokenType.USE,
"USING": TokenType.USING,
@ -686,12 +693,12 @@ class Tokenizer(metaclass=_Tokenizer):
"_current",
"_line",
"_col",
"_comment",
"_comments",
"_char",
"_end",
"_peek",
"_prev_token_line",
"_prev_token_comment",
"_prev_token_comments",
"_prev_token_type",
"_replace_backslash",
)
@ -708,13 +715,13 @@ class Tokenizer(metaclass=_Tokenizer):
self._current = 0
self._line = 1
self._col = 1
self._comment = None
self._comments: t.List[str] = []
self._char = None
self._end = None
self._peek = None
self._prev_token_line = -1
self._prev_token_comment = None
self._prev_token_comments: t.List[str] = []
self._prev_token_type = None
def tokenize(self, sql: str) -> t.List[Token]:
@ -767,7 +774,7 @@ class Tokenizer(metaclass=_Tokenizer):
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line
self._prev_token_comment = self._comment
self._prev_token_comments = self._comments
self._prev_token_type = token_type # type: ignore
self.tokens.append(
Token(
@ -775,10 +782,10 @@ class Tokenizer(metaclass=_Tokenizer):
self._text if text is None else text,
self._line,
self._col,
self._comment,
self._comments,
)
)
self._comment = None
self._comments = []
if token_type in self.COMMANDS and (
len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
@ -857,22 +864,18 @@ class Tokenizer(metaclass=_Tokenizer):
while not self._end and self._chars(comment_end_size) != comment_end:
self._advance()
self._comment = self._text[comment_start_size : -comment_end_size + 1] # type: ignore
self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore
self._advance(comment_end_size - 1)
else:
while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore
self._advance()
self._comment = self._text[comment_start_size:] # type: ignore
# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both
# types of comment can be attached to a token, the trailing one is discarded in favour of the leading one.
self._comments.append(self._text[comment_start_size:]) # type: ignore
# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
# Multiple consecutive comments are preserved by appending them to the current comments list.
if comment_start_line == self._prev_token_line:
if self._prev_token_comment is None:
self.tokens[-1].comment = self._comment
self._prev_token_comment = self._comment
self._comment = None
self.tokens[-1].comments.extend(self._comments)
self._comments = []
return True

View file

@ -2,6 +2,8 @@ from __future__ import annotations
import typing as t
from sqlglot.helper import find_new_name
if t.TYPE_CHECKING:
from sqlglot.generator import Generator
@ -43,6 +45,43 @@ def unalias_group(expression: exp.Expression) -> exp.Expression:
return expression
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
"""
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Args:
expression: the expression that will be transformed.
Returns:
The transformed expression.
"""
if (
isinstance(expression, exp.Select)
and expression.args.get("distinct")
and expression.args["distinct"].args.get("on")
and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
):
distinct_cols = [e.copy() for e in expression.args["distinct"].args["on"].expressions]
outer_selects = [e.copy() for e in expression.expressions]
nested = expression.copy()
nested.args["distinct"].pop()
row_number = find_new_name(expression.named_selects, "_row_number")
window = exp.Window(
this=exp.RowNumber(),
partition_by=distinct_cols,
)
order = nested.args.get("order")
if order:
window.set("order", order.copy())
order.pop()
window = exp.alias_(window, row_number)
nested.select(window, copy=False)
return exp.select(*outer_selects).from_(nested.subquery()).where(f'"{row_number}" = 1')
return expression
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
to_sql: t.Callable[[Generator, exp.Expression], str],
@ -81,3 +120,4 @@ def delegate(attr: str) -> t.Callable:
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}

View file

@ -1276,7 +1276,7 @@ class TestFunctions(unittest.TestCase):
col = SF.concat(SF.col("cola"), SF.col("colb"))
self.assertEqual("CONCAT(cola, colb)", col.sql())
col_single = SF.concat("cola")
self.assertEqual("CONCAT(cola)", col_single.sql())
self.assertEqual("cola", col_single.sql())
def test_array_position(self):
col_str = SF.array_position("cola", SF.col("colb"))

View file

@ -10,6 +10,10 @@ class TestClickhouse(Validator):
self.validate_identity("SELECT * FROM x AS y FINAL")
self.validate_identity("'a' IN mapKeys(map('a', 1, 'b', 2))")
self.validate_identity("CAST((1, 2) AS Tuple(a Int8, b Int16))")
self.validate_identity("SELECT * FROM foo LEFT ANY JOIN bla")
self.validate_identity("SELECT * FROM foo LEFT ASOF JOIN bla")
self.validate_identity("SELECT * FROM foo ASOF JOIN bla")
self.validate_identity("SELECT * FROM foo ANY JOIN bla")
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",

View file

@ -997,6 +997,13 @@ class TestDialect(Validator):
"spark": "CONCAT_WS('-', x)",
},
)
self.validate_all(
"CONCAT(a)",
write={
"mysql": "a",
"tsql": "a",
},
)
self.validate_all(
"IF(x > 1, 1, 0)",
write={
@ -1263,8 +1270,8 @@ class TestDialect(Validator):
self.validate_all(
"""/* comment1 */
SELECT
x, -- comment2
y -- comment3""",
x, /* comment2 */
y /* comment3 */""",
read={
"mysql": """SELECT # comment1
x, # comment2

View file

@ -89,6 +89,8 @@ class TestDuckDB(Validator):
"presto": "CAST(COL AS ARRAY(BIGINT))",
"hive": "CAST(COL AS ARRAY<BIGINT>)",
"spark": "CAST(COL AS ARRAY<LONG>)",
"postgres": "CAST(COL AS BIGINT[])",
"snowflake": "CAST(COL AS ARRAY)",
},
)
@ -104,6 +106,10 @@ class TestDuckDB(Validator):
"spark": "ARRAY(0, 1, 2)",
},
)
self.validate_all(
"SELECT ARRAY_LENGTH([0], 1) AS x",
write={"duckdb": "SELECT ARRAY_LENGTH(LIST_VALUE(0), 1) AS x"},
)
self.validate_all(
"REGEXP_MATCHES(x, y)",
write={

View file

@ -139,7 +139,7 @@ class TestHive(Validator):
"CREATE TABLE test STORED AS parquet TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1",
write={
"duckdb": "CREATE TABLE test AS SELECT 1",
"presto": "CREATE TABLE test WITH (FORMAT='parquet', x='1', Z='2') AS SELECT 1",
"presto": "CREATE TABLE test WITH (FORMAT='PARQUET', x='1', Z='2') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1",
"spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1",
},
@ -459,6 +459,7 @@ class TestHive(Validator):
"hive": "MAP(a, b, c, d)",
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
"spark": "MAP(a, b, c, d)",
"snowflake": "OBJECT_CONSTRUCT(a, b, c, d)",
},
write={
"": "MAP(ARRAY(a, c), ARRAY(b, d))",
@ -467,6 +468,7 @@ class TestHive(Validator):
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
"hive": "MAP(a, b, c, d)",
"spark": "MAP(a, b, c, d)",
"snowflake": "OBJECT_CONSTRUCT(a, b, c, d)",
},
)
self.validate_all(
@ -476,6 +478,7 @@ class TestHive(Validator):
"presto": "MAP(ARRAY[a], ARRAY[b])",
"hive": "MAP(a, b)",
"spark": "MAP(a, b)",
"snowflake": "OBJECT_CONSTRUCT(a, b)",
},
)
self.validate_all(

View file

@ -23,6 +23,8 @@ class TestMySQL(Validator):
self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')")
self.validate_identity("@@GLOBAL.max_connections")
self.validate_identity("CREATE TABLE A LIKE B")
# SET Commands
self.validate_identity("SET @var_name = expr")
self.validate_identity("SET @name = 43")
@ -177,14 +179,27 @@ class TestMySQL(Validator):
"GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
write={
"mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')",
"sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
"sqlite": "GROUP_CONCAT(DISTINCT x)",
"tsql": "STRING_AGG(x, ',') WITHIN GROUP (ORDER BY y DESC)",
"postgres": "STRING_AGG(DISTINCT x, ',' ORDER BY y DESC NULLS LAST)",
},
)
self.validate_all(
"GROUP_CONCAT(x ORDER BY y SEPARATOR z)",
write={
"mysql": "GROUP_CONCAT(x ORDER BY y SEPARATOR z)",
"sqlite": "GROUP_CONCAT(x, z)",
"tsql": "STRING_AGG(x, z) WITHIN GROUP (ORDER BY y)",
"postgres": "STRING_AGG(x, z ORDER BY y NULLS FIRST)",
},
)
self.validate_all(
"GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
write={
"mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
"sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')",
"sqlite": "GROUP_CONCAT(DISTINCT x, '')",
"tsql": "STRING_AGG(x, '') WITHIN GROUP (ORDER BY y DESC)",
"postgres": "STRING_AGG(DISTINCT x, '' ORDER BY y DESC NULLS LAST)",
},
)
self.validate_identity(

View file

@ -6,6 +6,9 @@ class TestPostgres(Validator):
dialect = "postgres"
def test_ddl(self):
self.validate_identity("CREATE TABLE test (foo HSTORE)")
self.validate_identity("CREATE TABLE test (foo JSONB)")
self.validate_identity("CREATE TABLE test (foo VARCHAR(64)[])")
self.validate_all(
"CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
write={
@ -60,6 +63,12 @@ class TestPostgres(Validator):
)
def test_postgres(self):
self.validate_identity("SELECT ARRAY[1, 2, 3]")
self.validate_identity("SELECT ARRAY_LENGTH(ARRAY[1, 2, 3], 1)")
self.validate_identity("STRING_AGG(x, y)")
self.validate_identity("STRING_AGG(x, ',' ORDER BY y)")
self.validate_identity("STRING_AGG(x, ',' ORDER BY y DESC)")
self.validate_identity("STRING_AGG(DISTINCT x, ',' ORDER BY y DESC)")
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
self.validate_identity(
"SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END"
@ -86,6 +95,14 @@ class TestPostgres(Validator):
self.validate_identity("SELECT e'\\xDEADBEEF'")
self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
self.validate_all(
"END WORK AND NO CHAIN",
write={"postgres": "COMMIT AND NO CHAIN"},
)
self.validate_all(
"END AND CHAIN",
write={"postgres": "COMMIT AND CHAIN"},
)
self.validate_all(
"CREATE TABLE x (a UUID, b BYTEA)",
write={
@ -95,6 +112,10 @@ class TestPostgres(Validator):
"spark": "CREATE TABLE x (a UUID, b BINARY)",
},
)
self.validate_identity(
"CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)"
)
self.validate_all(
"SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS 1 PRECEDING)",
write={

View file

@ -13,6 +13,7 @@ class TestPresto(Validator):
"duckdb": "CAST(a AS INT[])",
"presto": "CAST(a AS ARRAY(INTEGER))",
"spark": "CAST(a AS ARRAY<INT>)",
"snowflake": "CAST(a AS ARRAY)",
},
)
self.validate_all(
@ -31,6 +32,7 @@ class TestPresto(Validator):
"duckdb": "CAST(LIST_VALUE(1, 2) AS BIGINT[])",
"presto": "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))",
"spark": "CAST(ARRAY(1, 2) AS ARRAY<LONG>)",
"snowflake": "CAST([1, 2] AS ARRAY)",
},
)
self.validate_all(
@ -41,6 +43,7 @@ class TestPresto(Validator):
"presto": "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INTEGER, INTEGER))",
"hive": "CAST(MAP(1, 1) AS MAP<INT, INT>)",
"spark": "CAST(MAP_FROM_ARRAYS(ARRAY(1), ARRAY(1)) AS MAP<INT, INT>)",
"snowflake": "CAST(OBJECT_CONSTRUCT(1, 1) AS OBJECT)",
},
)
self.validate_all(
@ -51,6 +54,7 @@ class TestPresto(Validator):
"presto": "CAST(MAP(ARRAY['a', 'b', 'c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INTEGER)))",
"hive": "CAST(MAP('a', ARRAY(1), 'b', ARRAY(2), 'c', ARRAY(3)) AS MAP<STRING, ARRAY<INT>>)",
"spark": "CAST(MAP_FROM_ARRAYS(ARRAY('a', 'b', 'c'), ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))) AS MAP<STRING, ARRAY<INT>>)",
"snowflake": "CAST(OBJECT_CONSTRUCT('a', [1], 'b', [2], 'c', [3]) AS OBJECT)",
},
)
self.validate_all(
@ -393,6 +397,7 @@ class TestPresto(Validator):
write={
"hive": UnsupportedError,
"spark": "MAP_FROM_ARRAYS(a, b)",
"snowflake": UnsupportedError,
},
)
self.validate_all(
@ -401,6 +406,7 @@ class TestPresto(Validator):
"hive": "MAP(a, c, b, d)",
"presto": "MAP(ARRAY[a, b], ARRAY[c, d])",
"spark": "MAP_FROM_ARRAYS(ARRAY(a, b), ARRAY(c, d))",
"snowflake": "OBJECT_CONSTRUCT(a, c, b, d)",
},
)
self.validate_all(
@ -409,6 +415,7 @@ class TestPresto(Validator):
"hive": "MAP('a', 'b')",
"presto": "MAP(ARRAY['a'], ARRAY['b'])",
"spark": "MAP_FROM_ARRAYS(ARRAY('a'), ARRAY('b'))",
"snowflake": "OBJECT_CONSTRUCT('a', 'b')",
},
)
self.validate_all(

View file

@ -50,6 +50,12 @@ class TestRedshift(Validator):
"redshift": 'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\\\_%\' LIMIT 5'
},
)
self.validate_all(
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
write={
"redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
},
)
def test_identity(self):
self.validate_identity("CAST('bla' AS SUPER)")
@ -64,3 +70,13 @@ class TestRedshift(Validator):
self.validate_identity(
"SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
)
self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE AUTO")
self.validate_identity(
"CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid)"
)
self.validate_identity(
"COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'"
)
self.validate_identity(
"UNLOAD ('select * from venue') TO 's3://mybucket/unload/' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'"
)

View file

@ -172,13 +172,28 @@ class TestSnowflake(Validator):
self.validate_all(
"trim(date_column, 'UTC')",
write={
"bigquery": "TRIM(date_column, 'UTC')",
"snowflake": "TRIM(date_column, 'UTC')",
"postgres": "TRIM('UTC' FROM date_column)",
},
)
self.validate_all(
"trim(date_column)",
write={"snowflake": "TRIM(date_column)"},
write={
"snowflake": "TRIM(date_column)",
"bigquery": "TRIM(date_column)",
},
)
self.validate_all(
"DECODE(x, a, b, c, d)",
read={
"": "MATCHES(x, a, b, c, d)",
},
write={
"": "MATCHES(x, a, b, c, d)",
"oracle": "DECODE(x, a, b, c, d)",
"snowflake": "DECODE(x, a, b, c, d)",
},
)
def test_null_treatment(self):
@ -370,7 +385,8 @@ class TestSnowflake(Validator):
)
self.validate_all(
r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""}
r"""SELECT * FROM TABLE(?)""",
write={"snowflake": r"""SELECT * FROM TABLE(?)"""},
)
self.validate_all(

View file

@ -32,13 +32,14 @@ class TestSpark(Validator):
"presto": "CREATE TABLE db.example_table (col_a ARRAY(INTEGER), col_b ARRAY(ARRAY(INTEGER)))",
"hive": "CREATE TABLE db.example_table (col_a ARRAY<INT>, col_b ARRAY<ARRAY<INT>>)",
"spark": "CREATE TABLE db.example_table (col_a ARRAY<INT>, col_b ARRAY<ARRAY<INT>>)",
"snowflake": "CREATE TABLE db.example_table (col_a ARRAY, col_b ARRAY)",
},
)
self.validate_all(
"CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
write={
"duckdb": "CREATE TABLE x",
"presto": "CREATE TABLE x WITH (TABLE_FORMAT = 'ICEBERG', PARTITIONED_BY=ARRAY['MONTHS'])",
"presto": "CREATE TABLE x WITH (TABLE_FORMAT='ICEBERG', PARTITIONED_BY=ARRAY['MONTHS'])",
"hive": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
"spark": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
},
@ -94,6 +95,13 @@ TBLPROPERTIES (
pretty=True,
)
self.validate_all(
"CACHE TABLE testCache OPTIONS ('storageLevel' 'DISK_ONLY') SELECT * FROM testData",
write={
"spark": "CACHE TABLE testCache OPTIONS('storageLevel' = 'DISK_ONLY') AS SELECT * FROM testData"
},
)
def test_to_date(self):
self.validate_all(
"TO_DATE(x, 'yyyy-MM-dd')",
@ -271,6 +279,7 @@ TBLPROPERTIES (
"presto": "MAP(ARRAY[1], c)",
"hive": "MAP(ARRAY(1), c)",
"spark": "MAP_FROM_ARRAYS(ARRAY(1), c)",
"snowflake": "OBJECT_CONSTRUCT([1], c)",
},
)
self.validate_all(

View file

@ -5,6 +5,10 @@ class TestSQLite(Validator):
dialect = "sqlite"
def test_ddl(self):
self.validate_all(
"CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)",
write={"sqlite": "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)"},
)
self.validate_all(
"""
CREATE TABLE "Track"

View file

@ -17,7 +17,6 @@ class TestTSQL(Validator):
"spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo",
},
)
self.validate_all(
"CONVERT(INT, CONVERT(NUMERIC, '444.75'))",
write={
@ -25,6 +24,33 @@ class TestTSQL(Validator):
"tsql": "CAST(CAST('444.75' AS NUMERIC) AS INTEGER)",
},
)
self.validate_all(
"STRING_AGG(x, y) WITHIN GROUP (ORDER BY z DESC)",
write={
"tsql": "STRING_AGG(x, y) WITHIN GROUP (ORDER BY z DESC)",
"mysql": "GROUP_CONCAT(x ORDER BY z DESC SEPARATOR y)",
"sqlite": "GROUP_CONCAT(x, y)",
"postgres": "STRING_AGG(x, y ORDER BY z DESC NULLS LAST)",
},
)
self.validate_all(
"STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z ASC)",
write={
"tsql": "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z)",
"mysql": "GROUP_CONCAT(x ORDER BY z SEPARATOR '|')",
"sqlite": "GROUP_CONCAT(x, '|')",
"postgres": "STRING_AGG(x, '|' ORDER BY z NULLS FIRST)",
},
)
self.validate_all(
"STRING_AGG(x, '|')",
write={
"tsql": "STRING_AGG(x, '|')",
"mysql": "GROUP_CONCAT(x SEPARATOR '|')",
"sqlite": "GROUP_CONCAT(x, '|')",
"postgres": "STRING_AGG(x, '|')",
},
)
def test_types(self):
self.validate_identity("CAST(x AS XML)")

View file

@ -34,6 +34,7 @@ x >> 1
x >> 1 | 1 & 1 ^ 1
x || y
1 - -1
- -5
dec.x + y
a.filter
a.b.c
@ -438,6 +439,7 @@ SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(a, b)
SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b)
SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score)
SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score)
CREATE TABLE foo (id INT PRIMARY KEY ASC)
CREATE TABLE a.b AS SELECT 1
CREATE TABLE a.b AS SELECT a FROM a.c
CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d
@ -579,6 +581,7 @@ SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)
SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)
SELECT CAST(x AS INT) /* comment */ FROM foo
SELECT a /* x */, b /* x */
SELECT a /* x */ /* y */ /* z */, b /* k */ /* m */
SELECT * FROM foo /* x */, bla /* x */
SELECT 1 /* comment */ + 1
SELECT 1 /* c1 */ + 2 /* c2 */
@ -588,3 +591,7 @@ SELECT x FROM a.b.c /* x */, e.f.g /* x */
SELECT FOO(x /* c */) /* FOO */, b /* b */
SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */
SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b'
SELECT x AS INTO FROM bla
SELECT * INTO newevent FROM event
SELECT * INTO TEMPORARY newevent FROM event
SELECT * INTO UNLOGGED newevent FROM event

View file

@ -77,3 +77,15 @@ WITH x_2 AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT x.id FROM x
-- Existing duplicate CTE
WITH y AS (SELECT a FROM x) SELECT a FROM (SELECT a FROM x) AS y JOIN y AS z;
WITH y AS (SELECT a FROM x) SELECT a FROM y AS y JOIN y AS z;
-- Nested CTE
WITH cte1 AS (SELECT a FROM x) SELECT a FROM (WITH cte2 AS (SELECT a FROM cte1) SELECT a FROM cte2);
WITH cte1 AS (SELECT a FROM x), cte2 AS (SELECT a FROM cte1), cte AS (SELECT a FROM cte2 AS cte2) SELECT a FROM cte AS cte;
-- Nested CTE inside CTE
WITH cte1 AS (WITH cte2 AS (SELECT a FROM x) SELECT t.a FROM cte2 AS t) SELECT a FROM cte1;
WITH cte2 AS (SELECT a FROM x), cte1 AS (SELECT t.a FROM cte2 AS t) SELECT a FROM cte1;
-- Duplicate CTE nested in CTE
WITH cte1 AS (SELECT a FROM x), cte2 AS (WITH cte3 AS (SELECT a FROM x) SELECT a FROM cte3) SELECT a FROM cte2;
WITH cte1 AS (SELECT a FROM x), cte2 AS (SELECT a FROM cte1 AS cte3) SELECT a FROM cte2;

View file

@ -0,0 +1,41 @@
SELECT a FROM x;
SELECT a FROM x;
SELECT "A" FROM "X";
SELECT "A" FROM "X";
SELECT a AS A FROM x;
SELECT a AS A FROM x;
SELECT * FROM x;
SELECT * FROM x;
SELECT A FROM x;
SELECT a FROM x;
SELECT a FROM X;
SELECT a FROM x;
SELECT A AS A FROM (SELECT a AS A FROM x);
SELECT a AS A FROM (SELECT a AS a FROM x);
SELECT a AS B FROM x ORDER BY B;
SELECT a AS B FROM x ORDER BY B;
SELECT A FROM x ORDER BY A;
SELECT a FROM x ORDER BY a;
SELECT A AS B FROM X GROUP BY A HAVING SUM(B) > 0;
SELECT a AS B FROM x GROUP BY a HAVING SUM(b) > 0;
SELECT A AS B, SUM(B) AS C FROM X GROUP BY A HAVING C > 0;
SELECT a AS B, SUM(b) AS C FROM x GROUP BY a HAVING C > 0;
SELECT A FROM X UNION SELECT A FROM X;
SELECT a FROM x UNION SELECT a FROM x;
SELECT A AS A FROM X UNION SELECT A AS A FROM X;
SELECT a AS A FROM x UNION SELECT a AS A FROM x;
(SELECT A AS A FROM X);
(SELECT a AS A FROM x);

View file

@ -276,3 +276,18 @@ SELECT /*+ COALESCE(3),
FROM `x` AS `x`
JOIN `y` AS `y`
ON `x`.`b` = `y`.`b`;
WITH cte1 AS (
WITH cte2 AS (
SELECT a, b FROM x
)
SELECT a1
FROM (
WITH cte3 AS (SELECT 1)
SELECT a AS a1, b AS b1 FROM cte2
)
)
SELECT a1 FROM cte1;
SELECT
"x"."a" AS "a1"
FROM "x" AS "x";

View file

@ -274,6 +274,15 @@ TRUE;
-(-1);
1;
- -+1;
1;
+-1;
-1;
++1;
1;
0.06 - 0.01;
0.05;

View file

@ -666,19 +666,7 @@ WITH "supplier_2" AS (
FROM "nation" AS "nation"
WHERE
"nation"."n_name" = 'GERMANY'
)
SELECT
"partsupp"."ps_partkey" AS "ps_partkey",
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value"
FROM "partsupp" AS "partsupp"
JOIN "supplier_2" AS "supplier"
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "nation_2" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
GROUP BY
"partsupp"."ps_partkey"
HAVING
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > (
), "_u_0" AS (
SELECT
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0"
FROM "partsupp" AS "partsupp"
@ -686,7 +674,20 @@ HAVING
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "nation_2" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
)
)
SELECT
"partsupp"."ps_partkey" AS "ps_partkey",
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value"
FROM "partsupp" AS "partsupp"
CROSS JOIN "_u_0" AS "_u_0"
JOIN "supplier_2" AS "supplier"
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "nation_2" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
GROUP BY
"partsupp"."ps_partkey"
HAVING
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > MAX("_u_0"."_col_0")
ORDER BY
"value" DESC;
@ -880,6 +881,10 @@ WITH "revenue" AS (
AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE)
GROUP BY
"lineitem"."l_suppkey"
), "_u_0" AS (
SELECT
MAX("revenue"."total_revenue") AS "_col_0"
FROM "revenue"
)
SELECT
"supplier"."s_suppkey" AS "s_suppkey",
@ -889,12 +894,9 @@ SELECT
"revenue"."total_revenue" AS "total_revenue"
FROM "supplier" AS "supplier"
JOIN "revenue"
ON "revenue"."total_revenue" = (
SELECT
MAX("revenue"."total_revenue") AS "_col_0"
FROM "revenue"
)
AND "supplier"."s_suppkey" = "revenue"."supplier_no"
ON "supplier"."s_suppkey" = "revenue"."supplier_no"
JOIN "_u_0" AS "_u_0"
ON "revenue"."total_revenue" = "_u_0"."_col_0"
ORDER BY
"s_suppkey";
@ -1395,7 +1397,14 @@ order by
cntrycode;
WITH "_u_0" AS (
SELECT
"orders"."o_custkey" AS "_u_1"
AVG("customer"."c_acctbal") AS "_col_0"
FROM "customer" AS "customer"
WHERE
"customer"."c_acctbal" > 0.00
AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
), "_u_1" AS (
SELECT
"orders"."o_custkey" AS "_u_2"
FROM "orders" AS "orders"
GROUP BY
"orders"."o_custkey"
@ -1405,18 +1414,12 @@ SELECT
COUNT(*) AS "numcust",
SUM("customer"."c_acctbal") AS "totacctbal"
FROM "customer" AS "customer"
LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."_u_1" = "customer"."c_custkey"
JOIN "_u_0" AS "_u_0"
ON "customer"."c_acctbal" > "_u_0"."_col_0"
LEFT JOIN "_u_1" AS "_u_1"
ON "_u_1"."_u_2" = "customer"."c_custkey"
WHERE
"_u_0"."_u_1" IS NULL
AND "customer"."c_acctbal" > (
SELECT
AVG("customer"."c_acctbal") AS "_col_0"
FROM "customer" AS "customer"
WHERE
"customer"."c_acctbal" > 0.00
AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
)
"_u_1"."_u_2" IS NULL
AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
GROUP BY
SUBSTRING("customer"."c_phone", 1, 2)

View file

@ -1,10 +1,12 @@
--SELECT x.a > (SELECT SUM(y.a) AS b FROM y) FROM x;
--------------------------------------
-- Unnest Subqueries
--------------------------------------
SELECT *
FROM x AS x
WHERE
x.a IN (SELECT y.a AS a FROM y)
x.a = (SELECT SUM(y.a) AS a FROM y)
AND x.a IN (SELECT y.a AS a FROM y)
AND x.a IN (SELECT y.b AS b FROM y)
AND x.a = ANY (SELECT y.a AS a FROM y)
AND x.a = (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a)
@ -24,62 +26,57 @@ WHERE
SELECT
*
FROM x AS x
CROSS JOIN (
SELECT
SUM(y.a) AS a
FROM y
) AS "_u_0"
LEFT JOIN (
SELECT
y.a AS a
FROM y
GROUP BY
y.a
) AS "_u_0"
ON x.a = "_u_0"."a"
) AS "_u_1"
ON x.a = "_u_1"."a"
LEFT JOIN (
SELECT
y.b AS b
FROM y
GROUP BY
y.b
) AS "_u_1"
ON x.a = "_u_1"."b"
) AS "_u_2"
ON x.a = "_u_2"."b"
LEFT JOIN (
SELECT
y.a AS a
FROM y
GROUP BY
y.a
) AS "_u_2"
ON x.a = "_u_2"."a"
LEFT JOIN (
SELECT
SUM(y.b) AS b,
y.a AS _u_4
FROM y
WHERE
TRUE
GROUP BY
y.a
) AS "_u_3"
ON x.a = "_u_3"."_u_4"
ON x.a = "_u_3"."a"
LEFT JOIN (
SELECT
SUM(y.b) AS b,
y.a AS _u_6
y.a AS _u_5
FROM y
WHERE
TRUE
GROUP BY
y.a
) AS "_u_5"
ON x.a = "_u_5"."_u_6"
) AS "_u_4"
ON x.a = "_u_4"."_u_5"
LEFT JOIN (
SELECT
y.a AS a
SUM(y.b) AS b,
y.a AS _u_7
FROM y
WHERE
TRUE
GROUP BY
y.a
) AS "_u_7"
ON "_u_7".a = x.a
) AS "_u_6"
ON x.a = "_u_6"."_u_7"
LEFT JOIN (
SELECT
y.a AS a
@ -90,29 +87,39 @@ LEFT JOIN (
y.a
) AS "_u_8"
ON "_u_8".a = x.a
LEFT JOIN (
SELECT
y.a AS a
FROM y
WHERE
TRUE
GROUP BY
y.a
) AS "_u_9"
ON "_u_9".a = x.a
LEFT JOIN (
SELECT
ARRAY_AGG(y.a) AS a,
y.b AS _u_10
y.b AS _u_11
FROM y
WHERE
TRUE
GROUP BY
y.b
) AS "_u_9"
ON "_u_9"."_u_10" = x.a
) AS "_u_10"
ON "_u_10"."_u_11" = x.a
LEFT JOIN (
SELECT
SUM(y.a) AS a,
y.a AS _u_12,
ARRAY_AGG(y.b) AS _u_13
y.a AS _u_13,
ARRAY_AGG(y.b) AS _u_14
FROM y
WHERE
TRUE AND TRUE AND TRUE
GROUP BY
y.a
) AS "_u_11"
ON "_u_11"."_u_12" = x.a AND "_u_11"."_u_12" = x.b
) AS "_u_12"
ON "_u_12"."_u_13" = x.a AND "_u_12"."_u_13" = x.b
LEFT JOIN (
SELECT
y.a AS a
@ -121,37 +128,38 @@ LEFT JOIN (
TRUE
GROUP BY
y.a
) AS "_u_14"
ON x.a = "_u_14".a
) AS "_u_15"
ON x.a = "_u_15".a
WHERE
NOT "_u_0"."a" IS NULL
AND NOT "_u_1"."b" IS NULL
AND NOT "_u_2"."a" IS NULL
x.a = "_u_0".a
AND NOT "_u_1"."a" IS NULL
AND NOT "_u_2"."b" IS NULL
AND NOT "_u_3"."a" IS NULL
AND (
x.a = "_u_3".b AND NOT "_u_3"."_u_4" IS NULL
x.a = "_u_4".b AND NOT "_u_4"."_u_5" IS NULL
)
AND (
x.a > "_u_5".b AND NOT "_u_5"."_u_6" IS NULL
x.a > "_u_6".b AND NOT "_u_6"."_u_7" IS NULL
)
AND (
None = "_u_7".a AND NOT "_u_7".a IS NULL
None = "_u_8".a AND NOT "_u_8".a IS NULL
)
AND NOT (
x.a = "_u_8".a AND NOT "_u_8".a IS NULL
x.a = "_u_9".a AND NOT "_u_9".a IS NULL
)
AND (
ARRAY_ANY("_u_9".a, _x -> _x = x.a) AND NOT "_u_9"."_u_10" IS NULL
ARRAY_ANY("_u_10".a, _x -> _x = x.a) AND NOT "_u_10"."_u_11" IS NULL
)
AND (
(
(
x.a < "_u_11".a AND NOT "_u_11"."_u_12" IS NULL
) AND NOT "_u_11"."_u_12" IS NULL
x.a < "_u_12".a AND NOT "_u_12"."_u_13" IS NULL
) AND NOT "_u_12"."_u_13" IS NULL
)
AND ARRAY_ANY("_u_11"."_u_13", "_x" -> "_x" <> x.d)
AND ARRAY_ANY("_u_12"."_u_14", "_x" -> "_x" <> x.d)
)
AND (
NOT "_u_14".a IS NULL AND NOT "_u_14".a IS NULL
NOT "_u_15".a IS NULL AND NOT "_u_15".a IS NULL
)
AND x.a IN (
SELECT

View file

@ -68,13 +68,13 @@ class TestExecutor(unittest.TestCase):
def test_execute_tpch(self):
def to_csv(expression):
if isinstance(expression, exp.Table):
if isinstance(expression, exp.Table) and expression.name not in ("revenue"):
return parse_one(
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
)
return expression
for i, (sql, _) in enumerate(self.sqls[0:7]):
for i, (sql, _) in enumerate(self.sqls[0:16]):
with self.subTest(f"tpch-h {i + 1}"):
a = self.cached_execute(sql)
sql = parse_one(sql).transform(to_csv).sql(pretty=True)
@ -165,6 +165,39 @@ class TestExecutor(unittest.TestCase):
["a"],
[("a",)],
),
(
"SELECT DISTINCT a FROM (SELECT 1 AS a UNION ALL SELECT 1 AS a)",
["a"],
[(1,)],
),
(
"SELECT DISTINCT a, SUM(b) AS b "
"FROM (SELECT 'a' AS a, 1 AS b UNION ALL SELECT 'a' AS a, 2 AS b UNION ALL SELECT 'b' AS a, 1 AS b) "
"GROUP BY a "
"LIMIT 1",
["a", "b"],
[("a", 3)],
),
(
"SELECT COUNT(1) AS a FROM (SELECT 1)",
["a"],
[(1,)],
),
(
"SELECT COUNT(1) AS a FROM (SELECT 1) LIMIT 0",
["a"],
[],
),
(
"SELECT a FROM x GROUP BY a LIMIT 0",
["a"],
[],
),
(
"SELECT a FROM x LIMIT 0",
["a"],
[],
),
]:
with self.subTest(sql):
result = execute(sql, schema=schema, tables=tables)
@ -346,6 +379,28 @@ class TestExecutor(unittest.TestCase):
],
)
def test_execute_subqueries(self):
    """The executor evaluates a scalar subquery used in a WHERE predicate."""
    rows = [
        {"a": 1, "b": 1},
        {"a": 2, "b": 2},
    ]
    result = execute(
        """
SELECT *
FROM table
WHERE a = (SELECT MAX(a) FROM table)
""",
        tables={"table": rows},
    )
    # Only the row holding the maximum value of `a` survives the filter.
    self.assertEqual(result.rows, [(2, 2)])
def test_table_depth_mismatch(self):
tables = {"table": []}
schema = {"db": {"table": {"col": "VARCHAR"}}}
@ -401,6 +456,7 @@ class TestExecutor(unittest.TestCase):
("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
("SELECT SUM(x) FROM (SELECT 1 AS x WHERE FALSE)", ["_col_0"], [(0,)]),
]:
result = execute(sql)
self.assertEqual(result.columns, tuple(cols))
@ -462,7 +518,18 @@ class TestExecutor(unittest.TestCase):
("IF(false, 1, 0)", 0),
("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"),
("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)),
("1 IN (1, 2, 3)", True),
("1 IN (2, 3)", False),
("NULL IS NULL", True),
("NULL IS NOT NULL", False),
("NULL = NULL", None),
("NULL <> NULL", None),
]:
with self.subTest(sql):
result = execute(f"SELECT {sql}")
self.assertEqual(result.rows, [(expected,)])
def test_case_sensitivity(self):
result = execute("SELECT A AS A FROM X", tables={"x": [{"a": 1}]})
self.assertEqual(result.columns, ("A",))
self.assertEqual(result.rows, [(1,)])

View file

@ -525,24 +525,14 @@ class TestExpressions(unittest.TestCase):
),
exp.Properties(
expressions=[
exp.FileFormatProperty(
this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")
),
exp.FileFormatProperty(this=exp.Literal.string("parquet")),
exp.PartitionedByProperty(
this=exp.Literal.string("PARTITIONED_BY"),
value=exp.Tuple(
expressions=[exp.to_identifier("a"), exp.to_identifier("b")]
this=exp.Tuple(expressions=[exp.to_identifier("a"), exp.to_identifier("b")])
),
),
exp.AnonymousProperty(
this=exp.Literal.string("custom"), value=exp.Literal.number(1)
),
exp.TableFormatProperty(
this=exp.Literal.string("TABLE_FORMAT"),
value=exp.to_identifier("test_format"),
),
exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.null()),
exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.true()),
exp.Property(this=exp.Literal.string("custom"), value=exp.Literal.number(1)),
exp.TableFormatProperty(this=exp.to_identifier("test_format")),
exp.EngineProperty(this=exp.null()),
exp.CollateProperty(this=exp.true()),
]
),
)
@ -609,9 +599,9 @@ FROM foo""",
"""SELECT
a,
b AS B,
c, -- comment
d AS D, -- another comment
CAST(x AS INT) -- final comment
c, /* comment */
d AS D, /* another comment */
CAST(x AS INT) /* final comment */
FROM foo""",
)

View file

@ -85,9 +85,8 @@ class TestOptimizer(unittest.TestCase):
if leave_tables_isolated is not None:
func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated)
optimized = func(parse_one(sql, read=dialect), **func_kwargs)
with self.subTest(title):
optimized = func(parse_one(sql, read=dialect), **func_kwargs)
self.assertEqual(
expected,
optimized.sql(pretty=pretty, dialect=dialect),
@ -168,6 +167,9 @@ class TestOptimizer(unittest.TestCase):
def test_quote_identities(self):
self.check_file("quote_identities", optimizer.quote_identities.quote_identities)
def test_lower_identities(self):
self.check_file("lower_identities", optimizer.lower_identities.lower_identities)
def test_pushdown_projection(self):
def pushdown_projections(expression, **kwargs):
expression = optimizer.qualify_tables.qualify_tables(expression)

View file

@ -15,6 +15,51 @@ class TestParser(unittest.TestCase):
self.assertIsInstance(parse_one("int", into=exp.DataType), exp.DataType)
self.assertIsInstance(parse_one("array<int>", into=exp.DataType), exp.DataType)
def test_parse_into_error(self):
    """Parsing into a single wrong expression type raises a structured ParseError."""
    expected_message = "Failed to parse into [<class 'sqlglot.expressions.From'>]"
    error = {
        "description": "Invalid expression / Unexpected token",
        "line": 1,
        "col": 1,
        "start_context": "",
        "highlight": "SELECT",
        "end_context": " 1;",
        "into_expression": exp.From,
    }
    with self.assertRaises(ParseError) as ctx:
        parse_one("SELECT 1;", "sqlite", [exp.From])
    self.assertEqual(str(ctx.exception), expected_message)
    # The exception carries one structured error per attempted target type.
    self.assertEqual(ctx.exception.errors, [error])
def test_parse_into_errors(self):
    """With several target types, one structured error is reported per type."""
    expected_message = (
        "Failed to parse into [<class 'sqlglot.expressions.From'>, "
        "<class 'sqlglot.expressions.Join'>]"
    )
    base_error = {
        "description": "Invalid expression / Unexpected token",
        "line": 1,
        "col": 1,
        "start_context": "",
        "highlight": "SELECT",
        "end_context": " 1;",
    }
    # Each attempted target type yields the same positional error, differing
    # only in `into_expression`.
    expected_errors = [
        {**base_error, "into_expression": target} for target in (exp.From, exp.Join)
    ]
    with self.assertRaises(ParseError) as ctx:
        parse_one("SELECT 1;", "sqlite", [exp.From, exp.Join])
    self.assertEqual(str(ctx.exception), expected_message)
    self.assertEqual(ctx.exception.errors, expected_errors)
def test_column(self):
columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all(exp.Column)
assert len(list(columns)) == 1
@ -24,6 +69,9 @@ class TestParser(unittest.TestCase):
def test_float(self):
    """A leading-dot float literal parses identically to its zero-prefixed form."""
    bare = parse_one(".2")
    padded = parse_one("0.2")
    self.assertEqual(bare, padded)
def test_unary_plus(self):
    """A unary plus is dropped, leaving a plain numeric literal."""
    parsed = parse_one("+15")
    self.assertEqual(parsed, exp.Literal.number(15))
def test_table(self):
    """All table references are found, including schema-qualified and leading-dot names."""
    expression = parse_one("select * from a, b.c, .d")
    rendered = [table.sql() for table in expression.find_all(exp.Table)]
    self.assertEqual(rendered, ["a", "b.c", "d"])
@ -157,8 +205,9 @@ class TestParser(unittest.TestCase):
def test_comments(self):
expression = parse_one(
"""
--comment1
SELECT /* this won't be used */
--comment1.1
--comment1.2
SELECT /*comment1.3*/
a, --comment2
b as B, --comment3:testing
"test--annotation",
@ -169,13 +218,13 @@ class TestParser(unittest.TestCase):
"""
)
self.assertEqual(expression.comment, "comment1")
self.assertEqual(expression.expressions[0].comment, "comment2")
self.assertEqual(expression.expressions[1].comment, "comment3:testing")
self.assertEqual(expression.expressions[2].comment, None)
self.assertEqual(expression.expressions[3].comment, "comment4 --foo")
self.assertEqual(expression.expressions[4].comment, "")
self.assertEqual(expression.expressions[5].comment, " space")
self.assertEqual(expression.comments, ["comment1.1", "comment1.2", "comment1.3"])
self.assertEqual(expression.expressions[0].comments, ["comment2"])
self.assertEqual(expression.expressions[1].comments, ["comment3:testing"])
self.assertEqual(expression.expressions[2].comments, None)
self.assertEqual(expression.expressions[3].comments, ["comment4 --foo"])
self.assertEqual(expression.expressions[4].comments, [""])
self.assertEqual(expression.expressions[5].comments, [" space"])
def test_type_literals(self):
self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))

View file

@ -7,13 +7,13 @@ class TestTokens(unittest.TestCase):
def test_comment_attachment(self):
# NOTE(review): this span is a rendered diff with old and new lines overlaid.
# The scalar-comment entries and the `.comment` assertion appear to be the
# pre-change variant; the list-valued entries and the `.comments` assertion
# the post-change one. Confirm against the actual test file before editing.
tokenizer = Tokenizer()
# Each pair is (sql, expected comment(s) attached to the first token).
sql_comment = [
("/*comment*/ foo", "comment"),
("/*comment*/ foo --test", "comment"),
("--comment\nfoo --test", "comment"),
("foo --comment", "comment"),
("foo", None),
("foo /*comment 1*/ /*comment 2*/", "comment 1"),
("/*comment*/ foo", ["comment"]),
("/*comment*/ foo --test", ["comment", "test"]),
("--comment\nfoo --test", ["comment", "test"]),
("foo --comment", ["comment"]),
("foo", []),
("foo /*comment 1*/ /*comment 2*/", ["comment 1", "comment 2"]),
]
for sql, comment in sql_comment:
self.assertEqual(tokenizer.tokenize(sql)[0].comment, comment)
self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment)

View file

@ -1,7 +1,7 @@
import unittest
from sqlglot import parse_one
from sqlglot.transforms import unalias_group
from sqlglot.transforms import eliminate_distinct_on, unalias_group
class TestTime(unittest.TestCase):
@ -35,3 +35,30 @@ class TestTime(unittest.TestCase):
"SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date",
"SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date",
)
def test_eliminate_distinct_on(self):
    """DISTINCT ON is rewritten to a ROW_NUMBER() window filter; plain DISTINCT is untouched."""
    cases = [
        (
            "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
            'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
        ),
        (
            "SELECT DISTINCT ON (a) a, b FROM x",
            'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a) AS "_row_number" FROM x) WHERE "_row_number" = 1',
        ),
        (
            "SELECT DISTINCT ON (a, b) a, b FROM x ORDER BY c DESC",
            'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
        ),
        (
            # Plain DISTINCT (no ON) must be left alone.
            "SELECT DISTINCT a, b FROM x ORDER BY c DESC",
            "SELECT DISTINCT a, b FROM x ORDER BY c DESC",
        ),
        (
            # A column literally named _row_number forces a de-conflicted alias.
            "SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC",
            'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS "_row_number_2" FROM x) WHERE "_row_number_2" = 1',
        ),
    ]
    for sql, expected in cases:
        self.validate(eliminate_distinct_on, sql, expected)

View file

@ -26,6 +26,7 @@ class TestTranspile(unittest.TestCase):
)
self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date")
self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
self.assertEqual(transpile("SELECT 1 row")[0], "SELECT 1 AS row")
for key in ("union", "filter", "over", "from", "join"):
with self.subTest(f"alias {key}"):
@ -38,6 +39,11 @@ class TestTranspile(unittest.TestCase):
def test_asc(self):
    """An explicit ASC is redundant and stripped from ORDER BY."""
    sql = "SELECT x FROM y ORDER BY x ASC"
    self.validate(sql, "SELECT x FROM y ORDER BY x")
def test_unary(self):
    """Chains of unary plus collapse; unary minus signs are preserved as written."""
    for sql, expected in (
        ("+++1", "1"),
        ("+-1", "-1"),
        ("+- - -1", "- - -1"),
    ):
        self.validate(sql, expected)
def test_paren(self):
with self.assertRaises(ParseError):
transpile("1 + (2 + 3")
@ -58,7 +64,7 @@ class TestTranspile(unittest.TestCase):
)
self.validate(
"SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ",
"SELECT\n FOO -- x\n , BAR -- y\n , BAZ",
"SELECT\n FOO /* x */\n , BAR /* y */\n , BAZ",
leading_comma=True,
pretty=True,
)
@ -78,7 +84,8 @@ class TestTranspile(unittest.TestCase):
def test_comments(self):
self.validate("SELECT */*comment*/", "SELECT * /* comment */")
self.validate(
"SELECT * FROM table /*comment 1*/ /*comment 2*/", "SELECT * FROM table /* comment 1 */"
"SELECT * FROM table /*comment 1*/ /*comment 2*/",
"SELECT * FROM table /* comment 1 */ /* comment 2 */",
)
self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
@ -112,6 +119,53 @@ class TestTranspile(unittest.TestCase):
)
self.validate(
"""
-- comment 1
-- comment 2
-- comment 3
SELECT * FROM foo
""",
"/* comment 1 */ /* comment 2 */ /* comment 3 */ SELECT * FROM foo",
)
self.validate(
"""
-- comment 1
-- comment 2
-- comment 3
SELECT * FROM foo""",
"""/* comment 1 */
/* comment 2 */
/* comment 3 */
SELECT
*
FROM foo""",
pretty=True,
)
self.validate(
"""
SELECT * FROM tbl /*line1
line2
line3*/ /*another comment*/ where 1=1 -- comment at the end""",
"""SELECT * FROM tbl /* line1
line2
line3 */ /* another comment */ WHERE 1 = 1 /* comment at the end */""",
)
self.validate(
"""
SELECT * FROM tbl /*line1
line2
line3*/ /*another comment*/ where 1=1 -- comment at the end""",
"""SELECT
*
FROM tbl /* line1
line2
line3 */
/* another comment */
WHERE
1 = 1 /* comment at the end */""",
pretty=True,
)
self.validate(
"""
/* multi
line
comment
@ -130,8 +184,8 @@ class TestTranspile(unittest.TestCase):
*/
SELECT
tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
CAST(x AS INT), -- comment 3
y -- comment 4
CAST(x AS INT), /* comment 3 */
y /* comment 4 */
FROM bar /* comment 5 */, tbl /* comment 6 */""",
read="mysql",
pretty=True,
@ -364,33 +418,79 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
@mock.patch("sqlglot.parser.logger")
def test_error_level(self, logger):
invalid = "x + 1. ("
errors = [
expected_messages = [
"Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>. Line 1, Col: 8.\n x + 1. \033[4m(\033[0m",
"Expecting ). Line 1, Col: 8.\n x + 1. \033[4m(\033[0m",
]
expected_errors = [
{
"description": "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>",
"line": 1,
"col": 8,
"start_context": "x + 1. ",
"highlight": "(",
"end_context": "",
"into_expression": None,
},
{
"description": "Expecting )",
"line": 1,
"col": 8,
"start_context": "x + 1. ",
"highlight": "(",
"end_context": "",
"into_expression": None,
},
]
transpile(invalid, error_level=ErrorLevel.WARN)
for error in errors:
for error in expected_messages:
assert_logger_contains(error, logger)
with self.assertRaises(ParseError) as ctx:
transpile(invalid, error_level=ErrorLevel.IMMEDIATE)
self.assertEqual(str(ctx.exception), errors[0])
self.assertEqual(str(ctx.exception), expected_messages[0])
self.assertEqual(ctx.exception.errors[0], expected_errors[0])
with self.assertRaises(ParseError) as ctx:
transpile(invalid, error_level=ErrorLevel.RAISE)
self.assertEqual(str(ctx.exception), "\n\n".join(errors))
self.assertEqual(str(ctx.exception), "\n\n".join(expected_messages))
self.assertEqual(ctx.exception.errors, expected_errors)
more_than_max_errors = "(((("
expected = (
expected_messages = (
"Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
"Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>. Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
"Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
"... and 2 more"
)
expected_errors = [
{
"description": "Expecting )",
"line": 1,
"col": 4,
"start_context": "(((",
"highlight": "(",
"end_context": "",
"into_expression": None,
},
{
"description": "Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>",
"line": 1,
"col": 4,
"start_context": "(((",
"highlight": "(",
"end_context": "",
"into_expression": None,
},
]
# Also expect three trailing structured errors that match the first
expected_errors += [expected_errors[0]] * 3
with self.assertRaises(ParseError) as ctx:
transpile(more_than_max_errors, error_level=ErrorLevel.RAISE)
self.assertEqual(str(ctx.exception), expected)
self.assertEqual(str(ctx.exception), expected_messages)
self.assertEqual(ctx.exception.errors, expected_errors)
@mock.patch("sqlglot.generator.logger")
def test_unsupported_level(self, logger):