Merging upstream version 10.1.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
582b160275
commit
a5128ea109
57 changed files with 1542 additions and 529 deletions
41
CHANGELOG.md
41
CHANGELOG.md
|
@ -1,6 +1,47 @@
|
||||||
Changelog
|
Changelog
|
||||||
=========
|
=========
|
||||||
|
|
||||||
|
v10.1.0
|
||||||
|
------
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
- Breaking: [refactored](https://github.com/tobymao/sqlglot/commit/6b0da1e1a2b5d6bdf7b5b918400456422d30a1d4) the way SQL comments are handled. Before at most one comment could be attached to an expression, now multiple comments may be stored in a list.
|
||||||
|
|
||||||
|
- Breaking: [refactored](https://github.com/tobymao/sqlglot/commit/be332d10404f36b43ea6ece956a73bf451348641) the way properties are represented and parsed. The argument `this` now stores a property's attributes instead of its name.
|
||||||
|
|
||||||
|
- New: added structured ParseError properties.
|
||||||
|
|
||||||
|
- New: the executor now handles set operations.
|
||||||
|
|
||||||
|
- New: sqlglot can [now execute SQL queries](https://github.com/tobymao/sqlglot/commit/62d3496e761a4f38dfa61af793062690923dce74) using python objects.
|
||||||
|
|
||||||
|
- New: added support for the [Drill dialect](https://github.com/tobymao/sqlglot/commit/543eca314546e0bd42f97c354807b4e398ab36ec).
|
||||||
|
|
||||||
|
- New: added a `canonicalize` method which leverages existing type information for an expression to apply various transformations to it.
|
||||||
|
|
||||||
|
- New: TRIM function support for Snowflake and Bigquery.
|
||||||
|
|
||||||
|
- New: added support for SQLite primary key ordering constraints (ASC, DESC).
|
||||||
|
|
||||||
|
- New: added support for Redshift DISTKEY / SORTKEY / DISTSTYLE properties.
|
||||||
|
|
||||||
|
- New: added support for SET TRANSACTION MySQL statements.
|
||||||
|
|
||||||
|
- New: added `null`, `true`, `false` helper methods to avoid using singleton expressions.
|
||||||
|
|
||||||
|
- Improvement: allow multiple aggregations in an expression.
|
||||||
|
|
||||||
|
- Improvement: execution of left / right joins.
|
||||||
|
|
||||||
|
- Improvement: execution of aggregations without the GROUP BY clause.
|
||||||
|
|
||||||
|
- Improvement: static query execution (e.g. SELECT 1, SELECT CONCAT('a', 'b') AS x, etc).
|
||||||
|
|
||||||
|
- Improvement: include a rule for type inference in the optimizer.
|
||||||
|
|
||||||
|
- Improvement: transaction, commit expressions parsed [at finer granularity](https://github.com/tobymao/sqlglot/commit/148282e710fd79512bb7d32e6e519d631df8115d).
|
||||||
|
|
||||||
v10.0.0
|
v10.0.0
|
||||||
------
|
------
|
||||||
|
|
||||||
|
|
76
README.md
76
README.md
|
@ -25,6 +25,7 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/
|
||||||
* [AST Introspection](#ast-introspection)
|
* [AST Introspection](#ast-introspection)
|
||||||
* [AST Diff](#ast-diff)
|
* [AST Diff](#ast-diff)
|
||||||
* [Custom Dialects](#custom-dialects)
|
* [Custom Dialects](#custom-dialects)
|
||||||
|
* [SQL Execution](#sql-execution)
|
||||||
* [Benchmarks](#benchmarks)
|
* [Benchmarks](#benchmarks)
|
||||||
* [Optional Dependencies](#optional-dependencies)
|
* [Optional Dependencies](#optional-dependencies)
|
||||||
|
|
||||||
|
@ -147,9 +148,9 @@ print(sqlglot.transpile(sql, read='mysql', pretty=True)[0])
|
||||||
*/
|
*/
|
||||||
SELECT
|
SELECT
|
||||||
tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
|
tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
|
||||||
CAST(x AS INT), -- comment 3
|
CAST(x AS INT), /* comment 3 */
|
||||||
y -- comment 4
|
y /* comment 4 */
|
||||||
FROM bar /* comment 5 */, tbl /* comment 6*/
|
FROM bar /* comment 5 */, tbl /* comment 6 */
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
@ -189,6 +190,28 @@ sqlglot.errors.ParseError: Expecting ). Line 1, Col: 13.
|
||||||
~~~~
|
~~~~
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Structured syntax errors are accessible for programmatic use:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import sqlglot
|
||||||
|
try:
|
||||||
|
sqlglot.transpile("SELECT foo( FROM bar")
|
||||||
|
except sqlglot.errors.ParseError as e:
|
||||||
|
print(e.errors)
|
||||||
|
```
|
||||||
|
|
||||||
|
Output:
|
||||||
|
```python
|
||||||
|
[{
|
||||||
|
'description': 'Expecting )',
|
||||||
|
'line': 1,
|
||||||
|
'col': 13,
|
||||||
|
'start_context': 'SELECT foo( ',
|
||||||
|
'highlight': 'FROM',
|
||||||
|
'end_context': ' bar'
|
||||||
|
}]
|
||||||
|
```
|
||||||
|
|
||||||
### Unsupported Errors
|
### Unsupported Errors
|
||||||
|
|
||||||
Presto `APPROX_DISTINCT` supports the accuracy argument which is not supported in Hive:
|
Presto `APPROX_DISTINCT` supports the accuracy argument which is not supported in Hive:
|
||||||
|
@ -372,6 +395,53 @@ print(Dialect["custom"])
|
||||||
<class '__main__.Custom'>
|
<class '__main__.Custom'>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### SQL Execution
|
||||||
|
|
||||||
|
One can even interpret SQL queries using SQLGlot, where the tables are represented as Python dictionaries. Although the engine is not very fast (it's not supposed to be) and is in a relatively early stage of development, it can be useful for unit testing and running SQL natively across Python objects. Additionally, the foundation can be easily integrated with fast compute kernels (arrow, pandas). Below is an example showcasing the execution of a SELECT expression that involves aggregations and JOINs:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlglot.executor import execute
|
||||||
|
|
||||||
|
tables = {
|
||||||
|
"sushi": [
|
||||||
|
{"id": 1, "price": 1.0},
|
||||||
|
{"id": 2, "price": 2.0},
|
||||||
|
{"id": 3, "price": 3.0},
|
||||||
|
],
|
||||||
|
"order_items": [
|
||||||
|
{"sushi_id": 1, "order_id": 1},
|
||||||
|
{"sushi_id": 1, "order_id": 1},
|
||||||
|
{"sushi_id": 2, "order_id": 1},
|
||||||
|
{"sushi_id": 3, "order_id": 2},
|
||||||
|
],
|
||||||
|
"orders": [
|
||||||
|
{"id": 1, "user_id": 1},
|
||||||
|
{"id": 2, "user_id": 2},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
execute(
|
||||||
|
"""
|
||||||
|
SELECT
|
||||||
|
o.user_id,
|
||||||
|
SUM(s.price) AS price
|
||||||
|
FROM orders o
|
||||||
|
JOIN order_items i
|
||||||
|
ON o.id = i.order_id
|
||||||
|
JOIN sushi s
|
||||||
|
ON i.sushi_id = s.id
|
||||||
|
GROUP BY o.user_id
|
||||||
|
""",
|
||||||
|
tables=tables
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
user_id price
|
||||||
|
1 4.0
|
||||||
|
2 3.0
|
||||||
|
```
|
||||||
|
|
||||||
## Benchmarks
|
## Benchmarks
|
||||||
|
|
||||||
[Benchmarks](benchmarks) run on Python 3.10.5 in seconds.
|
[Benchmarks](benchmarks) run on Python 3.10.5 in seconds.
|
||||||
|
|
|
@ -30,7 +30,7 @@ from sqlglot.parser import Parser
|
||||||
from sqlglot.schema import MappingSchema
|
from sqlglot.schema import MappingSchema
|
||||||
from sqlglot.tokens import Tokenizer, TokenType
|
from sqlglot.tokens import Tokenizer, TokenType
|
||||||
|
|
||||||
__version__ = "10.0.8"
|
__version__ = "10.1.3"
|
||||||
|
|
||||||
pretty = False
|
pretty = False
|
||||||
|
|
||||||
|
|
|
@ -56,12 +56,12 @@ def _derived_table_values_to_unnest(self, expression):
|
||||||
|
|
||||||
|
|
||||||
def _returnsproperty_sql(self, expression):
|
def _returnsproperty_sql(self, expression):
|
||||||
value = expression.args.get("value")
|
this = expression.this
|
||||||
if isinstance(value, exp.Schema):
|
if isinstance(this, exp.Schema):
|
||||||
value = f"{value.this} <{self.expressions(value)}>"
|
this = f"{this.this} <{self.expressions(this)}>"
|
||||||
else:
|
else:
|
||||||
value = self.sql(value)
|
this = self.sql(this)
|
||||||
return f"RETURNS {value}"
|
return f"RETURNS {this}"
|
||||||
|
|
||||||
|
|
||||||
def _create_sql(self, expression):
|
def _create_sql(self, expression):
|
||||||
|
@ -142,6 +142,11 @@ class BigQuery(Dialect):
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FUNCTION_PARSERS = {
|
||||||
|
**parser.Parser.FUNCTION_PARSERS,
|
||||||
|
}
|
||||||
|
FUNCTION_PARSERS.pop("TRIM")
|
||||||
|
|
||||||
NO_PAREN_FUNCTIONS = {
|
NO_PAREN_FUNCTIONS = {
|
||||||
**parser.Parser.NO_PAREN_FUNCTIONS,
|
**parser.Parser.NO_PAREN_FUNCTIONS,
|
||||||
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
|
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
|
||||||
|
@ -174,6 +179,7 @@ class BigQuery(Dialect):
|
||||||
exp.Values: _derived_table_values_to_unnest,
|
exp.Values: _derived_table_values_to_unnest,
|
||||||
exp.ReturnsProperty: _returnsproperty_sql,
|
exp.ReturnsProperty: _returnsproperty_sql,
|
||||||
exp.Create: _create_sql,
|
exp.Create: _create_sql,
|
||||||
|
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
|
||||||
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
|
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
|
||||||
if e.name == "IMMUTABLE"
|
if e.name == "IMMUTABLE"
|
||||||
else "NOT DETERMINISTIC",
|
else "NOT DETERMINISTIC",
|
||||||
|
@ -200,9 +206,7 @@ class BigQuery(Dialect):
|
||||||
exp.VolatilityProperty,
|
exp.VolatilityProperty,
|
||||||
}
|
}
|
||||||
|
|
||||||
WITH_PROPERTIES = {
|
WITH_PROPERTIES = {exp.Property}
|
||||||
exp.AnonymousProperty,
|
|
||||||
}
|
|
||||||
|
|
||||||
EXPLICIT_UNION = True
|
EXPLICIT_UNION = True
|
||||||
|
|
||||||
|
|
|
@ -21,14 +21,15 @@ class ClickHouse(Dialect):
|
||||||
|
|
||||||
KEYWORDS = {
|
KEYWORDS = {
|
||||||
**tokens.Tokenizer.KEYWORDS,
|
**tokens.Tokenizer.KEYWORDS,
|
||||||
"FINAL": TokenType.FINAL,
|
"ASOF": TokenType.ASOF,
|
||||||
"DATETIME64": TokenType.DATETIME,
|
"DATETIME64": TokenType.DATETIME,
|
||||||
"INT8": TokenType.TINYINT,
|
"FINAL": TokenType.FINAL,
|
||||||
|
"FLOAT32": TokenType.FLOAT,
|
||||||
|
"FLOAT64": TokenType.DOUBLE,
|
||||||
"INT16": TokenType.SMALLINT,
|
"INT16": TokenType.SMALLINT,
|
||||||
"INT32": TokenType.INT,
|
"INT32": TokenType.INT,
|
||||||
"INT64": TokenType.BIGINT,
|
"INT64": TokenType.BIGINT,
|
||||||
"FLOAT32": TokenType.FLOAT,
|
"INT8": TokenType.TINYINT,
|
||||||
"FLOAT64": TokenType.DOUBLE,
|
|
||||||
"TUPLE": TokenType.STRUCT,
|
"TUPLE": TokenType.STRUCT,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,6 +39,10 @@ class ClickHouse(Dialect):
|
||||||
"MAP": parse_var_map,
|
"MAP": parse_var_map,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF}
|
||||||
|
|
||||||
|
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY}
|
||||||
|
|
||||||
def _parse_table(self, schema=False):
|
def _parse_table(self, schema=False):
|
||||||
this = super()._parse_table(schema)
|
this = super()._parse_table(schema)
|
||||||
|
|
||||||
|
|
|
@ -289,19 +289,19 @@ def struct_extract_sql(self, expression):
|
||||||
return f"{this}.{struct_key}"
|
return f"{this}.{struct_key}"
|
||||||
|
|
||||||
|
|
||||||
def var_map_sql(self, expression):
|
def var_map_sql(self, expression, map_func_name="MAP"):
|
||||||
keys = expression.args["keys"]
|
keys = expression.args["keys"]
|
||||||
values = expression.args["values"]
|
values = expression.args["values"]
|
||||||
|
|
||||||
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
|
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
|
||||||
self.unsupported("Cannot convert array columns into map.")
|
self.unsupported("Cannot convert array columns into map.")
|
||||||
return f"MAP({self.format_args(keys, values)})"
|
return f"{map_func_name}({self.format_args(keys, values)})"
|
||||||
|
|
||||||
args = []
|
args = []
|
||||||
for key, value in zip(keys.expressions, values.expressions):
|
for key, value in zip(keys.expressions, values.expressions):
|
||||||
args.append(self.sql(key))
|
args.append(self.sql(key))
|
||||||
args.append(self.sql(value))
|
args.append(self.sql(value))
|
||||||
return f"MAP({self.format_args(*args)})"
|
return f"{map_func_name}({self.format_args(*args)})"
|
||||||
|
|
||||||
|
|
||||||
def format_time_lambda(exp_class, dialect, default=None):
|
def format_time_lambda(exp_class, dialect, default=None):
|
||||||
|
@ -336,18 +336,13 @@ def create_with_partitions_sql(self, expression):
|
||||||
if has_schema and is_partitionable:
|
if has_schema and is_partitionable:
|
||||||
expression = expression.copy()
|
expression = expression.copy()
|
||||||
prop = expression.find(exp.PartitionedByProperty)
|
prop = expression.find(exp.PartitionedByProperty)
|
||||||
value = prop and prop.args.get("value")
|
this = prop and prop.this
|
||||||
if prop and not isinstance(value, exp.Schema):
|
if prop and not isinstance(this, exp.Schema):
|
||||||
schema = expression.this
|
schema = expression.this
|
||||||
columns = {v.name.upper() for v in value.expressions}
|
columns = {v.name.upper() for v in this.expressions}
|
||||||
partitions = [col for col in schema.expressions if col.name.upper() in columns]
|
partitions = [col for col in schema.expressions if col.name.upper() in columns]
|
||||||
schema.set(
|
schema.set("expressions", [e for e in schema.expressions if e not in partitions])
|
||||||
"expressions",
|
prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
|
||||||
[e for e in schema.expressions if e not in partitions],
|
|
||||||
)
|
|
||||||
prop.replace(
|
|
||||||
exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
|
|
||||||
)
|
|
||||||
expression.set("this", schema)
|
expression.set("this", schema)
|
||||||
|
|
||||||
return self.create_sql(expression)
|
return self.create_sql(expression)
|
||||||
|
|
|
@ -153,7 +153,7 @@ class Drill(Dialect):
|
||||||
exp.If: if_sql,
|
exp.If: if_sql,
|
||||||
exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
|
exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
|
||||||
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
|
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
|
||||||
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
|
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
|
||||||
exp.Pivot: no_pivot_sql,
|
exp.Pivot: no_pivot_sql,
|
||||||
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
|
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
|
||||||
exp.StrPosition: str_position_sql,
|
exp.StrPosition: str_position_sql,
|
||||||
|
|
|
@ -61,9 +61,7 @@ def _array_sort(self, expression):
|
||||||
|
|
||||||
|
|
||||||
def _property_sql(self, expression):
|
def _property_sql(self, expression):
|
||||||
key = expression.name
|
return f"'{expression.name}'={self.sql(expression, 'value')}"
|
||||||
value = self.sql(expression, "value")
|
|
||||||
return f"'{key}'={value}"
|
|
||||||
|
|
||||||
|
|
||||||
def _str_to_unix(self, expression):
|
def _str_to_unix(self, expression):
|
||||||
|
@ -250,7 +248,7 @@ class Hive(Dialect):
|
||||||
TRANSFORMS = {
|
TRANSFORMS = {
|
||||||
**generator.Generator.TRANSFORMS,
|
**generator.Generator.TRANSFORMS,
|
||||||
**transforms.UNALIAS_GROUP, # type: ignore
|
**transforms.UNALIAS_GROUP, # type: ignore
|
||||||
exp.AnonymousProperty: _property_sql,
|
exp.Property: _property_sql,
|
||||||
exp.ApproxDistinct: approx_count_distinct_sql,
|
exp.ApproxDistinct: approx_count_distinct_sql,
|
||||||
exp.ArrayAgg: rename_func("COLLECT_LIST"),
|
exp.ArrayAgg: rename_func("COLLECT_LIST"),
|
||||||
exp.ArrayConcat: rename_func("CONCAT"),
|
exp.ArrayConcat: rename_func("CONCAT"),
|
||||||
|
@ -262,7 +260,7 @@ class Hive(Dialect):
|
||||||
exp.DateStrToDate: rename_func("TO_DATE"),
|
exp.DateStrToDate: rename_func("TO_DATE"),
|
||||||
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
|
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
|
||||||
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
|
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
|
||||||
exp.FileFormatProperty: lambda self, e: f"STORED AS {e.text('value').upper()}",
|
exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}",
|
||||||
exp.If: if_sql,
|
exp.If: if_sql,
|
||||||
exp.Index: _index_sql,
|
exp.Index: _index_sql,
|
||||||
exp.ILike: no_ilike_sql,
|
exp.ILike: no_ilike_sql,
|
||||||
|
@ -285,7 +283,7 @@ class Hive(Dialect):
|
||||||
exp.StrToTime: _str_to_time,
|
exp.StrToTime: _str_to_time,
|
||||||
exp.StrToUnix: _str_to_unix,
|
exp.StrToUnix: _str_to_unix,
|
||||||
exp.StructExtract: struct_extract_sql,
|
exp.StructExtract: struct_extract_sql,
|
||||||
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'value')}",
|
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
|
||||||
exp.TimeStrToDate: rename_func("TO_DATE"),
|
exp.TimeStrToDate: rename_func("TO_DATE"),
|
||||||
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
|
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
|
||||||
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
|
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
|
||||||
|
@ -298,11 +296,11 @@ class Hive(Dialect):
|
||||||
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})",
|
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})",
|
||||||
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
|
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
|
||||||
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
|
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
|
||||||
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}",
|
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
|
||||||
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
|
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
|
||||||
}
|
}
|
||||||
|
|
||||||
WITH_PROPERTIES = {exp.AnonymousProperty}
|
WITH_PROPERTIES = {exp.Property}
|
||||||
|
|
||||||
ROOT_PROPERTIES = {
|
ROOT_PROPERTIES = {
|
||||||
exp.PartitionedByProperty,
|
exp.PartitionedByProperty,
|
||||||
|
|
|
@ -453,6 +453,7 @@ class MySQL(Dialect):
|
||||||
exp.CharacterSetProperty,
|
exp.CharacterSetProperty,
|
||||||
exp.CollateProperty,
|
exp.CollateProperty,
|
||||||
exp.SchemaCommentProperty,
|
exp.SchemaCommentProperty,
|
||||||
|
exp.LikeProperty,
|
||||||
}
|
}
|
||||||
|
|
||||||
WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
|
WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from sqlglot import exp, generator, tokens, transforms
|
from sqlglot import exp, generator, parser, tokens, transforms
|
||||||
from sqlglot.dialects.dialect import Dialect, no_ilike_sql
|
from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func
|
||||||
from sqlglot.helper import csv
|
from sqlglot.helper import csv
|
||||||
from sqlglot.tokens import TokenType
|
from sqlglot.tokens import TokenType
|
||||||
|
|
||||||
|
@ -37,6 +37,12 @@ class Oracle(Dialect):
|
||||||
"YYYY": "%Y", # 2015
|
"YYYY": "%Y", # 2015
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class Parser(parser.Parser):
|
||||||
|
FUNCTIONS = {
|
||||||
|
**parser.Parser.FUNCTIONS,
|
||||||
|
"DECODE": exp.Matches.from_arg_list,
|
||||||
|
}
|
||||||
|
|
||||||
class Generator(generator.Generator):
|
class Generator(generator.Generator):
|
||||||
TYPE_MAPPING = {
|
TYPE_MAPPING = {
|
||||||
**generator.Generator.TYPE_MAPPING,
|
**generator.Generator.TYPE_MAPPING,
|
||||||
|
@ -58,6 +64,7 @@ class Oracle(Dialect):
|
||||||
**transforms.UNALIAS_GROUP, # type: ignore
|
**transforms.UNALIAS_GROUP, # type: ignore
|
||||||
exp.ILike: no_ilike_sql,
|
exp.ILike: no_ilike_sql,
|
||||||
exp.Limit: _limit_sql,
|
exp.Limit: _limit_sql,
|
||||||
|
exp.Matches: rename_func("DECODE"),
|
||||||
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||||
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
|
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||||
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
|
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
|
||||||
|
|
|
@ -74,6 +74,27 @@ def _trim_sql(self, expression):
|
||||||
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
|
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
|
||||||
|
|
||||||
|
|
||||||
|
def _string_agg_sql(self, expression):
|
||||||
|
expression = expression.copy()
|
||||||
|
separator = expression.args.get("separator") or exp.Literal.string(",")
|
||||||
|
|
||||||
|
order = ""
|
||||||
|
this = expression.this
|
||||||
|
if isinstance(this, exp.Order):
|
||||||
|
if this.this:
|
||||||
|
this = this.this
|
||||||
|
this.pop()
|
||||||
|
order = self.sql(expression.this) # Order has a leading space
|
||||||
|
|
||||||
|
return f"STRING_AGG({self.format_args(this, separator)}{order})"
|
||||||
|
|
||||||
|
|
||||||
|
def _datatype_sql(self, expression):
|
||||||
|
if expression.this == exp.DataType.Type.ARRAY:
|
||||||
|
return f"{self.expressions(expression, flat=True)}[]"
|
||||||
|
return self.datatype_sql(expression)
|
||||||
|
|
||||||
|
|
||||||
def _auto_increment_to_serial(expression):
|
def _auto_increment_to_serial(expression):
|
||||||
auto = expression.find(exp.AutoIncrementColumnConstraint)
|
auto = expression.find(exp.AutoIncrementColumnConstraint)
|
||||||
|
|
||||||
|
@ -191,25 +212,27 @@ class Postgres(Dialect):
|
||||||
KEYWORDS = {
|
KEYWORDS = {
|
||||||
**tokens.Tokenizer.KEYWORDS,
|
**tokens.Tokenizer.KEYWORDS,
|
||||||
"ALWAYS": TokenType.ALWAYS,
|
"ALWAYS": TokenType.ALWAYS,
|
||||||
"BY DEFAULT": TokenType.BY_DEFAULT,
|
|
||||||
"IDENTITY": TokenType.IDENTITY,
|
|
||||||
"GENERATED": TokenType.GENERATED,
|
|
||||||
"DOUBLE PRECISION": TokenType.DOUBLE,
|
|
||||||
"BIGSERIAL": TokenType.BIGSERIAL,
|
|
||||||
"SERIAL": TokenType.SERIAL,
|
|
||||||
"SMALLSERIAL": TokenType.SMALLSERIAL,
|
|
||||||
"UUID": TokenType.UUID,
|
|
||||||
"TEMP": TokenType.TEMPORARY,
|
|
||||||
"BEGIN TRANSACTION": TokenType.BEGIN,
|
|
||||||
"BEGIN": TokenType.COMMAND,
|
"BEGIN": TokenType.COMMAND,
|
||||||
|
"BEGIN TRANSACTION": TokenType.BEGIN,
|
||||||
|
"BIGSERIAL": TokenType.BIGSERIAL,
|
||||||
|
"BY DEFAULT": TokenType.BY_DEFAULT,
|
||||||
"COMMENT ON": TokenType.COMMAND,
|
"COMMENT ON": TokenType.COMMAND,
|
||||||
"DECLARE": TokenType.COMMAND,
|
"DECLARE": TokenType.COMMAND,
|
||||||
"DO": TokenType.COMMAND,
|
"DO": TokenType.COMMAND,
|
||||||
|
"DOUBLE PRECISION": TokenType.DOUBLE,
|
||||||
|
"GENERATED": TokenType.GENERATED,
|
||||||
|
"GRANT": TokenType.COMMAND,
|
||||||
|
"HSTORE": TokenType.HSTORE,
|
||||||
|
"IDENTITY": TokenType.IDENTITY,
|
||||||
|
"JSONB": TokenType.JSONB,
|
||||||
"REFRESH": TokenType.COMMAND,
|
"REFRESH": TokenType.COMMAND,
|
||||||
"REINDEX": TokenType.COMMAND,
|
"REINDEX": TokenType.COMMAND,
|
||||||
"RESET": TokenType.COMMAND,
|
"RESET": TokenType.COMMAND,
|
||||||
"REVOKE": TokenType.COMMAND,
|
"REVOKE": TokenType.COMMAND,
|
||||||
"GRANT": TokenType.COMMAND,
|
"SERIAL": TokenType.SERIAL,
|
||||||
|
"SMALLSERIAL": TokenType.SMALLSERIAL,
|
||||||
|
"TEMP": TokenType.TEMPORARY,
|
||||||
|
"UUID": TokenType.UUID,
|
||||||
**{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
|
**{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
|
||||||
**{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
|
**{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
|
||||||
}
|
}
|
||||||
|
@ -265,4 +288,7 @@ class Postgres(Dialect):
|
||||||
exp.Trim: _trim_sql,
|
exp.Trim: _trim_sql,
|
||||||
exp.TryCast: no_trycast_sql,
|
exp.TryCast: no_trycast_sql,
|
||||||
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
|
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
|
||||||
|
exp.DataType: _datatype_sql,
|
||||||
|
exp.GroupConcat: _string_agg_sql,
|
||||||
|
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
|
||||||
}
|
}
|
||||||
|
|
|
@ -171,16 +171,7 @@ class Presto(Dialect):
|
||||||
|
|
||||||
STRUCT_DELIMITER = ("(", ")")
|
STRUCT_DELIMITER = ("(", ")")
|
||||||
|
|
||||||
ROOT_PROPERTIES = {
|
ROOT_PROPERTIES = {exp.SchemaCommentProperty}
|
||||||
exp.SchemaCommentProperty,
|
|
||||||
}
|
|
||||||
|
|
||||||
WITH_PROPERTIES = {
|
|
||||||
exp.PartitionedByProperty,
|
|
||||||
exp.FileFormatProperty,
|
|
||||||
exp.AnonymousProperty,
|
|
||||||
exp.TableFormatProperty,
|
|
||||||
}
|
|
||||||
|
|
||||||
TYPE_MAPPING = {
|
TYPE_MAPPING = {
|
||||||
**generator.Generator.TYPE_MAPPING,
|
**generator.Generator.TYPE_MAPPING,
|
||||||
|
@ -231,7 +222,8 @@ class Presto(Dialect):
|
||||||
exp.StrToTime: _str_to_time_sql,
|
exp.StrToTime: _str_to_time_sql,
|
||||||
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
|
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
|
||||||
exp.StructExtract: struct_extract_sql,
|
exp.StructExtract: struct_extract_sql,
|
||||||
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT = '{e.text('value').upper()}'",
|
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
|
||||||
|
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
|
||||||
exp.TimeStrToDate: _date_parse_sql,
|
exp.TimeStrToDate: _date_parse_sql,
|
||||||
exp.TimeStrToTime: _date_parse_sql,
|
exp.TimeStrToTime: _date_parse_sql,
|
||||||
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
|
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from sqlglot import exp
|
from sqlglot import exp, transforms
|
||||||
from sqlglot.dialects.postgres import Postgres
|
from sqlglot.dialects.postgres import Postgres
|
||||||
from sqlglot.tokens import TokenType
|
from sqlglot.tokens import TokenType
|
||||||
|
|
||||||
|
@ -18,12 +18,14 @@ class Redshift(Postgres):
|
||||||
|
|
||||||
KEYWORDS = {
|
KEYWORDS = {
|
||||||
**Postgres.Tokenizer.KEYWORDS, # type: ignore
|
**Postgres.Tokenizer.KEYWORDS, # type: ignore
|
||||||
|
"COPY": TokenType.COMMAND,
|
||||||
"GEOMETRY": TokenType.GEOMETRY,
|
"GEOMETRY": TokenType.GEOMETRY,
|
||||||
"GEOGRAPHY": TokenType.GEOGRAPHY,
|
"GEOGRAPHY": TokenType.GEOGRAPHY,
|
||||||
"HLLSKETCH": TokenType.HLLSKETCH,
|
"HLLSKETCH": TokenType.HLLSKETCH,
|
||||||
"SUPER": TokenType.SUPER,
|
"SUPER": TokenType.SUPER,
|
||||||
"TIME": TokenType.TIMESTAMP,
|
"TIME": TokenType.TIMESTAMP,
|
||||||
"TIMETZ": TokenType.TIMESTAMPTZ,
|
"TIMETZ": TokenType.TIMESTAMPTZ,
|
||||||
|
"UNLOAD": TokenType.COMMAND,
|
||||||
"VARBYTE": TokenType.VARBINARY,
|
"VARBYTE": TokenType.VARBINARY,
|
||||||
"SIMILAR TO": TokenType.SIMILAR_TO,
|
"SIMILAR TO": TokenType.SIMILAR_TO,
|
||||||
}
|
}
|
||||||
|
@ -35,3 +37,17 @@ class Redshift(Postgres):
|
||||||
exp.DataType.Type.VARBINARY: "VARBYTE",
|
exp.DataType.Type.VARBINARY: "VARBYTE",
|
||||||
exp.DataType.Type.INT: "INTEGER",
|
exp.DataType.Type.INT: "INTEGER",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ROOT_PROPERTIES = {
|
||||||
|
exp.DistKeyProperty,
|
||||||
|
exp.SortKeyProperty,
|
||||||
|
exp.DistStyleProperty,
|
||||||
|
}
|
||||||
|
|
||||||
|
TRANSFORMS = {
|
||||||
|
**Postgres.Generator.TRANSFORMS, # type: ignore
|
||||||
|
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
|
||||||
|
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
|
||||||
|
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
|
||||||
|
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
|
||||||
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
|
||||||
format_time_lambda,
|
format_time_lambda,
|
||||||
inline_array_sql,
|
inline_array_sql,
|
||||||
rename_func,
|
rename_func,
|
||||||
|
var_map_sql,
|
||||||
)
|
)
|
||||||
from sqlglot.expressions import Literal
|
from sqlglot.expressions import Literal
|
||||||
from sqlglot.helper import seq_get
|
from sqlglot.helper import seq_get
|
||||||
|
@ -100,6 +101,14 @@ def _parse_date_part(self):
|
||||||
return self.expression(exp.Extract, this=this, expression=expression)
|
return self.expression(exp.Extract, this=this, expression=expression)
|
||||||
|
|
||||||
|
|
||||||
|
def _datatype_sql(self, expression):
|
||||||
|
if expression.this == exp.DataType.Type.ARRAY:
|
||||||
|
return "ARRAY"
|
||||||
|
elif expression.this == exp.DataType.Type.MAP:
|
||||||
|
return "OBJECT"
|
||||||
|
return self.datatype_sql(expression)
|
||||||
|
|
||||||
|
|
||||||
class Snowflake(Dialect):
|
class Snowflake(Dialect):
|
||||||
null_ordering = "nulls_are_large"
|
null_ordering = "nulls_are_large"
|
||||||
time_format = "'yyyy-mm-dd hh24:mi:ss'"
|
time_format = "'yyyy-mm-dd hh24:mi:ss'"
|
||||||
|
@ -142,6 +151,8 @@ class Snowflake(Dialect):
|
||||||
"TO_TIMESTAMP": _snowflake_to_timestamp,
|
"TO_TIMESTAMP": _snowflake_to_timestamp,
|
||||||
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
|
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
|
||||||
"RLIKE": exp.RegexpLike.from_arg_list,
|
"RLIKE": exp.RegexpLike.from_arg_list,
|
||||||
|
"DECODE": exp.Matches.from_arg_list,
|
||||||
|
"OBJECT_CONSTRUCT": parser.parse_var_map,
|
||||||
}
|
}
|
||||||
|
|
||||||
FUNCTION_PARSERS = {
|
FUNCTION_PARSERS = {
|
||||||
|
@ -195,16 +206,20 @@ class Snowflake(Dialect):
|
||||||
|
|
||||||
TRANSFORMS = {
|
TRANSFORMS = {
|
||||||
**generator.Generator.TRANSFORMS,
|
**generator.Generator.TRANSFORMS,
|
||||||
exp.ArrayConcat: rename_func("ARRAY_CAT"),
|
|
||||||
exp.If: rename_func("IFF"),
|
|
||||||
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
|
||||||
exp.UnixToTime: _unix_to_time_sql,
|
|
||||||
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
|
|
||||||
exp.Array: inline_array_sql,
|
exp.Array: inline_array_sql,
|
||||||
exp.StrPosition: rename_func("POSITION"),
|
exp.ArrayConcat: rename_func("ARRAY_CAT"),
|
||||||
|
exp.DataType: _datatype_sql,
|
||||||
|
exp.If: rename_func("IFF"),
|
||||||
|
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
|
||||||
|
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
|
||||||
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
|
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
|
||||||
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
|
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
|
||||||
|
exp.Matches: rename_func("DECODE"),
|
||||||
|
exp.StrPosition: rename_func("POSITION"),
|
||||||
|
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||||
|
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
|
||||||
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
|
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
|
||||||
|
exp.UnixToTime: _unix_to_time_sql,
|
||||||
}
|
}
|
||||||
|
|
||||||
TYPE_MAPPING = {
|
TYPE_MAPPING = {
|
||||||
|
|
|
@ -98,7 +98,7 @@ class Spark(Hive):
|
||||||
TRANSFORMS = {
|
TRANSFORMS = {
|
||||||
**Hive.Generator.TRANSFORMS, # type: ignore
|
**Hive.Generator.TRANSFORMS, # type: ignore
|
||||||
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
|
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
|
||||||
exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
|
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
|
||||||
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
|
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
|
||||||
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
|
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
|
||||||
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
|
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
|
||||||
|
|
|
@ -13,6 +13,23 @@ from sqlglot.dialects.dialect import (
|
||||||
from sqlglot.tokens import TokenType
|
from sqlglot.tokens import TokenType
|
||||||
|
|
||||||
|
|
||||||
|
# https://www.sqlite.org/lang_aggfunc.html#group_concat
|
||||||
|
def _group_concat_sql(self, expression):
|
||||||
|
this = expression.this
|
||||||
|
distinct = expression.find(exp.Distinct)
|
||||||
|
if distinct:
|
||||||
|
this = distinct.expressions[0]
|
||||||
|
distinct = "DISTINCT "
|
||||||
|
|
||||||
|
if isinstance(expression.this, exp.Order):
|
||||||
|
self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
|
||||||
|
if expression.this.this and not distinct:
|
||||||
|
this = expression.this.this
|
||||||
|
|
||||||
|
separator = expression.args.get("separator")
|
||||||
|
return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
|
||||||
|
|
||||||
|
|
||||||
class SQLite(Dialect):
|
class SQLite(Dialect):
|
||||||
class Tokenizer(tokens.Tokenizer):
|
class Tokenizer(tokens.Tokenizer):
|
||||||
IDENTIFIERS = ['"', ("[", "]"), "`"]
|
IDENTIFIERS = ['"', ("[", "]"), "`"]
|
||||||
|
@ -62,6 +79,7 @@ class SQLite(Dialect):
|
||||||
exp.Levenshtein: rename_func("EDITDIST3"),
|
exp.Levenshtein: rename_func("EDITDIST3"),
|
||||||
exp.TableSample: no_tablesample_sql,
|
exp.TableSample: no_tablesample_sql,
|
||||||
exp.TryCast: no_trycast_sql,
|
exp.TryCast: no_trycast_sql,
|
||||||
|
exp.GroupConcat: _group_concat_sql,
|
||||||
}
|
}
|
||||||
|
|
||||||
def transaction_sql(self, expression):
|
def transaction_sql(self, expression):
|
||||||
|
|
|
@ -17,6 +17,7 @@ FULL_FORMAT_TIME_MAPPING = {
|
||||||
"mm": "%B",
|
"mm": "%B",
|
||||||
"m": "%B",
|
"m": "%B",
|
||||||
}
|
}
|
||||||
|
|
||||||
DATE_DELTA_INTERVAL = {
|
DATE_DELTA_INTERVAL = {
|
||||||
"year": "year",
|
"year": "year",
|
||||||
"yyyy": "year",
|
"yyyy": "year",
|
||||||
|
@ -37,11 +38,12 @@ DATE_DELTA_INTERVAL = {
|
||||||
|
|
||||||
|
|
||||||
DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})")
|
DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})")
|
||||||
|
|
||||||
# N = Numeric, C=Currency
|
# N = Numeric, C=Currency
|
||||||
TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
|
TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
|
||||||
|
|
||||||
|
|
||||||
def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
|
def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
|
||||||
def _format_time(args):
|
def _format_time(args):
|
||||||
return exp_class(
|
return exp_class(
|
||||||
this=seq_get(args, 1),
|
this=seq_get(args, 1),
|
||||||
|
@ -58,7 +60,7 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
|
||||||
return _format_time
|
return _format_time
|
||||||
|
|
||||||
|
|
||||||
def parse_format(args):
|
def _parse_format(args):
|
||||||
fmt = seq_get(args, 1)
|
fmt = seq_get(args, 1)
|
||||||
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
|
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
|
||||||
if number_fmt:
|
if number_fmt:
|
||||||
|
@ -78,7 +80,7 @@ def generate_date_delta_with_unit_sql(self, e):
|
||||||
return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
|
return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
|
||||||
|
|
||||||
|
|
||||||
def generate_format_sql(self, e):
|
def _format_sql(self, e):
|
||||||
fmt = (
|
fmt = (
|
||||||
e.args["format"]
|
e.args["format"]
|
||||||
if isinstance(e, exp.NumberToStr)
|
if isinstance(e, exp.NumberToStr)
|
||||||
|
@ -87,6 +89,28 @@ def generate_format_sql(self, e):
|
||||||
return f"FORMAT({self.format_args(e.this, fmt)})"
|
return f"FORMAT({self.format_args(e.this, fmt)})"
|
||||||
|
|
||||||
|
|
||||||
|
def _string_agg_sql(self, e):
|
||||||
|
e = e.copy()
|
||||||
|
|
||||||
|
this = e.this
|
||||||
|
distinct = e.find(exp.Distinct)
|
||||||
|
if distinct:
|
||||||
|
# exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
|
||||||
|
self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.")
|
||||||
|
this = distinct.expressions[0]
|
||||||
|
distinct.pop()
|
||||||
|
|
||||||
|
order = ""
|
||||||
|
if isinstance(e.this, exp.Order):
|
||||||
|
if e.this.this:
|
||||||
|
this = e.this.this
|
||||||
|
e.this.this.pop()
|
||||||
|
order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space
|
||||||
|
|
||||||
|
separator = e.args.get("separator") or exp.Literal.string(",")
|
||||||
|
return f"STRING_AGG({self.format_args(this, separator)}){order}"
|
||||||
|
|
||||||
|
|
||||||
class TSQL(Dialect):
|
class TSQL(Dialect):
|
||||||
null_ordering = "nulls_are_small"
|
null_ordering = "nulls_are_small"
|
||||||
time_format = "'yyyy-mm-dd hh:mm:ss'"
|
time_format = "'yyyy-mm-dd hh:mm:ss'"
|
||||||
|
@ -228,14 +252,14 @@ class TSQL(Dialect):
|
||||||
"ISNULL": exp.Coalesce.from_arg_list,
|
"ISNULL": exp.Coalesce.from_arg_list,
|
||||||
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
|
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
|
||||||
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
|
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
|
||||||
"DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True),
|
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
|
||||||
"DATEPART": tsql_format_time_lambda(exp.TimeToStr),
|
"DATEPART": _format_time_lambda(exp.TimeToStr),
|
||||||
"GETDATE": exp.CurrentDate.from_arg_list,
|
"GETDATE": exp.CurrentDate.from_arg_list,
|
||||||
"IIF": exp.If.from_arg_list,
|
"IIF": exp.If.from_arg_list,
|
||||||
"LEN": exp.Length.from_arg_list,
|
"LEN": exp.Length.from_arg_list,
|
||||||
"REPLICATE": exp.Repeat.from_arg_list,
|
"REPLICATE": exp.Repeat.from_arg_list,
|
||||||
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
|
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
|
||||||
"FORMAT": parse_format,
|
"FORMAT": _parse_format,
|
||||||
}
|
}
|
||||||
|
|
||||||
VAR_LENGTH_DATATYPES = {
|
VAR_LENGTH_DATATYPES = {
|
||||||
|
@ -298,6 +322,7 @@ class TSQL(Dialect):
|
||||||
exp.DateDiff: generate_date_delta_with_unit_sql,
|
exp.DateDiff: generate_date_delta_with_unit_sql,
|
||||||
exp.CurrentDate: rename_func("GETDATE"),
|
exp.CurrentDate: rename_func("GETDATE"),
|
||||||
exp.If: rename_func("IIF"),
|
exp.If: rename_func("IIF"),
|
||||||
exp.NumberToStr: generate_format_sql,
|
exp.NumberToStr: _format_sql,
|
||||||
exp.TimeToStr: generate_format_sql,
|
exp.TimeToStr: _format_sql,
|
||||||
|
exp.GroupConcat: _string_agg_sql,
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,40 @@ class UnsupportedError(SqlglotError):
|
||||||
|
|
||||||
|
|
||||||
class ParseError(SqlglotError):
|
class ParseError(SqlglotError):
|
||||||
pass
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
errors: t.Optional[t.List[t.Dict[str, t.Any]]] = None,
|
||||||
|
):
|
||||||
|
super().__init__(message)
|
||||||
|
self.errors = errors or []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def new(
|
||||||
|
cls,
|
||||||
|
message: str,
|
||||||
|
description: t.Optional[str] = None,
|
||||||
|
line: t.Optional[int] = None,
|
||||||
|
col: t.Optional[int] = None,
|
||||||
|
start_context: t.Optional[str] = None,
|
||||||
|
highlight: t.Optional[str] = None,
|
||||||
|
end_context: t.Optional[str] = None,
|
||||||
|
into_expression: t.Optional[str] = None,
|
||||||
|
) -> ParseError:
|
||||||
|
return cls(
|
||||||
|
message,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"description": description,
|
||||||
|
"line": line,
|
||||||
|
"col": col,
|
||||||
|
"start_context": start_context,
|
||||||
|
"highlight": highlight,
|
||||||
|
"end_context": end_context,
|
||||||
|
"into_expression": into_expression,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TokenError(SqlglotError):
|
class TokenError(SqlglotError):
|
||||||
|
@ -41,9 +74,13 @@ class ExecuteError(SqlglotError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
|
def concat_messages(errors: t.Sequence[t.Any], maximum: int) -> str:
|
||||||
msg = [str(e) for e in errors[:maximum]]
|
msg = [str(e) for e in errors[:maximum]]
|
||||||
remaining = len(errors) - maximum
|
remaining = len(errors) - maximum
|
||||||
if remaining > 0:
|
if remaining > 0:
|
||||||
msg.append(f"... and {remaining} more")
|
msg.append(f"... and {remaining} more")
|
||||||
return "\n\n".join(msg)
|
return "\n\n".join(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_errors(errors: t.Sequence[ParseError]) -> t.List[t.Dict[str, t.Any]]:
|
||||||
|
return [e_dict for error in errors for e_dict in error.errors]
|
||||||
|
|
|
@ -122,7 +122,6 @@ def interval(this, unit):
|
||||||
|
|
||||||
|
|
||||||
ENV = {
|
ENV = {
|
||||||
"__builtins__": {},
|
|
||||||
"exp": exp,
|
"exp": exp,
|
||||||
# aggs
|
# aggs
|
||||||
"SUM": filter_nulls(sum),
|
"SUM": filter_nulls(sum),
|
||||||
|
|
|
@ -115,6 +115,9 @@ class PythonExecutor:
|
||||||
sink = self.table(context.columns)
|
sink = self.table(context.columns)
|
||||||
|
|
||||||
for reader in table_iter:
|
for reader in table_iter:
|
||||||
|
if len(sink) >= step.limit:
|
||||||
|
break
|
||||||
|
|
||||||
if condition and not context.eval(condition):
|
if condition and not context.eval(condition):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -123,9 +126,6 @@ class PythonExecutor:
|
||||||
else:
|
else:
|
||||||
sink.append(reader.row)
|
sink.append(reader.row)
|
||||||
|
|
||||||
if len(sink) >= step.limit:
|
|
||||||
break
|
|
||||||
|
|
||||||
return self.context({step.name: sink})
|
return self.context({step.name: sink})
|
||||||
|
|
||||||
def static(self):
|
def static(self):
|
||||||
|
@ -288,7 +288,13 @@ class PythonExecutor:
|
||||||
end = 1
|
end = 1
|
||||||
length = len(context.table)
|
length = len(context.table)
|
||||||
table = self.table(list(step.group) + step.aggregations)
|
table = self.table(list(step.group) + step.aggregations)
|
||||||
|
condition = self.generate(step.condition)
|
||||||
|
|
||||||
|
def add_row():
|
||||||
|
if not condition or context.eval(condition):
|
||||||
|
table.append(group + context.eval_tuple(aggregations))
|
||||||
|
|
||||||
|
if length:
|
||||||
for i in range(length):
|
for i in range(length):
|
||||||
context.set_index(i)
|
context.set_index(i)
|
||||||
key = context.eval_tuple(group_by)
|
key = context.eval_tuple(group_by)
|
||||||
|
@ -296,12 +302,17 @@ class PythonExecutor:
|
||||||
end += 1
|
end += 1
|
||||||
if key != group:
|
if key != group:
|
||||||
context.set_range(start, end - 2)
|
context.set_range(start, end - 2)
|
||||||
table.append(group + context.eval_tuple(aggregations))
|
add_row()
|
||||||
group = key
|
group = key
|
||||||
start = end - 2
|
start = end - 2
|
||||||
|
if len(table.rows) >= step.limit:
|
||||||
|
break
|
||||||
if i == length - 1:
|
if i == length - 1:
|
||||||
context.set_range(start, end - 1)
|
context.set_range(start, end - 1)
|
||||||
table.append(group + context.eval_tuple(aggregations))
|
add_row()
|
||||||
|
elif step.limit > 0:
|
||||||
|
context.set_range(0, 0)
|
||||||
|
table.append(context.eval_tuple(group_by) + context.eval_tuple(aggregations))
|
||||||
|
|
||||||
context = self.context({step.name: table, **{name: table for name in context.tables}})
|
context = self.context({step.name: table, **{name: table for name in context.tables}})
|
||||||
|
|
||||||
|
@ -311,11 +322,9 @@ class PythonExecutor:
|
||||||
|
|
||||||
def sort(self, step, context):
|
def sort(self, step, context):
|
||||||
projections = self.generate_tuple(step.projections)
|
projections = self.generate_tuple(step.projections)
|
||||||
|
|
||||||
projection_columns = [p.alias_or_name for p in step.projections]
|
projection_columns = [p.alias_or_name for p in step.projections]
|
||||||
all_columns = list(context.columns) + projection_columns
|
all_columns = list(context.columns) + projection_columns
|
||||||
sink = self.table(all_columns)
|
sink = self.table(all_columns)
|
||||||
|
|
||||||
for reader, ctx in context:
|
for reader, ctx in context:
|
||||||
sink.append(reader.row + ctx.eval_tuple(projections))
|
sink.append(reader.row + ctx.eval_tuple(projections))
|
||||||
|
|
||||||
|
@ -401,8 +410,9 @@ class Python(Dialect):
|
||||||
exp.Boolean: lambda self, e: "True" if e.this else "False",
|
exp.Boolean: lambda self, e: "True" if e.this else "False",
|
||||||
exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
|
exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
|
||||||
exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
|
exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
|
||||||
|
exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
|
||||||
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
|
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
|
||||||
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
|
exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
|
||||||
exp.Is: lambda self, e: self.binary(e, "is"),
|
exp.Is: lambda self, e: self.binary(e, "is"),
|
||||||
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
|
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
|
||||||
exp.Null: lambda *_: "None",
|
exp.Null: lambda *_: "None",
|
||||||
|
|
|
@ -43,14 +43,14 @@ class Expression(metaclass=_Expression):
|
||||||
|
|
||||||
key = "Expression"
|
key = "Expression"
|
||||||
arg_types = {"this": True}
|
arg_types = {"this": True}
|
||||||
__slots__ = ("args", "parent", "arg_key", "type", "comment")
|
__slots__ = ("args", "parent", "arg_key", "type", "comments")
|
||||||
|
|
||||||
def __init__(self, **args):
|
def __init__(self, **args):
|
||||||
self.args = args
|
self.args = args
|
||||||
self.parent = None
|
self.parent = None
|
||||||
self.arg_key = None
|
self.arg_key = None
|
||||||
self.type = None
|
self.type = None
|
||||||
self.comment = None
|
self.comments = None
|
||||||
|
|
||||||
for arg_key, value in self.args.items():
|
for arg_key, value in self.args.items():
|
||||||
self._set_parent(arg_key, value)
|
self._set_parent(arg_key, value)
|
||||||
|
@ -88,19 +88,6 @@ class Expression(metaclass=_Expression):
|
||||||
return field.this
|
return field.this
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def find_comment(self, key: str) -> str:
|
|
||||||
"""
|
|
||||||
Finds the comment that is attached to a specified child node.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: the key of the target child node (e.g. "this", "expression", etc).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The comment attached to the child node, or the empty string, if it doesn't exist.
|
|
||||||
"""
|
|
||||||
field = self.args.get(key)
|
|
||||||
return field.comment if isinstance(field, Expression) else ""
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_string(self):
|
def is_string(self):
|
||||||
return isinstance(self, Literal) and self.args["is_string"]
|
return isinstance(self, Literal) and self.args["is_string"]
|
||||||
|
@ -137,7 +124,7 @@ class Expression(metaclass=_Expression):
|
||||||
|
|
||||||
def __deepcopy__(self, memo):
|
def __deepcopy__(self, memo):
|
||||||
copy = self.__class__(**deepcopy(self.args))
|
copy = self.__class__(**deepcopy(self.args))
|
||||||
copy.comment = self.comment
|
copy.comments = self.comments
|
||||||
copy.type = self.type
|
copy.type = self.type
|
||||||
return copy
|
return copy
|
||||||
|
|
||||||
|
@ -369,7 +356,7 @@ class Expression(metaclass=_Expression):
|
||||||
)
|
)
|
||||||
for k, vs in self.args.items()
|
for k, vs in self.args.items()
|
||||||
}
|
}
|
||||||
args["comment"] = self.comment
|
args["comments"] = self.comments
|
||||||
args["type"] = self.type
|
args["type"] = self.type
|
||||||
args = {k: v for k, v in args.items() if v or not hide_missing}
|
args = {k: v for k, v in args.items() if v or not hide_missing}
|
||||||
|
|
||||||
|
@ -767,7 +754,7 @@ class NotNullColumnConstraint(ColumnConstraintKind):
|
||||||
|
|
||||||
|
|
||||||
class PrimaryKeyColumnConstraint(ColumnConstraintKind):
|
class PrimaryKeyColumnConstraint(ColumnConstraintKind):
|
||||||
pass
|
arg_types = {"desc": False}
|
||||||
|
|
||||||
|
|
||||||
class UniqueColumnConstraint(ColumnConstraintKind):
|
class UniqueColumnConstraint(ColumnConstraintKind):
|
||||||
|
@ -819,6 +806,12 @@ class Unique(Expression):
|
||||||
arg_types = {"expressions": True}
|
arg_types = {"expressions": True}
|
||||||
|
|
||||||
|
|
||||||
|
# https://www.postgresql.org/docs/9.1/sql-selectinto.html
|
||||||
|
# https://docs.aws.amazon.com/redshift/latest/dg/r_SELECT_INTO.html#r_SELECT_INTO-examples
|
||||||
|
class Into(Expression):
|
||||||
|
arg_types = {"this": True, "temporary": False, "unlogged": False}
|
||||||
|
|
||||||
|
|
||||||
class From(Expression):
|
class From(Expression):
|
||||||
arg_types = {"expressions": True}
|
arg_types = {"expressions": True}
|
||||||
|
|
||||||
|
@ -1065,67 +1058,67 @@ class Property(Expression):
|
||||||
|
|
||||||
|
|
||||||
class TableFormatProperty(Property):
|
class TableFormatProperty(Property):
|
||||||
pass
|
arg_types = {"this": True}
|
||||||
|
|
||||||
|
|
||||||
class PartitionedByProperty(Property):
|
class PartitionedByProperty(Property):
|
||||||
pass
|
arg_types = {"this": True}
|
||||||
|
|
||||||
|
|
||||||
class FileFormatProperty(Property):
|
class FileFormatProperty(Property):
|
||||||
pass
|
arg_types = {"this": True}
|
||||||
|
|
||||||
|
|
||||||
class DistKeyProperty(Property):
|
class DistKeyProperty(Property):
|
||||||
pass
|
arg_types = {"this": True}
|
||||||
|
|
||||||
|
|
||||||
class SortKeyProperty(Property):
|
class SortKeyProperty(Property):
|
||||||
pass
|
arg_types = {"this": True, "compound": False}
|
||||||
|
|
||||||
|
|
||||||
class DistStyleProperty(Property):
|
class DistStyleProperty(Property):
|
||||||
pass
|
arg_types = {"this": True}
|
||||||
|
|
||||||
|
|
||||||
|
class LikeProperty(Property):
|
||||||
|
arg_types = {"this": True, "expressions": False}
|
||||||
|
|
||||||
|
|
||||||
class LocationProperty(Property):
|
class LocationProperty(Property):
|
||||||
pass
|
arg_types = {"this": True}
|
||||||
|
|
||||||
|
|
||||||
class EngineProperty(Property):
|
class EngineProperty(Property):
|
||||||
pass
|
arg_types = {"this": True}
|
||||||
|
|
||||||
|
|
||||||
class AutoIncrementProperty(Property):
|
class AutoIncrementProperty(Property):
|
||||||
pass
|
arg_types = {"this": True}
|
||||||
|
|
||||||
|
|
||||||
class CharacterSetProperty(Property):
|
class CharacterSetProperty(Property):
|
||||||
arg_types = {"this": True, "value": True, "default": True}
|
arg_types = {"this": True, "default": True}
|
||||||
|
|
||||||
|
|
||||||
class CollateProperty(Property):
|
class CollateProperty(Property):
|
||||||
pass
|
arg_types = {"this": True}
|
||||||
|
|
||||||
|
|
||||||
class SchemaCommentProperty(Property):
|
class SchemaCommentProperty(Property):
|
||||||
pass
|
arg_types = {"this": True}
|
||||||
|
|
||||||
|
|
||||||
class AnonymousProperty(Property):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ReturnsProperty(Property):
|
class ReturnsProperty(Property):
|
||||||
arg_types = {"this": True, "value": True, "is_table": False}
|
arg_types = {"this": True, "is_table": False}
|
||||||
|
|
||||||
|
|
||||||
class LanguageProperty(Property):
|
class LanguageProperty(Property):
|
||||||
pass
|
arg_types = {"this": True}
|
||||||
|
|
||||||
|
|
||||||
class ExecuteAsProperty(Property):
|
class ExecuteAsProperty(Property):
|
||||||
pass
|
arg_types = {"this": True}
|
||||||
|
|
||||||
|
|
||||||
class VolatilityProperty(Property):
|
class VolatilityProperty(Property):
|
||||||
|
@ -1135,27 +1128,36 @@ class VolatilityProperty(Property):
|
||||||
class Properties(Expression):
|
class Properties(Expression):
|
||||||
arg_types = {"expressions": True}
|
arg_types = {"expressions": True}
|
||||||
|
|
||||||
PROPERTY_KEY_MAPPING = {
|
NAME_TO_PROPERTY = {
|
||||||
"AUTO_INCREMENT": AutoIncrementProperty,
|
"AUTO_INCREMENT": AutoIncrementProperty,
|
||||||
"CHARACTER_SET": CharacterSetProperty,
|
"CHARACTER SET": CharacterSetProperty,
|
||||||
"COLLATE": CollateProperty,
|
"COLLATE": CollateProperty,
|
||||||
"COMMENT": SchemaCommentProperty,
|
"COMMENT": SchemaCommentProperty,
|
||||||
"ENGINE": EngineProperty,
|
|
||||||
"FORMAT": FileFormatProperty,
|
|
||||||
"LOCATION": LocationProperty,
|
|
||||||
"PARTITIONED_BY": PartitionedByProperty,
|
|
||||||
"TABLE_FORMAT": TableFormatProperty,
|
|
||||||
"DISTKEY": DistKeyProperty,
|
"DISTKEY": DistKeyProperty,
|
||||||
"DISTSTYLE": DistStyleProperty,
|
"DISTSTYLE": DistStyleProperty,
|
||||||
|
"ENGINE": EngineProperty,
|
||||||
|
"EXECUTE AS": ExecuteAsProperty,
|
||||||
|
"FORMAT": FileFormatProperty,
|
||||||
|
"LANGUAGE": LanguageProperty,
|
||||||
|
"LOCATION": LocationProperty,
|
||||||
|
"PARTITIONED_BY": PartitionedByProperty,
|
||||||
|
"RETURNS": ReturnsProperty,
|
||||||
"SORTKEY": SortKeyProperty,
|
"SORTKEY": SortKeyProperty,
|
||||||
|
"TABLE_FORMAT": TableFormatProperty,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, properties_dict) -> Properties:
|
def from_dict(cls, properties_dict) -> Properties:
|
||||||
expressions = []
|
expressions = []
|
||||||
for key, value in properties_dict.items():
|
for key, value in properties_dict.items():
|
||||||
property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
|
property_cls = cls.NAME_TO_PROPERTY.get(key.upper())
|
||||||
expressions.append(property_cls(this=Literal.string(key), value=convert(value)))
|
if property_cls:
|
||||||
|
expressions.append(property_cls(this=convert(value)))
|
||||||
|
else:
|
||||||
|
expressions.append(Property(this=Literal.string(key), value=convert(value)))
|
||||||
|
|
||||||
return cls(expressions=expressions)
|
return cls(expressions=expressions)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1383,6 +1385,7 @@ class Select(Subqueryable):
|
||||||
"expressions": False,
|
"expressions": False,
|
||||||
"hint": False,
|
"hint": False,
|
||||||
"distinct": False,
|
"distinct": False,
|
||||||
|
"into": False,
|
||||||
"from": False,
|
"from": False,
|
||||||
**QUERY_MODIFIERS,
|
**QUERY_MODIFIERS,
|
||||||
}
|
}
|
||||||
|
@ -2015,6 +2018,7 @@ class DataType(Expression):
|
||||||
DECIMAL = auto()
|
DECIMAL = auto()
|
||||||
BOOLEAN = auto()
|
BOOLEAN = auto()
|
||||||
JSON = auto()
|
JSON = auto()
|
||||||
|
JSONB = auto()
|
||||||
INTERVAL = auto()
|
INTERVAL = auto()
|
||||||
TIMESTAMP = auto()
|
TIMESTAMP = auto()
|
||||||
TIMESTAMPTZ = auto()
|
TIMESTAMPTZ = auto()
|
||||||
|
@ -2029,6 +2033,7 @@ class DataType(Expression):
|
||||||
STRUCT = auto()
|
STRUCT = auto()
|
||||||
NULLABLE = auto()
|
NULLABLE = auto()
|
||||||
HLLSKETCH = auto()
|
HLLSKETCH = auto()
|
||||||
|
HSTORE = auto()
|
||||||
SUPER = auto()
|
SUPER = auto()
|
||||||
SERIAL = auto()
|
SERIAL = auto()
|
||||||
SMALLSERIAL = auto()
|
SMALLSERIAL = auto()
|
||||||
|
@ -2109,7 +2114,7 @@ class Transaction(Command):
|
||||||
|
|
||||||
|
|
||||||
class Commit(Command):
|
class Commit(Command):
|
||||||
arg_types = {} # type: ignore
|
arg_types = {"chain": False}
|
||||||
|
|
||||||
|
|
||||||
class Rollback(Command):
|
class Rollback(Command):
|
||||||
|
@ -2442,7 +2447,7 @@ class ArrayFilter(Func):
|
||||||
|
|
||||||
|
|
||||||
class ArraySize(Func):
|
class ArraySize(Func):
|
||||||
pass
|
arg_types = {"this": True, "expression": False}
|
||||||
|
|
||||||
|
|
||||||
class ArraySort(Func):
|
class ArraySort(Func):
|
||||||
|
@ -2726,6 +2731,16 @@ class VarMap(Func):
|
||||||
is_var_len_args = True
|
is_var_len_args = True
|
||||||
|
|
||||||
|
|
||||||
|
class Matches(Func):
|
||||||
|
"""Oracle/Snowflake decode.
|
||||||
|
https://docs.oracle.com/cd/B19306_01/server.102/b14200/functions040.htm
|
||||||
|
Pattern matching MATCHES(value, search1, result1, ...searchN, resultN, else)
|
||||||
|
"""
|
||||||
|
|
||||||
|
arg_types = {"this": True, "expressions": True}
|
||||||
|
is_var_len_args = True
|
||||||
|
|
||||||
|
|
||||||
class Max(AggFunc):
|
class Max(AggFunc):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -2785,6 +2800,10 @@ class Round(Func):
|
||||||
arg_types = {"this": True, "decimals": False}
|
arg_types = {"this": True, "decimals": False}
|
||||||
|
|
||||||
|
|
||||||
|
class RowNumber(Func):
|
||||||
|
arg_types: t.Dict[str, t.Any] = {}
|
||||||
|
|
||||||
|
|
||||||
class SafeDivide(Func):
|
class SafeDivide(Func):
|
||||||
arg_types = {"this": True, "expression": True}
|
arg_types = {"this": True, "expression": True}
|
||||||
|
|
||||||
|
|
|
@ -1,19 +1,16 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
from sqlglot import exp
|
from sqlglot import exp
|
||||||
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
|
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
|
||||||
from sqlglot.helper import apply_index_offset, csv
|
from sqlglot.helper import apply_index_offset, csv
|
||||||
from sqlglot.time import format_time
|
from sqlglot.time import format_time
|
||||||
from sqlglot.tokens import TokenType
|
from sqlglot.tokens import TokenType
|
||||||
|
|
||||||
logger = logging.getLogger("sqlglot")
|
logger = logging.getLogger("sqlglot")
|
||||||
|
|
||||||
NEWLINE_RE = re.compile("\r\n?|\n")
|
|
||||||
|
|
||||||
|
|
||||||
class Generator:
|
class Generator:
|
||||||
"""
|
"""
|
||||||
|
@ -58,11 +55,11 @@ class Generator:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
TRANSFORMS = {
|
TRANSFORMS = {
|
||||||
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
|
|
||||||
exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
|
exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
|
||||||
exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
|
exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
|
||||||
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
|
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
|
||||||
exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})",
|
exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})",
|
||||||
|
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}",
|
||||||
exp.LanguageProperty: lambda self, e: self.naked_property(e),
|
exp.LanguageProperty: lambda self, e: self.naked_property(e),
|
||||||
exp.LocationProperty: lambda self, e: self.naked_property(e),
|
exp.LocationProperty: lambda self, e: self.naked_property(e),
|
||||||
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
|
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
|
||||||
|
@ -97,16 +94,17 @@ class Generator:
|
||||||
exp.DistStyleProperty,
|
exp.DistStyleProperty,
|
||||||
exp.DistKeyProperty,
|
exp.DistKeyProperty,
|
||||||
exp.SortKeyProperty,
|
exp.SortKeyProperty,
|
||||||
|
exp.LikeProperty,
|
||||||
}
|
}
|
||||||
|
|
||||||
WITH_PROPERTIES = {
|
WITH_PROPERTIES = {
|
||||||
exp.AnonymousProperty,
|
exp.Property,
|
||||||
exp.FileFormatProperty,
|
exp.FileFormatProperty,
|
||||||
exp.PartitionedByProperty,
|
exp.PartitionedByProperty,
|
||||||
exp.TableFormatProperty,
|
exp.TableFormatProperty,
|
||||||
}
|
}
|
||||||
|
|
||||||
WITH_SEPARATED_COMMENTS = (exp.Select,)
|
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
|
||||||
|
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
"time_mapping",
|
"time_mapping",
|
||||||
|
@ -211,7 +209,7 @@ class Generator:
|
||||||
for msg in self.unsupported_messages:
|
for msg in self.unsupported_messages:
|
||||||
logger.warning(msg)
|
logger.warning(msg)
|
||||||
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
|
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
|
||||||
raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported))
|
raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))
|
||||||
|
|
||||||
return sql
|
return sql
|
||||||
|
|
||||||
|
@ -226,25 +224,24 @@ class Generator:
|
||||||
def seg(self, sql, sep=" "):
|
def seg(self, sql, sep=" "):
|
||||||
return f"{self.sep(sep)}{sql}"
|
return f"{self.sep(sep)}{sql}"
|
||||||
|
|
||||||
def maybe_comment(self, sql, expression, single_line=False):
|
def pad_comment(self, comment):
|
||||||
comment = expression.comment if self._comments else None
|
|
||||||
|
|
||||||
if not comment:
|
|
||||||
return sql
|
|
||||||
|
|
||||||
comment = " " + comment if comment[0].strip() else comment
|
comment = " " + comment if comment[0].strip() else comment
|
||||||
comment = comment + " " if comment[-1].strip() else comment
|
comment = comment + " " if comment[-1].strip() else comment
|
||||||
|
return comment
|
||||||
|
|
||||||
|
def maybe_comment(self, sql, expression):
|
||||||
|
comments = expression.comments if self._comments else None
|
||||||
|
|
||||||
|
if not comments:
|
||||||
|
return sql
|
||||||
|
|
||||||
|
sep = "\n" if self.pretty else " "
|
||||||
|
comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments)
|
||||||
|
|
||||||
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
|
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
|
||||||
return f"/*{comment}*/{self.sep()}{sql}"
|
return f"{comments}{self.sep()}{sql}"
|
||||||
|
|
||||||
if not self.pretty:
|
return f"{sql} {comments}"
|
||||||
return f"{sql} /*{comment}*/"
|
|
||||||
|
|
||||||
if not NEWLINE_RE.search(comment):
|
|
||||||
return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
|
|
||||||
|
|
||||||
return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/"
|
|
||||||
|
|
||||||
def wrap(self, expression):
|
def wrap(self, expression):
|
||||||
this_sql = self.indent(
|
this_sql = self.indent(
|
||||||
|
@ -387,8 +384,11 @@ class Generator:
|
||||||
def notnullcolumnconstraint_sql(self, _):
|
def notnullcolumnconstraint_sql(self, _):
|
||||||
return "NOT NULL"
|
return "NOT NULL"
|
||||||
|
|
||||||
def primarykeycolumnconstraint_sql(self, _):
|
def primarykeycolumnconstraint_sql(self, expression):
|
||||||
return "PRIMARY KEY"
|
desc = expression.args.get("desc")
|
||||||
|
if desc is not None:
|
||||||
|
return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
|
||||||
|
return f"PRIMARY KEY"
|
||||||
|
|
||||||
def uniquecolumnconstraint_sql(self, _):
|
def uniquecolumnconstraint_sql(self, _):
|
||||||
return "UNIQUE"
|
return "UNIQUE"
|
||||||
|
@ -546,36 +546,33 @@ class Generator:
|
||||||
|
|
||||||
def root_properties(self, properties):
|
def root_properties(self, properties):
|
||||||
if properties.expressions:
|
if properties.expressions:
|
||||||
return self.sep() + self.expressions(
|
return self.sep() + self.expressions(properties, indent=False, sep=" ")
|
||||||
properties,
|
|
||||||
indent=False,
|
|
||||||
sep=" ",
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def properties(self, properties, prefix="", sep=", "):
|
def properties(self, properties, prefix="", sep=", "):
|
||||||
if properties.expressions:
|
if properties.expressions:
|
||||||
expressions = self.expressions(
|
expressions = self.expressions(properties, sep=sep, indent=False)
|
||||||
properties,
|
|
||||||
sep=sep,
|
|
||||||
indent=False,
|
|
||||||
)
|
|
||||||
return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}"
|
return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}"
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def with_properties(self, properties):
|
def with_properties(self, properties):
|
||||||
return self.properties(
|
return self.properties(properties, prefix="WITH")
|
||||||
properties,
|
|
||||||
prefix="WITH",
|
|
||||||
)
|
|
||||||
|
|
||||||
def property_sql(self, expression):
|
def property_sql(self, expression):
|
||||||
if isinstance(expression.this, exp.Literal):
|
property_cls = expression.__class__
|
||||||
key = expression.this.this
|
if property_cls == exp.Property:
|
||||||
else:
|
return f"{expression.name}={self.sql(expression, 'value')}"
|
||||||
key = expression.name
|
|
||||||
value = self.sql(expression, "value")
|
property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls)
|
||||||
return f"{key}={value}"
|
if not property_name:
|
||||||
|
self.unsupported(f"Unsupported property {property_name}")
|
||||||
|
|
||||||
|
return f"{property_name}={self.sql(expression, 'this')}"
|
||||||
|
|
||||||
|
def likeproperty_sql(self, expression):
|
||||||
|
options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
|
||||||
|
options = f" {options}" if options else ""
|
||||||
|
return f"LIKE {self.sql(expression, 'this')}{options}"
|
||||||
|
|
||||||
def insert_sql(self, expression):
|
def insert_sql(self, expression):
|
||||||
overwrite = expression.args.get("overwrite")
|
overwrite = expression.args.get("overwrite")
|
||||||
|
@ -700,6 +697,11 @@ class Generator:
|
||||||
def var_sql(self, expression):
|
def var_sql(self, expression):
|
||||||
return self.sql(expression, "this")
|
return self.sql(expression, "this")
|
||||||
|
|
||||||
|
def into_sql(self, expression):
|
||||||
|
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
|
||||||
|
unlogged = " UNLOGGED" if expression.args.get("unlogged") else ""
|
||||||
|
return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}"
|
||||||
|
|
||||||
def from_sql(self, expression):
|
def from_sql(self, expression):
|
||||||
expressions = self.expressions(expression, flat=True)
|
expressions = self.expressions(expression, flat=True)
|
||||||
return f"{self.seg('FROM')} {expressions}"
|
return f"{self.seg('FROM')} {expressions}"
|
||||||
|
@ -883,6 +885,7 @@ class Generator:
|
||||||
sql = self.query_modifiers(
|
sql = self.query_modifiers(
|
||||||
expression,
|
expression,
|
||||||
f"SELECT{hint}{distinct}{expressions}",
|
f"SELECT{hint}{distinct}{expressions}",
|
||||||
|
self.sql(expression, "into", comment=False),
|
||||||
self.sql(expression, "from", comment=False),
|
self.sql(expression, "from", comment=False),
|
||||||
)
|
)
|
||||||
return self.prepend_ctes(expression, sql)
|
return self.prepend_ctes(expression, sql)
|
||||||
|
@ -1061,6 +1064,11 @@ class Generator:
|
||||||
else:
|
else:
|
||||||
return f"TRIM({target})"
|
return f"TRIM({target})"
|
||||||
|
|
||||||
|
def concat_sql(self, expression):
|
||||||
|
if len(expression.expressions) == 1:
|
||||||
|
return self.sql(expression.expressions[0])
|
||||||
|
return self.function_fallback_sql(expression)
|
||||||
|
|
||||||
def check_sql(self, expression):
|
def check_sql(self, expression):
|
||||||
this = self.sql(expression, key="this")
|
this = self.sql(expression, key="this")
|
||||||
return f"CHECK ({this})"
|
return f"CHECK ({this})"
|
||||||
|
@ -1125,7 +1133,10 @@ class Generator:
|
||||||
return self.prepend_ctes(expression, sql)
|
return self.prepend_ctes(expression, sql)
|
||||||
|
|
||||||
def neg_sql(self, expression):
|
def neg_sql(self, expression):
|
||||||
return f"-{self.sql(expression, 'this')}"
|
# This makes sure we don't convert "- - 5" to "--5", which is a comment
|
||||||
|
this_sql = self.sql(expression, "this")
|
||||||
|
sep = " " if this_sql[0] == "-" else ""
|
||||||
|
return f"-{sep}{this_sql}"
|
||||||
|
|
||||||
def not_sql(self, expression):
|
def not_sql(self, expression):
|
||||||
return f"NOT {self.sql(expression, 'this')}"
|
return f"NOT {self.sql(expression, 'this')}"
|
||||||
|
@ -1191,8 +1202,12 @@ class Generator:
|
||||||
def transaction_sql(self, *_):
|
def transaction_sql(self, *_):
|
||||||
return "BEGIN"
|
return "BEGIN"
|
||||||
|
|
||||||
def commit_sql(self, *_):
|
def commit_sql(self, expression):
|
||||||
return "COMMIT"
|
chain = expression.args.get("chain")
|
||||||
|
if chain is not None:
|
||||||
|
chain = " AND CHAIN" if chain else " AND NO CHAIN"
|
||||||
|
|
||||||
|
return f"COMMIT{chain or ''}"
|
||||||
|
|
||||||
def rollback_sql(self, expression):
|
def rollback_sql(self, expression):
|
||||||
savepoint = expression.args.get("savepoint")
|
savepoint = expression.args.get("savepoint")
|
||||||
|
@ -1334,15 +1349,15 @@ class Generator:
|
||||||
result_sqls = []
|
result_sqls = []
|
||||||
for i, e in enumerate(expressions):
|
for i, e in enumerate(expressions):
|
||||||
sql = self.sql(e, comment=False)
|
sql = self.sql(e, comment=False)
|
||||||
comment = self.maybe_comment("", e, single_line=True)
|
comments = self.maybe_comment("", e)
|
||||||
|
|
||||||
if self.pretty:
|
if self.pretty:
|
||||||
if self._leading_comma:
|
if self._leading_comma:
|
||||||
result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}")
|
result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}")
|
||||||
else:
|
else:
|
||||||
result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}")
|
result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}")
|
||||||
else:
|
else:
|
||||||
result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}")
|
result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}")
|
||||||
|
|
||||||
result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
|
result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
|
||||||
return self.indent(result_sqls, skip_first=False) if indent else result_sqls
|
return self.indent(result_sqls, skip_first=False) if indent else result_sqls
|
||||||
|
@ -1354,7 +1369,10 @@ class Generator:
|
||||||
return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"
|
return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"
|
||||||
|
|
||||||
def naked_property(self, expression):
|
def naked_property(self, expression):
|
||||||
return f"{expression.name} {self.sql(expression, 'value')}"
|
property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__)
|
||||||
|
if not property_name:
|
||||||
|
self.unsupported(f"Unsupported property {expression.__class__.__name__}")
|
||||||
|
return f"{property_name} {self.sql(expression, 'this')}"
|
||||||
|
|
||||||
def set_operation(self, expression, op):
|
def set_operation(self, expression, op):
|
||||||
this = self.sql(expression, "this")
|
this = self.sql(expression, "this")
|
||||||
|
|
|
@ -68,6 +68,9 @@ def eliminate_subqueries(expression):
|
||||||
for cte_scope in root.cte_scopes:
|
for cte_scope in root.cte_scopes:
|
||||||
# Append all the new CTEs from this existing CTE
|
# Append all the new CTEs from this existing CTE
|
||||||
for scope in cte_scope.traverse():
|
for scope in cte_scope.traverse():
|
||||||
|
if scope is cte_scope:
|
||||||
|
# Don't try to eliminate this CTE itself
|
||||||
|
continue
|
||||||
new_cte = _eliminate(scope, existing_ctes, taken)
|
new_cte = _eliminate(scope, existing_ctes, taken)
|
||||||
if new_cte:
|
if new_cte:
|
||||||
new_ctes.append(new_cte)
|
new_ctes.append(new_cte)
|
||||||
|
@ -97,6 +100,9 @@ def _eliminate(scope, existing_ctes, taken):
|
||||||
if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
|
if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
|
||||||
return _eliminate_derived_table(scope, existing_ctes, taken)
|
return _eliminate_derived_table(scope, existing_ctes, taken)
|
||||||
|
|
||||||
|
if scope.is_cte:
|
||||||
|
return _eliminate_cte(scope, existing_ctes, taken)
|
||||||
|
|
||||||
|
|
||||||
def _eliminate_union(scope, existing_ctes, taken):
|
def _eliminate_union(scope, existing_ctes, taken):
|
||||||
duplicate_cte_alias = existing_ctes.get(scope.expression)
|
duplicate_cte_alias = existing_ctes.get(scope.expression)
|
||||||
|
@ -127,26 +133,61 @@ def _eliminate_union(scope, existing_ctes, taken):
|
||||||
|
|
||||||
|
|
||||||
def _eliminate_derived_table(scope, existing_ctes, taken):
|
def _eliminate_derived_table(scope, existing_ctes, taken):
|
||||||
|
parent = scope.expression.parent
|
||||||
|
name, cte = _new_cte(scope, existing_ctes, taken)
|
||||||
|
|
||||||
|
table = exp.alias_(exp.table_(name), alias=parent.alias or name)
|
||||||
|
parent.replace(table)
|
||||||
|
|
||||||
|
return cte
|
||||||
|
|
||||||
|
|
||||||
|
def _eliminate_cte(scope, existing_ctes, taken):
|
||||||
|
parent = scope.expression.parent
|
||||||
|
name, cte = _new_cte(scope, existing_ctes, taken)
|
||||||
|
|
||||||
|
with_ = parent.parent
|
||||||
|
parent.pop()
|
||||||
|
if not with_.expressions:
|
||||||
|
with_.pop()
|
||||||
|
|
||||||
|
# Rename references to this CTE
|
||||||
|
for child_scope in scope.parent.traverse():
|
||||||
|
for table, source in child_scope.selected_sources.values():
|
||||||
|
if source is scope:
|
||||||
|
new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
|
||||||
|
table.replace(new_table)
|
||||||
|
|
||||||
|
return cte
|
||||||
|
|
||||||
|
|
||||||
|
def _new_cte(scope, existing_ctes, taken):
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
tuple of (name, cte)
|
||||||
|
where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
|
||||||
|
If this CTE duplicates an existing CTE, `cte` will be None.
|
||||||
|
"""
|
||||||
duplicate_cte_alias = existing_ctes.get(scope.expression)
|
duplicate_cte_alias = existing_ctes.get(scope.expression)
|
||||||
parent = scope.expression.parent
|
parent = scope.expression.parent
|
||||||
name = alias = parent.alias
|
name = parent.alias
|
||||||
|
|
||||||
if not alias:
|
if not name:
|
||||||
name = alias = find_new_name(taken=taken, base="cte")
|
name = find_new_name(taken=taken, base="cte")
|
||||||
|
|
||||||
if duplicate_cte_alias:
|
if duplicate_cte_alias:
|
||||||
name = duplicate_cte_alias
|
name = duplicate_cte_alias
|
||||||
elif taken.get(alias):
|
elif taken.get(name):
|
||||||
name = find_new_name(taken=taken, base=alias)
|
name = find_new_name(taken=taken, base=name)
|
||||||
|
|
||||||
taken[name] = scope
|
taken[name] = scope
|
||||||
|
|
||||||
table = exp.alias_(exp.table_(name), alias=alias)
|
|
||||||
parent.replace(table)
|
|
||||||
|
|
||||||
if not duplicate_cte_alias:
|
if not duplicate_cte_alias:
|
||||||
existing_ctes[scope.expression] = name
|
existing_ctes[scope.expression] = name
|
||||||
return exp.CTE(
|
cte = exp.CTE(
|
||||||
this=scope.expression,
|
this=scope.expression,
|
||||||
alias=exp.TableAlias(this=exp.to_identifier(name)),
|
alias=exp.TableAlias(this=exp.to_identifier(name)),
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
cte = None
|
||||||
|
return name, cte
|
||||||
|
|
92
sqlglot/optimizer/lower_identities.py
Normal file
92
sqlglot/optimizer/lower_identities.py
Normal file
|
@ -0,0 +1,92 @@
|
||||||
|
from sqlglot import exp
|
||||||
|
from sqlglot.helper import ensure_collection
|
||||||
|
|
||||||
|
|
||||||
|
def lower_identities(expression):
|
||||||
|
"""
|
||||||
|
Convert all unquoted identifiers to lower case.
|
||||||
|
|
||||||
|
Assuming the schema is all lower case, this essentially makes identifiers case-insensitive.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> import sqlglot
|
||||||
|
>>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
|
||||||
|
>>> lower_identities(expression).sql()
|
||||||
|
'SELECT bar.a AS A FROM "Foo".bar'
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expression (sqlglot.Expression): expression to quote
|
||||||
|
Returns:
|
||||||
|
sqlglot.Expression: quoted expression
|
||||||
|
"""
|
||||||
|
# We need to leave the output aliases unchanged, so the selects need special handling
|
||||||
|
_lower_selects(expression)
|
||||||
|
|
||||||
|
# These clauses can reference output aliases and also need special handling
|
||||||
|
_lower_order(expression)
|
||||||
|
_lower_having(expression)
|
||||||
|
|
||||||
|
# We've already handled these args, so don't traverse into them
|
||||||
|
traversed = {"expressions", "order", "having"}
|
||||||
|
|
||||||
|
if isinstance(expression, exp.Subquery):
|
||||||
|
# Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1
|
||||||
|
lower_identities(expression.this)
|
||||||
|
traversed |= {"this"}
|
||||||
|
|
||||||
|
if isinstance(expression, exp.Union):
|
||||||
|
# Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X
|
||||||
|
lower_identities(expression.left)
|
||||||
|
lower_identities(expression.right)
|
||||||
|
traversed |= {"this", "expression"}
|
||||||
|
|
||||||
|
for k, v in expression.args.items():
|
||||||
|
if k in traversed:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for child in ensure_collection(v):
|
||||||
|
if isinstance(child, exp.Expression):
|
||||||
|
child.transform(_lower, copy=False)
|
||||||
|
|
||||||
|
return expression
|
||||||
|
|
||||||
|
|
||||||
|
def _lower_selects(expression):
|
||||||
|
for e in expression.expressions:
|
||||||
|
# Leave output aliases as-is
|
||||||
|
e.unalias().transform(_lower, copy=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _lower_order(expression):
|
||||||
|
order = expression.args.get("order")
|
||||||
|
|
||||||
|
if not order:
|
||||||
|
return
|
||||||
|
|
||||||
|
output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)}
|
||||||
|
|
||||||
|
for ordered in order.expressions:
|
||||||
|
# Don't lower references to output aliases
|
||||||
|
if not (
|
||||||
|
isinstance(ordered.this, exp.Column)
|
||||||
|
and not ordered.this.table
|
||||||
|
and ordered.this.name in output_aliases
|
||||||
|
):
|
||||||
|
ordered.transform(_lower, copy=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _lower_having(expression):
|
||||||
|
having = expression.args.get("having")
|
||||||
|
|
||||||
|
if not having:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Don't lower references to output aliases
|
||||||
|
for agg in having.find_all(exp.AggFunc):
|
||||||
|
agg.transform(_lower, copy=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _lower(node):
|
||||||
|
if isinstance(node, exp.Identifier) and not node.quoted:
|
||||||
|
node.set("this", node.this.lower())
|
||||||
|
return node
|
|
@ -6,6 +6,7 @@ from sqlglot.optimizer.eliminate_joins import eliminate_joins
|
||||||
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
|
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
|
||||||
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
|
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
|
||||||
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
|
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
|
||||||
|
from sqlglot.optimizer.lower_identities import lower_identities
|
||||||
from sqlglot.optimizer.merge_subqueries import merge_subqueries
|
from sqlglot.optimizer.merge_subqueries import merge_subqueries
|
||||||
from sqlglot.optimizer.normalize import normalize
|
from sqlglot.optimizer.normalize import normalize
|
||||||
from sqlglot.optimizer.optimize_joins import optimize_joins
|
from sqlglot.optimizer.optimize_joins import optimize_joins
|
||||||
|
@ -17,6 +18,7 @@ from sqlglot.optimizer.quote_identities import quote_identities
|
||||||
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
|
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
|
||||||
|
|
||||||
RULES = (
|
RULES = (
|
||||||
|
lower_identities,
|
||||||
qualify_tables,
|
qualify_tables,
|
||||||
isolate_table_selects,
|
isolate_table_selects,
|
||||||
qualify_columns,
|
qualify_columns,
|
||||||
|
|
|
@ -1,16 +1,15 @@
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
from sqlglot import exp
|
from sqlglot import exp
|
||||||
from sqlglot.optimizer.scope import traverse_scope
|
from sqlglot.optimizer.scope import ScopeType, traverse_scope
|
||||||
|
|
||||||
|
|
||||||
def unnest_subqueries(expression):
|
def unnest_subqueries(expression):
|
||||||
"""
|
"""
|
||||||
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
|
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
|
||||||
|
|
||||||
Convert the subquery into a group by so it is not a many to many left join.
|
Convert scalar subqueries into cross joins.
|
||||||
Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
|
Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
|
||||||
Unnesting non correlated subqueries only happens on IN statements or = ANY statements.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> import sqlglot
|
>>> import sqlglot
|
||||||
|
@ -29,21 +28,43 @@ def unnest_subqueries(expression):
|
||||||
for scope in traverse_scope(expression):
|
for scope in traverse_scope(expression):
|
||||||
select = scope.expression
|
select = scope.expression
|
||||||
parent = select.parent_select
|
parent = select.parent_select
|
||||||
|
if not parent:
|
||||||
|
continue
|
||||||
if scope.external_columns:
|
if scope.external_columns:
|
||||||
decorrelate(select, parent, scope.external_columns, sequence)
|
decorrelate(select, parent, scope.external_columns, sequence)
|
||||||
else:
|
elif scope.scope_type == ScopeType.SUBQUERY:
|
||||||
unnest(select, parent, sequence)
|
unnest(select, parent, sequence)
|
||||||
|
|
||||||
return expression
|
return expression
|
||||||
|
|
||||||
|
|
||||||
def unnest(select, parent_select, sequence):
|
def unnest(select, parent_select, sequence):
|
||||||
predicate = select.find_ancestor(exp.In, exp.Any)
|
if len(select.selects) > 1:
|
||||||
|
return
|
||||||
|
|
||||||
|
predicate = select.find_ancestor(exp.Condition)
|
||||||
|
alias = _alias(sequence)
|
||||||
|
|
||||||
if not predicate or parent_select is not predicate.parent_select:
|
if not predicate or parent_select is not predicate.parent_select:
|
||||||
return
|
return
|
||||||
|
|
||||||
if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
|
# this subquery returns a scalar and can just be converted to a cross join
|
||||||
|
if not isinstance(predicate, (exp.In, exp.Any)):
|
||||||
|
having = predicate.find_ancestor(exp.Having)
|
||||||
|
column = exp.column(select.selects[0].alias_or_name, alias)
|
||||||
|
if having and having.parent_select is parent_select:
|
||||||
|
column = exp.Max(this=column)
|
||||||
|
_replace(select.parent, column)
|
||||||
|
|
||||||
|
parent_select.join(
|
||||||
|
select,
|
||||||
|
join_type="CROSS",
|
||||||
|
join_alias=alias,
|
||||||
|
copy=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if select.find(exp.Limit, exp.Offset):
|
||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(predicate, exp.Any):
|
if isinstance(predicate, exp.Any):
|
||||||
|
@ -54,7 +75,6 @@ def unnest(select, parent_select, sequence):
|
||||||
|
|
||||||
column = _other_operand(predicate)
|
column = _other_operand(predicate)
|
||||||
value = select.selects[0]
|
value = select.selects[0]
|
||||||
alias = _alias(sequence)
|
|
||||||
|
|
||||||
on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
|
on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
|
||||||
_replace(predicate, f"NOT {on.right} IS NULL")
|
_replace(predicate, f"NOT {on.right} IS NULL")
|
||||||
|
|
|
@ -4,7 +4,7 @@ import logging
|
||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
from sqlglot import exp
|
from sqlglot import exp
|
||||||
from sqlglot.errors import ErrorLevel, ParseError, concat_errors
|
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
|
||||||
from sqlglot.helper import apply_index_offset, ensure_collection, seq_get
|
from sqlglot.helper import apply_index_offset, ensure_collection, seq_get
|
||||||
from sqlglot.tokens import Token, Tokenizer, TokenType
|
from sqlglot.tokens import Token, Tokenizer, TokenType
|
||||||
from sqlglot.trie import in_trie, new_trie
|
from sqlglot.trie import in_trie, new_trie
|
||||||
|
@ -104,6 +104,7 @@ class Parser(metaclass=_Parser):
|
||||||
TokenType.BINARY,
|
TokenType.BINARY,
|
||||||
TokenType.VARBINARY,
|
TokenType.VARBINARY,
|
||||||
TokenType.JSON,
|
TokenType.JSON,
|
||||||
|
TokenType.JSONB,
|
||||||
TokenType.INTERVAL,
|
TokenType.INTERVAL,
|
||||||
TokenType.TIMESTAMP,
|
TokenType.TIMESTAMP,
|
||||||
TokenType.TIMESTAMPTZ,
|
TokenType.TIMESTAMPTZ,
|
||||||
|
@ -115,6 +116,7 @@ class Parser(metaclass=_Parser):
|
||||||
TokenType.GEOGRAPHY,
|
TokenType.GEOGRAPHY,
|
||||||
TokenType.GEOMETRY,
|
TokenType.GEOMETRY,
|
||||||
TokenType.HLLSKETCH,
|
TokenType.HLLSKETCH,
|
||||||
|
TokenType.HSTORE,
|
||||||
TokenType.SUPER,
|
TokenType.SUPER,
|
||||||
TokenType.SERIAL,
|
TokenType.SERIAL,
|
||||||
TokenType.SMALLSERIAL,
|
TokenType.SMALLSERIAL,
|
||||||
|
@ -153,6 +155,7 @@ class Parser(metaclass=_Parser):
|
||||||
TokenType.COLLATE,
|
TokenType.COLLATE,
|
||||||
TokenType.COMMAND,
|
TokenType.COMMAND,
|
||||||
TokenType.COMMIT,
|
TokenType.COMMIT,
|
||||||
|
TokenType.COMPOUND,
|
||||||
TokenType.CONSTRAINT,
|
TokenType.CONSTRAINT,
|
||||||
TokenType.CURRENT_TIME,
|
TokenType.CURRENT_TIME,
|
||||||
TokenType.DEFAULT,
|
TokenType.DEFAULT,
|
||||||
|
@ -194,6 +197,7 @@ class Parser(metaclass=_Parser):
|
||||||
TokenType.RANGE,
|
TokenType.RANGE,
|
||||||
TokenType.REFERENCES,
|
TokenType.REFERENCES,
|
||||||
TokenType.RETURNS,
|
TokenType.RETURNS,
|
||||||
|
TokenType.ROW,
|
||||||
TokenType.ROWS,
|
TokenType.ROWS,
|
||||||
TokenType.SCHEMA,
|
TokenType.SCHEMA,
|
||||||
TokenType.SCHEMA_COMMENT,
|
TokenType.SCHEMA_COMMENT,
|
||||||
|
@ -213,6 +217,7 @@ class Parser(metaclass=_Parser):
|
||||||
TokenType.TRUE,
|
TokenType.TRUE,
|
||||||
TokenType.UNBOUNDED,
|
TokenType.UNBOUNDED,
|
||||||
TokenType.UNIQUE,
|
TokenType.UNIQUE,
|
||||||
|
TokenType.UNLOGGED,
|
||||||
TokenType.UNPIVOT,
|
TokenType.UNPIVOT,
|
||||||
TokenType.PROPERTIES,
|
TokenType.PROPERTIES,
|
||||||
TokenType.PROCEDURE,
|
TokenType.PROCEDURE,
|
||||||
|
@ -400,9 +405,17 @@ class Parser(metaclass=_Parser):
|
||||||
TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
|
TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
|
||||||
TokenType.BEGIN: lambda self: self._parse_transaction(),
|
TokenType.BEGIN: lambda self: self._parse_transaction(),
|
||||||
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
|
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
|
||||||
|
TokenType.END: lambda self: self._parse_commit_or_rollback(),
|
||||||
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
|
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
UNARY_PARSERS = {
|
||||||
|
TokenType.PLUS: lambda self: self._parse_unary(), # Unary + is handled as a no-op
|
||||||
|
TokenType.NOT: lambda self: self.expression(exp.Not, this=self._parse_equality()),
|
||||||
|
TokenType.TILDA: lambda self: self.expression(exp.BitwiseNot, this=self._parse_unary()),
|
||||||
|
TokenType.DASH: lambda self: self.expression(exp.Neg, this=self._parse_unary()),
|
||||||
|
}
|
||||||
|
|
||||||
PRIMARY_PARSERS = {
|
PRIMARY_PARSERS = {
|
||||||
TokenType.STRING: lambda self, token: self.expression(
|
TokenType.STRING: lambda self, token: self.expression(
|
||||||
exp.Literal, this=token.text, is_string=True
|
exp.Literal, this=token.text, is_string=True
|
||||||
|
@ -446,19 +459,20 @@ class Parser(metaclass=_Parser):
|
||||||
}
|
}
|
||||||
|
|
||||||
PROPERTY_PARSERS = {
|
PROPERTY_PARSERS = {
|
||||||
TokenType.AUTO_INCREMENT: lambda self: self._parse_auto_increment(),
|
TokenType.AUTO_INCREMENT: lambda self: self._parse_property_assignment(
|
||||||
TokenType.CHARACTER_SET: lambda self: self._parse_character_set(),
|
exp.AutoIncrementProperty
|
||||||
TokenType.LOCATION: lambda self: self.expression(
|
|
||||||
exp.LocationProperty,
|
|
||||||
this=exp.Literal.string("LOCATION"),
|
|
||||||
value=self._parse_string(),
|
|
||||||
),
|
),
|
||||||
|
TokenType.CHARACTER_SET: lambda self: self._parse_character_set(),
|
||||||
|
TokenType.LOCATION: lambda self: self._parse_property_assignment(exp.LocationProperty),
|
||||||
TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(),
|
TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(),
|
||||||
TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(),
|
TokenType.SCHEMA_COMMENT: lambda self: self._parse_property_assignment(
|
||||||
TokenType.STORED: lambda self: self._parse_stored(),
|
exp.SchemaCommentProperty
|
||||||
|
),
|
||||||
|
TokenType.STORED: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
|
||||||
TokenType.DISTKEY: lambda self: self._parse_distkey(),
|
TokenType.DISTKEY: lambda self: self._parse_distkey(),
|
||||||
TokenType.DISTSTYLE: lambda self: self._parse_diststyle(),
|
TokenType.DISTSTYLE: lambda self: self._parse_property_assignment(exp.DistStyleProperty),
|
||||||
TokenType.SORTKEY: lambda self: self._parse_sortkey(),
|
TokenType.SORTKEY: lambda self: self._parse_sortkey(),
|
||||||
|
TokenType.LIKE: lambda self: self._parse_create_like(),
|
||||||
TokenType.RETURNS: lambda self: self._parse_returns(),
|
TokenType.RETURNS: lambda self: self._parse_returns(),
|
||||||
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
|
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
|
||||||
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
|
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
|
||||||
|
@ -468,7 +482,7 @@ class Parser(metaclass=_Parser):
|
||||||
),
|
),
|
||||||
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
|
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
|
||||||
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
|
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
|
||||||
TokenType.EXECUTE: lambda self: self._parse_execute_as(),
|
TokenType.EXECUTE: lambda self: self._parse_property_assignment(exp.ExecuteAsProperty),
|
||||||
TokenType.DETERMINISTIC: lambda self: self.expression(
|
TokenType.DETERMINISTIC: lambda self: self.expression(
|
||||||
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
|
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
|
||||||
),
|
),
|
||||||
|
@ -489,6 +503,7 @@ class Parser(metaclass=_Parser):
|
||||||
),
|
),
|
||||||
TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(),
|
TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(),
|
||||||
TokenType.UNIQUE: lambda self: self._parse_unique(),
|
TokenType.UNIQUE: lambda self: self._parse_unique(),
|
||||||
|
TokenType.LIKE: lambda self: self._parse_create_like(),
|
||||||
}
|
}
|
||||||
|
|
||||||
NO_PAREN_FUNCTION_PARSERS = {
|
NO_PAREN_FUNCTION_PARSERS = {
|
||||||
|
@ -505,6 +520,7 @@ class Parser(metaclass=_Parser):
|
||||||
"TRIM": lambda self: self._parse_trim(),
|
"TRIM": lambda self: self._parse_trim(),
|
||||||
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
|
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
|
||||||
"TRY_CAST": lambda self: self._parse_cast(False),
|
"TRY_CAST": lambda self: self._parse_cast(False),
|
||||||
|
"STRING_AGG": lambda self: self._parse_string_agg(),
|
||||||
}
|
}
|
||||||
|
|
||||||
QUERY_MODIFIER_PARSERS = {
|
QUERY_MODIFIER_PARSERS = {
|
||||||
|
@ -556,7 +572,7 @@ class Parser(metaclass=_Parser):
|
||||||
"_curr",
|
"_curr",
|
||||||
"_next",
|
"_next",
|
||||||
"_prev",
|
"_prev",
|
||||||
"_prev_comment",
|
"_prev_comments",
|
||||||
"_show_trie",
|
"_show_trie",
|
||||||
"_set_trie",
|
"_set_trie",
|
||||||
)
|
)
|
||||||
|
@ -589,7 +605,7 @@ class Parser(metaclass=_Parser):
|
||||||
self._curr = None
|
self._curr = None
|
||||||
self._next = None
|
self._next = None
|
||||||
self._prev = None
|
self._prev = None
|
||||||
self._prev_comment = None
|
self._prev_comments = None
|
||||||
|
|
||||||
def parse(self, raw_tokens, sql=None):
|
def parse(self, raw_tokens, sql=None):
|
||||||
"""
|
"""
|
||||||
|
@ -608,6 +624,7 @@ class Parser(metaclass=_Parser):
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse_into(self, expression_types, raw_tokens, sql=None):
|
def parse_into(self, expression_types, raw_tokens, sql=None):
|
||||||
|
errors = []
|
||||||
for expression_type in ensure_collection(expression_types):
|
for expression_type in ensure_collection(expression_types):
|
||||||
parser = self.EXPRESSION_PARSERS.get(expression_type)
|
parser = self.EXPRESSION_PARSERS.get(expression_type)
|
||||||
if not parser:
|
if not parser:
|
||||||
|
@ -615,8 +632,12 @@ class Parser(metaclass=_Parser):
|
||||||
try:
|
try:
|
||||||
return self._parse(parser, raw_tokens, sql)
|
return self._parse(parser, raw_tokens, sql)
|
||||||
except ParseError as e:
|
except ParseError as e:
|
||||||
error = e
|
e.errors[0]["into_expression"] = expression_type
|
||||||
raise ParseError(f"Failed to parse into {expression_types}") from error
|
errors.append(e)
|
||||||
|
raise ParseError(
|
||||||
|
f"Failed to parse into {expression_types}",
|
||||||
|
errors=merge_errors(errors),
|
||||||
|
) from errors[-1]
|
||||||
|
|
||||||
def _parse(self, parse_method, raw_tokens, sql=None):
|
def _parse(self, parse_method, raw_tokens, sql=None):
|
||||||
self.reset()
|
self.reset()
|
||||||
|
@ -650,7 +671,10 @@ class Parser(metaclass=_Parser):
|
||||||
for error in self.errors:
|
for error in self.errors:
|
||||||
logger.error(str(error))
|
logger.error(str(error))
|
||||||
elif self.error_level == ErrorLevel.RAISE and self.errors:
|
elif self.error_level == ErrorLevel.RAISE and self.errors:
|
||||||
raise ParseError(concat_errors(self.errors, self.max_errors))
|
raise ParseError(
|
||||||
|
concat_messages(self.errors, self.max_errors),
|
||||||
|
errors=merge_errors(self.errors),
|
||||||
|
)
|
||||||
|
|
||||||
def raise_error(self, message, token=None):
|
def raise_error(self, message, token=None):
|
||||||
token = token or self._curr or self._prev or Token.string("")
|
token = token or self._curr or self._prev or Token.string("")
|
||||||
|
@ -659,19 +683,27 @@ class Parser(metaclass=_Parser):
|
||||||
start_context = self.sql[max(start - self.error_message_context, 0) : start]
|
start_context = self.sql[max(start - self.error_message_context, 0) : start]
|
||||||
highlight = self.sql[start:end]
|
highlight = self.sql[start:end]
|
||||||
end_context = self.sql[end : end + self.error_message_context]
|
end_context = self.sql[end : end + self.error_message_context]
|
||||||
error = ParseError(
|
error = ParseError.new(
|
||||||
f"{message}. Line {token.line}, Col: {token.col}.\n"
|
f"{message}. Line {token.line}, Col: {token.col}.\n"
|
||||||
f" {start_context}\033[4m{highlight}\033[0m{end_context}"
|
f" {start_context}\033[4m{highlight}\033[0m{end_context}",
|
||||||
|
description=message,
|
||||||
|
line=token.line,
|
||||||
|
col=token.col,
|
||||||
|
start_context=start_context,
|
||||||
|
highlight=highlight,
|
||||||
|
end_context=end_context,
|
||||||
)
|
)
|
||||||
if self.error_level == ErrorLevel.IMMEDIATE:
|
if self.error_level == ErrorLevel.IMMEDIATE:
|
||||||
raise error
|
raise error
|
||||||
self.errors.append(error)
|
self.errors.append(error)
|
||||||
|
|
||||||
def expression(self, exp_class, **kwargs):
|
def expression(self, exp_class, comments=None, **kwargs):
|
||||||
instance = exp_class(**kwargs)
|
instance = exp_class(**kwargs)
|
||||||
if self._prev_comment:
|
if self._prev_comments:
|
||||||
instance.comment = self._prev_comment
|
instance.comments = self._prev_comments
|
||||||
self._prev_comment = None
|
self._prev_comments = None
|
||||||
|
if comments:
|
||||||
|
instance.comments = comments
|
||||||
self.validate_expression(instance)
|
self.validate_expression(instance)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
@ -714,10 +746,10 @@ class Parser(metaclass=_Parser):
|
||||||
self._next = seq_get(self._tokens, self._index + 1)
|
self._next = seq_get(self._tokens, self._index + 1)
|
||||||
if self._index > 0:
|
if self._index > 0:
|
||||||
self._prev = self._tokens[self._index - 1]
|
self._prev = self._tokens[self._index - 1]
|
||||||
self._prev_comment = self._prev.comment
|
self._prev_comments = self._prev.comments
|
||||||
else:
|
else:
|
||||||
self._prev = None
|
self._prev = None
|
||||||
self._prev_comment = None
|
self._prev_comments = None
|
||||||
|
|
||||||
def _retreat(self, index):
|
def _retreat(self, index):
|
||||||
self._advance(index - self._index)
|
self._advance(index - self._index)
|
||||||
|
@ -768,7 +800,7 @@ class Parser(metaclass=_Parser):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_create(self):
|
def _parse_create(self):
|
||||||
replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
|
replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
|
||||||
temporary = self._match(TokenType.TEMPORARY)
|
temporary = self._match(TokenType.TEMPORARY)
|
||||||
transient = self._match(TokenType.TRANSIENT)
|
transient = self._match(TokenType.TRANSIENT)
|
||||||
unique = self._match(TokenType.UNIQUE)
|
unique = self._match(TokenType.UNIQUE)
|
||||||
|
@ -822,97 +854,57 @@ class Parser(metaclass=_Parser):
|
||||||
def _parse_property(self):
|
def _parse_property(self):
|
||||||
if self._match_set(self.PROPERTY_PARSERS):
|
if self._match_set(self.PROPERTY_PARSERS):
|
||||||
return self.PROPERTY_PARSERS[self._prev.token_type](self)
|
return self.PROPERTY_PARSERS[self._prev.token_type](self)
|
||||||
|
|
||||||
if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
|
if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
|
||||||
return self._parse_character_set(True)
|
return self._parse_character_set(True)
|
||||||
|
|
||||||
if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False):
|
if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
|
||||||
key = self._parse_var().this
|
return self._parse_sortkey(compound=True)
|
||||||
self._match(TokenType.EQ)
|
|
||||||
|
|
||||||
return self.expression(
|
if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False):
|
||||||
exp.AnonymousProperty,
|
key = self._parse_var()
|
||||||
this=exp.Literal.string(key),
|
self._match(TokenType.EQ)
|
||||||
value=self._parse_column(),
|
return self.expression(exp.Property, this=key, value=self._parse_column())
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _parse_property_assignment(self, exp_class):
|
def _parse_property_assignment(self, exp_class):
|
||||||
prop = self._prev.text
|
|
||||||
self._match(TokenType.EQ)
|
self._match(TokenType.EQ)
|
||||||
return self.expression(exp_class, this=prop, value=self._parse_var_or_string())
|
self._match(TokenType.ALIAS)
|
||||||
|
return self.expression(exp_class, this=self._parse_var_or_string() or self._parse_number())
|
||||||
|
|
||||||
def _parse_partitioned_by(self):
|
def _parse_partitioned_by(self):
|
||||||
self._match(TokenType.EQ)
|
self._match(TokenType.EQ)
|
||||||
return self.expression(
|
return self.expression(
|
||||||
exp.PartitionedByProperty,
|
exp.PartitionedByProperty,
|
||||||
this=exp.Literal.string("PARTITIONED_BY"),
|
this=self._parse_schema() or self._parse_bracket(self._parse_field()),
|
||||||
value=self._parse_schema() or self._parse_bracket(self._parse_field()),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _parse_stored(self):
|
|
||||||
self._match(TokenType.ALIAS)
|
|
||||||
self._match(TokenType.EQ)
|
|
||||||
return self.expression(
|
|
||||||
exp.FileFormatProperty,
|
|
||||||
this=exp.Literal.string("FORMAT"),
|
|
||||||
value=exp.Literal.string(self._parse_var_or_string().name),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_distkey(self):
|
def _parse_distkey(self):
|
||||||
self._match_l_paren()
|
return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_var))
|
||||||
this = exp.Literal.string("DISTKEY")
|
|
||||||
value = exp.Literal.string(self._parse_var().name)
|
|
||||||
self._match_r_paren()
|
|
||||||
return self.expression(
|
|
||||||
exp.DistKeyProperty,
|
|
||||||
this=this,
|
|
||||||
value=value,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _parse_sortkey(self):
|
def _parse_create_like(self):
|
||||||
self._match_l_paren()
|
table = self._parse_table(schema=True)
|
||||||
this = exp.Literal.string("SORTKEY")
|
options = []
|
||||||
value = exp.Literal.string(self._parse_var().name)
|
while self._match_texts(("INCLUDING", "EXCLUDING")):
|
||||||
self._match_r_paren()
|
options.append(
|
||||||
return self.expression(
|
self.expression(
|
||||||
exp.SortKeyProperty,
|
exp.Property,
|
||||||
this=this,
|
this=self._prev.text.upper(),
|
||||||
value=value,
|
value=exp.Var(this=self._parse_id_var().this.upper()),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_diststyle(self):
|
|
||||||
this = exp.Literal.string("DISTSTYLE")
|
|
||||||
value = exp.Literal.string(self._parse_var().name)
|
|
||||||
return self.expression(
|
|
||||||
exp.DistStyleProperty,
|
|
||||||
this=this,
|
|
||||||
value=value,
|
|
||||||
)
|
)
|
||||||
|
return self.expression(exp.LikeProperty, this=table, expressions=options)
|
||||||
|
|
||||||
def _parse_auto_increment(self):
|
def _parse_sortkey(self, compound=False):
|
||||||
self._match(TokenType.EQ)
|
|
||||||
return self.expression(
|
return self.expression(
|
||||||
exp.AutoIncrementProperty,
|
exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_var), compound=compound
|
||||||
this=exp.Literal.string("AUTO_INCREMENT"),
|
|
||||||
value=self._parse_number(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _parse_schema_comment(self):
|
|
||||||
self._match(TokenType.EQ)
|
|
||||||
return self.expression(
|
|
||||||
exp.SchemaCommentProperty,
|
|
||||||
this=exp.Literal.string("COMMENT"),
|
|
||||||
value=self._parse_string(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_character_set(self, default=False):
|
def _parse_character_set(self, default=False):
|
||||||
self._match(TokenType.EQ)
|
self._match(TokenType.EQ)
|
||||||
return self.expression(
|
return self.expression(
|
||||||
exp.CharacterSetProperty,
|
exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default
|
||||||
this=exp.Literal.string("CHARACTER_SET"),
|
|
||||||
value=self._parse_var_or_string(),
|
|
||||||
default=default,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_returns(self):
|
def _parse_returns(self):
|
||||||
|
@ -931,20 +923,7 @@ class Parser(metaclass=_Parser):
|
||||||
else:
|
else:
|
||||||
value = self._parse_types()
|
value = self._parse_types()
|
||||||
|
|
||||||
return self.expression(
|
return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)
|
||||||
exp.ReturnsProperty,
|
|
||||||
this=exp.Literal.string("RETURNS"),
|
|
||||||
value=value,
|
|
||||||
is_table=is_table,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _parse_execute_as(self):
|
|
||||||
self._match(TokenType.ALIAS)
|
|
||||||
return self.expression(
|
|
||||||
exp.ExecuteAsProperty,
|
|
||||||
this=exp.Literal.string("EXECUTE AS"),
|
|
||||||
value=self._parse_var(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _parse_properties(self):
|
def _parse_properties(self):
|
||||||
properties = []
|
properties = []
|
||||||
|
@ -956,7 +935,7 @@ class Parser(metaclass=_Parser):
|
||||||
properties.extend(
|
properties.extend(
|
||||||
self._parse_wrapped_csv(
|
self._parse_wrapped_csv(
|
||||||
lambda: self.expression(
|
lambda: self.expression(
|
||||||
exp.AnonymousProperty,
|
exp.Property,
|
||||||
this=self._parse_string(),
|
this=self._parse_string(),
|
||||||
value=self._match(TokenType.EQ) and self._parse_string(),
|
value=self._match(TokenType.EQ) and self._parse_string(),
|
||||||
)
|
)
|
||||||
|
@ -1076,7 +1055,12 @@ class Parser(metaclass=_Parser):
|
||||||
options = []
|
options = []
|
||||||
|
|
||||||
if self._match(TokenType.OPTIONS):
|
if self._match(TokenType.OPTIONS):
|
||||||
options = self._parse_wrapped_csv(self._parse_string, sep=TokenType.EQ)
|
self._match_l_paren()
|
||||||
|
k = self._parse_string()
|
||||||
|
self._match(TokenType.EQ)
|
||||||
|
v = self._parse_string()
|
||||||
|
options = [k, v]
|
||||||
|
self._match_r_paren()
|
||||||
|
|
||||||
self._match(TokenType.ALIAS)
|
self._match(TokenType.ALIAS)
|
||||||
return self.expression(
|
return self.expression(
|
||||||
|
@ -1116,7 +1100,7 @@ class Parser(metaclass=_Parser):
|
||||||
self.raise_error(f"{this.key} does not support CTE")
|
self.raise_error(f"{this.key} does not support CTE")
|
||||||
this = cte
|
this = cte
|
||||||
elif self._match(TokenType.SELECT):
|
elif self._match(TokenType.SELECT):
|
||||||
comment = self._prev_comment
|
comments = self._prev_comments
|
||||||
|
|
||||||
hint = self._parse_hint()
|
hint = self._parse_hint()
|
||||||
all_ = self._match(TokenType.ALL)
|
all_ = self._match(TokenType.ALL)
|
||||||
|
@ -1141,10 +1125,16 @@ class Parser(metaclass=_Parser):
|
||||||
expressions=expressions,
|
expressions=expressions,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
this.comment = comment
|
this.comments = comments
|
||||||
|
|
||||||
|
into = self._parse_into()
|
||||||
|
if into:
|
||||||
|
this.set("into", into)
|
||||||
|
|
||||||
from_ = self._parse_from()
|
from_ = self._parse_from()
|
||||||
if from_:
|
if from_:
|
||||||
this.set("from", from_)
|
this.set("from", from_)
|
||||||
|
|
||||||
self._parse_query_modifiers(this)
|
self._parse_query_modifiers(this)
|
||||||
elif (table or nested) and self._match(TokenType.L_PAREN):
|
elif (table or nested) and self._match(TokenType.L_PAREN):
|
||||||
this = self._parse_table() if table else self._parse_select(nested=True)
|
this = self._parse_table() if table else self._parse_select(nested=True)
|
||||||
|
@ -1248,11 +1238,24 @@ class Parser(metaclass=_Parser):
|
||||||
return self.expression(exp.Hint, expressions=hints)
|
return self.expression(exp.Hint, expressions=hints)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _parse_into(self):
|
||||||
|
if not self._match(TokenType.INTO):
|
||||||
|
return None
|
||||||
|
|
||||||
|
temp = self._match(TokenType.TEMPORARY)
|
||||||
|
unlogged = self._match(TokenType.UNLOGGED)
|
||||||
|
self._match(TokenType.TABLE)
|
||||||
|
|
||||||
|
return self.expression(
|
||||||
|
exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged
|
||||||
|
)
|
||||||
|
|
||||||
def _parse_from(self):
|
def _parse_from(self):
|
||||||
if not self._match(TokenType.FROM):
|
if not self._match(TokenType.FROM):
|
||||||
return None
|
return None
|
||||||
|
return self.expression(
|
||||||
return self.expression(exp.From, expressions=self._parse_csv(self._parse_table))
|
exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table)
|
||||||
|
)
|
||||||
|
|
||||||
def _parse_lateral(self):
|
def _parse_lateral(self):
|
||||||
outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY)
|
outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY)
|
||||||
|
@ -1515,7 +1518,9 @@ class Parser(metaclass=_Parser):
|
||||||
def _parse_where(self, skip_where_token=False):
|
def _parse_where(self, skip_where_token=False):
|
||||||
if not skip_where_token and not self._match(TokenType.WHERE):
|
if not skip_where_token and not self._match(TokenType.WHERE):
|
||||||
return None
|
return None
|
||||||
return self.expression(exp.Where, this=self._parse_conjunction())
|
return self.expression(
|
||||||
|
exp.Where, comments=self._prev_comments, this=self._parse_conjunction()
|
||||||
|
)
|
||||||
|
|
||||||
def _parse_group(self, skip_group_by_token=False):
|
def _parse_group(self, skip_group_by_token=False):
|
||||||
if not skip_group_by_token and not self._match(TokenType.GROUP_BY):
|
if not skip_group_by_token and not self._match(TokenType.GROUP_BY):
|
||||||
|
@ -1737,12 +1742,8 @@ class Parser(metaclass=_Parser):
|
||||||
return self._parse_tokens(self._parse_unary, self.FACTOR)
|
return self._parse_tokens(self._parse_unary, self.FACTOR)
|
||||||
|
|
||||||
def _parse_unary(self):
|
def _parse_unary(self):
|
||||||
if self._match(TokenType.NOT):
|
if self._match_set(self.UNARY_PARSERS):
|
||||||
return self.expression(exp.Not, this=self._parse_equality())
|
return self.UNARY_PARSERS[self._prev.token_type](self)
|
||||||
if self._match(TokenType.TILDA):
|
|
||||||
return self.expression(exp.BitwiseNot, this=self._parse_unary())
|
|
||||||
if self._match(TokenType.DASH):
|
|
||||||
return self.expression(exp.Neg, this=self._parse_unary())
|
|
||||||
return self._parse_at_time_zone(self._parse_type())
|
return self._parse_at_time_zone(self._parse_type())
|
||||||
|
|
||||||
def _parse_type(self):
|
def _parse_type(self):
|
||||||
|
@ -1775,17 +1776,6 @@ class Parser(metaclass=_Parser):
|
||||||
expressions = None
|
expressions = None
|
||||||
maybe_func = False
|
maybe_func = False
|
||||||
|
|
||||||
if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
|
|
||||||
return exp.DataType(
|
|
||||||
this=exp.DataType.Type.ARRAY,
|
|
||||||
expressions=[exp.DataType.build(type_token.value)],
|
|
||||||
nested=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._match(TokenType.L_BRACKET):
|
|
||||||
self._retreat(index)
|
|
||||||
return None
|
|
||||||
|
|
||||||
if self._match(TokenType.L_PAREN):
|
if self._match(TokenType.L_PAREN):
|
||||||
if is_struct:
|
if is_struct:
|
||||||
expressions = self._parse_csv(self._parse_struct_kwargs)
|
expressions = self._parse_csv(self._parse_struct_kwargs)
|
||||||
|
@ -1801,6 +1791,17 @@ class Parser(metaclass=_Parser):
|
||||||
self._match_r_paren()
|
self._match_r_paren()
|
||||||
maybe_func = True
|
maybe_func = True
|
||||||
|
|
||||||
|
if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
|
||||||
|
return exp.DataType(
|
||||||
|
this=exp.DataType.Type.ARRAY,
|
||||||
|
expressions=[exp.DataType.build(type_token.value, expressions=expressions)],
|
||||||
|
nested=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._match(TokenType.L_BRACKET):
|
||||||
|
self._retreat(index)
|
||||||
|
return None
|
||||||
|
|
||||||
if nested and self._match(TokenType.LT):
|
if nested and self._match(TokenType.LT):
|
||||||
if is_struct:
|
if is_struct:
|
||||||
expressions = self._parse_csv(self._parse_struct_kwargs)
|
expressions = self._parse_csv(self._parse_struct_kwargs)
|
||||||
|
@ -1904,7 +1905,7 @@ class Parser(metaclass=_Parser):
|
||||||
return exp.Literal.number(f"0.{self._prev.text}")
|
return exp.Literal.number(f"0.{self._prev.text}")
|
||||||
|
|
||||||
if self._match(TokenType.L_PAREN):
|
if self._match(TokenType.L_PAREN):
|
||||||
comment = self._prev_comment
|
comments = self._prev_comments
|
||||||
query = self._parse_select()
|
query = self._parse_select()
|
||||||
|
|
||||||
if query:
|
if query:
|
||||||
|
@ -1924,8 +1925,8 @@ class Parser(metaclass=_Parser):
|
||||||
this = self.expression(exp.Tuple, expressions=expressions)
|
this = self.expression(exp.Tuple, expressions=expressions)
|
||||||
else:
|
else:
|
||||||
this = self.expression(exp.Paren, this=this)
|
this = self.expression(exp.Paren, this=this)
|
||||||
if comment:
|
if comments:
|
||||||
this.comment = comment
|
this.comments = comments
|
||||||
return this
|
return this
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
@ -2098,7 +2099,10 @@ class Parser(metaclass=_Parser):
|
||||||
elif self._match(TokenType.SCHEMA_COMMENT):
|
elif self._match(TokenType.SCHEMA_COMMENT):
|
||||||
kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
|
kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
|
||||||
elif self._match(TokenType.PRIMARY_KEY):
|
elif self._match(TokenType.PRIMARY_KEY):
|
||||||
kind = exp.PrimaryKeyColumnConstraint()
|
desc = None
|
||||||
|
if self._match(TokenType.ASC) or self._match(TokenType.DESC):
|
||||||
|
desc = self._prev.token_type == TokenType.DESC
|
||||||
|
kind = exp.PrimaryKeyColumnConstraint(desc=desc)
|
||||||
elif self._match(TokenType.UNIQUE):
|
elif self._match(TokenType.UNIQUE):
|
||||||
kind = exp.UniqueColumnConstraint()
|
kind = exp.UniqueColumnConstraint()
|
||||||
elif self._match(TokenType.GENERATED):
|
elif self._match(TokenType.GENERATED):
|
||||||
|
@ -2189,7 +2193,7 @@ class Parser(metaclass=_Parser):
|
||||||
if not self._match(TokenType.R_BRACKET):
|
if not self._match(TokenType.R_BRACKET):
|
||||||
self.raise_error("Expected ]")
|
self.raise_error("Expected ]")
|
||||||
|
|
||||||
this.comment = self._prev_comment
|
this.comments = self._prev_comments
|
||||||
return self._parse_bracket(this)
|
return self._parse_bracket(this)
|
||||||
|
|
||||||
def _parse_case(self):
|
def _parse_case(self):
|
||||||
|
@ -2256,6 +2260,33 @@ class Parser(metaclass=_Parser):
|
||||||
|
|
||||||
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
|
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
|
||||||
|
|
||||||
|
def _parse_string_agg(self):
|
||||||
|
if self._match(TokenType.DISTINCT):
|
||||||
|
args = self._parse_csv(self._parse_conjunction)
|
||||||
|
expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)])
|
||||||
|
else:
|
||||||
|
args = self._parse_csv(self._parse_conjunction)
|
||||||
|
expression = seq_get(args, 0)
|
||||||
|
|
||||||
|
index = self._index
|
||||||
|
if not self._match(TokenType.R_PAREN):
|
||||||
|
# postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]])
|
||||||
|
order = self._parse_order(this=expression)
|
||||||
|
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
|
||||||
|
|
||||||
|
# Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]).
|
||||||
|
# This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that
|
||||||
|
# the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them.
|
||||||
|
if not self._match(TokenType.WITHIN_GROUP):
|
||||||
|
self._retreat(index)
|
||||||
|
this = exp.GroupConcat.from_arg_list(args)
|
||||||
|
self.validate_expression(this, args)
|
||||||
|
return this
|
||||||
|
|
||||||
|
self._match_l_paren() # The corresponding match_r_paren will be called in parse_function (caller)
|
||||||
|
order = self._parse_order(this=expression)
|
||||||
|
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
|
||||||
|
|
||||||
def _parse_convert(self, strict):
|
def _parse_convert(self, strict):
|
||||||
this = self._parse_column()
|
this = self._parse_column()
|
||||||
if self._match(TokenType.USING):
|
if self._match(TokenType.USING):
|
||||||
|
@ -2511,8 +2542,8 @@ class Parser(metaclass=_Parser):
|
||||||
items = [parse_result] if parse_result is not None else []
|
items = [parse_result] if parse_result is not None else []
|
||||||
|
|
||||||
while self._match(sep):
|
while self._match(sep):
|
||||||
if parse_result and self._prev_comment is not None:
|
if parse_result and self._prev_comments:
|
||||||
parse_result.comment = self._prev_comment
|
parse_result.comments = self._prev_comments
|
||||||
|
|
||||||
parse_result = parse_method()
|
parse_result = parse_method()
|
||||||
if parse_result is not None:
|
if parse_result is not None:
|
||||||
|
@ -2525,7 +2556,10 @@ class Parser(metaclass=_Parser):
|
||||||
|
|
||||||
while self._match_set(expressions):
|
while self._match_set(expressions):
|
||||||
this = self.expression(
|
this = self.expression(
|
||||||
expressions[self._prev.token_type], this=this, expression=parse_method()
|
expressions[self._prev.token_type],
|
||||||
|
this=this,
|
||||||
|
comments=self._prev_comments,
|
||||||
|
expression=parse_method(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return this
|
return this
|
||||||
|
@ -2566,6 +2600,7 @@ class Parser(metaclass=_Parser):
|
||||||
return self.expression(exp.Transaction, this=this, modes=modes)
|
return self.expression(exp.Transaction, this=this, modes=modes)
|
||||||
|
|
||||||
def _parse_commit_or_rollback(self):
|
def _parse_commit_or_rollback(self):
|
||||||
|
chain = None
|
||||||
savepoint = None
|
savepoint = None
|
||||||
is_rollback = self._prev.token_type == TokenType.ROLLBACK
|
is_rollback = self._prev.token_type == TokenType.ROLLBACK
|
||||||
|
|
||||||
|
@ -2575,9 +2610,13 @@ class Parser(metaclass=_Parser):
|
||||||
self._match_text_seq("SAVEPOINT")
|
self._match_text_seq("SAVEPOINT")
|
||||||
savepoint = self._parse_id_var()
|
savepoint = self._parse_id_var()
|
||||||
|
|
||||||
|
if self._match(TokenType.AND):
|
||||||
|
chain = not self._match_text_seq("NO")
|
||||||
|
self._match_text_seq("CHAIN")
|
||||||
|
|
||||||
if is_rollback:
|
if is_rollback:
|
||||||
return self.expression(exp.Rollback, savepoint=savepoint)
|
return self.expression(exp.Rollback, savepoint=savepoint)
|
||||||
return self.expression(exp.Commit)
|
return self.expression(exp.Commit, chain=chain)
|
||||||
|
|
||||||
def _parse_show(self):
|
def _parse_show(self):
|
||||||
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
|
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
|
||||||
|
@ -2651,14 +2690,14 @@ class Parser(metaclass=_Parser):
|
||||||
def _match_l_paren(self, expression=None):
|
def _match_l_paren(self, expression=None):
|
||||||
if not self._match(TokenType.L_PAREN):
|
if not self._match(TokenType.L_PAREN):
|
||||||
self.raise_error("Expecting (")
|
self.raise_error("Expecting (")
|
||||||
if expression and self._prev_comment:
|
if expression and self._prev_comments:
|
||||||
expression.comment = self._prev_comment
|
expression.comments = self._prev_comments
|
||||||
|
|
||||||
def _match_r_paren(self, expression=None):
|
def _match_r_paren(self, expression=None):
|
||||||
if not self._match(TokenType.R_PAREN):
|
if not self._match(TokenType.R_PAREN):
|
||||||
self.raise_error("Expecting )")
|
self.raise_error("Expecting )")
|
||||||
if expression and self._prev_comment:
|
if expression and self._prev_comments:
|
||||||
expression.comment = self._prev_comment
|
expression.comments = self._prev_comments
|
||||||
|
|
||||||
def _match_texts(self, texts):
|
def _match_texts(self, texts):
|
||||||
if self._curr and self._curr.text.upper() in texts:
|
if self._curr and self._curr.text.upper() in texts:
|
||||||
|
|
|
@ -130,18 +130,20 @@ class Step:
|
||||||
aggregations = []
|
aggregations = []
|
||||||
sequence = itertools.count()
|
sequence = itertools.count()
|
||||||
|
|
||||||
for e in expression.expressions:
|
def extract_agg_operands(expression):
|
||||||
aggregation = e.find(exp.AggFunc)
|
for agg in expression.find_all(exp.AggFunc):
|
||||||
|
for operand in agg.unnest_operands():
|
||||||
if aggregation:
|
|
||||||
projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
|
|
||||||
aggregations.append(e)
|
|
||||||
for operand in aggregation.unnest_operands():
|
|
||||||
if isinstance(operand, exp.Column):
|
if isinstance(operand, exp.Column):
|
||||||
continue
|
continue
|
||||||
if operand not in operands:
|
if operand not in operands:
|
||||||
operands[operand] = f"_a_{next(sequence)}"
|
operands[operand] = f"_a_{next(sequence)}"
|
||||||
operand.replace(exp.column(operands[operand], quoted=True))
|
operand.replace(exp.column(operands[operand], quoted=True))
|
||||||
|
|
||||||
|
for e in expression.expressions:
|
||||||
|
if e.find(exp.AggFunc):
|
||||||
|
projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
|
||||||
|
aggregations.append(e)
|
||||||
|
extract_agg_operands(e)
|
||||||
else:
|
else:
|
||||||
projections.append(e)
|
projections.append(e)
|
||||||
|
|
||||||
|
@ -156,6 +158,13 @@ class Step:
|
||||||
aggregate = Aggregate()
|
aggregate = Aggregate()
|
||||||
aggregate.source = step.name
|
aggregate.source = step.name
|
||||||
aggregate.name = step.name
|
aggregate.name = step.name
|
||||||
|
|
||||||
|
having = expression.args.get("having")
|
||||||
|
|
||||||
|
if having:
|
||||||
|
extract_agg_operands(having)
|
||||||
|
aggregate.condition = having.this
|
||||||
|
|
||||||
aggregate.operands = tuple(
|
aggregate.operands = tuple(
|
||||||
alias(operand, alias_) for operand, alias_ in operands.items()
|
alias(operand, alias_) for operand, alias_ in operands.items()
|
||||||
)
|
)
|
||||||
|
@ -172,11 +181,6 @@ class Step:
|
||||||
aggregate.add_dependency(step)
|
aggregate.add_dependency(step)
|
||||||
step = aggregate
|
step = aggregate
|
||||||
|
|
||||||
having = expression.args.get("having")
|
|
||||||
|
|
||||||
if having:
|
|
||||||
step.condition = having.this
|
|
||||||
|
|
||||||
order = expression.args.get("order")
|
order = expression.args.get("order")
|
||||||
|
|
||||||
if order:
|
if order:
|
||||||
|
@ -188,6 +192,17 @@ class Step:
|
||||||
|
|
||||||
step.projections = projections
|
step.projections = projections
|
||||||
|
|
||||||
|
if isinstance(expression, exp.Select) and expression.args.get("distinct"):
|
||||||
|
distinct = Aggregate()
|
||||||
|
distinct.source = step.name
|
||||||
|
distinct.name = step.name
|
||||||
|
distinct.group = {
|
||||||
|
e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
|
||||||
|
for e in projections or expression.expressions
|
||||||
|
}
|
||||||
|
distinct.add_dependency(step)
|
||||||
|
step = distinct
|
||||||
|
|
||||||
limit = expression.args.get("limit")
|
limit = expression.args.get("limit")
|
||||||
|
|
||||||
if limit:
|
if limit:
|
||||||
|
@ -231,6 +246,9 @@ class Step:
|
||||||
if self.condition:
|
if self.condition:
|
||||||
lines.append(f"{nested}Condition: {self.condition.sql()}")
|
lines.append(f"{nested}Condition: {self.condition.sql()}")
|
||||||
|
|
||||||
|
if self.limit is not math.inf:
|
||||||
|
lines.append(f"{nested}Limit: {self.limit}")
|
||||||
|
|
||||||
if self.dependencies:
|
if self.dependencies:
|
||||||
lines.append(f"{nested}Dependencies:")
|
lines.append(f"{nested}Dependencies:")
|
||||||
for dependency in self.dependencies:
|
for dependency in self.dependencies:
|
||||||
|
@ -258,12 +276,7 @@ class Scan(Step):
|
||||||
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
|
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
|
||||||
) -> Step:
|
) -> Step:
|
||||||
table = expression
|
table = expression
|
||||||
alias_ = expression.alias
|
alias_ = expression.alias_or_name
|
||||||
|
|
||||||
if not alias_:
|
|
||||||
raise UnsupportedError(
|
|
||||||
"Tables/Subqueries must be aliased. Run it through the optimizer"
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(expression, exp.Subquery):
|
if isinstance(expression, exp.Subquery):
|
||||||
table = expression.this
|
table = expression.this
|
||||||
|
@ -338,6 +351,9 @@ class Aggregate(Step):
|
||||||
lines.append(f"{indent}Group:")
|
lines.append(f"{indent}Group:")
|
||||||
for expression in self.group.values():
|
for expression in self.group.values():
|
||||||
lines.append(f"{indent} - {expression.sql()}")
|
lines.append(f"{indent} - {expression.sql()}")
|
||||||
|
if self.condition:
|
||||||
|
lines.append(f"{indent}Having:")
|
||||||
|
lines.append(f"{indent} - {self.condition.sql()}")
|
||||||
if self.operands:
|
if self.operands:
|
||||||
lines.append(f"{indent}Operands:")
|
lines.append(f"{indent}Operands:")
|
||||||
for expression in self.operands:
|
for expression in self.operands:
|
||||||
|
|
|
@ -81,6 +81,7 @@ class TokenType(AutoName):
|
||||||
BINARY = auto()
|
BINARY = auto()
|
||||||
VARBINARY = auto()
|
VARBINARY = auto()
|
||||||
JSON = auto()
|
JSON = auto()
|
||||||
|
JSONB = auto()
|
||||||
TIMESTAMP = auto()
|
TIMESTAMP = auto()
|
||||||
TIMESTAMPTZ = auto()
|
TIMESTAMPTZ = auto()
|
||||||
TIMESTAMPLTZ = auto()
|
TIMESTAMPLTZ = auto()
|
||||||
|
@ -91,6 +92,7 @@ class TokenType(AutoName):
|
||||||
NULLABLE = auto()
|
NULLABLE = auto()
|
||||||
GEOMETRY = auto()
|
GEOMETRY = auto()
|
||||||
HLLSKETCH = auto()
|
HLLSKETCH = auto()
|
||||||
|
HSTORE = auto()
|
||||||
SUPER = auto()
|
SUPER = auto()
|
||||||
SERIAL = auto()
|
SERIAL = auto()
|
||||||
SMALLSERIAL = auto()
|
SMALLSERIAL = auto()
|
||||||
|
@ -113,6 +115,7 @@ class TokenType(AutoName):
|
||||||
APPLY = auto()
|
APPLY = auto()
|
||||||
ARRAY = auto()
|
ARRAY = auto()
|
||||||
ASC = auto()
|
ASC = auto()
|
||||||
|
ASOF = auto()
|
||||||
AT_TIME_ZONE = auto()
|
AT_TIME_ZONE = auto()
|
||||||
AUTO_INCREMENT = auto()
|
AUTO_INCREMENT = auto()
|
||||||
BEGIN = auto()
|
BEGIN = auto()
|
||||||
|
@ -130,6 +133,7 @@ class TokenType(AutoName):
|
||||||
COMMAND = auto()
|
COMMAND = auto()
|
||||||
COMMENT = auto()
|
COMMENT = auto()
|
||||||
COMMIT = auto()
|
COMMIT = auto()
|
||||||
|
COMPOUND = auto()
|
||||||
CONSTRAINT = auto()
|
CONSTRAINT = auto()
|
||||||
CREATE = auto()
|
CREATE = auto()
|
||||||
CROSS = auto()
|
CROSS = auto()
|
||||||
|
@ -271,6 +275,7 @@ class TokenType(AutoName):
|
||||||
UNBOUNDED = auto()
|
UNBOUNDED = auto()
|
||||||
UNCACHE = auto()
|
UNCACHE = auto()
|
||||||
UNION = auto()
|
UNION = auto()
|
||||||
|
UNLOGGED = auto()
|
||||||
UNNEST = auto()
|
UNNEST = auto()
|
||||||
UNPIVOT = auto()
|
UNPIVOT = auto()
|
||||||
UPDATE = auto()
|
UPDATE = auto()
|
||||||
|
@ -291,7 +296,7 @@ class TokenType(AutoName):
|
||||||
|
|
||||||
|
|
||||||
class Token:
|
class Token:
|
||||||
__slots__ = ("token_type", "text", "line", "col", "comment")
|
__slots__ = ("token_type", "text", "line", "col", "comments")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def number(cls, number: int) -> Token:
|
def number(cls, number: int) -> Token:
|
||||||
|
@ -319,13 +324,13 @@ class Token:
|
||||||
text: str,
|
text: str,
|
||||||
line: int = 1,
|
line: int = 1,
|
||||||
col: int = 1,
|
col: int = 1,
|
||||||
comment: t.Optional[str] = None,
|
comments: t.List[str] = [],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.token_type = token_type
|
self.token_type = token_type
|
||||||
self.text = text
|
self.text = text
|
||||||
self.line = line
|
self.line = line
|
||||||
self.col = max(col - len(text), 1)
|
self.col = max(col - len(text), 1)
|
||||||
self.comment = comment
|
self.comments = comments
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
|
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
|
||||||
|
@ -452,6 +457,7 @@ class Tokenizer(metaclass=_Tokenizer):
|
||||||
"COLLATE": TokenType.COLLATE,
|
"COLLATE": TokenType.COLLATE,
|
||||||
"COMMENT": TokenType.SCHEMA_COMMENT,
|
"COMMENT": TokenType.SCHEMA_COMMENT,
|
||||||
"COMMIT": TokenType.COMMIT,
|
"COMMIT": TokenType.COMMIT,
|
||||||
|
"COMPOUND": TokenType.COMPOUND,
|
||||||
"CONSTRAINT": TokenType.CONSTRAINT,
|
"CONSTRAINT": TokenType.CONSTRAINT,
|
||||||
"CREATE": TokenType.CREATE,
|
"CREATE": TokenType.CREATE,
|
||||||
"CROSS": TokenType.CROSS,
|
"CROSS": TokenType.CROSS,
|
||||||
|
@ -582,8 +588,9 @@ class Tokenizer(metaclass=_Tokenizer):
|
||||||
"TRAILING": TokenType.TRAILING,
|
"TRAILING": TokenType.TRAILING,
|
||||||
"UNBOUNDED": TokenType.UNBOUNDED,
|
"UNBOUNDED": TokenType.UNBOUNDED,
|
||||||
"UNION": TokenType.UNION,
|
"UNION": TokenType.UNION,
|
||||||
"UNPIVOT": TokenType.UNPIVOT,
|
"UNLOGGED": TokenType.UNLOGGED,
|
||||||
"UNNEST": TokenType.UNNEST,
|
"UNNEST": TokenType.UNNEST,
|
||||||
|
"UNPIVOT": TokenType.UNPIVOT,
|
||||||
"UPDATE": TokenType.UPDATE,
|
"UPDATE": TokenType.UPDATE,
|
||||||
"USE": TokenType.USE,
|
"USE": TokenType.USE,
|
||||||
"USING": TokenType.USING,
|
"USING": TokenType.USING,
|
||||||
|
@ -686,12 +693,12 @@ class Tokenizer(metaclass=_Tokenizer):
|
||||||
"_current",
|
"_current",
|
||||||
"_line",
|
"_line",
|
||||||
"_col",
|
"_col",
|
||||||
"_comment",
|
"_comments",
|
||||||
"_char",
|
"_char",
|
||||||
"_end",
|
"_end",
|
||||||
"_peek",
|
"_peek",
|
||||||
"_prev_token_line",
|
"_prev_token_line",
|
||||||
"_prev_token_comment",
|
"_prev_token_comments",
|
||||||
"_prev_token_type",
|
"_prev_token_type",
|
||||||
"_replace_backslash",
|
"_replace_backslash",
|
||||||
)
|
)
|
||||||
|
@ -708,13 +715,13 @@ class Tokenizer(metaclass=_Tokenizer):
|
||||||
self._current = 0
|
self._current = 0
|
||||||
self._line = 1
|
self._line = 1
|
||||||
self._col = 1
|
self._col = 1
|
||||||
self._comment = None
|
self._comments: t.List[str] = []
|
||||||
|
|
||||||
self._char = None
|
self._char = None
|
||||||
self._end = None
|
self._end = None
|
||||||
self._peek = None
|
self._peek = None
|
||||||
self._prev_token_line = -1
|
self._prev_token_line = -1
|
||||||
self._prev_token_comment = None
|
self._prev_token_comments: t.List[str] = []
|
||||||
self._prev_token_type = None
|
self._prev_token_type = None
|
||||||
|
|
||||||
def tokenize(self, sql: str) -> t.List[Token]:
|
def tokenize(self, sql: str) -> t.List[Token]:
|
||||||
|
@ -767,7 +774,7 @@ class Tokenizer(metaclass=_Tokenizer):
|
||||||
|
|
||||||
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
|
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
|
||||||
self._prev_token_line = self._line
|
self._prev_token_line = self._line
|
||||||
self._prev_token_comment = self._comment
|
self._prev_token_comments = self._comments
|
||||||
self._prev_token_type = token_type # type: ignore
|
self._prev_token_type = token_type # type: ignore
|
||||||
self.tokens.append(
|
self.tokens.append(
|
||||||
Token(
|
Token(
|
||||||
|
@ -775,10 +782,10 @@ class Tokenizer(metaclass=_Tokenizer):
|
||||||
self._text if text is None else text,
|
self._text if text is None else text,
|
||||||
self._line,
|
self._line,
|
||||||
self._col,
|
self._col,
|
||||||
self._comment,
|
self._comments,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self._comment = None
|
self._comments = []
|
||||||
|
|
||||||
if token_type in self.COMMANDS and (
|
if token_type in self.COMMANDS and (
|
||||||
len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
|
len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
|
||||||
|
@ -857,22 +864,18 @@ class Tokenizer(metaclass=_Tokenizer):
|
||||||
while not self._end and self._chars(comment_end_size) != comment_end:
|
while not self._end and self._chars(comment_end_size) != comment_end:
|
||||||
self._advance()
|
self._advance()
|
||||||
|
|
||||||
self._comment = self._text[comment_start_size : -comment_end_size + 1] # type: ignore
|
self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore
|
||||||
self._advance(comment_end_size - 1)
|
self._advance(comment_end_size - 1)
|
||||||
else:
|
else:
|
||||||
while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore
|
while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore
|
||||||
self._advance()
|
self._advance()
|
||||||
self._comment = self._text[comment_start_size:] # type: ignore
|
self._comments.append(self._text[comment_start_size:]) # type: ignore
|
||||||
|
|
||||||
# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both
|
|
||||||
# types of comment can be attached to a token, the trailing one is discarded in favour of the leading one.
|
|
||||||
|
|
||||||
|
# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
|
||||||
|
# Multiple consecutive comments are preserved by appending them to the current comments list.
|
||||||
if comment_start_line == self._prev_token_line:
|
if comment_start_line == self._prev_token_line:
|
||||||
if self._prev_token_comment is None:
|
self.tokens[-1].comments.extend(self._comments)
|
||||||
self.tokens[-1].comment = self._comment
|
self._comments = []
|
||||||
self._prev_token_comment = self._comment
|
|
||||||
|
|
||||||
self._comment = None
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,8 @@ from __future__ import annotations
|
||||||
|
|
||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
|
from sqlglot.helper import find_new_name
|
||||||
|
|
||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
from sqlglot.generator import Generator
|
from sqlglot.generator import Generator
|
||||||
|
|
||||||
|
@ -43,6 +45,43 @@ def unalias_group(expression: exp.Expression) -> exp.Expression:
|
||||||
return expression
|
return expression
|
||||||
|
|
||||||
|
|
||||||
|
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
|
||||||
|
"""
|
||||||
|
Convert SELECT DISTINCT ON statements to a subquery with a window function.
|
||||||
|
|
||||||
|
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expression: the expression that will be transformed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The transformed expression.
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
isinstance(expression, exp.Select)
|
||||||
|
and expression.args.get("distinct")
|
||||||
|
and expression.args["distinct"].args.get("on")
|
||||||
|
and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
|
||||||
|
):
|
||||||
|
distinct_cols = [e.copy() for e in expression.args["distinct"].args["on"].expressions]
|
||||||
|
outer_selects = [e.copy() for e in expression.expressions]
|
||||||
|
nested = expression.copy()
|
||||||
|
nested.args["distinct"].pop()
|
||||||
|
row_number = find_new_name(expression.named_selects, "_row_number")
|
||||||
|
window = exp.Window(
|
||||||
|
this=exp.RowNumber(),
|
||||||
|
partition_by=distinct_cols,
|
||||||
|
)
|
||||||
|
order = nested.args.get("order")
|
||||||
|
if order:
|
||||||
|
window.set("order", order.copy())
|
||||||
|
order.pop()
|
||||||
|
window = exp.alias_(window, row_number)
|
||||||
|
nested.select(window, copy=False)
|
||||||
|
return exp.select(*outer_selects).from_(nested.subquery()).where(f'"{row_number}" = 1')
|
||||||
|
return expression
|
||||||
|
|
||||||
|
|
||||||
def preprocess(
|
def preprocess(
|
||||||
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
|
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
|
||||||
to_sql: t.Callable[[Generator, exp.Expression], str],
|
to_sql: t.Callable[[Generator, exp.Expression], str],
|
||||||
|
@ -81,3 +120,4 @@ def delegate(attr: str) -> t.Callable:
|
||||||
|
|
||||||
|
|
||||||
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
|
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
|
||||||
|
ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}
|
||||||
|
|
|
@ -1276,7 +1276,7 @@ class TestFunctions(unittest.TestCase):
|
||||||
col = SF.concat(SF.col("cola"), SF.col("colb"))
|
col = SF.concat(SF.col("cola"), SF.col("colb"))
|
||||||
self.assertEqual("CONCAT(cola, colb)", col.sql())
|
self.assertEqual("CONCAT(cola, colb)", col.sql())
|
||||||
col_single = SF.concat("cola")
|
col_single = SF.concat("cola")
|
||||||
self.assertEqual("CONCAT(cola)", col_single.sql())
|
self.assertEqual("cola", col_single.sql())
|
||||||
|
|
||||||
def test_array_position(self):
|
def test_array_position(self):
|
||||||
col_str = SF.array_position("cola", SF.col("colb"))
|
col_str = SF.array_position("cola", SF.col("colb"))
|
||||||
|
|
|
@ -10,6 +10,10 @@ class TestClickhouse(Validator):
|
||||||
self.validate_identity("SELECT * FROM x AS y FINAL")
|
self.validate_identity("SELECT * FROM x AS y FINAL")
|
||||||
self.validate_identity("'a' IN mapKeys(map('a', 1, 'b', 2))")
|
self.validate_identity("'a' IN mapKeys(map('a', 1, 'b', 2))")
|
||||||
self.validate_identity("CAST((1, 2) AS Tuple(a Int8, b Int16))")
|
self.validate_identity("CAST((1, 2) AS Tuple(a Int8, b Int16))")
|
||||||
|
self.validate_identity("SELECT * FROM foo LEFT ANY JOIN bla")
|
||||||
|
self.validate_identity("SELECT * FROM foo LEFT ASOF JOIN bla")
|
||||||
|
self.validate_identity("SELECT * FROM foo ASOF JOIN bla")
|
||||||
|
self.validate_identity("SELECT * FROM foo ANY JOIN bla")
|
||||||
|
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
|
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
|
||||||
|
|
|
@ -997,6 +997,13 @@ class TestDialect(Validator):
|
||||||
"spark": "CONCAT_WS('-', x)",
|
"spark": "CONCAT_WS('-', x)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"CONCAT(a)",
|
||||||
|
write={
|
||||||
|
"mysql": "a",
|
||||||
|
"tsql": "a",
|
||||||
|
},
|
||||||
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"IF(x > 1, 1, 0)",
|
"IF(x > 1, 1, 0)",
|
||||||
write={
|
write={
|
||||||
|
@ -1263,8 +1270,8 @@ class TestDialect(Validator):
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"""/* comment1 */
|
"""/* comment1 */
|
||||||
SELECT
|
SELECT
|
||||||
x, -- comment2
|
x, /* comment2 */
|
||||||
y -- comment3""",
|
y /* comment3 */""",
|
||||||
read={
|
read={
|
||||||
"mysql": """SELECT # comment1
|
"mysql": """SELECT # comment1
|
||||||
x, # comment2
|
x, # comment2
|
||||||
|
|
|
@ -89,6 +89,8 @@ class TestDuckDB(Validator):
|
||||||
"presto": "CAST(COL AS ARRAY(BIGINT))",
|
"presto": "CAST(COL AS ARRAY(BIGINT))",
|
||||||
"hive": "CAST(COL AS ARRAY<BIGINT>)",
|
"hive": "CAST(COL AS ARRAY<BIGINT>)",
|
||||||
"spark": "CAST(COL AS ARRAY<LONG>)",
|
"spark": "CAST(COL AS ARRAY<LONG>)",
|
||||||
|
"postgres": "CAST(COL AS BIGINT[])",
|
||||||
|
"snowflake": "CAST(COL AS ARRAY)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -104,6 +106,10 @@ class TestDuckDB(Validator):
|
||||||
"spark": "ARRAY(0, 1, 2)",
|
"spark": "ARRAY(0, 1, 2)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT ARRAY_LENGTH([0], 1) AS x",
|
||||||
|
write={"duckdb": "SELECT ARRAY_LENGTH(LIST_VALUE(0), 1) AS x"},
|
||||||
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"REGEXP_MATCHES(x, y)",
|
"REGEXP_MATCHES(x, y)",
|
||||||
write={
|
write={
|
||||||
|
|
|
@ -139,7 +139,7 @@ class TestHive(Validator):
|
||||||
"CREATE TABLE test STORED AS parquet TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1",
|
"CREATE TABLE test STORED AS parquet TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1",
|
||||||
write={
|
write={
|
||||||
"duckdb": "CREATE TABLE test AS SELECT 1",
|
"duckdb": "CREATE TABLE test AS SELECT 1",
|
||||||
"presto": "CREATE TABLE test WITH (FORMAT='parquet', x='1', Z='2') AS SELECT 1",
|
"presto": "CREATE TABLE test WITH (FORMAT='PARQUET', x='1', Z='2') AS SELECT 1",
|
||||||
"hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1",
|
"hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1",
|
||||||
"spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1",
|
"spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1",
|
||||||
},
|
},
|
||||||
|
@ -459,6 +459,7 @@ class TestHive(Validator):
|
||||||
"hive": "MAP(a, b, c, d)",
|
"hive": "MAP(a, b, c, d)",
|
||||||
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
|
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
|
||||||
"spark": "MAP(a, b, c, d)",
|
"spark": "MAP(a, b, c, d)",
|
||||||
|
"snowflake": "OBJECT_CONSTRUCT(a, b, c, d)",
|
||||||
},
|
},
|
||||||
write={
|
write={
|
||||||
"": "MAP(ARRAY(a, c), ARRAY(b, d))",
|
"": "MAP(ARRAY(a, c), ARRAY(b, d))",
|
||||||
|
@ -467,6 +468,7 @@ class TestHive(Validator):
|
||||||
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
|
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
|
||||||
"hive": "MAP(a, b, c, d)",
|
"hive": "MAP(a, b, c, d)",
|
||||||
"spark": "MAP(a, b, c, d)",
|
"spark": "MAP(a, b, c, d)",
|
||||||
|
"snowflake": "OBJECT_CONSTRUCT(a, b, c, d)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
@ -476,6 +478,7 @@ class TestHive(Validator):
|
||||||
"presto": "MAP(ARRAY[a], ARRAY[b])",
|
"presto": "MAP(ARRAY[a], ARRAY[b])",
|
||||||
"hive": "MAP(a, b)",
|
"hive": "MAP(a, b)",
|
||||||
"spark": "MAP(a, b)",
|
"spark": "MAP(a, b)",
|
||||||
|
"snowflake": "OBJECT_CONSTRUCT(a, b)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
|
|
@ -23,6 +23,8 @@ class TestMySQL(Validator):
|
||||||
self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')")
|
self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')")
|
||||||
self.validate_identity("@@GLOBAL.max_connections")
|
self.validate_identity("@@GLOBAL.max_connections")
|
||||||
|
|
||||||
|
self.validate_identity("CREATE TABLE A LIKE B")
|
||||||
|
|
||||||
# SET Commands
|
# SET Commands
|
||||||
self.validate_identity("SET @var_name = expr")
|
self.validate_identity("SET @var_name = expr")
|
||||||
self.validate_identity("SET @name = 43")
|
self.validate_identity("SET @name = 43")
|
||||||
|
@ -177,14 +179,27 @@ class TestMySQL(Validator):
|
||||||
"GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
|
"GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
|
||||||
write={
|
write={
|
||||||
"mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')",
|
"mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')",
|
||||||
"sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
|
"sqlite": "GROUP_CONCAT(DISTINCT x)",
|
||||||
|
"tsql": "STRING_AGG(x, ',') WITHIN GROUP (ORDER BY y DESC)",
|
||||||
|
"postgres": "STRING_AGG(DISTINCT x, ',' ORDER BY y DESC NULLS LAST)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"GROUP_CONCAT(x ORDER BY y SEPARATOR z)",
|
||||||
|
write={
|
||||||
|
"mysql": "GROUP_CONCAT(x ORDER BY y SEPARATOR z)",
|
||||||
|
"sqlite": "GROUP_CONCAT(x, z)",
|
||||||
|
"tsql": "STRING_AGG(x, z) WITHIN GROUP (ORDER BY y)",
|
||||||
|
"postgres": "STRING_AGG(x, z ORDER BY y NULLS FIRST)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
|
"GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
|
||||||
write={
|
write={
|
||||||
"mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
|
"mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
|
||||||
"sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')",
|
"sqlite": "GROUP_CONCAT(DISTINCT x, '')",
|
||||||
|
"tsql": "STRING_AGG(x, '') WITHIN GROUP (ORDER BY y DESC)",
|
||||||
|
"postgres": "STRING_AGG(DISTINCT x, '' ORDER BY y DESC NULLS LAST)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_identity(
|
self.validate_identity(
|
||||||
|
|
|
@ -6,6 +6,9 @@ class TestPostgres(Validator):
|
||||||
dialect = "postgres"
|
dialect = "postgres"
|
||||||
|
|
||||||
def test_ddl(self):
|
def test_ddl(self):
|
||||||
|
self.validate_identity("CREATE TABLE test (foo HSTORE)")
|
||||||
|
self.validate_identity("CREATE TABLE test (foo JSONB)")
|
||||||
|
self.validate_identity("CREATE TABLE test (foo VARCHAR(64)[])")
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
|
"CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
|
||||||
write={
|
write={
|
||||||
|
@ -60,6 +63,12 @@ class TestPostgres(Validator):
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_postgres(self):
|
def test_postgres(self):
|
||||||
|
self.validate_identity("SELECT ARRAY[1, 2, 3]")
|
||||||
|
self.validate_identity("SELECT ARRAY_LENGTH(ARRAY[1, 2, 3], 1)")
|
||||||
|
self.validate_identity("STRING_AGG(x, y)")
|
||||||
|
self.validate_identity("STRING_AGG(x, ',' ORDER BY y)")
|
||||||
|
self.validate_identity("STRING_AGG(x, ',' ORDER BY y DESC)")
|
||||||
|
self.validate_identity("STRING_AGG(DISTINCT x, ',' ORDER BY y DESC)")
|
||||||
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
|
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
|
||||||
self.validate_identity(
|
self.validate_identity(
|
||||||
"SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END"
|
"SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END"
|
||||||
|
@ -86,6 +95,14 @@ class TestPostgres(Validator):
|
||||||
self.validate_identity("SELECT e'\\xDEADBEEF'")
|
self.validate_identity("SELECT e'\\xDEADBEEF'")
|
||||||
self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
|
self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
|
||||||
|
|
||||||
|
self.validate_all(
|
||||||
|
"END WORK AND NO CHAIN",
|
||||||
|
write={"postgres": "COMMIT AND NO CHAIN"},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"END AND CHAIN",
|
||||||
|
write={"postgres": "COMMIT AND CHAIN"},
|
||||||
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"CREATE TABLE x (a UUID, b BYTEA)",
|
"CREATE TABLE x (a UUID, b BYTEA)",
|
||||||
write={
|
write={
|
||||||
|
@ -95,6 +112,10 @@ class TestPostgres(Validator):
|
||||||
"spark": "CREATE TABLE x (a UUID, b BINARY)",
|
"spark": "CREATE TABLE x (a UUID, b BINARY)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.validate_identity(
|
||||||
|
"CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)"
|
||||||
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS 1 PRECEDING)",
|
"SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS 1 PRECEDING)",
|
||||||
write={
|
write={
|
||||||
|
|
|
@ -13,6 +13,7 @@ class TestPresto(Validator):
|
||||||
"duckdb": "CAST(a AS INT[])",
|
"duckdb": "CAST(a AS INT[])",
|
||||||
"presto": "CAST(a AS ARRAY(INTEGER))",
|
"presto": "CAST(a AS ARRAY(INTEGER))",
|
||||||
"spark": "CAST(a AS ARRAY<INT>)",
|
"spark": "CAST(a AS ARRAY<INT>)",
|
||||||
|
"snowflake": "CAST(a AS ARRAY)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
@ -31,6 +32,7 @@ class TestPresto(Validator):
|
||||||
"duckdb": "CAST(LIST_VALUE(1, 2) AS BIGINT[])",
|
"duckdb": "CAST(LIST_VALUE(1, 2) AS BIGINT[])",
|
||||||
"presto": "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))",
|
"presto": "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))",
|
||||||
"spark": "CAST(ARRAY(1, 2) AS ARRAY<LONG>)",
|
"spark": "CAST(ARRAY(1, 2) AS ARRAY<LONG>)",
|
||||||
|
"snowflake": "CAST([1, 2] AS ARRAY)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
@ -41,6 +43,7 @@ class TestPresto(Validator):
|
||||||
"presto": "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INTEGER, INTEGER))",
|
"presto": "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INTEGER, INTEGER))",
|
||||||
"hive": "CAST(MAP(1, 1) AS MAP<INT, INT>)",
|
"hive": "CAST(MAP(1, 1) AS MAP<INT, INT>)",
|
||||||
"spark": "CAST(MAP_FROM_ARRAYS(ARRAY(1), ARRAY(1)) AS MAP<INT, INT>)",
|
"spark": "CAST(MAP_FROM_ARRAYS(ARRAY(1), ARRAY(1)) AS MAP<INT, INT>)",
|
||||||
|
"snowflake": "CAST(OBJECT_CONSTRUCT(1, 1) AS OBJECT)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
@ -51,6 +54,7 @@ class TestPresto(Validator):
|
||||||
"presto": "CAST(MAP(ARRAY['a', 'b', 'c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INTEGER)))",
|
"presto": "CAST(MAP(ARRAY['a', 'b', 'c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INTEGER)))",
|
||||||
"hive": "CAST(MAP('a', ARRAY(1), 'b', ARRAY(2), 'c', ARRAY(3)) AS MAP<STRING, ARRAY<INT>>)",
|
"hive": "CAST(MAP('a', ARRAY(1), 'b', ARRAY(2), 'c', ARRAY(3)) AS MAP<STRING, ARRAY<INT>>)",
|
||||||
"spark": "CAST(MAP_FROM_ARRAYS(ARRAY('a', 'b', 'c'), ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))) AS MAP<STRING, ARRAY<INT>>)",
|
"spark": "CAST(MAP_FROM_ARRAYS(ARRAY('a', 'b', 'c'), ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))) AS MAP<STRING, ARRAY<INT>>)",
|
||||||
|
"snowflake": "CAST(OBJECT_CONSTRUCT('a', [1], 'b', [2], 'c', [3]) AS OBJECT)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
@ -393,6 +397,7 @@ class TestPresto(Validator):
|
||||||
write={
|
write={
|
||||||
"hive": UnsupportedError,
|
"hive": UnsupportedError,
|
||||||
"spark": "MAP_FROM_ARRAYS(a, b)",
|
"spark": "MAP_FROM_ARRAYS(a, b)",
|
||||||
|
"snowflake": UnsupportedError,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
@ -401,6 +406,7 @@ class TestPresto(Validator):
|
||||||
"hive": "MAP(a, c, b, d)",
|
"hive": "MAP(a, c, b, d)",
|
||||||
"presto": "MAP(ARRAY[a, b], ARRAY[c, d])",
|
"presto": "MAP(ARRAY[a, b], ARRAY[c, d])",
|
||||||
"spark": "MAP_FROM_ARRAYS(ARRAY(a, b), ARRAY(c, d))",
|
"spark": "MAP_FROM_ARRAYS(ARRAY(a, b), ARRAY(c, d))",
|
||||||
|
"snowflake": "OBJECT_CONSTRUCT(a, c, b, d)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
@ -409,6 +415,7 @@ class TestPresto(Validator):
|
||||||
"hive": "MAP('a', 'b')",
|
"hive": "MAP('a', 'b')",
|
||||||
"presto": "MAP(ARRAY['a'], ARRAY['b'])",
|
"presto": "MAP(ARRAY['a'], ARRAY['b'])",
|
||||||
"spark": "MAP_FROM_ARRAYS(ARRAY('a'), ARRAY('b'))",
|
"spark": "MAP_FROM_ARRAYS(ARRAY('a'), ARRAY('b'))",
|
||||||
|
"snowflake": "OBJECT_CONSTRUCT('a', 'b')",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
|
|
@ -50,6 +50,12 @@ class TestRedshift(Validator):
|
||||||
"redshift": 'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\\\_%\' LIMIT 5'
|
"redshift": 'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\\\_%\' LIMIT 5'
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
|
||||||
|
write={
|
||||||
|
"redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_identity(self):
|
def test_identity(self):
|
||||||
self.validate_identity("CAST('bla' AS SUPER)")
|
self.validate_identity("CAST('bla' AS SUPER)")
|
||||||
|
@ -64,3 +70,13 @@ class TestRedshift(Validator):
|
||||||
self.validate_identity(
|
self.validate_identity(
|
||||||
"SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
|
"SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
|
||||||
)
|
)
|
||||||
|
self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE AUTO")
|
||||||
|
self.validate_identity(
|
||||||
|
"CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid)"
|
||||||
|
)
|
||||||
|
self.validate_identity(
|
||||||
|
"COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'"
|
||||||
|
)
|
||||||
|
self.validate_identity(
|
||||||
|
"UNLOAD ('select * from venue') TO 's3://mybucket/unload/' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'"
|
||||||
|
)
|
||||||
|
|
|
@ -172,13 +172,28 @@ class TestSnowflake(Validator):
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"trim(date_column, 'UTC')",
|
"trim(date_column, 'UTC')",
|
||||||
write={
|
write={
|
||||||
|
"bigquery": "TRIM(date_column, 'UTC')",
|
||||||
"snowflake": "TRIM(date_column, 'UTC')",
|
"snowflake": "TRIM(date_column, 'UTC')",
|
||||||
"postgres": "TRIM('UTC' FROM date_column)",
|
"postgres": "TRIM('UTC' FROM date_column)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"trim(date_column)",
|
"trim(date_column)",
|
||||||
write={"snowflake": "TRIM(date_column)"},
|
write={
|
||||||
|
"snowflake": "TRIM(date_column)",
|
||||||
|
"bigquery": "TRIM(date_column)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"DECODE(x, a, b, c, d)",
|
||||||
|
read={
|
||||||
|
"": "MATCHES(x, a, b, c, d)",
|
||||||
|
},
|
||||||
|
write={
|
||||||
|
"": "MATCHES(x, a, b, c, d)",
|
||||||
|
"oracle": "DECODE(x, a, b, c, d)",
|
||||||
|
"snowflake": "DECODE(x, a, b, c, d)",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_null_treatment(self):
|
def test_null_treatment(self):
|
||||||
|
@ -370,7 +385,8 @@ class TestSnowflake(Validator):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""}
|
r"""SELECT * FROM TABLE(?)""",
|
||||||
|
write={"snowflake": r"""SELECT * FROM TABLE(?)"""},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
|
|
@ -32,13 +32,14 @@ class TestSpark(Validator):
|
||||||
"presto": "CREATE TABLE db.example_table (col_a ARRAY(INTEGER), col_b ARRAY(ARRAY(INTEGER)))",
|
"presto": "CREATE TABLE db.example_table (col_a ARRAY(INTEGER), col_b ARRAY(ARRAY(INTEGER)))",
|
||||||
"hive": "CREATE TABLE db.example_table (col_a ARRAY<INT>, col_b ARRAY<ARRAY<INT>>)",
|
"hive": "CREATE TABLE db.example_table (col_a ARRAY<INT>, col_b ARRAY<ARRAY<INT>>)",
|
||||||
"spark": "CREATE TABLE db.example_table (col_a ARRAY<INT>, col_b ARRAY<ARRAY<INT>>)",
|
"spark": "CREATE TABLE db.example_table (col_a ARRAY<INT>, col_b ARRAY<ARRAY<INT>>)",
|
||||||
|
"snowflake": "CREATE TABLE db.example_table (col_a ARRAY, col_b ARRAY)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
|
"CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
|
||||||
write={
|
write={
|
||||||
"duckdb": "CREATE TABLE x",
|
"duckdb": "CREATE TABLE x",
|
||||||
"presto": "CREATE TABLE x WITH (TABLE_FORMAT = 'ICEBERG', PARTITIONED_BY=ARRAY['MONTHS'])",
|
"presto": "CREATE TABLE x WITH (TABLE_FORMAT='ICEBERG', PARTITIONED_BY=ARRAY['MONTHS'])",
|
||||||
"hive": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
|
"hive": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
|
||||||
"spark": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
|
"spark": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
|
||||||
},
|
},
|
||||||
|
@ -94,6 +95,13 @@ TBLPROPERTIES (
|
||||||
pretty=True,
|
pretty=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.validate_all(
|
||||||
|
"CACHE TABLE testCache OPTIONS ('storageLevel' 'DISK_ONLY') SELECT * FROM testData",
|
||||||
|
write={
|
||||||
|
"spark": "CACHE TABLE testCache OPTIONS('storageLevel' = 'DISK_ONLY') AS SELECT * FROM testData"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_to_date(self):
|
def test_to_date(self):
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"TO_DATE(x, 'yyyy-MM-dd')",
|
"TO_DATE(x, 'yyyy-MM-dd')",
|
||||||
|
@ -271,6 +279,7 @@ TBLPROPERTIES (
|
||||||
"presto": "MAP(ARRAY[1], c)",
|
"presto": "MAP(ARRAY[1], c)",
|
||||||
"hive": "MAP(ARRAY(1), c)",
|
"hive": "MAP(ARRAY(1), c)",
|
||||||
"spark": "MAP_FROM_ARRAYS(ARRAY(1), c)",
|
"spark": "MAP_FROM_ARRAYS(ARRAY(1), c)",
|
||||||
|
"snowflake": "OBJECT_CONSTRUCT([1], c)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
|
|
@ -5,6 +5,10 @@ class TestSQLite(Validator):
|
||||||
dialect = "sqlite"
|
dialect = "sqlite"
|
||||||
|
|
||||||
def test_ddl(self):
|
def test_ddl(self):
|
||||||
|
self.validate_all(
|
||||||
|
"CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)",
|
||||||
|
write={"sqlite": "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)"},
|
||||||
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"""
|
"""
|
||||||
CREATE TABLE "Track"
|
CREATE TABLE "Track"
|
||||||
|
|
|
@ -17,7 +17,6 @@ class TestTSQL(Validator):
|
||||||
"spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo",
|
"spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"CONVERT(INT, CONVERT(NUMERIC, '444.75'))",
|
"CONVERT(INT, CONVERT(NUMERIC, '444.75'))",
|
||||||
write={
|
write={
|
||||||
|
@ -25,6 +24,33 @@ class TestTSQL(Validator):
|
||||||
"tsql": "CAST(CAST('444.75' AS NUMERIC) AS INTEGER)",
|
"tsql": "CAST(CAST('444.75' AS NUMERIC) AS INTEGER)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"STRING_AGG(x, y) WITHIN GROUP (ORDER BY z DESC)",
|
||||||
|
write={
|
||||||
|
"tsql": "STRING_AGG(x, y) WITHIN GROUP (ORDER BY z DESC)",
|
||||||
|
"mysql": "GROUP_CONCAT(x ORDER BY z DESC SEPARATOR y)",
|
||||||
|
"sqlite": "GROUP_CONCAT(x, y)",
|
||||||
|
"postgres": "STRING_AGG(x, y ORDER BY z DESC NULLS LAST)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z ASC)",
|
||||||
|
write={
|
||||||
|
"tsql": "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z)",
|
||||||
|
"mysql": "GROUP_CONCAT(x ORDER BY z SEPARATOR '|')",
|
||||||
|
"sqlite": "GROUP_CONCAT(x, '|')",
|
||||||
|
"postgres": "STRING_AGG(x, '|' ORDER BY z NULLS FIRST)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"STRING_AGG(x, '|')",
|
||||||
|
write={
|
||||||
|
"tsql": "STRING_AGG(x, '|')",
|
||||||
|
"mysql": "GROUP_CONCAT(x SEPARATOR '|')",
|
||||||
|
"sqlite": "GROUP_CONCAT(x, '|')",
|
||||||
|
"postgres": "STRING_AGG(x, '|')",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_types(self):
|
def test_types(self):
|
||||||
self.validate_identity("CAST(x AS XML)")
|
self.validate_identity("CAST(x AS XML)")
|
||||||
|
|
7
tests/fixtures/identity.sql
vendored
7
tests/fixtures/identity.sql
vendored
|
@ -34,6 +34,7 @@ x >> 1
|
||||||
x >> 1 | 1 & 1 ^ 1
|
x >> 1 | 1 & 1 ^ 1
|
||||||
x || y
|
x || y
|
||||||
1 - -1
|
1 - -1
|
||||||
|
- -5
|
||||||
dec.x + y
|
dec.x + y
|
||||||
a.filter
|
a.filter
|
||||||
a.b.c
|
a.b.c
|
||||||
|
@ -438,6 +439,7 @@ SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(a, b)
|
||||||
SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b)
|
SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b)
|
||||||
SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score)
|
SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score)
|
||||||
SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score)
|
SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score)
|
||||||
|
CREATE TABLE foo (id INT PRIMARY KEY ASC)
|
||||||
CREATE TABLE a.b AS SELECT 1
|
CREATE TABLE a.b AS SELECT 1
|
||||||
CREATE TABLE a.b AS SELECT a FROM a.c
|
CREATE TABLE a.b AS SELECT a FROM a.c
|
||||||
CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d
|
CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d
|
||||||
|
@ -579,6 +581,7 @@ SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)
|
||||||
SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)
|
SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)
|
||||||
SELECT CAST(x AS INT) /* comment */ FROM foo
|
SELECT CAST(x AS INT) /* comment */ FROM foo
|
||||||
SELECT a /* x */, b /* x */
|
SELECT a /* x */, b /* x */
|
||||||
|
SELECT a /* x */ /* y */ /* z */, b /* k */ /* m */
|
||||||
SELECT * FROM foo /* x */, bla /* x */
|
SELECT * FROM foo /* x */, bla /* x */
|
||||||
SELECT 1 /* comment */ + 1
|
SELECT 1 /* comment */ + 1
|
||||||
SELECT 1 /* c1 */ + 2 /* c2 */
|
SELECT 1 /* c1 */ + 2 /* c2 */
|
||||||
|
@ -588,3 +591,7 @@ SELECT x FROM a.b.c /* x */, e.f.g /* x */
|
||||||
SELECT FOO(x /* c */) /* FOO */, b /* b */
|
SELECT FOO(x /* c */) /* FOO */, b /* b */
|
||||||
SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */
|
SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */
|
||||||
SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b'
|
SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b'
|
||||||
|
SELECT x AS INTO FROM bla
|
||||||
|
SELECT * INTO newevent FROM event
|
||||||
|
SELECT * INTO TEMPORARY newevent FROM event
|
||||||
|
SELECT * INTO UNLOGGED newevent FROM event
|
||||||
|
|
|
@ -77,3 +77,15 @@ WITH x_2 AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT x.id FROM x
|
||||||
-- Existing duplicate CTE
|
-- Existing duplicate CTE
|
||||||
WITH y AS (SELECT a FROM x) SELECT a FROM (SELECT a FROM x) AS y JOIN y AS z;
|
WITH y AS (SELECT a FROM x) SELECT a FROM (SELECT a FROM x) AS y JOIN y AS z;
|
||||||
WITH y AS (SELECT a FROM x) SELECT a FROM y AS y JOIN y AS z;
|
WITH y AS (SELECT a FROM x) SELECT a FROM y AS y JOIN y AS z;
|
||||||
|
|
||||||
|
-- Nested CTE
|
||||||
|
WITH cte1 AS (SELECT a FROM x) SELECT a FROM (WITH cte2 AS (SELECT a FROM cte1) SELECT a FROM cte2);
|
||||||
|
WITH cte1 AS (SELECT a FROM x), cte2 AS (SELECT a FROM cte1), cte AS (SELECT a FROM cte2 AS cte2) SELECT a FROM cte AS cte;
|
||||||
|
|
||||||
|
-- Nested CTE inside CTE
|
||||||
|
WITH cte1 AS (WITH cte2 AS (SELECT a FROM x) SELECT t.a FROM cte2 AS t) SELECT a FROM cte1;
|
||||||
|
WITH cte2 AS (SELECT a FROM x), cte1 AS (SELECT t.a FROM cte2 AS t) SELECT a FROM cte1;
|
||||||
|
|
||||||
|
-- Duplicate CTE nested in CTE
|
||||||
|
WITH cte1 AS (SELECT a FROM x), cte2 AS (WITH cte3 AS (SELECT a FROM x) SELECT a FROM cte3) SELECT a FROM cte2;
|
||||||
|
WITH cte1 AS (SELECT a FROM x), cte2 AS (SELECT a FROM cte1 AS cte3) SELECT a FROM cte2;
|
||||||
|
|
41
tests/fixtures/optimizer/lower_identities.sql
vendored
Normal file
41
tests/fixtures/optimizer/lower_identities.sql
vendored
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
SELECT a FROM x;
|
||||||
|
SELECT a FROM x;
|
||||||
|
|
||||||
|
SELECT "A" FROM "X";
|
||||||
|
SELECT "A" FROM "X";
|
||||||
|
|
||||||
|
SELECT a AS A FROM x;
|
||||||
|
SELECT a AS A FROM x;
|
||||||
|
|
||||||
|
SELECT * FROM x;
|
||||||
|
SELECT * FROM x;
|
||||||
|
|
||||||
|
SELECT A FROM x;
|
||||||
|
SELECT a FROM x;
|
||||||
|
|
||||||
|
SELECT a FROM X;
|
||||||
|
SELECT a FROM x;
|
||||||
|
|
||||||
|
SELECT A AS A FROM (SELECT a AS A FROM x);
|
||||||
|
SELECT a AS A FROM (SELECT a AS a FROM x);
|
||||||
|
|
||||||
|
SELECT a AS B FROM x ORDER BY B;
|
||||||
|
SELECT a AS B FROM x ORDER BY B;
|
||||||
|
|
||||||
|
SELECT A FROM x ORDER BY A;
|
||||||
|
SELECT a FROM x ORDER BY a;
|
||||||
|
|
||||||
|
SELECT A AS B FROM X GROUP BY A HAVING SUM(B) > 0;
|
||||||
|
SELECT a AS B FROM x GROUP BY a HAVING SUM(b) > 0;
|
||||||
|
|
||||||
|
SELECT A AS B, SUM(B) AS C FROM X GROUP BY A HAVING C > 0;
|
||||||
|
SELECT a AS B, SUM(b) AS C FROM x GROUP BY a HAVING C > 0;
|
||||||
|
|
||||||
|
SELECT A FROM X UNION SELECT A FROM X;
|
||||||
|
SELECT a FROM x UNION SELECT a FROM x;
|
||||||
|
|
||||||
|
SELECT A AS A FROM X UNION SELECT A AS A FROM X;
|
||||||
|
SELECT a AS A FROM x UNION SELECT a AS A FROM x;
|
||||||
|
|
||||||
|
(SELECT A AS A FROM X);
|
||||||
|
(SELECT a AS A FROM x);
|
15
tests/fixtures/optimizer/optimizer.sql
vendored
15
tests/fixtures/optimizer/optimizer.sql
vendored
|
@ -276,3 +276,18 @@ SELECT /*+ COALESCE(3),
|
||||||
FROM `x` AS `x`
|
FROM `x` AS `x`
|
||||||
JOIN `y` AS `y`
|
JOIN `y` AS `y`
|
||||||
ON `x`.`b` = `y`.`b`;
|
ON `x`.`b` = `y`.`b`;
|
||||||
|
|
||||||
|
WITH cte1 AS (
|
||||||
|
WITH cte2 AS (
|
||||||
|
SELECT a, b FROM x
|
||||||
|
)
|
||||||
|
SELECT a1
|
||||||
|
FROM (
|
||||||
|
WITH cte3 AS (SELECT 1)
|
||||||
|
SELECT a AS a1, b AS b1 FROM cte2
|
||||||
|
)
|
||||||
|
)
|
||||||
|
SELECT a1 FROM cte1;
|
||||||
|
SELECT
|
||||||
|
"x"."a" AS "a1"
|
||||||
|
FROM "x" AS "x";
|
||||||
|
|
9
tests/fixtures/optimizer/simplify.sql
vendored
9
tests/fixtures/optimizer/simplify.sql
vendored
|
@ -274,6 +274,15 @@ TRUE;
|
||||||
-(-1);
|
-(-1);
|
||||||
1;
|
1;
|
||||||
|
|
||||||
|
- -+1;
|
||||||
|
1;
|
||||||
|
|
||||||
|
+-1;
|
||||||
|
-1;
|
||||||
|
|
||||||
|
++1;
|
||||||
|
1;
|
||||||
|
|
||||||
0.06 - 0.01;
|
0.06 - 0.01;
|
||||||
0.05;
|
0.05;
|
||||||
|
|
||||||
|
|
67
tests/fixtures/optimizer/tpc-h/tpc-h.sql
vendored
67
tests/fixtures/optimizer/tpc-h/tpc-h.sql
vendored
|
@ -666,19 +666,7 @@ WITH "supplier_2" AS (
|
||||||
FROM "nation" AS "nation"
|
FROM "nation" AS "nation"
|
||||||
WHERE
|
WHERE
|
||||||
"nation"."n_name" = 'GERMANY'
|
"nation"."n_name" = 'GERMANY'
|
||||||
)
|
), "_u_0" AS (
|
||||||
SELECT
|
|
||||||
"partsupp"."ps_partkey" AS "ps_partkey",
|
|
||||||
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value"
|
|
||||||
FROM "partsupp" AS "partsupp"
|
|
||||||
JOIN "supplier_2" AS "supplier"
|
|
||||||
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
|
|
||||||
JOIN "nation_2" AS "nation"
|
|
||||||
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
|
|
||||||
GROUP BY
|
|
||||||
"partsupp"."ps_partkey"
|
|
||||||
HAVING
|
|
||||||
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > (
|
|
||||||
SELECT
|
SELECT
|
||||||
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0"
|
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0"
|
||||||
FROM "partsupp" AS "partsupp"
|
FROM "partsupp" AS "partsupp"
|
||||||
|
@ -686,7 +674,20 @@ HAVING
|
||||||
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
|
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
|
||||||
JOIN "nation_2" AS "nation"
|
JOIN "nation_2" AS "nation"
|
||||||
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
|
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
|
||||||
)
|
)
|
||||||
|
SELECT
|
||||||
|
"partsupp"."ps_partkey" AS "ps_partkey",
|
||||||
|
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value"
|
||||||
|
FROM "partsupp" AS "partsupp"
|
||||||
|
CROSS JOIN "_u_0" AS "_u_0"
|
||||||
|
JOIN "supplier_2" AS "supplier"
|
||||||
|
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
|
||||||
|
JOIN "nation_2" AS "nation"
|
||||||
|
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
|
||||||
|
GROUP BY
|
||||||
|
"partsupp"."ps_partkey"
|
||||||
|
HAVING
|
||||||
|
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > MAX("_u_0"."_col_0")
|
||||||
ORDER BY
|
ORDER BY
|
||||||
"value" DESC;
|
"value" DESC;
|
||||||
|
|
||||||
|
@ -880,6 +881,10 @@ WITH "revenue" AS (
|
||||||
AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE)
|
AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE)
|
||||||
GROUP BY
|
GROUP BY
|
||||||
"lineitem"."l_suppkey"
|
"lineitem"."l_suppkey"
|
||||||
|
), "_u_0" AS (
|
||||||
|
SELECT
|
||||||
|
MAX("revenue"."total_revenue") AS "_col_0"
|
||||||
|
FROM "revenue"
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
"supplier"."s_suppkey" AS "s_suppkey",
|
"supplier"."s_suppkey" AS "s_suppkey",
|
||||||
|
@ -889,12 +894,9 @@ SELECT
|
||||||
"revenue"."total_revenue" AS "total_revenue"
|
"revenue"."total_revenue" AS "total_revenue"
|
||||||
FROM "supplier" AS "supplier"
|
FROM "supplier" AS "supplier"
|
||||||
JOIN "revenue"
|
JOIN "revenue"
|
||||||
ON "revenue"."total_revenue" = (
|
ON "supplier"."s_suppkey" = "revenue"."supplier_no"
|
||||||
SELECT
|
JOIN "_u_0" AS "_u_0"
|
||||||
MAX("revenue"."total_revenue") AS "_col_0"
|
ON "revenue"."total_revenue" = "_u_0"."_col_0"
|
||||||
FROM "revenue"
|
|
||||||
)
|
|
||||||
AND "supplier"."s_suppkey" = "revenue"."supplier_no"
|
|
||||||
ORDER BY
|
ORDER BY
|
||||||
"s_suppkey";
|
"s_suppkey";
|
||||||
|
|
||||||
|
@ -1395,7 +1397,14 @@ order by
|
||||||
cntrycode;
|
cntrycode;
|
||||||
WITH "_u_0" AS (
|
WITH "_u_0" AS (
|
||||||
SELECT
|
SELECT
|
||||||
"orders"."o_custkey" AS "_u_1"
|
AVG("customer"."c_acctbal") AS "_col_0"
|
||||||
|
FROM "customer" AS "customer"
|
||||||
|
WHERE
|
||||||
|
"customer"."c_acctbal" > 0.00
|
||||||
|
AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
|
||||||
|
), "_u_1" AS (
|
||||||
|
SELECT
|
||||||
|
"orders"."o_custkey" AS "_u_2"
|
||||||
FROM "orders" AS "orders"
|
FROM "orders" AS "orders"
|
||||||
GROUP BY
|
GROUP BY
|
||||||
"orders"."o_custkey"
|
"orders"."o_custkey"
|
||||||
|
@ -1405,18 +1414,12 @@ SELECT
|
||||||
COUNT(*) AS "numcust",
|
COUNT(*) AS "numcust",
|
||||||
SUM("customer"."c_acctbal") AS "totacctbal"
|
SUM("customer"."c_acctbal") AS "totacctbal"
|
||||||
FROM "customer" AS "customer"
|
FROM "customer" AS "customer"
|
||||||
LEFT JOIN "_u_0" AS "_u_0"
|
JOIN "_u_0" AS "_u_0"
|
||||||
ON "_u_0"."_u_1" = "customer"."c_custkey"
|
ON "customer"."c_acctbal" > "_u_0"."_col_0"
|
||||||
|
LEFT JOIN "_u_1" AS "_u_1"
|
||||||
|
ON "_u_1"."_u_2" = "customer"."c_custkey"
|
||||||
WHERE
|
WHERE
|
||||||
"_u_0"."_u_1" IS NULL
|
"_u_1"."_u_2" IS NULL
|
||||||
AND "customer"."c_acctbal" > (
|
|
||||||
SELECT
|
|
||||||
AVG("customer"."c_acctbal") AS "_col_0"
|
|
||||||
FROM "customer" AS "customer"
|
|
||||||
WHERE
|
|
||||||
"customer"."c_acctbal" > 0.00
|
|
||||||
AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
|
|
||||||
)
|
|
||||||
AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
|
AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
|
||||||
GROUP BY
|
GROUP BY
|
||||||
SUBSTRING("customer"."c_phone", 1, 2)
|
SUBSTRING("customer"."c_phone", 1, 2)
|
||||||
|
|
96
tests/fixtures/optimizer/unnest_subqueries.sql
vendored
96
tests/fixtures/optimizer/unnest_subqueries.sql
vendored
|
@ -1,10 +1,12 @@
|
||||||
|
--SELECT x.a > (SELECT SUM(y.a) AS b FROM y) FROM x;
|
||||||
--------------------------------------
|
--------------------------------------
|
||||||
-- Unnest Subqueries
|
-- Unnest Subqueries
|
||||||
--------------------------------------
|
--------------------------------------
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM x AS x
|
FROM x AS x
|
||||||
WHERE
|
WHERE
|
||||||
x.a IN (SELECT y.a AS a FROM y)
|
x.a = (SELECT SUM(y.a) AS a FROM y)
|
||||||
|
AND x.a IN (SELECT y.a AS a FROM y)
|
||||||
AND x.a IN (SELECT y.b AS b FROM y)
|
AND x.a IN (SELECT y.b AS b FROM y)
|
||||||
AND x.a = ANY (SELECT y.a AS a FROM y)
|
AND x.a = ANY (SELECT y.a AS a FROM y)
|
||||||
AND x.a = (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a)
|
AND x.a = (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a)
|
||||||
|
@ -24,62 +26,57 @@ WHERE
|
||||||
SELECT
|
SELECT
|
||||||
*
|
*
|
||||||
FROM x AS x
|
FROM x AS x
|
||||||
|
CROSS JOIN (
|
||||||
|
SELECT
|
||||||
|
SUM(y.a) AS a
|
||||||
|
FROM y
|
||||||
|
) AS "_u_0"
|
||||||
LEFT JOIN (
|
LEFT JOIN (
|
||||||
SELECT
|
SELECT
|
||||||
y.a AS a
|
y.a AS a
|
||||||
FROM y
|
FROM y
|
||||||
GROUP BY
|
GROUP BY
|
||||||
y.a
|
y.a
|
||||||
) AS "_u_0"
|
) AS "_u_1"
|
||||||
ON x.a = "_u_0"."a"
|
ON x.a = "_u_1"."a"
|
||||||
LEFT JOIN (
|
LEFT JOIN (
|
||||||
SELECT
|
SELECT
|
||||||
y.b AS b
|
y.b AS b
|
||||||
FROM y
|
FROM y
|
||||||
GROUP BY
|
GROUP BY
|
||||||
y.b
|
y.b
|
||||||
) AS "_u_1"
|
) AS "_u_2"
|
||||||
ON x.a = "_u_1"."b"
|
ON x.a = "_u_2"."b"
|
||||||
LEFT JOIN (
|
LEFT JOIN (
|
||||||
SELECT
|
SELECT
|
||||||
y.a AS a
|
y.a AS a
|
||||||
FROM y
|
FROM y
|
||||||
GROUP BY
|
GROUP BY
|
||||||
y.a
|
y.a
|
||||||
) AS "_u_2"
|
|
||||||
ON x.a = "_u_2"."a"
|
|
||||||
LEFT JOIN (
|
|
||||||
SELECT
|
|
||||||
SUM(y.b) AS b,
|
|
||||||
y.a AS _u_4
|
|
||||||
FROM y
|
|
||||||
WHERE
|
|
||||||
TRUE
|
|
||||||
GROUP BY
|
|
||||||
y.a
|
|
||||||
) AS "_u_3"
|
) AS "_u_3"
|
||||||
ON x.a = "_u_3"."_u_4"
|
ON x.a = "_u_3"."a"
|
||||||
LEFT JOIN (
|
LEFT JOIN (
|
||||||
SELECT
|
SELECT
|
||||||
SUM(y.b) AS b,
|
SUM(y.b) AS b,
|
||||||
y.a AS _u_6
|
y.a AS _u_5
|
||||||
FROM y
|
FROM y
|
||||||
WHERE
|
WHERE
|
||||||
TRUE
|
TRUE
|
||||||
GROUP BY
|
GROUP BY
|
||||||
y.a
|
y.a
|
||||||
) AS "_u_5"
|
) AS "_u_4"
|
||||||
ON x.a = "_u_5"."_u_6"
|
ON x.a = "_u_4"."_u_5"
|
||||||
LEFT JOIN (
|
LEFT JOIN (
|
||||||
SELECT
|
SELECT
|
||||||
y.a AS a
|
SUM(y.b) AS b,
|
||||||
|
y.a AS _u_7
|
||||||
FROM y
|
FROM y
|
||||||
WHERE
|
WHERE
|
||||||
TRUE
|
TRUE
|
||||||
GROUP BY
|
GROUP BY
|
||||||
y.a
|
y.a
|
||||||
) AS "_u_7"
|
) AS "_u_6"
|
||||||
ON "_u_7".a = x.a
|
ON x.a = "_u_6"."_u_7"
|
||||||
LEFT JOIN (
|
LEFT JOIN (
|
||||||
SELECT
|
SELECT
|
||||||
y.a AS a
|
y.a AS a
|
||||||
|
@ -90,29 +87,39 @@ LEFT JOIN (
|
||||||
y.a
|
y.a
|
||||||
) AS "_u_8"
|
) AS "_u_8"
|
||||||
ON "_u_8".a = x.a
|
ON "_u_8".a = x.a
|
||||||
|
LEFT JOIN (
|
||||||
|
SELECT
|
||||||
|
y.a AS a
|
||||||
|
FROM y
|
||||||
|
WHERE
|
||||||
|
TRUE
|
||||||
|
GROUP BY
|
||||||
|
y.a
|
||||||
|
) AS "_u_9"
|
||||||
|
ON "_u_9".a = x.a
|
||||||
LEFT JOIN (
|
LEFT JOIN (
|
||||||
SELECT
|
SELECT
|
||||||
ARRAY_AGG(y.a) AS a,
|
ARRAY_AGG(y.a) AS a,
|
||||||
y.b AS _u_10
|
y.b AS _u_11
|
||||||
FROM y
|
FROM y
|
||||||
WHERE
|
WHERE
|
||||||
TRUE
|
TRUE
|
||||||
GROUP BY
|
GROUP BY
|
||||||
y.b
|
y.b
|
||||||
) AS "_u_9"
|
) AS "_u_10"
|
||||||
ON "_u_9"."_u_10" = x.a
|
ON "_u_10"."_u_11" = x.a
|
||||||
LEFT JOIN (
|
LEFT JOIN (
|
||||||
SELECT
|
SELECT
|
||||||
SUM(y.a) AS a,
|
SUM(y.a) AS a,
|
||||||
y.a AS _u_12,
|
y.a AS _u_13,
|
||||||
ARRAY_AGG(y.b) AS _u_13
|
ARRAY_AGG(y.b) AS _u_14
|
||||||
FROM y
|
FROM y
|
||||||
WHERE
|
WHERE
|
||||||
TRUE AND TRUE AND TRUE
|
TRUE AND TRUE AND TRUE
|
||||||
GROUP BY
|
GROUP BY
|
||||||
y.a
|
y.a
|
||||||
) AS "_u_11"
|
) AS "_u_12"
|
||||||
ON "_u_11"."_u_12" = x.a AND "_u_11"."_u_12" = x.b
|
ON "_u_12"."_u_13" = x.a AND "_u_12"."_u_13" = x.b
|
||||||
LEFT JOIN (
|
LEFT JOIN (
|
||||||
SELECT
|
SELECT
|
||||||
y.a AS a
|
y.a AS a
|
||||||
|
@ -121,37 +128,38 @@ LEFT JOIN (
|
||||||
TRUE
|
TRUE
|
||||||
GROUP BY
|
GROUP BY
|
||||||
y.a
|
y.a
|
||||||
) AS "_u_14"
|
) AS "_u_15"
|
||||||
ON x.a = "_u_14".a
|
ON x.a = "_u_15".a
|
||||||
WHERE
|
WHERE
|
||||||
NOT "_u_0"."a" IS NULL
|
x.a = "_u_0".a
|
||||||
AND NOT "_u_1"."b" IS NULL
|
AND NOT "_u_1"."a" IS NULL
|
||||||
AND NOT "_u_2"."a" IS NULL
|
AND NOT "_u_2"."b" IS NULL
|
||||||
|
AND NOT "_u_3"."a" IS NULL
|
||||||
AND (
|
AND (
|
||||||
x.a = "_u_3".b AND NOT "_u_3"."_u_4" IS NULL
|
x.a = "_u_4".b AND NOT "_u_4"."_u_5" IS NULL
|
||||||
)
|
)
|
||||||
AND (
|
AND (
|
||||||
x.a > "_u_5".b AND NOT "_u_5"."_u_6" IS NULL
|
x.a > "_u_6".b AND NOT "_u_6"."_u_7" IS NULL
|
||||||
)
|
)
|
||||||
AND (
|
AND (
|
||||||
None = "_u_7".a AND NOT "_u_7".a IS NULL
|
None = "_u_8".a AND NOT "_u_8".a IS NULL
|
||||||
)
|
)
|
||||||
AND NOT (
|
AND NOT (
|
||||||
x.a = "_u_8".a AND NOT "_u_8".a IS NULL
|
x.a = "_u_9".a AND NOT "_u_9".a IS NULL
|
||||||
)
|
)
|
||||||
AND (
|
AND (
|
||||||
ARRAY_ANY("_u_9".a, _x -> _x = x.a) AND NOT "_u_9"."_u_10" IS NULL
|
ARRAY_ANY("_u_10".a, _x -> _x = x.a) AND NOT "_u_10"."_u_11" IS NULL
|
||||||
)
|
)
|
||||||
AND (
|
AND (
|
||||||
(
|
(
|
||||||
(
|
(
|
||||||
x.a < "_u_11".a AND NOT "_u_11"."_u_12" IS NULL
|
x.a < "_u_12".a AND NOT "_u_12"."_u_13" IS NULL
|
||||||
) AND NOT "_u_11"."_u_12" IS NULL
|
) AND NOT "_u_12"."_u_13" IS NULL
|
||||||
)
|
)
|
||||||
AND ARRAY_ANY("_u_11"."_u_13", "_x" -> "_x" <> x.d)
|
AND ARRAY_ANY("_u_12"."_u_14", "_x" -> "_x" <> x.d)
|
||||||
)
|
)
|
||||||
AND (
|
AND (
|
||||||
NOT "_u_14".a IS NULL AND NOT "_u_14".a IS NULL
|
NOT "_u_15".a IS NULL AND NOT "_u_15".a IS NULL
|
||||||
)
|
)
|
||||||
AND x.a IN (
|
AND x.a IN (
|
||||||
SELECT
|
SELECT
|
||||||
|
|
|
@ -68,13 +68,13 @@ class TestExecutor(unittest.TestCase):
|
||||||
|
|
||||||
def test_execute_tpch(self):
|
def test_execute_tpch(self):
|
||||||
def to_csv(expression):
|
def to_csv(expression):
|
||||||
if isinstance(expression, exp.Table):
|
if isinstance(expression, exp.Table) and expression.name not in ("revenue"):
|
||||||
return parse_one(
|
return parse_one(
|
||||||
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
|
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
|
||||||
)
|
)
|
||||||
return expression
|
return expression
|
||||||
|
|
||||||
for i, (sql, _) in enumerate(self.sqls[0:7]):
|
for i, (sql, _) in enumerate(self.sqls[0:16]):
|
||||||
with self.subTest(f"tpch-h {i + 1}"):
|
with self.subTest(f"tpch-h {i + 1}"):
|
||||||
a = self.cached_execute(sql)
|
a = self.cached_execute(sql)
|
||||||
sql = parse_one(sql).transform(to_csv).sql(pretty=True)
|
sql = parse_one(sql).transform(to_csv).sql(pretty=True)
|
||||||
|
@ -165,6 +165,39 @@ class TestExecutor(unittest.TestCase):
|
||||||
["a"],
|
["a"],
|
||||||
[("a",)],
|
[("a",)],
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
"SELECT DISTINCT a FROM (SELECT 1 AS a UNION ALL SELECT 1 AS a)",
|
||||||
|
["a"],
|
||||||
|
[(1,)],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SELECT DISTINCT a, SUM(b) AS b "
|
||||||
|
"FROM (SELECT 'a' AS a, 1 AS b UNION ALL SELECT 'a' AS a, 2 AS b UNION ALL SELECT 'b' AS a, 1 AS b) "
|
||||||
|
"GROUP BY a "
|
||||||
|
"LIMIT 1",
|
||||||
|
["a", "b"],
|
||||||
|
[("a", 3)],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SELECT COUNT(1) AS a FROM (SELECT 1)",
|
||||||
|
["a"],
|
||||||
|
[(1,)],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SELECT COUNT(1) AS a FROM (SELECT 1) LIMIT 0",
|
||||||
|
["a"],
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SELECT a FROM x GROUP BY a LIMIT 0",
|
||||||
|
["a"],
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SELECT a FROM x LIMIT 0",
|
||||||
|
["a"],
|
||||||
|
[],
|
||||||
|
),
|
||||||
]:
|
]:
|
||||||
with self.subTest(sql):
|
with self.subTest(sql):
|
||||||
result = execute(sql, schema=schema, tables=tables)
|
result = execute(sql, schema=schema, tables=tables)
|
||||||
|
@ -346,6 +379,28 @@ class TestExecutor(unittest.TestCase):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_execute_subqueries(self):
|
||||||
|
tables = {
|
||||||
|
"table": [
|
||||||
|
{"a": 1, "b": 1},
|
||||||
|
{"a": 2, "b": 2},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
execute(
|
||||||
|
"""
|
||||||
|
SELECT *
|
||||||
|
FROM table
|
||||||
|
WHERE a = (SELECT MAX(a) FROM table)
|
||||||
|
""",
|
||||||
|
tables=tables,
|
||||||
|
).rows,
|
||||||
|
[
|
||||||
|
(2, 2),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
def test_table_depth_mismatch(self):
|
def test_table_depth_mismatch(self):
|
||||||
tables = {"table": []}
|
tables = {"table": []}
|
||||||
schema = {"db": {"table": {"col": "VARCHAR"}}}
|
schema = {"db": {"table": {"col": "VARCHAR"}}}
|
||||||
|
@ -401,6 +456,7 @@ class TestExecutor(unittest.TestCase):
|
||||||
("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
|
("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
|
||||||
("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
|
("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
|
||||||
("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
|
("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
|
||||||
|
("SELECT SUM(x) FROM (SELECT 1 AS x WHERE FALSE)", ["_col_0"], [(0,)]),
|
||||||
]:
|
]:
|
||||||
result = execute(sql)
|
result = execute(sql)
|
||||||
self.assertEqual(result.columns, tuple(cols))
|
self.assertEqual(result.columns, tuple(cols))
|
||||||
|
@ -462,7 +518,18 @@ class TestExecutor(unittest.TestCase):
|
||||||
("IF(false, 1, 0)", 0),
|
("IF(false, 1, 0)", 0),
|
||||||
("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"),
|
("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"),
|
||||||
("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)),
|
("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)),
|
||||||
|
("1 IN (1, 2, 3)", True),
|
||||||
|
("1 IN (2, 3)", False),
|
||||||
|
("NULL IS NULL", True),
|
||||||
|
("NULL IS NOT NULL", False),
|
||||||
|
("NULL = NULL", None),
|
||||||
|
("NULL <> NULL", None),
|
||||||
]:
|
]:
|
||||||
with self.subTest(sql):
|
with self.subTest(sql):
|
||||||
result = execute(f"SELECT {sql}")
|
result = execute(f"SELECT {sql}")
|
||||||
self.assertEqual(result.rows, [(expected,)])
|
self.assertEqual(result.rows, [(expected,)])
|
||||||
|
|
||||||
|
def test_case_sensitivity(self):
|
||||||
|
result = execute("SELECT A AS A FROM X", tables={"x": [{"a": 1}]})
|
||||||
|
self.assertEqual(result.columns, ("A",))
|
||||||
|
self.assertEqual(result.rows, [(1,)])
|
||||||
|
|
|
@ -525,24 +525,14 @@ class TestExpressions(unittest.TestCase):
|
||||||
),
|
),
|
||||||
exp.Properties(
|
exp.Properties(
|
||||||
expressions=[
|
expressions=[
|
||||||
exp.FileFormatProperty(
|
exp.FileFormatProperty(this=exp.Literal.string("parquet")),
|
||||||
this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")
|
|
||||||
),
|
|
||||||
exp.PartitionedByProperty(
|
exp.PartitionedByProperty(
|
||||||
this=exp.Literal.string("PARTITIONED_BY"),
|
this=exp.Tuple(expressions=[exp.to_identifier("a"), exp.to_identifier("b")])
|
||||||
value=exp.Tuple(
|
|
||||||
expressions=[exp.to_identifier("a"), exp.to_identifier("b")]
|
|
||||||
),
|
),
|
||||||
),
|
exp.Property(this=exp.Literal.string("custom"), value=exp.Literal.number(1)),
|
||||||
exp.AnonymousProperty(
|
exp.TableFormatProperty(this=exp.to_identifier("test_format")),
|
||||||
this=exp.Literal.string("custom"), value=exp.Literal.number(1)
|
exp.EngineProperty(this=exp.null()),
|
||||||
),
|
exp.CollateProperty(this=exp.true()),
|
||||||
exp.TableFormatProperty(
|
|
||||||
this=exp.Literal.string("TABLE_FORMAT"),
|
|
||||||
value=exp.to_identifier("test_format"),
|
|
||||||
),
|
|
||||||
exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.null()),
|
|
||||||
exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.true()),
|
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -609,9 +599,9 @@ FROM foo""",
|
||||||
"""SELECT
|
"""SELECT
|
||||||
a,
|
a,
|
||||||
b AS B,
|
b AS B,
|
||||||
c, -- comment
|
c, /* comment */
|
||||||
d AS D, -- another comment
|
d AS D, /* another comment */
|
||||||
CAST(x AS INT) -- final comment
|
CAST(x AS INT) /* final comment */
|
||||||
FROM foo""",
|
FROM foo""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -85,9 +85,8 @@ class TestOptimizer(unittest.TestCase):
|
||||||
if leave_tables_isolated is not None:
|
if leave_tables_isolated is not None:
|
||||||
func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated)
|
func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated)
|
||||||
|
|
||||||
optimized = func(parse_one(sql, read=dialect), **func_kwargs)
|
|
||||||
|
|
||||||
with self.subTest(title):
|
with self.subTest(title):
|
||||||
|
optimized = func(parse_one(sql, read=dialect), **func_kwargs)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
expected,
|
expected,
|
||||||
optimized.sql(pretty=pretty, dialect=dialect),
|
optimized.sql(pretty=pretty, dialect=dialect),
|
||||||
|
@ -168,6 +167,9 @@ class TestOptimizer(unittest.TestCase):
|
||||||
def test_quote_identities(self):
|
def test_quote_identities(self):
|
||||||
self.check_file("quote_identities", optimizer.quote_identities.quote_identities)
|
self.check_file("quote_identities", optimizer.quote_identities.quote_identities)
|
||||||
|
|
||||||
|
def test_lower_identities(self):
|
||||||
|
self.check_file("lower_identities", optimizer.lower_identities.lower_identities)
|
||||||
|
|
||||||
def test_pushdown_projection(self):
|
def test_pushdown_projection(self):
|
||||||
def pushdown_projections(expression, **kwargs):
|
def pushdown_projections(expression, **kwargs):
|
||||||
expression = optimizer.qualify_tables.qualify_tables(expression)
|
expression = optimizer.qualify_tables.qualify_tables(expression)
|
||||||
|
|
|
@ -15,6 +15,51 @@ class TestParser(unittest.TestCase):
|
||||||
self.assertIsInstance(parse_one("int", into=exp.DataType), exp.DataType)
|
self.assertIsInstance(parse_one("int", into=exp.DataType), exp.DataType)
|
||||||
self.assertIsInstance(parse_one("array<int>", into=exp.DataType), exp.DataType)
|
self.assertIsInstance(parse_one("array<int>", into=exp.DataType), exp.DataType)
|
||||||
|
|
||||||
|
def test_parse_into_error(self):
|
||||||
|
expected_message = "Failed to parse into [<class 'sqlglot.expressions.From'>]"
|
||||||
|
expected_errors = [
|
||||||
|
{
|
||||||
|
"description": "Invalid expression / Unexpected token",
|
||||||
|
"line": 1,
|
||||||
|
"col": 1,
|
||||||
|
"start_context": "",
|
||||||
|
"highlight": "SELECT",
|
||||||
|
"end_context": " 1;",
|
||||||
|
"into_expression": exp.From,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
with self.assertRaises(ParseError) as ctx:
|
||||||
|
parse_one("SELECT 1;", "sqlite", [exp.From])
|
||||||
|
self.assertEqual(str(ctx.exception), expected_message)
|
||||||
|
self.assertEqual(ctx.exception.errors, expected_errors)
|
||||||
|
|
||||||
|
def test_parse_into_errors(self):
|
||||||
|
expected_message = "Failed to parse into [<class 'sqlglot.expressions.From'>, <class 'sqlglot.expressions.Join'>]"
|
||||||
|
expected_errors = [
|
||||||
|
{
|
||||||
|
"description": "Invalid expression / Unexpected token",
|
||||||
|
"line": 1,
|
||||||
|
"col": 1,
|
||||||
|
"start_context": "",
|
||||||
|
"highlight": "SELECT",
|
||||||
|
"end_context": " 1;",
|
||||||
|
"into_expression": exp.From,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Invalid expression / Unexpected token",
|
||||||
|
"line": 1,
|
||||||
|
"col": 1,
|
||||||
|
"start_context": "",
|
||||||
|
"highlight": "SELECT",
|
||||||
|
"end_context": " 1;",
|
||||||
|
"into_expression": exp.Join,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
with self.assertRaises(ParseError) as ctx:
|
||||||
|
parse_one("SELECT 1;", "sqlite", [exp.From, exp.Join])
|
||||||
|
self.assertEqual(str(ctx.exception), expected_message)
|
||||||
|
self.assertEqual(ctx.exception.errors, expected_errors)
|
||||||
|
|
||||||
def test_column(self):
|
def test_column(self):
|
||||||
columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all(exp.Column)
|
columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all(exp.Column)
|
||||||
assert len(list(columns)) == 1
|
assert len(list(columns)) == 1
|
||||||
|
@ -24,6 +69,9 @@ class TestParser(unittest.TestCase):
|
||||||
def test_float(self):
|
def test_float(self):
|
||||||
self.assertEqual(parse_one(".2"), parse_one("0.2"))
|
self.assertEqual(parse_one(".2"), parse_one("0.2"))
|
||||||
|
|
||||||
|
def test_unary_plus(self):
|
||||||
|
self.assertEqual(parse_one("+15"), exp.Literal.number(15))
|
||||||
|
|
||||||
def test_table(self):
|
def test_table(self):
|
||||||
tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
|
tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
|
||||||
self.assertEqual(tables, ["a", "b.c", "d"])
|
self.assertEqual(tables, ["a", "b.c", "d"])
|
||||||
|
@ -157,8 +205,9 @@ class TestParser(unittest.TestCase):
|
||||||
def test_comments(self):
|
def test_comments(self):
|
||||||
expression = parse_one(
|
expression = parse_one(
|
||||||
"""
|
"""
|
||||||
--comment1
|
--comment1.1
|
||||||
SELECT /* this won't be used */
|
--comment1.2
|
||||||
|
SELECT /*comment1.3*/
|
||||||
a, --comment2
|
a, --comment2
|
||||||
b as B, --comment3:testing
|
b as B, --comment3:testing
|
||||||
"test--annotation",
|
"test--annotation",
|
||||||
|
@ -169,13 +218,13 @@ class TestParser(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(expression.comment, "comment1")
|
self.assertEqual(expression.comments, ["comment1.1", "comment1.2", "comment1.3"])
|
||||||
self.assertEqual(expression.expressions[0].comment, "comment2")
|
self.assertEqual(expression.expressions[0].comments, ["comment2"])
|
||||||
self.assertEqual(expression.expressions[1].comment, "comment3:testing")
|
self.assertEqual(expression.expressions[1].comments, ["comment3:testing"])
|
||||||
self.assertEqual(expression.expressions[2].comment, None)
|
self.assertEqual(expression.expressions[2].comments, None)
|
||||||
self.assertEqual(expression.expressions[3].comment, "comment4 --foo")
|
self.assertEqual(expression.expressions[3].comments, ["comment4 --foo"])
|
||||||
self.assertEqual(expression.expressions[4].comment, "")
|
self.assertEqual(expression.expressions[4].comments, [""])
|
||||||
self.assertEqual(expression.expressions[5].comment, " space")
|
self.assertEqual(expression.expressions[5].comments, [" space"])
|
||||||
|
|
||||||
def test_type_literals(self):
|
def test_type_literals(self):
|
||||||
self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
|
self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
|
||||||
|
|
|
@ -7,13 +7,13 @@ class TestTokens(unittest.TestCase):
|
||||||
def test_comment_attachment(self):
|
def test_comment_attachment(self):
|
||||||
tokenizer = Tokenizer()
|
tokenizer = Tokenizer()
|
||||||
sql_comment = [
|
sql_comment = [
|
||||||
("/*comment*/ foo", "comment"),
|
("/*comment*/ foo", ["comment"]),
|
||||||
("/*comment*/ foo --test", "comment"),
|
("/*comment*/ foo --test", ["comment", "test"]),
|
||||||
("--comment\nfoo --test", "comment"),
|
("--comment\nfoo --test", ["comment", "test"]),
|
||||||
("foo --comment", "comment"),
|
("foo --comment", ["comment"]),
|
||||||
("foo", None),
|
("foo", []),
|
||||||
("foo /*comment 1*/ /*comment 2*/", "comment 1"),
|
("foo /*comment 1*/ /*comment 2*/", ["comment 1", "comment 2"]),
|
||||||
]
|
]
|
||||||
|
|
||||||
for sql, comment in sql_comment:
|
for sql, comment in sql_comment:
|
||||||
self.assertEqual(tokenizer.tokenize(sql)[0].comment, comment)
|
self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from sqlglot import parse_one
|
from sqlglot import parse_one
|
||||||
from sqlglot.transforms import unalias_group
|
from sqlglot.transforms import eliminate_distinct_on, unalias_group
|
||||||
|
|
||||||
|
|
||||||
class TestTime(unittest.TestCase):
|
class TestTime(unittest.TestCase):
|
||||||
|
@ -35,3 +35,30 @@ class TestTime(unittest.TestCase):
|
||||||
"SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date",
|
"SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date",
|
||||||
"SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date",
|
"SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_eliminate_distinct_on(self):
|
||||||
|
self.validate(
|
||||||
|
eliminate_distinct_on,
|
||||||
|
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
|
||||||
|
'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
|
||||||
|
)
|
||||||
|
self.validate(
|
||||||
|
eliminate_distinct_on,
|
||||||
|
"SELECT DISTINCT ON (a) a, b FROM x",
|
||||||
|
'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a) AS "_row_number" FROM x) WHERE "_row_number" = 1',
|
||||||
|
)
|
||||||
|
self.validate(
|
||||||
|
eliminate_distinct_on,
|
||||||
|
"SELECT DISTINCT ON (a, b) a, b FROM x ORDER BY c DESC",
|
||||||
|
'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
|
||||||
|
)
|
||||||
|
self.validate(
|
||||||
|
eliminate_distinct_on,
|
||||||
|
"SELECT DISTINCT a, b FROM x ORDER BY c DESC",
|
||||||
|
"SELECT DISTINCT a, b FROM x ORDER BY c DESC",
|
||||||
|
)
|
||||||
|
self.validate(
|
||||||
|
eliminate_distinct_on,
|
||||||
|
"SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC",
|
||||||
|
'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS "_row_number_2" FROM x) WHERE "_row_number_2" = 1',
|
||||||
|
)
|
||||||
|
|
|
@ -26,6 +26,7 @@ class TestTranspile(unittest.TestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date")
|
self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date")
|
||||||
self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
|
self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
|
||||||
|
self.assertEqual(transpile("SELECT 1 row")[0], "SELECT 1 AS row")
|
||||||
|
|
||||||
for key in ("union", "filter", "over", "from", "join"):
|
for key in ("union", "filter", "over", "from", "join"):
|
||||||
with self.subTest(f"alias {key}"):
|
with self.subTest(f"alias {key}"):
|
||||||
|
@ -38,6 +39,11 @@ class TestTranspile(unittest.TestCase):
|
||||||
def test_asc(self):
|
def test_asc(self):
|
||||||
self.validate("SELECT x FROM y ORDER BY x ASC", "SELECT x FROM y ORDER BY x")
|
self.validate("SELECT x FROM y ORDER BY x ASC", "SELECT x FROM y ORDER BY x")
|
||||||
|
|
||||||
|
def test_unary(self):
|
||||||
|
self.validate("+++1", "1")
|
||||||
|
self.validate("+-1", "-1")
|
||||||
|
self.validate("+- - -1", "- - -1")
|
||||||
|
|
||||||
def test_paren(self):
|
def test_paren(self):
|
||||||
with self.assertRaises(ParseError):
|
with self.assertRaises(ParseError):
|
||||||
transpile("1 + (2 + 3")
|
transpile("1 + (2 + 3")
|
||||||
|
@ -58,7 +64,7 @@ class TestTranspile(unittest.TestCase):
|
||||||
)
|
)
|
||||||
self.validate(
|
self.validate(
|
||||||
"SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ",
|
"SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ",
|
||||||
"SELECT\n FOO -- x\n , BAR -- y\n , BAZ",
|
"SELECT\n FOO /* x */\n , BAR /* y */\n , BAZ",
|
||||||
leading_comma=True,
|
leading_comma=True,
|
||||||
pretty=True,
|
pretty=True,
|
||||||
)
|
)
|
||||||
|
@ -78,7 +84,8 @@ class TestTranspile(unittest.TestCase):
|
||||||
def test_comments(self):
|
def test_comments(self):
|
||||||
self.validate("SELECT */*comment*/", "SELECT * /* comment */")
|
self.validate("SELECT */*comment*/", "SELECT * /* comment */")
|
||||||
self.validate(
|
self.validate(
|
||||||
"SELECT * FROM table /*comment 1*/ /*comment 2*/", "SELECT * FROM table /* comment 1 */"
|
"SELECT * FROM table /*comment 1*/ /*comment 2*/",
|
||||||
|
"SELECT * FROM table /* comment 1 */ /* comment 2 */",
|
||||||
)
|
)
|
||||||
self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
|
self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
|
||||||
self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
|
self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
|
||||||
|
@ -112,6 +119,53 @@ class TestTranspile(unittest.TestCase):
|
||||||
)
|
)
|
||||||
self.validate(
|
self.validate(
|
||||||
"""
|
"""
|
||||||
|
-- comment 1
|
||||||
|
-- comment 2
|
||||||
|
-- comment 3
|
||||||
|
SELECT * FROM foo
|
||||||
|
""",
|
||||||
|
"/* comment 1 */ /* comment 2 */ /* comment 3 */ SELECT * FROM foo",
|
||||||
|
)
|
||||||
|
self.validate(
|
||||||
|
"""
|
||||||
|
-- comment 1
|
||||||
|
-- comment 2
|
||||||
|
-- comment 3
|
||||||
|
SELECT * FROM foo""",
|
||||||
|
"""/* comment 1 */
|
||||||
|
/* comment 2 */
|
||||||
|
/* comment 3 */
|
||||||
|
SELECT
|
||||||
|
*
|
||||||
|
FROM foo""",
|
||||||
|
pretty=True,
|
||||||
|
)
|
||||||
|
self.validate(
|
||||||
|
"""
|
||||||
|
SELECT * FROM tbl /*line1
|
||||||
|
line2
|
||||||
|
line3*/ /*another comment*/ where 1=1 -- comment at the end""",
|
||||||
|
"""SELECT * FROM tbl /* line1
|
||||||
|
line2
|
||||||
|
line3 */ /* another comment */ WHERE 1 = 1 /* comment at the end */""",
|
||||||
|
)
|
||||||
|
self.validate(
|
||||||
|
"""
|
||||||
|
SELECT * FROM tbl /*line1
|
||||||
|
line2
|
||||||
|
line3*/ /*another comment*/ where 1=1 -- comment at the end""",
|
||||||
|
"""SELECT
|
||||||
|
*
|
||||||
|
FROM tbl /* line1
|
||||||
|
line2
|
||||||
|
line3 */
|
||||||
|
/* another comment */
|
||||||
|
WHERE
|
||||||
|
1 = 1 /* comment at the end */""",
|
||||||
|
pretty=True,
|
||||||
|
)
|
||||||
|
self.validate(
|
||||||
|
"""
|
||||||
/* multi
|
/* multi
|
||||||
line
|
line
|
||||||
comment
|
comment
|
||||||
|
@ -130,8 +184,8 @@ class TestTranspile(unittest.TestCase):
|
||||||
*/
|
*/
|
||||||
SELECT
|
SELECT
|
||||||
tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
|
tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
|
||||||
CAST(x AS INT), -- comment 3
|
CAST(x AS INT), /* comment 3 */
|
||||||
y -- comment 4
|
y /* comment 4 */
|
||||||
FROM bar /* comment 5 */, tbl /* comment 6 */""",
|
FROM bar /* comment 5 */, tbl /* comment 6 */""",
|
||||||
read="mysql",
|
read="mysql",
|
||||||
pretty=True,
|
pretty=True,
|
||||||
|
@ -364,33 +418,79 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
|
||||||
@mock.patch("sqlglot.parser.logger")
|
@mock.patch("sqlglot.parser.logger")
|
||||||
def test_error_level(self, logger):
|
def test_error_level(self, logger):
|
||||||
invalid = "x + 1. ("
|
invalid = "x + 1. ("
|
||||||
errors = [
|
expected_messages = [
|
||||||
"Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>. Line 1, Col: 8.\n x + 1. \033[4m(\033[0m",
|
"Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>. Line 1, Col: 8.\n x + 1. \033[4m(\033[0m",
|
||||||
"Expecting ). Line 1, Col: 8.\n x + 1. \033[4m(\033[0m",
|
"Expecting ). Line 1, Col: 8.\n x + 1. \033[4m(\033[0m",
|
||||||
]
|
]
|
||||||
|
expected_errors = [
|
||||||
|
{
|
||||||
|
"description": "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>",
|
||||||
|
"line": 1,
|
||||||
|
"col": 8,
|
||||||
|
"start_context": "x + 1. ",
|
||||||
|
"highlight": "(",
|
||||||
|
"end_context": "",
|
||||||
|
"into_expression": None,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Expecting )",
|
||||||
|
"line": 1,
|
||||||
|
"col": 8,
|
||||||
|
"start_context": "x + 1. ",
|
||||||
|
"highlight": "(",
|
||||||
|
"end_context": "",
|
||||||
|
"into_expression": None,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
transpile(invalid, error_level=ErrorLevel.WARN)
|
transpile(invalid, error_level=ErrorLevel.WARN)
|
||||||
for error in errors:
|
for error in expected_messages:
|
||||||
assert_logger_contains(error, logger)
|
assert_logger_contains(error, logger)
|
||||||
|
|
||||||
with self.assertRaises(ParseError) as ctx:
|
with self.assertRaises(ParseError) as ctx:
|
||||||
transpile(invalid, error_level=ErrorLevel.IMMEDIATE)
|
transpile(invalid, error_level=ErrorLevel.IMMEDIATE)
|
||||||
self.assertEqual(str(ctx.exception), errors[0])
|
self.assertEqual(str(ctx.exception), expected_messages[0])
|
||||||
|
self.assertEqual(ctx.exception.errors[0], expected_errors[0])
|
||||||
|
|
||||||
with self.assertRaises(ParseError) as ctx:
|
with self.assertRaises(ParseError) as ctx:
|
||||||
transpile(invalid, error_level=ErrorLevel.RAISE)
|
transpile(invalid, error_level=ErrorLevel.RAISE)
|
||||||
self.assertEqual(str(ctx.exception), "\n\n".join(errors))
|
self.assertEqual(str(ctx.exception), "\n\n".join(expected_messages))
|
||||||
|
self.assertEqual(ctx.exception.errors, expected_errors)
|
||||||
|
|
||||||
more_than_max_errors = "(((("
|
more_than_max_errors = "(((("
|
||||||
expected = (
|
expected_messages = (
|
||||||
"Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
|
"Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
|
||||||
"Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>. Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
|
"Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>. Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
|
||||||
"Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
|
"Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
|
||||||
"... and 2 more"
|
"... and 2 more"
|
||||||
)
|
)
|
||||||
|
expected_errors = [
|
||||||
|
{
|
||||||
|
"description": "Expecting )",
|
||||||
|
"line": 1,
|
||||||
|
"col": 4,
|
||||||
|
"start_context": "(((",
|
||||||
|
"highlight": "(",
|
||||||
|
"end_context": "",
|
||||||
|
"into_expression": None,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>",
|
||||||
|
"line": 1,
|
||||||
|
"col": 4,
|
||||||
|
"start_context": "(((",
|
||||||
|
"highlight": "(",
|
||||||
|
"end_context": "",
|
||||||
|
"into_expression": None,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
# Also expect three trailing structured errors that match the first
|
||||||
|
expected_errors += [expected_errors[0]] * 3
|
||||||
|
|
||||||
with self.assertRaises(ParseError) as ctx:
|
with self.assertRaises(ParseError) as ctx:
|
||||||
transpile(more_than_max_errors, error_level=ErrorLevel.RAISE)
|
transpile(more_than_max_errors, error_level=ErrorLevel.RAISE)
|
||||||
self.assertEqual(str(ctx.exception), expected)
|
self.assertEqual(str(ctx.exception), expected_messages)
|
||||||
|
self.assertEqual(ctx.exception.errors, expected_errors)
|
||||||
|
|
||||||
@mock.patch("sqlglot.generator.logger")
|
@mock.patch("sqlglot.generator.logger")
|
||||||
def test_unsupported_level(self, logger):
|
def test_unsupported_level(self, logger):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue