Merging upstream version 10.0.8.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 407314e8d2
commit efc1e37108

67 changed files with 2461 additions and 840 deletions

CHANGELOG.md (22 changes)

@@ -7,21 +7,37 @@ v10.0.0

Changes:

- Breaking: replaced SQLGlot annotations with comments. Comments can now be preserved after transpilation, and they can appear in places other than a SELECT's expressions.
- Breaking: renamed list_get to seq_get.
- Breaking: activated mypy type checking for SQLGlot.
- New: Azure Databricks support.
- New: placeholders can now be replaced in an expression.
- New: null-safe equal operator (<=>).
- New: [SET statements](https://github.com/tobymao/sqlglot/pull/673) for MySQL.
- New: [SHOW commands](https://dev.mysql.com/doc/refman/8.0/en/show.html) for MySQL.
- New: [FORMAT function](https://www.w3schools.com/sql/func_sqlserver_format.asp) for TSQL.
- New: CROSS APPLY / OUTER APPLY [support](https://github.com/tobymao/sqlglot/pull/641) for TSQL.
- New: added formats for TSQL's [DATENAME/DATEPART functions](https://learn.microsoft.com/en-us/sql/t-sql/functions/datename-transact-sql?view=sql-server-ver16).
- New: added styles for TSQL's [CONVERT function](https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16).
- Improvement: [refactored the schema](https://github.com/tobymao/sqlglot/pull/668) to be more lenient; it previously required an exact db.table match, but now it finds the table as long as there are no ambiguities.
- Improvement: allow functions to [inherit](https://github.com/tobymao/sqlglot/pull/674) their arguments' types, so that annotating CASE, IF, etc. is possible.
- Improvement: allow [joining with same names](https://github.com/tobymao/sqlglot/pull/660) in the python executor.
- Improvement: the "using" field can now be set for the [join expression builders](https://github.com/tobymao/sqlglot/pull/636).
- Improvement: qualify_columns [now qualifies](https://github.com/tobymao/sqlglot/pull/635) only non-alias columns in the HAVING clause.
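A few of these changes in one minimal sketch (the placeholder swap uses the generic transform API; the release also ships a dedicated helper for it; output comments are approximate):

```python
import sqlglot
from sqlglot import exp, parse_one

# Comments are now attached to the AST and survive transpilation.
print(sqlglot.transpile("SELECT a /* keep me */ FROM t")[0])
# SELECT a /* keep me */ FROM t

# Placeholders can be swapped out inside an expression tree.
ast = parse_one("SELECT * FROM :tbl")
ast = ast.transform(
    lambda n: exp.to_table("foo") if isinstance(n, exp.Placeholder) else n
)
print(ast.sql())  # SELECT * FROM foo

# The null-safe equal operator round-trips through MySQL.
print(sqlglot.transpile("SELECT a <=> b", read="mysql", write="mysql")[0])
```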

v9.0.0

@@ -37,6 +53,7 @@ v8.0.0

Changes:

- Breaking: New add_table method in Schema ABC.
- New: SQLGlot now supports the [PySpark](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dataframe) dataframe API. This is still relatively experimental.

v7.1.0

@@ -45,8 +62,11 @@ v7.1.0

Changes:

- Improvement: the pretty generator now takes max_text_width, which breaks segments into new lines.
- New: exp.to_table helper to turn table names into table expression objects.
- New: int[] type parsers.
- New: annotations are now generated in SQL.

v7.0.0

CONTRIBUTING.md

@@ -21,7 +21,7 @@ Pull requests are the best way to propose changes to the codebase. We actively w

5. Issue that pull request and wait for it to be reviewed by a maintainer or contributor!

## Report bugs using Github's [issues](https://github.com/tobymao/sqlglot/issues)

-We use GitHub issues to track public bugs. Report a bug by [opening a new issue]().
+We use GitHub issues to track public bugs. Report a bug by opening a new issue.

**Great Bug Reports** tend to have:

README.md (14 changes)

@@ -90,7 +90,7 @@ sqlglot.transpile("SELECT STRFTIME(x, '%y-%-m-%S')", read="duckdb", write="hive"

"SELECT DATE_FORMAT(x, 'yy-M-ss')"
```

-As another example, let's suppose that we want to read in a SQL query that contains a CTE and a cast to `REAL`, and then transpile it to Spark, which uses backticks as identifiers and `FLOAT` instead of `REAL`:
+As another example, let's suppose that we want to read in a SQL query that contains a CTE and a cast to `REAL`, and then transpile it to Spark, which uses backticks for identifiers and `FLOAT` instead of `REAL`:

```python
import sqlglot
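# A sketch of how the transpilation described above might look
# (hypothetical input; the README's own example continues beyond this hunk):
sql = "WITH t AS (SELECT CAST(x AS REAL) AS x FROM tbl) SELECT x FROM t"
print(sqlglot.transpile(sql, write="spark", identify=True)[0])
# roughly: WITH `t` AS (SELECT CAST(`x` AS FLOAT) AS `x` FROM `tbl`) SELECT `x` FROM `t`
```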

@@ -376,12 +376,12 @@ print(Dialect["custom"])

[Benchmarks](benchmarks) run on Python 3.10.5 in seconds.

-| Query | sqlglot | sqltree | sqlparse | moz_sql_parser | sqloxide |
-| --------------- | --------------- | --------------- | --------------- | --------------- | --------------- |
-| tpch  | 0.01178 (1.0) | 0.01173 (0.995) | 0.04676 (3.966) | 0.06800 (5.768) | 0.00094 (0.080) |
-| short | 0.00084 (1.0) | 0.00079 (0.948) | 0.00296 (3.524) | 0.00443 (5.266) | 0.00006 (0.072) |
-| long  | 0.01102 (1.0) | 0.01044 (0.947) | 0.04349 (3.945) | 0.05998 (5.440) | 0.00084 (0.077) |
-| crazy | 0.03751 (1.0) | 0.03471 (0.925) | 11.0796 (295.3) | 1.03355 (27.55) | 0.00529 (0.141) |
+| Query | sqlglot | sqlfluff | sqltree | sqlparse | moz_sql_parser | sqloxide |
+| --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- |
+| tpch  | 0.01308 (1.0) | 1.60626 (122.7) | 0.01168 (0.893) | 0.04958 (3.791) | 0.08543 (6.531) | 0.00136 (0.104) |
+| short | 0.00109 (1.0) | 0.14134 (129.2) | 0.00099 (0.906) | 0.00342 (3.131) | 0.00652 (5.970) | 8.76621e-05 (0.080) |
+| long  | 0.01399 (1.0) | 2.12632 (151.9) | 0.01126 (0.805) | 0.04410 (3.151) | 0.06671 (4.767) | 0.00107 (0.076) |
+| crazy | 0.03969 (1.0) | 24.3777 (614.1) | 0.03917 (0.987) | 11.7043 (294.8) | 1.03280 (26.02) | 0.00625 (0.157) |

(The extraction dropped the exponent on the sqloxide "short" entry; 8.76621e-05 is consistent with the 0.080 ratio.)

## Optional Dependencies

benchmarks/bench.py

@@ -5,8 +5,10 @@ collections.Iterable = collections.abc.Iterable

import gc
import timeit

-import moz_sql_parser
import numpy as np

+import sqlfluff
+
+import moz_sql_parser
import sqloxide
import sqlparse
import sqltree

@@ -177,6 +179,10 @@ def sqloxide_parse(sql):

    sqloxide.parse_sql(sql, dialect="ansi")


+def sqlfluff_parse(sql):
+    sqlfluff.parse(sql)
+
+
def border(columns):
    columns = " | ".join(columns)
    return f"| {columns} |"

@@ -193,6 +199,7 @@ def diff(row, column):

libs = [
    "sqlglot",
+    "sqlfluff",
    "sqltree",
    "sqlparse",
    "moz_sql_parser",

@@ -206,7 +213,8 @@ for name, sql in {"tpch": tpch, "short": short, "long": long, "crazy": crazy}.it

    for lib in libs:
        try:
            row[lib] = np.mean(timeit.repeat(lambda: globals()[lib + "_parse"](sql), number=3))
-        except:
+        except Exception as e:
+            print(e)
            row[lib] = "error"

columns = ["Query"] + libs

sqlglot/__init__.py

@@ -30,7 +30,7 @@ from sqlglot.parser import Parser

from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType

-__version__ = "10.0.1"
+__version__ = "10.0.8"

pretty = False

sqlglot/dataframe/sql/column.py

@@ -260,7 +260,10 @@ class Column:

    """
    if isinstance(dataType, DataType):
        dataType = dataType.simpleString()
-    new_expression = exp.Cast(this=self.column_expression, to=dataType)
+    new_expression = exp.Cast(
+        this=self.column_expression,
+        to=sqlglot.parse_one(dataType, into=exp.DataType, read="spark"),  # type: ignore
+    )
    return Column(new_expression)

def startswith(self, value: t.Union[str, Column]) -> Column:

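The point of this change is that the target type is now a parsed DataType expression rather than a raw string; a minimal sketch with a hypothetical column x (output approximate):

```python
import sqlglot
from sqlglot import exp

# Complex Spark types such as array<int> now become structured
# DataType expressions instead of opaque strings.
dtype = sqlglot.parse_one("array<int>", into=exp.DataType, read="spark")
cast = exp.Cast(this=exp.column("x"), to=dtype)
print(cast.sql(dialect="spark"))  # CAST(x AS ARRAY<INT>)
```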
sqlglot/dataframe/sql/dataframe.py

@@ -314,7 +314,13 @@ class DataFrame:

    replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
        cache_table_name
    )
-    sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
+    sqlglot.schema.add_table(
+        cache_table_name,
+        {
+            expression.alias_or_name: expression.type.name
+            for expression in select_expression.expressions
+        },
+    )
    cache_storage_level = select_expression.args["cache_storage_level"]
    options = [
        exp.Literal.string("storageLevel"),

sqlglot/dataframe/sql/functions.py

@@ -757,11 +757,15 @@ def concat_ws(sep: str, *cols: ColumnOrName) -> Column:


def decode(col: ColumnOrName, charset: str) -> Column:
-    return Column.invoke_anonymous_function(col, "DECODE", lit(charset))
+    return Column.invoke_expression_over_column(
+        col, glotexp.Decode, charset=glotexp.Literal.string(charset)
+    )


def encode(col: ColumnOrName, charset: str) -> Column:
-    return Column.invoke_anonymous_function(col, "ENCODE", lit(charset))
+    return Column.invoke_expression_over_column(
+        col, glotexp.Encode, charset=glotexp.Literal.string(charset)
+    )


def format_number(col: ColumnOrName, d: int) -> Column:

@@ -867,11 +871,11 @@ def bin(col: ColumnOrName) -> Column:


def hex(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "HEX")
+    return Column.invoke_expression_over_column(col, glotexp.Hex)


def unhex(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "UNHEX")
+    return Column.invoke_expression_over_column(col, glotexp.Unhex)


def length(col: ColumnOrName) -> Column:

@@ -939,11 +943,7 @@ def array_join(


def concat(*cols: ColumnOrName) -> Column:
-    if len(cols) == 1:
-        return Column.invoke_anonymous_function(cols[0], "CONCAT")
-    return Column.invoke_anonymous_function(
-        cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]
-    )
+    return Column.invoke_expression_over_column(None, glotexp.Concat, expressions=cols)


def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:

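Using typed expressions instead of anonymous functions lets other dialects re-render these calls; a rough sketch, relying on the Presto mappings added further down in this commit (output approximate):

```python
import sqlglot

# HEX(x) now parses to exp.Hex, which Presto spells TO_HEX.
print(sqlglot.transpile("SELECT HEX(x)", read="hive", write="presto")[0])
# SELECT TO_HEX(x)
```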
sqlglot/dataframe/sql/session.py

@@ -88,14 +88,14 @@ class SparkSession:

    "expressions": sel_columns,
    "from": exp.From(
        expressions=[
-            exp.Subquery(
-                this=exp.Values(expressions=data_expressions),
+            exp.Values(
+                expressions=data_expressions,
                alias=exp.TableAlias(
                    this=exp.to_identifier(self._auto_incrementing_name),
                    columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                ),
-            )
-        ]
+            ),
+        ],
    ),
}

sqlglot/dialects/__init__.py

@@ -2,6 +2,7 @@ from sqlglot.dialects.bigquery import BigQuery

from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks
from sqlglot.dialects.dialect import Dialect, Dialects
+from sqlglot.dialects.drill import Drill
from sqlglot.dialects.duckdb import DuckDB
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL

sqlglot/dialects/bigquery.py

@@ -119,6 +119,8 @@ class BigQuery(Dialect):

        "UNKNOWN": TokenType.NULL,
        "WINDOW": TokenType.WINDOW,
        "NOT DETERMINISTIC": TokenType.VOLATILE,
+        "BEGIN": TokenType.COMMAND,
+        "BEGIN TRANSACTION": TokenType.BEGIN,
    }
    KEYWORDS.pop("DIV")

@@ -204,6 +206,15 @@ class BigQuery(Dialect):

    EXPLICIT_UNION = True

+    def transaction_sql(self, *_):
+        return "BEGIN TRANSACTION"
+
+    def commit_sql(self, *_):
+        return "COMMIT TRANSACTION"
+
+    def rollback_sql(self, *_):
+        return "ROLLBACK TRANSACTION"
+
    def in_unnest_op(self, unnest):
        return self.sql(unnest)

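A quick sketch of the new transaction handling (output shown approximately):

```python
import sqlglot

# BEGIN TRANSACTION now tokenizes as a real BEGIN statement for BigQuery,
# and the generator emits the TRANSACTION-suffixed forms throughout.
for stmt in ("BEGIN TRANSACTION", "COMMIT", "ROLLBACK"):
    print(sqlglot.transpile(stmt, read="bigquery", write="bigquery")[0])
# BEGIN TRANSACTION / COMMIT TRANSACTION / ROLLBACK TRANSACTION
```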
sqlglot/dialects/dialect.py

@@ -32,6 +32,7 @@ class Dialects(str, Enum):

    TRINO = "trino"
    TSQL = "tsql"
    DATABRICKS = "databricks"
+    DRILL = "drill"


class _Dialect(type):

@@ -362,3 +363,18 @@ def parse_date_delta(exp_class, unit_mapping=None):

        return exp_class(this=this, expression=expression, unit=unit)

    return inner_func


+def locate_to_strposition(args):
+    return exp.StrPosition(
+        this=seq_get(args, 1),
+        substr=seq_get(args, 0),
+        position=seq_get(args, 2),
+    )
+
+
+def strposition_to_local_sql(self, expression):
+    args = self.format_args(
+        expression.args.get("substr"), expression.this, expression.args.get("position")
+    )
+    return f"LOCATE({args})"

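These two helpers centralize the LOCATE/StrPosition mapping that Hive and MySQL now share; a small sketch of the round trip (output approximate):

```python
import sqlglot

# MySQL's LOCATE(substr, str) becomes exp.StrPosition, which Presto
# spells STRPOS(str, substr).
print(sqlglot.transpile("SELECT LOCATE('a', x)", read="mysql", write="presto")[0])
# SELECT STRPOS(x, 'a')
```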
sqlglot/dialects/drill.py (new file, 174 lines)

@@ -0,0 +1,174 @@

from __future__ import annotations

import re

from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
    Dialect,
    create_with_partitions_sql,
    format_time_lambda,
    no_pivot_sql,
    no_trycast_sql,
    rename_func,
    str_position_sql,
)
from sqlglot.dialects.postgres import _lateral_sql


def _to_timestamp(args):
    # TO_TIMESTAMP accepts either a single double argument or (text, text)
    if len(args) == 1 and args[0].is_number:
        return exp.UnixToTime.from_arg_list(args)
    return format_time_lambda(exp.StrToTime, "drill")(args)


def _str_to_time_sql(self, expression):
    return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"


def _ts_or_ds_to_date_sql(self, expression):
    time_format = self.format_time(expression)
    if time_format and time_format not in (Drill.time_format, Drill.date_format):
        return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
    return f"CAST({self.sql(expression, 'this')} AS DATE)"


def _date_add_sql(kind):
    def func(self, expression):
        this = self.sql(expression, "this")
        unit = expression.text("unit").upper() or "DAY"
        expression = self.sql(expression, "expression")
        return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})"

    return func


def if_sql(self, expression):
    """
    Drill requires backticks around certain SQL reserved words, IF being one of them. This function
    adds the backticks around the keyword IF.
    Args:
        self: The Drill dialect
        expression: The input IF expression

    Returns: The expression with IF in backticks.

    """
    expressions = self.format_args(
        expression.this, expression.args.get("true"), expression.args.get("false")
    )
    return f"`IF`({expressions})"


def _str_to_date(self, expression):
    this = self.sql(expression, "this")
    time_format = self.format_time(expression)
    if time_format == Drill.date_format:
        return f"CAST({this} AS DATE)"
    return f"TO_DATE({this}, {time_format})"


class Drill(Dialect):
    normalize_functions = None
    null_ordering = "nulls_are_last"
    date_format = "'yyyy-MM-dd'"
    dateint_format = "'yyyyMMdd'"
    time_format = "'yyyy-MM-dd HH:mm:ss'"

    time_mapping = {
        "y": "%Y",
        "Y": "%Y",
        "YYYY": "%Y",
        "yyyy": "%Y",
        "YY": "%y",
        "yy": "%y",
        "MMMM": "%B",
        "MMM": "%b",
        "MM": "%m",
        "M": "%-m",
        "dd": "%d",
        "d": "%-d",
        "HH": "%H",
        "H": "%-H",
        "hh": "%I",
        "h": "%-I",
        "mm": "%M",
        "m": "%-M",
        "ss": "%S",
        "s": "%-S",
        "SSSSSS": "%f",
        "a": "%p",
        "DD": "%j",
        "D": "%-j",
        "E": "%a",
        "EE": "%a",
        "EEE": "%a",
        "EEEE": "%A",
        "''T''": "T",
    }

    class Tokenizer(tokens.Tokenizer):
        QUOTES = ["'"]
        IDENTIFIERS = ["`"]
        ESCAPES = ["\\"]
        ENCODE = "utf-8"

    class Parser(parser.Parser):
        STRICT_CAST = False

        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,
            "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
            "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
        }

    class Generator(generator.Generator):
        TYPE_MAPPING = {
            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.INT: "INTEGER",
            exp.DataType.Type.SMALLINT: "INTEGER",
            exp.DataType.Type.TINYINT: "INTEGER",
            exp.DataType.Type.BINARY: "VARBINARY",
            exp.DataType.Type.TEXT: "VARCHAR",
            exp.DataType.Type.NCHAR: "VARCHAR",
            exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
            exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
            exp.DataType.Type.DATETIME: "TIMESTAMP",
        }

        ROOT_PROPERTIES = {exp.PartitionedByProperty}

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
            exp.Lateral: _lateral_sql,
            exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
            exp.ArraySize: rename_func("REPEATED_COUNT"),
            exp.Create: create_with_partitions_sql,
            exp.DateAdd: _date_add_sql("ADD"),
            exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
            exp.DateSub: _date_add_sql("SUB"),
            exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
            exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
            exp.If: if_sql,
            exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
            exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
            exp.Pivot: no_pivot_sql,
            exp.RegexpLike: rename_func("REGEXP_MATCHES"),
            exp.StrPosition: str_position_sql,
            exp.StrToDate: _str_to_date,
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
            exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
            exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
            exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
            exp.TryCast: no_trycast_sql,
            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)",
            exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
            exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
        }

        def normalize_func(self, name):
            return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"

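A small smoke test of the new dialect (standard transpile API; output approximate):

```python
import sqlglot

# Drill needs backticks around reserved words such as IF.
print(sqlglot.transpile("SELECT IF(a > 1, 'x', 'y') FROM t", read="hive", write="drill")[0])
# SELECT `IF`(a > 1, 'x', 'y') FROM t
```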
sqlglot/dialects/duckdb.py

@@ -55,13 +55,13 @@ def _array_sort_sql(self, expression):


def _sort_array_sql(self, expression):
    this = self.sql(expression, "this")
-    if expression.args.get("asc") == exp.FALSE:
+    if expression.args.get("asc") == exp.false():
        return f"ARRAY_REVERSE_SORT({this})"
    return f"ARRAY_SORT({this})"


def _sort_array_reverse(args):
-    return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE)
+    return exp.SortArray(this=seq_get(args, 0), asc=exp.false())


def _struct_pack_sql(self, expression):

sqlglot/dialects/hive.py

@@ -7,16 +7,19 @@ from sqlglot.dialects.dialect import (

    create_with_partitions_sql,
    format_time_lambda,
    if_sql,
+    locate_to_strposition,
    no_ilike_sql,
    no_recursive_cte_sql,
    no_safe_divide_sql,
    no_trycast_sql,
    rename_func,
+    strposition_to_local_sql,
    struct_extract_sql,
    var_map_sql,
)
from sqlglot.helper import seq_get
from sqlglot.parser import parse_var_map
+from sqlglot.tokens import TokenType

# (FuncType, Multiplier)
DATE_DELTA_INTERVAL = {

@@ -181,6 +184,15 @@ class Hive(Dialect):

            "F": "FLOAT",
            "BD": "DECIMAL",
        }

+        KEYWORDS = {
+            **tokens.Tokenizer.KEYWORDS,
+            "ADD ARCHIVE": TokenType.COMMAND,
+            "ADD ARCHIVES": TokenType.COMMAND,
+            "ADD FILE": TokenType.COMMAND,
+            "ADD FILES": TokenType.COMMAND,
+            "ADD JAR": TokenType.COMMAND,
+            "ADD JARS": TokenType.COMMAND,
+        }

    class Parser(parser.Parser):
        STRICT_CAST = False

@@ -210,11 +222,7 @@ class Hive(Dialect):

            "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
            "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
            "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
-            "LOCATE": lambda args: exp.StrPosition(
-                this=seq_get(args, 1),
-                substr=seq_get(args, 0),
-                position=seq_get(args, 2),
-            ),
+            "LOCATE": locate_to_strposition,
            "LOG": (
                lambda args: exp.Log.from_arg_list(args)
                if len(args) > 1

@@ -272,7 +280,7 @@ class Hive(Dialect):

            exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
            exp.SetAgg: rename_func("COLLECT_SET"),
            exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
-            exp.StrPosition: lambda self, e: f"LOCATE({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
+            exp.StrPosition: strposition_to_local_sql,
            exp.StrToDate: _str_to_date,
            exp.StrToTime: _str_to_time,
            exp.StrToUnix: _str_to_unix,

sqlglot/dialects/mysql.py

@@ -5,10 +5,12 @@ import typing as t

from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
    Dialect,
+    locate_to_strposition,
    no_ilike_sql,
    no_paren_current_date_sql,
    no_tablesample_sql,
    no_trycast_sql,
+    strposition_to_local_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType

@@ -120,6 +122,7 @@ class MySQL(Dialect):

        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
+            "START": TokenType.BEGIN,
            "SEPARATOR": TokenType.SEPARATOR,
            "_ARMSCII8": TokenType.INTRODUCER,
            "_ASCII": TokenType.INTRODUCER,

@@ -172,13 +175,18 @@ class MySQL(Dialect):

        COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}

    class Parser(parser.Parser):
-        STRICT_CAST = False
+        FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA}

        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,
            "DATE_ADD": _date_add(exp.DateAdd),
            "DATE_SUB": _date_add(exp.DateSub),
            "STR_TO_DATE": _str_to_date,
+            "LOCATE": locate_to_strposition,
+            "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
+            "LEFT": lambda args: exp.Substring(
+                this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
+            ),
        }

        FUNCTION_PARSERS = {

@@ -264,6 +272,7 @@ class MySQL(Dialect):

            "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
            "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
            "NAMES": lambda self: self._parse_set_item_names(),
+            "TRANSACTION": lambda self: self._parse_set_transaction(),
        }

        PROFILE_TYPES = {

@@ -278,39 +287,48 @@ class MySQL(Dialect):

            "SWAPS",
        }

+        TRANSACTION_CHARACTERISTICS = {
+            "ISOLATION LEVEL REPEATABLE READ",
+            "ISOLATION LEVEL READ COMMITTED",
+            "ISOLATION LEVEL READ UNCOMMITTED",
+            "ISOLATION LEVEL SERIALIZABLE",
+            "READ WRITE",
+            "READ ONLY",
+        }

        def _parse_show_mysql(self, this, target=False, full=None, global_=None):
            if target:
                if isinstance(target, str):
-                    self._match_text(target)
+                    self._match_text_seq(target)
                    target_id = self._parse_id_var()
            else:
                target_id = None

-            log = self._parse_string() if self._match_text("IN") else None
+            log = self._parse_string() if self._match_text_seq("IN") else None

            if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
-                position = self._parse_number() if self._match_text("FROM") else None
+                position = self._parse_number() if self._match_text_seq("FROM") else None
                db = None
            else:
                position = None
-                db = self._parse_id_var() if self._match_text("FROM") else None
+                db = self._parse_id_var() if self._match_text_seq("FROM") else None

-            channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None
+            channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None

-            like = self._parse_string() if self._match_text("LIKE") else None
+            like = self._parse_string() if self._match_text_seq("LIKE") else None
            where = self._parse_where()

            if this == "PROFILE":
-                types = self._parse_csv(self._parse_show_profile_type)
+                types = self._parse_csv(lambda: self._parse_var_from_options(self.PROFILE_TYPES))
-                query = self._parse_number() if self._match_text("FOR", "QUERY") else None
+                query = self._parse_number() if self._match_text_seq("FOR", "QUERY") else None
-                offset = self._parse_number() if self._match_text("OFFSET") else None
+                offset = self._parse_number() if self._match_text_seq("OFFSET") else None
-                limit = self._parse_number() if self._match_text("LIMIT") else None
+                limit = self._parse_number() if self._match_text_seq("LIMIT") else None
            else:
                types, query = None, None
                offset, limit = self._parse_oldstyle_limit()

-            mutex = True if self._match_text("MUTEX") else None
+            mutex = True if self._match_text_seq("MUTEX") else None
-            mutex = False if self._match_text("STATUS") else mutex
+            mutex = False if self._match_text_seq("STATUS") else mutex

            return self.expression(
                exp.Show,

@@ -331,16 +349,16 @@ class MySQL(Dialect):

                **{"global": global_},
            )

-        def _parse_show_profile_type(self):
-            for type_ in self.PROFILE_TYPES:
-                if self._match_text(*type_.split(" ")):
-                    return exp.Var(this=type_)
+        def _parse_var_from_options(self, options):
+            for option in options:
+                if self._match_text_seq(*option.split(" ")):
+                    return exp.Var(this=option)
            return None

        def _parse_oldstyle_limit(self):
            limit = None
            offset = None
-            if self._match_text("LIMIT"):
+            if self._match_text_seq("LIMIT"):
                parts = self._parse_csv(self._parse_number)
                if len(parts) == 1:
                    limit = parts[0]

@@ -353,6 +371,9 @@ class MySQL(Dialect):

            return self._parse_set_item_assignment(kind=None)

        def _parse_set_item_assignment(self, kind):
+            if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"):
+                return self._parse_set_transaction(global_=kind == "GLOBAL")
+
            left = self._parse_primary() or self._parse_id_var()
            if not self._match(TokenType.EQ):
                self.raise_error("Expected =")

@@ -381,7 +402,7 @@ class MySQL(Dialect):

        def _parse_set_item_names(self):
            charset = self._parse_string() or self._parse_id_var()
-            if self._match_text("COLLATE"):
+            if self._match_text_seq("COLLATE"):
                collate = self._parse_string() or self._parse_id_var()
            else:
                collate = None

@@ -392,6 +413,18 @@ class MySQL(Dialect):

                kind="NAMES",
            )

+        def _parse_set_transaction(self, global_=False):
+            self._match_text_seq("TRANSACTION")
+            characteristics = self._parse_csv(
+                lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS)
+            )
+            return self.expression(
+                exp.SetItem,
+                expressions=characteristics,
+                kind="TRANSACTION",
+                **{"global": global_},
+            )

    class Generator(generator.Generator):
        NULL_ORDERING_SUPPORTED = False

@@ -411,6 +444,7 @@ class MySQL(Dialect):

            exp.Trim: _trim_sql,
            exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
            exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
+            exp.StrPosition: strposition_to_local_sql,
        }

        ROOT_PROPERTIES = {

@@ -481,9 +515,11 @@ class MySQL(Dialect):

            kind = self.sql(expression, "kind")
            kind = f"{kind} " if kind else ""
            this = self.sql(expression, "this")
+            expressions = self.expressions(expression)
            collate = self.sql(expression, "collate")
            collate = f" COLLATE {collate}" if collate else ""
-            return f"{kind}{this}{collate}"
+            global_ = "GLOBAL " if expression.args.get("global") else ""
+            return f"{global_}{kind}{this}{expressions}{collate}"

        def set_sql(self, expression):
            return f"SET {self.expressions(expression)}"

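With the pieces above, SET TRANSACTION statements should round-trip; a sketch (output approximate):

```python
import sqlglot

sql = "SET GLOBAL TRANSACTION ISOLATION LEVEL READ COMMITTED, READ WRITE"
print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])
# SET GLOBAL TRANSACTION ISOLATION LEVEL READ COMMITTED, READ WRITE
```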
sqlglot/dialects/oracle.py

@@ -91,6 +91,7 @@ class Oracle(Dialect):

    class Tokenizer(tokens.Tokenizer):
        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
+            "START": TokenType.BEGIN,
            "TOP": TokenType.TOP,
            "VARCHAR2": TokenType.VARCHAR,
            "NVARCHAR2": TokenType.NVARCHAR,

sqlglot/dialects/postgres.py

@@ -164,11 +164,34 @@ class Postgres(Dialect):

        BIT_STRINGS = [("b'", "'"), ("B'", "'")]
        HEX_STRINGS = [("x'", "'"), ("X'", "'")]
        BYTE_STRINGS = [("e'", "'"), ("E'", "'")]

+        CREATABLES = (
+            "AGGREGATE",
+            "CAST",
+            "CONVERSION",
+            "COLLATION",
+            "DEFAULT CONVERSION",
+            "CONSTRAINT",
+            "DOMAIN",
+            "EXTENSION",
+            "FOREIGN",
+            "FUNCTION",
+            "OPERATOR",
+            "POLICY",
+            "ROLE",
+            "RULE",
+            "SEQUENCE",
+            "TEXT",
+            "TRIGGER",
+            "TYPE",
+            "UNLOGGED",
+            "USER",
+        )

        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
            "ALWAYS": TokenType.ALWAYS,
            "BY DEFAULT": TokenType.BY_DEFAULT,
-            "COMMENT ON": TokenType.COMMENT_ON,
            "IDENTITY": TokenType.IDENTITY,
            "GENERATED": TokenType.GENERATED,
            "DOUBLE PRECISION": TokenType.DOUBLE,

@@ -176,6 +199,19 @@ class Postgres(Dialect):

            "SERIAL": TokenType.SERIAL,
            "SMALLSERIAL": TokenType.SMALLSERIAL,
            "UUID": TokenType.UUID,
+            "TEMP": TokenType.TEMPORARY,
+            "BEGIN TRANSACTION": TokenType.BEGIN,
+            "BEGIN": TokenType.COMMAND,
+            "COMMENT ON": TokenType.COMMAND,
+            "DECLARE": TokenType.COMMAND,
+            "DO": TokenType.COMMAND,
+            "REFRESH": TokenType.COMMAND,
+            "REINDEX": TokenType.COMMAND,
+            "RESET": TokenType.COMMAND,
+            "REVOKE": TokenType.COMMAND,
+            "GRANT": TokenType.COMMAND,
+            **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
+            **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
        }
        QUOTES = ["'", "$$"]
        SINGLE_TOKENS = {

sqlglot/dialects/presto.py

@@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (

    struct_extract_sql,
)
from sqlglot.dialects.mysql import MySQL
+from sqlglot.errors import UnsupportedError
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType

@@ -61,8 +62,18 @@ def _initcap_sql(self, expression):

    return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"


+def _decode_sql(self, expression):
+    _ensure_utf8(expression.args.get("charset"))
+    return f"FROM_UTF8({self.sql(expression, 'this')})"
+
+
+def _encode_sql(self, expression):
+    _ensure_utf8(expression.args.get("charset"))
+    return f"TO_UTF8({self.sql(expression, 'this')})"
+
+
def _no_sort_array(self, expression):
-    if expression.args.get("asc") == exp.FALSE:
+    if expression.args.get("asc") == exp.false():
        comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
    else:
        comparator = None

@@ -72,7 +83,7 @@ def _no_sort_array(self, expression):


def _schema_sql(self, expression):
    if isinstance(expression.parent, exp.Property):
-        columns = ", ".join(f"'{c.text('this')}'" for c in expression.expressions)
+        columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
        return f"ARRAY[{columns}]"

    for schema in expression.parent.find_all(exp.Schema):

@@ -106,6 +117,11 @@ def _ts_or_ds_add_sql(self, expression):

    return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"


+def _ensure_utf8(charset):
+    if charset.name.lower() != "utf-8":
+        raise UnsupportedError(f"Unsupported charset {charset}")
+
+
class Presto(Dialect):
    index_offset = 1
    null_ordering = "nulls_are_last"

@@ -115,6 +131,7 @@ class Presto(Dialect):

    class Tokenizer(tokens.Tokenizer):
        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
+            "START": TokenType.BEGIN,
            "ROW": TokenType.STRUCT,
        }

@@ -140,6 +157,14 @@ class Presto(Dialect):

            "STRPOS": exp.StrPosition.from_arg_list,
            "TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
+            "FROM_HEX": exp.Unhex.from_arg_list,
+            "TO_HEX": exp.Hex.from_arg_list,
+            "TO_UTF8": lambda args: exp.Encode(
+                this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+            ),
+            "FROM_UTF8": lambda args: exp.Decode(
+                this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+            ),
        }

    class Generator(generator.Generator):

@@ -187,7 +212,10 @@ class Presto(Dialect):

            exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
            exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
            exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
+            exp.Decode: _decode_sql,
            exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
+            exp.Encode: _encode_sql,
+            exp.Hex: rename_func("TO_HEX"),
            exp.If: if_sql,
            exp.ILike: no_ilike_sql,
            exp.Initcap: _initcap_sql,

@@ -212,7 +240,13 @@ class Presto(Dialect):

            exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
            exp.TsOrDsAdd: _ts_or_ds_add_sql,
            exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+            exp.Unhex: rename_func("FROM_HEX"),
            exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
            exp.UnixToTime: rename_func("FROM_UNIXTIME"),
            exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
        }

+        def transaction_sql(self, expression):
+            modes = expression.args.get("modes")
+            modes = f" {', '.join(modes)}" if modes else ""
+            return f"START TRANSACTION{modes}"

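The new Encode/Decode mappings in action; a sketch (the ENCODE spelling on the Spark side is the generic fallback rendering, shown here as an assumption):

```python
import sqlglot

# Presto's TO_UTF8 parses to exp.Encode; charsets other than utf-8
# raise UnsupportedError when generating Presto SQL.
print(sqlglot.transpile("SELECT TO_UTF8(x)", read="presto", write="spark")[0])
# SELECT ENCODE(x, 'utf-8')
```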
sqlglot/dialects/snowflake.py

@@ -148,6 +148,7 @@ class Snowflake(Dialect):

            **parser.Parser.FUNCTION_PARSERS,
            "DATE_PART": _parse_date_part,
        }
+        FUNCTION_PARSERS.pop("TRIM")

        FUNC_TOKENS = {
            *parser.Parser.FUNC_TOKENS,

@@ -203,6 +204,7 @@ class Snowflake(Dialect):

            exp.StrPosition: rename_func("POSITION"),
            exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
+            exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
        }

        TYPE_MAPPING = {

sqlglot/dialects/sqlite.py

@@ -63,3 +63,8 @@ class SQLite(Dialect):

            exp.TableSample: no_tablesample_sql,
            exp.TryCast: no_trycast_sql,
        }

+        def transaction_sql(self, expression):
+            this = expression.this
+            this = f" {this}" if this else ""
+            return f"BEGIN{this} TRANSACTION"

sqlglot/dialects/tsql.py

@@ -248,7 +248,7 @@ class TSQL(Dialect):

    def _parse_convert(self, strict):
        to = self._parse_types()
        self._match(TokenType.COMMA)
-        this = self._parse_column()
+        this = self._parse_conjunction()

        # Retrieve length of datatype and override to default if not specified
        if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:

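Parsing the value as a conjunction rather than a bare column means composite expressions are accepted; a sketch:

```python
import sqlglot

# The value argument of CONVERT may now be a full expression,
# not just a single column reference.
ast = sqlglot.parse_one("SELECT CONVERT(VARCHAR(10), a + b) FROM t", read="tsql")
print(ast.sql(dialect="tsql"))  # parses instead of raising an error
```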
sqlglot/diff.py

@@ -1,3 +1,6 @@

+from __future__ import annotations
+
+import typing as t
from collections import defaultdict
from dataclasses import dataclass
from heapq import heappop, heappush

@@ -6,6 +9,10 @@ from sqlglot import Dialect

from sqlglot import expressions as exp
from sqlglot.helper import ensure_collection

+if t.TYPE_CHECKING:
+    T = t.TypeVar("T")
+    Edit = t.Union[Insert, Remove, Move, Update, Keep]


@dataclass(frozen=True)
class Insert:

@@ -44,7 +51,7 @@ class Keep:

    target: exp.Expression


-def diff(source, target):
+def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
    """
    Returns the list of changes between the source and the target expressions.

@@ -89,25 +96,25 @@ class ChangeDistiller:

    Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf.
    """

-    def __init__(self, f=0.6, t=0.6):
+    def __init__(self, f: float = 0.6, t: float = 0.6) -> None:
        self.f = f
        self.t = t
        self._sql_generator = Dialect().generator()

-    def diff(self, source, target):
+    def diff(self, source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
        self._source = source
        self._target = target
        self._source_index = {id(n[0]): n[0] for n in source.bfs()}
        self._target_index = {id(n[0]): n[0] for n in target.bfs()}
        self._unmatched_source_nodes = set(self._source_index)
        self._unmatched_target_nodes = set(self._target_index)
-        self._bigram_histo_cache = {}
+        self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}

        matching_set = self._compute_matching_set()
        return self._generate_edit_script(matching_set)

-    def _generate_edit_script(self, matching_set):
+    def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.List[Edit]:
-        edit_script = []
+        edit_script: t.List[Edit] = []
        for removed_node_id in self._unmatched_source_nodes:
            edit_script.append(Remove(self._source_index[removed_node_id]))
        for inserted_node_id in self._unmatched_target_nodes:

@@ -125,7 +132,9 @@ class ChangeDistiller:

        return edit_script

-    def _generate_move_edits(self, source, target, matching_set):
+    def _generate_move_edits(
+        self, source: exp.Expression, target: exp.Expression, matching_set: t.Set[t.Tuple[int, int]]
+    ) -> t.List[Move]:
        source_args = [id(e) for e in _expression_only_args(source)]
        target_args = [id(e) for e in _expression_only_args(target)]

@@ -138,7 +147,7 @@ class ChangeDistiller:

        return move_edits

-    def _compute_matching_set(self):
+    def _compute_matching_set(self) -> t.Set[t.Tuple[int, int]]:
        leaves_matching_set = self._compute_leaf_matching_set()
        matching_set = leaves_matching_set.copy()

@@ -183,8 +192,8 @@ class ChangeDistiller:

        return matching_set

-    def _compute_leaf_matching_set(self):
+    def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]:
-        candidate_matchings = []
+        candidate_matchings: t.List[t.Tuple[float, int, exp.Expression, exp.Expression]] = []
        source_leaves = list(_get_leaves(self._source))
        target_leaves = list(_get_leaves(self._target))
        for source_leaf in source_leaves:

@@ -216,7 +225,7 @@ class ChangeDistiller:

        return matching_set

-    def _dice_coefficient(self, source, target):
+    def _dice_coefficient(self, source: exp.Expression, target: exp.Expression) -> float:
        source_histo = self._bigram_histo(source)
        target_histo = self._bigram_histo(target)

@@ -231,13 +240,13 @@ class ChangeDistiller:

        return 2 * overlap_len / total_grams

-    def _bigram_histo(self, expression):
+    def _bigram_histo(self, expression: exp.Expression) -> t.DefaultDict[str, int]:
        if id(expression) in self._bigram_histo_cache:
            return self._bigram_histo_cache[id(expression)]

        expression_str = self._sql_generator.generate(expression)
        count = max(0, len(expression_str) - 1)
-        bigram_histo = defaultdict(int)
+        bigram_histo: t.DefaultDict[str, int] = defaultdict(int)
        for i in range(count):
            bigram_histo[expression_str[i : i + 2]] += 1

@@ -245,7 +254,7 @@ class ChangeDistiller:

        return bigram_histo


-def _get_leaves(expression):
+def _get_leaves(expression: exp.Expression) -> t.Generator[exp.Expression, None, None]:
    has_child_exprs = False

    for a in expression.args.values():

@@ -258,7 +267,7 @@ def _get_leaves(expression):

    yield expression


-def _is_same_type(source, target):
+def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
    if type(source) is type(target):
        if isinstance(source, exp.Join):
            return source.args.get("side") == target.args.get("side")

@@ -271,15 +280,17 @@ def _is_same_type(source, target):

    return False


-def _expression_only_args(expression):
+def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
-    args = []
+    args: t.List[t.Union[exp.Expression, t.List]] = []
    if expression:
        for a in expression.args.values():
            args.extend(ensure_collection(a))
    return [a for a in args if isinstance(a, exp.Expression)]


-def _lcs(seq_a, seq_b, equal):
+def _lcs(
+    seq_a: t.Sequence[T], seq_b: t.Sequence[T], equal: t.Callable[[T, T], bool]
+) -> t.Sequence[t.Optional[T]]:
    """Calculates the longest common subsequence"""

    len_a = len(seq_a)

@@ -289,14 +300,14 @@ def _lcs(seq_a, seq_b, equal):

    for i in range(len_a + 1):
        for j in range(len_b + 1):
            if i == 0 or j == 0:
-                lcs_result[i][j] = []
+                lcs_result[i][j] = []  # type: ignore
            elif equal(seq_a[i - 1], seq_b[j - 1]):
|
||||||
lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]]
|
lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]] # type: ignore
|
||||||
else:
|
else:
|
||||||
lcs_result[i][j] = (
|
lcs_result[i][j] = (
|
||||||
lcs_result[i - 1][j]
|
lcs_result[i - 1][j]
|
||||||
if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1])
|
if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1]) # type: ignore
|
||||||
else lcs_result[i][j - 1]
|
else lcs_result[i][j - 1]
|
||||||
)
|
)
|
||||||
|
|
||||||
return lcs_result[len_a][len_b]
|
return lcs_result[len_a][len_b] # type: ignore
|
||||||
|
|
|
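The leaf-matching pass above scores similarity with a Dice coefficient over character bigrams of each node's generated SQL. A minimal standalone sketch of that metric (not the ChangeDistiller internals; inputs are illustrative):

from collections import defaultdict

def bigram_histo(text: str) -> defaultdict:
    # Count overlapping two-character windows, mirroring _bigram_histo above.
    histo: defaultdict = defaultdict(int)
    for i in range(max(0, len(text) - 1)):
        histo[text[i : i + 2]] += 1
    return histo

def dice_coefficient(a: str, b: str) -> float:
    # 2 * |bigrams(a) ∩ bigrams(b)| / (|bigrams(a)| + |bigrams(b)|)
    histo_a, histo_b = bigram_histo(a), bigram_histo(b)
    total_grams = sum(histo_a.values()) + sum(histo_b.values())
    if not total_grams:
        return 1.0 if a == b else 0.0
    overlap = sum(min(count, histo_b[gram]) for gram, count in histo_a.items())
    return 2 * overlap / total_grams

print(dice_coefficient('"a" + "b"', '"a" + "c"'))  # 0.75: mostly-shared bigrams

Leaves whose score clears a threshold become candidate matchings; everything left unmatched turns into Remove/Insert edits.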
@@ -37,6 +37,10 @@ class SchemaError(SqlglotError):
     pass


+class ExecuteError(SqlglotError):
+    pass
+
+
 def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
     msg = [str(e) for e in errors[:maximum]]
     remaining = len(errors) - maximum
@@ -1,20 +1,23 @@
 import logging
 import time

-from sqlglot import parse_one
+from sqlglot import maybe_parse
+from sqlglot.errors import ExecuteError
 from sqlglot.executor.python import PythonExecutor
+from sqlglot.executor.table import Table, ensure_tables
 from sqlglot.optimizer import optimize
 from sqlglot.planner import Plan
+from sqlglot.schema import ensure_schema

 logger = logging.getLogger("sqlglot")


-def execute(sql, schema, read=None):
+def execute(sql, schema=None, read=None, tables=None):
     """
     Run a sql query against data.

     Args:
-        sql (str): a sql statement
+        sql (str|sqlglot.Expression): a sql statement
         schema (dict|sqlglot.optimizer.Schema): database schema.
             This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
             the following forms:
@@ -23,10 +26,20 @@ def execute(sql, schema, read=None):
             3. {catalog: {db: {table: {col: type}}}}
         read (str): the SQL dialect to apply during parsing
             (eg. "spark", "hive", "presto", "mysql").
+        tables (dict): additional tables to register.
     Returns:
         sqlglot.executor.Table: Simple columnar data structure.
     """
-    expression = parse_one(sql, read=read)
+    tables = ensure_tables(tables)
+    if not schema:
+        schema = {
+            name: {column: type(table[0][column]).__name__ for column in table.columns}
+            for name, table in tables.mapping.items()
+        }
+    schema = ensure_schema(schema)
+    if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args:
+        raise ExecuteError("Tables must support the same table args as schema")
+    expression = maybe_parse(sql, dialect=read)
     now = time.time()
     expression = optimize(expression, schema, leave_tables_isolated=True)
     logger.debug("Optimization finished: %f", time.time() - now)
@@ -34,6 +47,6 @@ def execute(sql, schema, read=None):
     plan = Plan(expression)
     logger.debug("Logical Plan: %s", plan)
     now = time.time()
-    result = PythonExecutor().execute(plan)
+    result = PythonExecutor(tables=tables).execute(plan)
     logger.debug("Query finished: %f", time.time() - now)
     return result
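With the new `tables` argument, the executor can infer a schema directly from the registered rows instead of requiring one up front. A usage sketch (table name and data are illustrative, not from this diff):

from sqlglot.executor import execute

# Rows are plain dicts; the schema is inferred from the first row's value types.
result = execute(
    "SELECT name, SUM(price) AS total FROM sushi GROUP BY name",
    tables={"sushi": [
        {"name": "maguro", "price": 2.0},
        {"name": "maguro", "price": 3.0},
        {"name": "sake", "price": 1.5},
    ]},
)
print(result.columns)  # a Table: ('name', 'total')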
@@ -1,5 +1,12 @@
+from __future__ import annotations
+
+import typing as t
+
 from sqlglot.executor.env import ENV

+if t.TYPE_CHECKING:
+    from sqlglot.executor.table import Table, TableIter
+

 class Context:
     """
@@ -12,14 +19,14 @@ class Context:
     evaluation of aggregation functions.
     """

-    def __init__(self, tables, env=None):
+    def __init__(self, tables: t.Dict[str, Table], env: t.Optional[t.Dict] = None) -> None:
         """
         Args
-            tables (dict): table_name -> Table, representing the scope of the current execution context
-            env (Optional[dict]): dictionary of functions within the execution context
+            tables: representing the scope of the current execution context.
+            env: dictionary of functions within the execution context.
         """
         self.tables = tables
-        self._table = None
+        self._table: t.Optional[Table] = None
         self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
         self.row_readers = {name: table.reader for name, table in tables.items()}
         self.env = {**(env or {}), "scope": self.row_readers}
@@ -31,7 +38,7 @@ class Context:
         return tuple(self.eval(code) for code in codes)

     @property
-    def table(self):
+    def table(self) -> Table:
         if self._table is None:
             self._table = list(self.tables.values())[0]
             for other in self.tables.values():
@@ -41,8 +48,12 @@ class Context:
                 raise Exception(f"Rows are different.")
         return self._table

+    def add_columns(self, *columns: str) -> None:
+        for table in self.tables.values():
+            table.add_columns(*columns)
+
     @property
-    def columns(self):
+    def columns(self) -> t.Tuple:
         return self.table.columns

     def __iter__(self):
@@ -52,35 +63,39 @@ class Context:
             reader = table[i]
             yield reader, self

-    def table_iter(self, table):
+    def table_iter(self, table: str) -> t.Generator[t.Tuple[TableIter, Context], None, None]:
         self.env["scope"] = self.row_readers

         for reader in self.tables[table]:
             yield reader, self

-    def sort(self, key):
-        table = self.table
+    def filter(self, condition) -> None:
+        rows = [reader.row for reader, _ in self if self.eval(condition)]

-        def sort_key(row):
-            table.reader.row = row
+        for table in self.tables.values():
+            table.rows = rows
+
+    def sort(self, key) -> None:
+        def sort_key(row: t.Tuple) -> t.Tuple:
+            self.set_row(row)
             return self.eval_tuple(key)

-        table.rows.sort(key=sort_key)
+        self.table.rows.sort(key=sort_key)

-    def set_row(self, row):
+    def set_row(self, row: t.Tuple) -> None:
         for table in self.tables.values():
             table.reader.row = row
         self.env["scope"] = self.row_readers

-    def set_index(self, index):
+    def set_index(self, index: int) -> None:
         for table in self.tables.values():
             table[index]
         self.env["scope"] = self.row_readers

-    def set_range(self, start, end):
+    def set_range(self, start: int, end: int) -> None:
         for name in self.tables:
             self.range_readers[name].range = range(start, end)
         self.env["scope"] = self.range_readers

-    def __contains__(self, table):
+    def __contains__(self, table: str) -> bool:
         return table in self.tables
@@ -1,7 +1,10 @@
 import datetime
+import inspect
 import re
 import statistics
+from functools import wraps

+from sqlglot import exp
 from sqlglot.helper import PYTHON_VERSION


@@ -16,20 +19,153 @@ class reverse_key:
         return other.obj < self.obj


+def filter_nulls(func):
+    @wraps(func)
+    def _func(values):
+        return func(v for v in values if v is not None)
+
+    return _func
+
+
+def null_if_any(*required):
+    """
+    Decorator that makes a function return `None` if any of the `required` arguments are `None`.
+
+    This also supports decoration with no arguments, e.g.:
+
+        @null_if_any
+        def foo(a, b): ...
+
+    In which case all arguments are required.
+    """
+    f = None
+    if len(required) == 1 and callable(required[0]):
+        f = required[0]
+        required = ()
+
+    def decorator(func):
+        if required:
+            required_indices = [
+                i for i, param in enumerate(inspect.signature(func).parameters) if param in required
+            ]
+
+            def predicate(*args):
+                return any(args[i] is None for i in required_indices)
+
+        else:
+
+            def predicate(*args):
+                return any(a is None for a in args)
+
+        @wraps(func)
+        def _func(*args):
+            if predicate(*args):
+                return None
+            return func(*args)
+
+        return _func
+
+    if f:
+        return decorator(f)
+
+    return decorator
+
+
+@null_if_any("substr", "this")
+def str_position(substr, this, position=None):
+    position = position - 1 if position is not None else position
+    return this.find(substr, position) + 1
+
+
+@null_if_any("this")
+def substring(this, start=None, length=None):
+    if start is None:
+        return this
+    elif start == 0:
+        return ""
+    elif start < 0:
+        start = len(this) + start
+    else:
+        start -= 1
+
+    end = None if length is None else start + length
+
+    return this[start:end]
+
+
+@null_if_any
+def cast(this, to):
+    if to == exp.DataType.Type.DATE:
+        return datetime.date.fromisoformat(this)
+    if to == exp.DataType.Type.DATETIME:
+        return datetime.datetime.fromisoformat(this)
+    if to in exp.DataType.TEXT_TYPES:
+        return str(this)
+    if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
+        return float(this)
+    if to in exp.DataType.NUMERIC_TYPES:
+        return int(this)
+    raise NotImplementedError(f"Casting to '{to}' not implemented.")
+
+
+def ordered(this, desc, nulls_first):
+    if desc:
+        return reverse_key(this)
+    return this
+
+
+@null_if_any
+def interval(this, unit):
+    if unit == "DAY":
+        return datetime.timedelta(days=float(this))
+    raise NotImplementedError
+
+
 ENV = {
     "__builtins__": {},
-    "datetime": datetime,
-    "locals": locals,
-    "re": re,
-    "bool": bool,
-    "float": float,
-    "int": int,
-    "str": str,
-    "desc": reverse_key,
-    "SUM": sum,
-    "AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean,  # type: ignore
-    "COUNT": lambda acc: sum(1 for e in acc if e is not None),
-    "MAX": max,
-    "MIN": min,
+    "exp": exp,
+    # aggs
+    "SUM": filter_nulls(sum),
+    "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean),  # type: ignore
+    "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)),
+    "MAX": filter_nulls(max),
+    "MIN": filter_nulls(min),
+    # scalar functions
+    "ABS": null_if_any(lambda this: abs(this)),
+    "ADD": null_if_any(lambda e, this: e + this),
+    "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
+    "BITWISEAND": null_if_any(lambda this, e: this & e),
+    "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
+    "BITWISEOR": null_if_any(lambda this, e: this | e),
+    "BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e),
+    "BITWISEXOR": null_if_any(lambda this, e: this ^ e),
+    "CAST": cast,
+    "COALESCE": lambda *args: next((a for a in args if a is not None), None),
+    "CONCAT": null_if_any(lambda *args: "".join(args)),
+    "CONCATWS": null_if_any(lambda this, *args: this.join(args)),
+    "DIV": null_if_any(lambda e, this: e / this),
+    "EQ": null_if_any(lambda this, e: this == e),
+    "EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
+    "GT": null_if_any(lambda this, e: this > e),
+    "GTE": null_if_any(lambda this, e: this >= e),
+    "IFNULL": lambda e, alt: alt if e is None else e,
+    "IF": lambda predicate, true, false: true if predicate else false,
+    "INTDIV": null_if_any(lambda e, this: e // this),
+    "INTERVAL": interval,
+    "LIKE": null_if_any(
+        lambda this, e: bool(re.match(e.replace("_", ".").replace("%", ".*"), this))
+    ),
+    "LOWER": null_if_any(lambda arg: arg.lower()),
+    "LT": null_if_any(lambda this, e: this < e),
+    "LTE": null_if_any(lambda this, e: this <= e),
+    "MOD": null_if_any(lambda e, this: e % this),
+    "MUL": null_if_any(lambda e, this: e * this),
+    "NEQ": null_if_any(lambda this, e: this != e),
+    "ORD": null_if_any(ord),
+    "ORDERED": ordered,
     "POW": pow,
+    "STRPOSITION": str_position,
+    "SUB": null_if_any(lambda e, this: e - this),
+    "SUBSTRING": substring,
+    "UPPER": null_if_any(lambda arg: arg.upper()),
 }
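The `null_if_any` decorator is what gives these Python implementations SQL's NULL semantics: if any required argument is None, the call short-circuits to None instead of raising. A small sketch of the observable behavior (the `div` helper is made up for illustration):

from sqlglot.executor.env import ENV, null_if_any

print(ENV["UPPER"]("abc"))  # 'ABC'
print(ENV["UPPER"](None))   # None, not an AttributeError

@null_if_any("b")
def div(a, b):
    # Only `b` is declared required; the decorator looks the name up
    # in the function signature and checks that positional slot.
    return a / b

print(div(10, 2))     # 5.0
print(div(10, None))  # None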
@@ -5,16 +5,18 @@ import math

 from sqlglot import exp, generator, planner, tokens
 from sqlglot.dialects.dialect import Dialect, inline_array_sql
+from sqlglot.errors import ExecuteError
 from sqlglot.executor.context import Context
 from sqlglot.executor.env import ENV
-from sqlglot.executor.table import Table
-from sqlglot.helper import csv_reader
+from sqlglot.executor.table import RowReader, Table
+from sqlglot.helper import csv_reader, subclasses


 class PythonExecutor:
-    def __init__(self, env=None):
-        self.generator = Python().generator(identify=True)
+    def __init__(self, env=None, tables=None):
+        self.generator = Python().generator(identify=True, comments=False)
         self.env = {**ENV, **(env or {})}
+        self.tables = tables or {}

     def execute(self, plan):
         running = set()
@@ -24,6 +26,7 @@ class PythonExecutor:

         while queue:
             node = queue.pop()
+            try:
                 context = self.context(
                     {
                         name: table
@@ -41,6 +44,8 @@ class PythonExecutor:
                     contexts[node] = self.join(node, context)
                 elif isinstance(node, planner.Sort):
                     contexts[node] = self.sort(node, context)
+                elif isinstance(node, planner.SetOperation):
+                    contexts[node] = self.set_operation(node, context)
                 else:
                     raise NotImplementedError
@@ -54,6 +59,8 @@ class PythonExecutor:
                 for dep in node.dependencies:
                     if all(d in finished for d in dep.dependents):
                         contexts.pop(dep)
+            except Exception as e:
+                raise ExecuteError(f"Step '{node.id}' failed: {e}") from e

         root = plan.root
         return contexts[root].tables[root.name]
@@ -76,38 +83,43 @@ class PythonExecutor:
         return Context(tables, env=self.env)

     def table(self, expressions):
-        return Table(expression.alias_or_name for expression in expressions)
+        return Table(
+            expression.alias_or_name if isinstance(expression, exp.Expression) else expression
+            for expression in expressions
+        )

     def scan(self, step, context):
         source = step.source

-        if isinstance(source, exp.Expression):
+        if source and isinstance(source, exp.Expression):
             source = source.name or source.alias

         condition = self.generate(step.condition)
         projections = self.generate_tuple(step.projections)

-        if source in context:
+        if source is None:
+            context, table_iter = self.static()
+        elif source in context:
             if not projections and not condition:
                 return self.context({step.name: context.tables[source]})
             table_iter = context.table_iter(source)
-        else:
+        elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV):
             table_iter = self.scan_csv(step)
+            context = next(table_iter)
+        else:
+            context, table_iter = self.scan_table(step)

         if projections:
             sink = self.table(step.projections)
         else:
-            sink = None
+            sink = self.table(context.columns)

-        for reader, ctx in table_iter:
-            if sink is None:
-                sink = Table(reader.columns)
-
-            if condition and not ctx.eval(condition):
+        for reader in table_iter:
+            if condition and not context.eval(condition):
                 continue

             if projections:
-                sink.append(ctx.eval_tuple(projections))
+                sink.append(context.eval_tuple(projections))
             else:
                 sink.append(reader.row)

@@ -116,14 +128,23 @@ class PythonExecutor:

         return self.context({step.name: sink})

+    def static(self):
+        return self.context({}), [RowReader(())]
+
+    def scan_table(self, step):
+        table = self.tables.find(step.source)
+        context = self.context({step.source.alias_or_name: table})
+        return context, iter(table)
+
     def scan_csv(self, step):
-        source = step.source
-        alias = source.alias
+        alias = step.source.alias
+        source = step.source.this

         with csv_reader(source) as reader:
             columns = next(reader)
             table = Table(columns)
             context = self.context({alias: table})
+            yield context
             types = []

             for row in reader:
@@ -134,7 +155,7 @@ class PythonExecutor:
                     except (ValueError, SyntaxError):
                         types.append(str)
                 context.set_row(tuple(t(v) for t, v in zip(types, row)))
-                yield context.table.reader, context
+                yield context.table.reader
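`scan_csv` is now a generator that yields its `Context` first and the row readers afterwards, which is why `scan` primes it with `next(table_iter)` before the loop. The pattern in isolation, with hypothetical names:

def make_source():
    # First yield hands back the shared state object;
    # every later yield streams one row through that state.
    state = {"rows_seen": 0}
    yield state
    for row in [("a",), ("b",)]:
        state["rows_seen"] += 1
        yield row

source = make_source()
state = next(source)   # prime: grab the context before any rows flow
for row in source:     # then consume the actual rows
    print(row, state["rows_seen"])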
     def join(self, step, context):
         source = step.name
@@ -160,16 +181,19 @@ class PythonExecutor:
                     for name, column_range in column_ranges.items()
                 }
             )
+            condition = self.generate(join["condition"])
+            if condition:
+                source_context.filter(condition)

         condition = self.generate(step.condition)
         projections = self.generate_tuple(step.projections)

-        if not condition or not projections:
+        if not condition and not projections:
             return source_context

         sink = self.table(step.projections if projections else source_context.columns)

-        for reader, ctx in join_context:
+        for reader, ctx in source_context:
             if condition and not ctx.eval(condition):
                 continue

@@ -181,7 +205,15 @@ class PythonExecutor:
             if len(sink) >= step.limit:
                 break

-        return self.context({step.name: sink})
+        if projections:
+            return self.context({step.name: sink})
+        else:
+            return self.context(
+                {
+                    name: Table(table.columns, sink.rows, table.column_range)
+                    for name, table in source_context.tables.items()
+                }
+            )

     def nested_loop_join(self, _join, source_context, join_context):
         table = Table(source_context.columns + join_context.columns)
@@ -195,6 +227,8 @@ class PythonExecutor:
     def hash_join(self, join, source_context, join_context):
         source_key = self.generate_tuple(join["source_key"])
         join_key = self.generate_tuple(join["join_key"])
+        left = join.get("side") == "LEFT"
+        right = join.get("side") == "RIGHT"

         results = collections.defaultdict(lambda: ([], []))

@@ -204,28 +238,47 @@ class PythonExecutor:
             results[ctx.eval_tuple(join_key)][1].append(reader.row)

         table = Table(source_context.columns + join_context.columns)
+        nulls = [(None,) * len(join_context.columns if left else source_context.columns)]

         for a_group, b_group in results.values():
+            if left:
+                b_group = b_group or nulls
+            elif right:
+                a_group = a_group or nulls
+
             for a_row, b_row in itertools.product(a_group, b_group):
                 table.append(a_row + b_row)

         return table
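The new `left`/`right` flags turn the plain hash join into an outer join by padding the unmatched side with a tuple of NULLs. The core idea, sketched independently of the executor with illustrative rows:

import collections
import itertools

left_rows = [(1, "a"), (2, "b")]
right_rows = [(1, "x")]

# Bucket both inputs by join key (first column here).
buckets = collections.defaultdict(lambda: ([], []))
for row in left_rows:
    buckets[row[0]][0].append(row)
for row in right_rows:
    buckets[row[0]][1].append(row)

nulls = [(None,) * 2]  # pad width = arity of the missing side
for a_group, b_group in buckets.values():
    b_group = b_group or nulls  # LEFT JOIN: keep unmatched left rows
    for a, b in itertools.product(a_group, b_group):
        print(a + b)  # (1, 'a', 1, 'x') then (2, 'b', None, None)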
     def aggregate(self, step, context):
-        source = step.source
-        group_by = self.generate_tuple(step.group)
+        group_by = self.generate_tuple(step.group.values())
         aggregations = self.generate_tuple(step.aggregations)
         operands = self.generate_tuple(step.operands)

         if operands:
-            source_table = context.tables[source]
-            operand_table = Table(source_table.columns + self.table(step.operands).columns)
+            operand_table = Table(self.table(step.operands).columns)

             for reader, ctx in context:
-                operand_table.append(reader.row + ctx.eval_tuple(operands))
+                operand_table.append(ctx.eval_tuple(operands))
+
+            for i, (a, b) in enumerate(zip(context.table.rows, operand_table.rows)):
+                context.table.rows[i] = a + b
+
+            width = len(context.columns)
+            context.add_columns(*operand_table.columns)
+
+            operand_table = Table(
+                context.columns,
+                context.table.rows,
+                range(width, width + len(operand_table.columns)),
+            )

             context = self.context(
-                {None: operand_table, **{table: operand_table for table in context.tables}}
+                {
+                    None: operand_table,
+                    **context.tables,
+                }
             )

         context.sort(group_by)
@@ -233,25 +286,22 @@ class PythonExecutor:
         group = None
         start = 0
         end = 1
-        length = len(context.tables[source])
-        table = self.table(step.group + step.aggregations)
+        length = len(context.table)
+        table = self.table(list(step.group) + step.aggregations)

         for i in range(length):
             context.set_index(i)
             key = context.eval_tuple(group_by)
             group = key if group is None else group
             end += 1
-            if i == length - 1:
-                context.set_range(start, end - 1)
-            elif key != group:
+            if key != group:
                 context.set_range(start, end - 2)
-            else:
-                continue
-
-            table.append(group + context.eval_tuple(aggregations))
-            group = key
-            start = end - 2
+                table.append(group + context.eval_tuple(aggregations))
+                group = key
+                start = end - 2
+            if i == length - 1:
+                context.set_range(start, end - 1)
+                table.append(group + context.eval_tuple(aggregations))

         context = self.context({step.name: table, **{name: table for name in context.tables}})
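Aggregation here sorts by the group key first, then sweeps the rows once, flushing a group whenever the key changes plus a final flush at the last row (which the rewritten loop above now handles even when the last row also starts a new group). Reduced to plain Python with illustrative data:

rows = sorted([("a", 1), ("b", 5), ("a", 3)])  # sort by group key first

group, acc, out = None, 0, []
for i, (key, value) in enumerate(rows):
    if group is not None and key != group:
        out.append((group, acc))  # key changed: flush the finished group
        acc = 0
    group, acc = key, acc + value
    if i == len(rows) - 1:
        out.append((group, acc))  # final flush so the last group isn't dropped

print(out)  # [('a', 4), ('b', 5)]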
@@ -262,87 +312,67 @@ class PythonExecutor:
     def sort(self, step, context):
         projections = self.generate_tuple(step.projections)

-        sink = self.table(step.projections)
+        projection_columns = [p.alias_or_name for p in step.projections]
+        all_columns = list(context.columns) + projection_columns
+        sink = self.table(all_columns)

         for reader, ctx in context:
-            sink.append(ctx.eval_tuple(projections))
+            sink.append(reader.row + ctx.eval_tuple(projections))

-        context = self.context(
+        sort_ctx = self.context(
             {
                 None: sink,
                 **{table: sink for table in context.tables},
             }
         )
-        context.sort(self.generate_tuple(step.key))
+        sort_ctx.sort(self.generate_tuple(step.key))

         if not math.isinf(step.limit):
-            context.table.rows = context.table.rows[0 : step.limit]
+            sort_ctx.table.rows = sort_ctx.table.rows[0 : step.limit]

-        return self.context({step.name: context.table})
+        output = Table(
+            projection_columns,
+            rows=[r[len(context.columns) : len(all_columns)] for r in sort_ctx.table.rows],
+        )
+        return self.context({step.name: output})
+
+    def set_operation(self, step, context):
+        left = context.tables[step.left]
+        right = context.tables[step.right]
+
+        sink = self.table(left.columns)
+
+        if issubclass(step.op, exp.Intersect):
+            sink.rows = list(set(left.rows).intersection(set(right.rows)))
+        elif issubclass(step.op, exp.Except):
+            sink.rows = list(set(left.rows).difference(set(right.rows)))
+        elif issubclass(step.op, exp.Union) and step.distinct:
+            sink.rows = list(set(left.rows).union(set(right.rows)))
+        else:
+            sink.rows = left.rows + right.rows
+
+        return self.context({step.name: sink})


-def _cast_py(self, expression):
-    to = expression.args["to"].this
-    this = self.sql(expression, "this")
-
-    if to == exp.DataType.Type.DATE:
-        return f"datetime.date.fromisoformat({this})"
-    if to == exp.DataType.Type.TEXT:
-        return f"str({this})"
-    raise NotImplementedError
-
-
-def _column_py(self, expression):
-    table = self.sql(expression, "table") or None
-    this = self.sql(expression, "this")
-    return f"scope[{table}][{this}]"
-
-
-def _interval_py(self, expression):
-    this = self.sql(expression, "this")
-    unit = expression.text("unit").upper()
-    if unit == "DAY":
-        return f"datetime.timedelta(days=float({this}))"
-    raise NotImplementedError
-
-
-def _like_py(self, expression):
-    this = self.sql(expression, "this")
-    expression = self.sql(expression, "expression")
-    return f"""bool(re.match({expression}.replace("_", ".").replace("%", ".*"), {this}))"""
-
-
 def _ordered_py(self, expression):
     this = self.sql(expression, "this")
-    desc = expression.args.get("desc")
-    return f"desc({this})" if desc else this
+    desc = "True" if expression.args.get("desc") else "False"
+    nulls_first = "True" if expression.args.get("nulls_first") else "False"
+    return f"ORDERED({this}, {desc}, {nulls_first})"


-class Python(Dialect):
-    class Tokenizer(tokens.Tokenizer):
-        ESCAPES = ["\\"]
+def _rename(self, e):
+    try:
+        if "expressions" in e.args:
+            this = self.sql(e, "this")
+            this = f"{this}, " if this else ""
+            return f"{e.key.upper()}({this}{self.expressions(e)})"
+        return f"{e.key.upper()}({self.format_args(*e.args.values())})"
+    except Exception as ex:
+        raise Exception(f"Could not rename {repr(e)}") from ex

-    class Generator(generator.Generator):
-        TRANSFORMS = {
-            exp.Alias: lambda self, e: self.sql(e.this),
-            exp.Array: inline_array_sql,
-            exp.And: lambda self, e: self.binary(e, "and"),
-            exp.Boolean: lambda self, e: "True" if e.this else "False",
-            exp.Cast: _cast_py,
-            exp.Column: _column_py,
-            exp.EQ: lambda self, e: self.binary(e, "=="),
-            exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
-            exp.Interval: _interval_py,
-            exp.Is: lambda self, e: self.binary(e, "is"),
-            exp.Like: _like_py,
-            exp.Not: lambda self, e: f"not {self.sql(e.this)}",
-            exp.Null: lambda *_: "None",
-            exp.Or: lambda self, e: self.binary(e, "or"),
-            exp.Ordered: _ordered_py,
-            exp.Star: lambda *_: "1",
-        }

-        def case_sql(self, expression):
+def _case_sql(self, expression):
     this = self.sql(expression, "this")
     chain = self.sql(expression, "default") or "None"

@@ -353,3 +383,30 @@ class Python(Dialect):
         chain = f"{true} if {condition} else ({chain})"

     return chain
+
+
+class Python(Dialect):
+    class Tokenizer(tokens.Tokenizer):
+        ESCAPES = ["\\"]
+
+    class Generator(generator.Generator):
+        TRANSFORMS = {
+            **{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)},
+            **{klass: _rename for klass in exp.ALL_FUNCTIONS},
+            exp.Case: _case_sql,
+            exp.Alias: lambda self, e: self.sql(e.this),
+            exp.Array: inline_array_sql,
+            exp.And: lambda self, e: self.binary(e, "and"),
+            exp.Between: _rename,
+            exp.Boolean: lambda self, e: "True" if e.this else "False",
+            exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
+            exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
+            exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
+            exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
+            exp.Is: lambda self, e: self.binary(e, "is"),
+            exp.Not: lambda self, e: f"not {self.sql(e.this)}",
+            exp.Null: lambda *_: "None",
+            exp.Or: lambda self, e: self.binary(e, "or"),
+            exp.Ordered: _ordered_py,
+            exp.Star: lambda *_: "1",
+        }
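With `SetOperation` steps handled, the Python executor can run UNION/INTERSECT/EXCEPT end to end. A usage sketch, assuming the planner emits a SetOperation step for this query (table names and data are illustrative):

from sqlglot.executor import execute

result = execute(
    "SELECT x FROM t UNION SELECT x FROM u",
    tables={
        "t": [{"x": 1}, {"x": 2}],
        "u": [{"x": 2}, {"x": 3}],
    },
)
print(sorted(result.rows))  # plain UNION is distinct: [(1,), (2,), (3,)]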
@@ -1,14 +1,27 @@
+from __future__ import annotations
+
+from sqlglot.helper import dict_depth
+from sqlglot.schema import AbstractMappingSchema
+
+
 class Table:
     def __init__(self, columns, rows=None, column_range=None):
         self.columns = tuple(columns)
         self.column_range = column_range
         self.reader = RowReader(self.columns, self.column_range)

         self.rows = rows or []
         if rows:
             assert len(rows[0]) == len(self.columns)
         self.range_reader = RangeReader(self)

+    def add_columns(self, *columns: str) -> None:
+        self.columns += columns
+        if self.column_range:
+            self.column_range = range(
+                self.column_range.start, self.column_range.stop + len(columns)
+            )
+        self.reader = RowReader(self.columns, self.column_range)
+
     def append(self, row):
         assert len(row) == len(self.columns)
         self.rows.append(row)
@@ -87,3 +100,31 @@ class RowReader:

     def __getitem__(self, column):
         return self.row[self.columns[column]]
+
+
+class Tables(AbstractMappingSchema[Table]):
+    pass
+
+
+def ensure_tables(d: dict | None) -> Tables:
+    return Tables(_ensure_tables(d))
+
+
+def _ensure_tables(d: dict | None) -> dict:
+    if not d:
+        return {}
+
+    depth = dict_depth(d)
+
+    if depth > 1:
+        return {k: _ensure_tables(v) for k, v in d.items()}
+
+    result = {}
+    for name, table in d.items():
+        if isinstance(table, Table):
+            result[name] = table
+        else:
+            columns = tuple(table[0]) if table else ()
+            rows = [tuple(row[c] for c in columns) for row in table]
+            result[name] = Table(columns=columns, rows=rows)
+    return result
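`ensure_tables` accepts the same nesting shapes as a schema mapping (table, db.table, or catalog.db.table, distinguished by `dict_depth`) and normalizes plain row dicts into columnar `Table` objects. For example (data is illustrative):

from sqlglot.executor.table import ensure_tables

tables = ensure_tables({"t": [{"a": 1, "b": 2}, {"a": 3, "b": 4}]})
t = tables.mapping["t"]
print(t.columns)  # ('a', 'b')
print(t.rows)     # [(1, 2), (3, 4)] -- dict rows become positional tuples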
@@ -641,9 +641,11 @@ class Set(Expression):

 class SetItem(Expression):
     arg_types = {
-        "this": True,
+        "this": False,
+        "expressions": False,
         "kind": False,
         "collate": False,  # MySQL SET NAMES statement
+        "global": False,
     }


@@ -787,6 +789,7 @@ class Drop(Expression):
         "exists": False,
         "temporary": False,
         "materialized": False,
+        "cascade": False,
     }


@@ -1073,6 +1076,18 @@ class FileFormatProperty(Property):
     pass


+class DistKeyProperty(Property):
+    pass
+
+
+class SortKeyProperty(Property):
+    pass
+
+
+class DistStyleProperty(Property):
+    pass
+
+
 class LocationProperty(Property):
     pass

@@ -1130,6 +1145,9 @@ class Properties(Expression):
         "LOCATION": LocationProperty,
         "PARTITIONED_BY": PartitionedByProperty,
         "TABLE_FORMAT": TableFormatProperty,
+        "DISTKEY": DistKeyProperty,
+        "DISTSTYLE": DistStyleProperty,
+        "SORTKEY": SortKeyProperty,
     }
     @classmethod
@@ -1356,7 +1374,7 @@ class Var(Expression):


 class Schema(Expression):
-    arg_types = {"this": False, "expressions": True}
+    arg_types = {"this": False, "expressions": False}


 class Select(Subqueryable):
@@ -1741,7 +1759,7 @@ class Select(Subqueryable):
         )

         if join_alias:
-            join.set("this", alias_(join.args["this"], join_alias, table=True))
+            join.set("this", alias_(join.this, join_alias, table=True))
         return _apply_list_builder(
             join,
             instance=self,
@@ -1884,6 +1902,7 @@ class Subquery(DerivedTable, Unionable):
     arg_types = {
         "this": True,
         "alias": False,
+        "with": False,
         **QUERY_MODIFIERS,
     }

@@ -2025,6 +2044,31 @@ class DataType(Expression):
         NULL = auto()
         UNKNOWN = auto()  # Sentinel value, useful for type annotation

+    TEXT_TYPES = {
+        Type.CHAR,
+        Type.NCHAR,
+        Type.VARCHAR,
+        Type.NVARCHAR,
+        Type.TEXT,
+    }
+
+    NUMERIC_TYPES = {
+        Type.INT,
+        Type.TINYINT,
+        Type.SMALLINT,
+        Type.BIGINT,
+        Type.FLOAT,
+        Type.DOUBLE,
+    }
+
+    TEMPORAL_TYPES = {
+        Type.TIMESTAMP,
+        Type.TIMESTAMPTZ,
+        Type.TIMESTAMPLTZ,
+        Type.DATE,
+        Type.DATETIME,
+    }
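The new `TEXT_TYPES`/`NUMERIC_TYPES`/`TEMPORAL_TYPES` sets let callers, such as the executor's `cast` above, test a type's family with a membership check instead of enumerating variants:

from sqlglot import exp

dtype = exp.DataType.build("varchar")
print(dtype.this in exp.DataType.TEXT_TYPES)     # True
print(dtype.this in exp.DataType.NUMERIC_TYPES)  # False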
     @classmethod
     def build(cls, dtype, **kwargs) -> DataType:
         return DataType(
@@ -2054,16 +2098,25 @@ class Exists(SubqueryPredicate):
     pass


-# Commands to interact with the databases or engines
-# These expressions don't truly parse the expression and consume
-# whatever exists as a string until the end or a semicolon
+# Commands to interact with the databases or engines. For most of the command
+# expressions we parse whatever comes after the command's name as a string.
 class Command(Expression):
     arg_types = {"this": True, "expression": False}


-# Binary Expressions
-# (ADD a b)
-# (FROM table selects)
+class Transaction(Command):
+    arg_types = {"this": False, "modes": False}
+
+
+class Commit(Command):
+    arg_types = {}  # type: ignore
+
+
+class Rollback(Command):
+    arg_types = {"savepoint": False}
+
+
+# Binary expressions like (ADD a b)
 class Binary(Expression):
     arg_types = {"this": True, "expression": True}

@@ -2215,7 +2268,7 @@ class Not(Unary, Condition):


 class Paren(Unary, Condition):
-    pass
+    arg_types = {"this": True, "with": False}


 class Neg(Unary):
@@ -2428,6 +2481,10 @@ class Cast(Func):
         return self.args["to"]


+class Collate(Binary):
+    pass
+
+
 class TryCast(Cast):
     pass

@@ -2442,13 +2499,17 @@ class Coalesce(Func):
     is_var_len_args = True


-class ConcatWs(Func):
-    arg_types = {"expressions": False}
+class Concat(Func):
+    arg_types = {"expressions": True}
     is_var_len_args = True


+class ConcatWs(Concat):
+    _sql_names = ["CONCAT_WS"]
+
+
 class Count(AggFunc):
-    pass
+    arg_types = {"this": False}


 class CurrentDate(Func):
@@ -2556,10 +2617,18 @@ class Day(Func):
     pass


+class Decode(Func):
+    arg_types = {"this": True, "charset": True}
+
+
 class DiToDate(Func):
     pass


+class Encode(Func):
+    arg_types = {"this": True, "charset": True}
+
+
 class Exp(Func):
     pass

@@ -2581,6 +2650,10 @@ class GroupConcat(Func):
     arg_types = {"this": True, "separator": False}


+class Hex(Func):
+    pass
+
+
 class If(Func):
     arg_types = {"this": True, "true": True, "false": False}

@@ -2641,7 +2714,7 @@ class Log10(Func):


 class Lower(Func):
-    pass
+    _sql_names = ["LOWER", "LCASE"]


 class Map(Func):
@@ -2686,6 +2759,12 @@ class ApproxQuantile(Quantile):
     arg_types = {"this": True, "quantile": True, "accuracy": False}


+class ReadCSV(Func):
+    _sql_names = ["READ_CSV"]
+    is_var_len_args = True
+    arg_types = {"this": True, "expressions": False}
+
+
 class Reduce(Func):
     arg_types = {"this": True, "initial": True, "merge": True, "finish": True}

@@ -2804,8 +2883,8 @@ class TimeStrToUnix(Func):
 class Trim(Func):
     arg_types = {
         "this": True,
-        "position": False,
         "expression": False,
+        "position": False,
         "collation": False,
     }

@@ -2826,6 +2905,10 @@ class TsOrDiToDi(Func):
     pass


+class Unhex(Func):
+    pass
+
+
 class UnixToStr(Func):
     arg_types = {"this": True, "format": False}

@@ -2843,7 +2926,7 @@ class UnixToTimeStr(Func):


 class Upper(Func):
-    pass
+    _sql_names = ["UPPER", "UCASE"]


 class Variance(AggFunc):
@@ -3701,6 +3784,19 @@ def replace_placeholders(expression, *args, **kwargs):
     return expression.transform(_replace_placeholders, iter(args), **kwargs)


+def true():
+    return Boolean(this=True)
+
+
+def false():
+    return Boolean(this=False)
+
+
+def null():
+    return Null()
+
+
+# TODO: deprecate this
 TRUE = Boolean(this=True)
 FALSE = Boolean(this=False)
 NULL = Null()
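`_sql_names` additions like `LCASE`/`UCASE` register extra spellings for the same function node, so parsing either alias lands on `Lower`/`Upper` and regenerates under the first (canonical) name. A quick check of the expected behavior:

import sqlglot

print(sqlglot.transpile("SELECT LCASE(a), UCASE(b)")[0])
# SELECT LOWER(a), UPPER(b)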
@ -67,7 +67,7 @@ class Generator:
|
||||||
exp.LocationProperty: lambda self, e: self.naked_property(e),
|
exp.LocationProperty: lambda self, e: self.naked_property(e),
|
||||||
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
|
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
|
||||||
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
|
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
|
||||||
exp.VolatilityProperty: lambda self, e: self.sql(e.name),
|
exp.VolatilityProperty: lambda self, e: e.name,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
|
# Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
|
||||||
|
@ -94,6 +94,9 @@ class Generator:
|
||||||
ROOT_PROPERTIES = {
|
ROOT_PROPERTIES = {
|
||||||
exp.ReturnsProperty,
|
exp.ReturnsProperty,
|
||||||
exp.LanguageProperty,
|
exp.LanguageProperty,
|
||||||
|
exp.DistStyleProperty,
|
||||||
|
exp.DistKeyProperty,
|
||||||
|
exp.SortKeyProperty,
|
||||||
}
|
}
|
||||||
|
|
||||||
WITH_PROPERTIES = {
|
WITH_PROPERTIES = {
|
||||||
|
@ -241,7 +244,7 @@ class Generator:
|
||||||
if not NEWLINE_RE.search(comment):
|
if not NEWLINE_RE.search(comment):
|
||||||
return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
|
return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
|
||||||
|
|
||||||
return f"/*{comment}*/\n{sql}"
|
return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/"
|
||||||
|
|
||||||
def wrap(self, expression):
|
def wrap(self, expression):
|
||||||
this_sql = self.indent(
|
this_sql = self.indent(
|
||||||
|
@ -475,7 +478,8 @@ class Generator:
|
||||||
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
|
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
|
||||||
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
|
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
|
||||||
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
|
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
|
||||||
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}"
|
cascade = " CASCADE" if expression.args.get("cascade") else ""
|
||||||
|
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}"
|
||||||
|
|
||||||
def except_sql(self, expression):
|
def except_sql(self, expression):
|
||||||
return self.prepend_ctes(
|
return self.prepend_ctes(
|
||||||
|
@ -915,13 +919,15 @@ class Generator:
|
||||||
def subquery_sql(self, expression):
|
def subquery_sql(self, expression):
|
||||||
alias = self.sql(expression, "alias")
|
alias = self.sql(expression, "alias")
|
||||||
|
|
||||||
return self.query_modifiers(
|
sql = self.query_modifiers(
|
||||||
expression,
|
expression,
|
||||||
self.wrap(expression),
|
self.wrap(expression),
|
||||||
self.expressions(expression, key="pivots", sep=" "),
|
self.expressions(expression, key="pivots", sep=" "),
|
||||||
f" AS {alias}" if alias else "",
|
f" AS {alias}" if alias else "",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return self.prepend_ctes(expression, sql)
|
||||||
|
|
||||||
def qualify_sql(self, expression):
|
def qualify_sql(self, expression):
|
||||||
this = self.indent(self.sql(expression, "this"))
|
this = self.indent(self.sql(expression, "this"))
|
||||||
return f"{self.seg('QUALIFY')}{self.sep()}{this}"
|
return f"{self.seg('QUALIFY')}{self.sep()}{this}"
|
||||||
|
@ -1111,9 +1117,12 @@ class Generator:
|
||||||
|
|
||||||
def paren_sql(self, expression):
|
def paren_sql(self, expression):
|
||||||
if isinstance(expression.unnest(), exp.Select):
|
if isinstance(expression.unnest(), exp.Select):
|
||||||
return self.wrap(expression)
|
sql = self.wrap(expression)
|
||||||
|
else:
|
||||||
sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
|
sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
|
||||||
return f"({sql}{self.seg(')', sep='')}"
|
sql = f"({sql}{self.seg(')', sep='')}"
|
||||||
|
|
||||||
|
return self.prepend_ctes(expression, sql)
|
||||||
|
|
||||||
def neg_sql(self, expression):
|
def neg_sql(self, expression):
|
||||||
return f"-{self.sql(expression, 'this')}"
|
return f"-{self.sql(expression, 'this')}"
|
||||||
@@ -1173,9 +1182,23 @@ class Generator:
         zone = self.sql(expression, "this")
         return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE"

+    def collate_sql(self, expression):
+        return self.binary(expression, "COLLATE")
+
     def command_sql(self, expression):
         return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"

+    def transaction_sql(self, *_):
+        return "BEGIN"
+
+    def commit_sql(self, *_):
+        return "COMMIT"
+
+    def rollback_sql(self, expression):
+        savepoint = expression.args.get("savepoint")
+        savepoint = f" TO {savepoint}" if savepoint else ""
+        return f"ROLLBACK{savepoint}"
+
     def distinct_sql(self, expression):
         this = self.expressions(expression, flat=True)
         this = f" {this}" if this else ""
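Together with the BEGIN/COMMIT/ROLLBACK entries added to STATEMENT_PARSERS later in this commit, these generator methods let transaction statements round-trip. A minimal sketch:

    import sqlglot

    print(sqlglot.transpile("BEGIN"))     # ['BEGIN']
    print(sqlglot.transpile("COMMIT"))    # ['COMMIT']
    print(sqlglot.transpile("ROLLBACK"))  # ['ROLLBACK']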
@@ -1193,10 +1216,7 @@ class Generator:
     def intdiv_sql(self, expression):
         return self.sql(
             exp.Cast(
-                this=exp.Div(
-                    this=expression.args["this"],
-                    expression=expression.args["expression"],
-                ),
+                this=exp.Div(this=expression.this, expression=expression.expression),
                 to=exp.DataType(this=exp.DataType.Type.INT),
             )
         )
sqlglot/helper.py

@@ -11,7 +11,8 @@ from copy import copy
 from enum import Enum

 if t.TYPE_CHECKING:
-    from sqlglot.expressions import Expression, Table
+    from sqlglot import exp
+    from sqlglot.expressions import Expression

 T = t.TypeVar("T")
 E = t.TypeVar("E", bound=Expression)

@@ -150,7 +151,7 @@ def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
         if expression.is_int:
             expression = expression.copy()
             logger.warning("Applying array index offset (%s)", offset)
-            expression.args["this"] = str(int(expression.args["this"]) + offset)
+            expression.args["this"] = str(int(expression.this) + offset)
             return [expression]

     return expressions

@@ -228,19 +229,18 @@ def open_file(file_name: str) -> t.TextIO:
 @contextmanager
-def csv_reader(table: Table) -> t.Any:
+def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
     """
     Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

     Args:
-        table: a `Table` expression with an anonymous function `READ_CSV` in it.
+        read_csv: a `ReadCSV` function call

     Yields:
         A python csv reader.
     """
-    file, *args = table.this.expressions
-    file = file.name
-    file = open_file(file)
+    args = read_csv.expressions
+    file = open_file(read_csv.name)

     delimiter = ","
     args = iter(arg.name for arg in args)
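Callers now hand csv_reader the exp.ReadCSV node itself instead of the enclosing Table. A hedged usage sketch ('data.csv' is a hypothetical file):

    from sqlglot import exp, parse_one
    from sqlglot.helper import csv_reader

    # READ_CSV(...) parses to an exp.ReadCSV function call.
    read_csv = parse_one("SELECT * FROM READ_CSV('data.csv')").find(exp.ReadCSV)

    with csv_reader(read_csv) as reader:
        for row in reader:
            print(row)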
@@ -354,3 +354,34 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, None, None]:
         yield from flatten(value)
     else:
         yield value
+
+
+def dict_depth(d: t.Dict) -> int:
+    """
+    Get the nesting depth of a dictionary.
+
+    For example:
+        >>> dict_depth(None)
+        0
+        >>> dict_depth({})
+        1
+        >>> dict_depth({"a": "b"})
+        1
+        >>> dict_depth({"a": {}})
+        2
+        >>> dict_depth({"a": {"b": {}}})
+        3
+
+    Args:
+        d (dict): dictionary
+    Returns:
+        int: depth
+    """
+    try:
+        return 1 + dict_depth(next(iter(d.values())))
+    except AttributeError:
+        # d doesn't have attribute "values"
+        return 0
+    except StopIteration:
+        # d.values() returns an empty sequence
+        return 1
sqlglot/optimizer/annotate_types.py

@@ -245,23 +245,31 @@ class TypeAnnotator:
     def annotate(self, expression):
         if isinstance(expression, self.TRAVERSABLES):
            for scope in traverse_scope(expression):
-                subscope_selects = {
-                    name: {select.alias_or_name: select for select in source.selects}
-                    for name, source in scope.sources.items()
-                    if isinstance(source, Scope)
-                }
+                selects = {}
+                for name, source in scope.sources.items():
+                    if not isinstance(source, Scope):
+                        continue
+                    if isinstance(source.expression, exp.Values):
+                        selects[name] = {
+                            alias: column
+                            for alias, column in zip(
+                                source.expression.alias_column_names,
+                                source.expression.expressions[0].expressions,
+                            )
+                        }
+                    else:
+                        selects[name] = {
+                            select.alias_or_name: select for select in source.expression.selects
+                        }

                 # First annotate the current scope's column references
                 for col in scope.columns:
                     source = scope.sources[col.table]
                     if isinstance(source, exp.Table):
                         col.type = self.schema.get_column_type(source, col)
                     else:
-                        col.type = subscope_selects[col.table][col.name].type
+                        col.type = selects[col.table][col.name].type

                 # Then (possibly) annotate the remaining expressions in the scope
                 self._maybe_annotate(scope.expression)

         return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

     def _maybe_annotate(self, expression):
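With VALUES sources handled explicitly, columns selected from a VALUES alias can pick up types from the literals in the first row. A hedged sketch (assuming literal types propagate as described):

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types

    # Sketch: v.a should come out typed from the literal 1, v.b from 'x'.
    expression = annotate_types(
        sqlglot.parse_one("SELECT v.a, v.b FROM (VALUES (1, 'x')) AS v(a, b)")
    )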
sqlglot/optimizer/canonicalize.py (new file, +48)

@@ -0,0 +1,48 @@
+import itertools
+
+from sqlglot import exp
+
+
+def canonicalize(expression: exp.Expression) -> exp.Expression:
+    """Converts a sql expression into a standard form.
+
+    This method relies on annotate_types because many of the
+    conversions rely on type inference.
+
+    Args:
+        expression: The expression to canonicalize.
+    """
+    exp.replace_children(expression, canonicalize)
+    expression = add_text_to_concat(expression)
+    expression = coerce_type(expression)
+    return expression
+
+
+def add_text_to_concat(node: exp.Expression) -> exp.Expression:
+    if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES:
+        node = exp.Concat(this=node.this, expression=node.expression)
+    return node
+
+
+def coerce_type(node: exp.Expression) -> exp.Expression:
+    if isinstance(node, exp.Binary):
+        _coerce_date(node.left, node.right)
+    elif isinstance(node, exp.Between):
+        _coerce_date(node.this, node.args["low"])
+    elif isinstance(node, exp.Extract):
+        if node.expression.type not in exp.DataType.TEMPORAL_TYPES:
+            _replace_cast(node.expression, "datetime")
+    return node
+
+
+def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
+    for a, b in itertools.permutations([a, b]):
+        if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE:
+            _replace_cast(b, "date")
+
+
+def _replace_cast(node: exp.Expression, to: str) -> None:
+    data_type = exp.DataType.build(to)
+    cast = exp.Cast(this=node.copy(), to=data_type)
+    cast.type = data_type
+    node.replace(cast)
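canonicalize assumes annotate_types has already run, since the rewrites key off node types. A hedged standalone sketch:

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.canonicalize import canonicalize

    # Sketch: with x.a and x.b known to be TEXT, the + becomes a CONCAT;
    # date/non-date comparisons would gain CASTs instead.
    expression = annotate_types(
        sqlglot.parse_one("SELECT x.a + x.b FROM x"),
        schema={"x": {"a": "TEXT", "b": "TEXT"}},
    )
    print(canonicalize(expression).sql())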
sqlglot/optimizer/eliminate_joins.py

@@ -128,8 +128,8 @@ def join_condition(join):
         Tuple of (source key, join key, remaining predicate)
     """
     name = join.this.alias_or_name
-    on = join.args.get("on") or exp.TRUE
-    on = on.copy()
+    on = (join.args.get("on") or exp.true()).copy()
+    on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
     source_key = []
     join_key = []

@@ -141,7 +141,7 @@ def join_condition(join):
     #
     # should pull y.b as the join key and x.a as the source key
     if normalized(on):
-        for condition in on.flatten() if isinstance(on, exp.And) else [on]:
+        for condition in on.flatten():
             if isinstance(condition, exp.EQ):
                 left, right = condition.unnest_operands()
                 left_tables = exp.column_table_names(left)

@@ -150,13 +150,12 @@ def join_condition(join):
                 if name in left_tables and name not in right_tables:
                     join_key.append(left)
                     source_key.append(right)
-                    condition.replace(exp.TRUE)
+                    condition.replace(exp.true())
                 elif name in right_tables and name not in left_tables:
                     join_key.append(right)
                     source_key.append(left)
-                    condition.replace(exp.TRUE)
+                    condition.replace(exp.true())

     on = simplify(on)
-    remaining_condition = None if on == exp.TRUE else on
+    remaining_condition = None if on == exp.true() else on

     return source_key, join_key, remaining_condition
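This commit consistently swaps the shared exp.TRUE/exp.FALSE constants for exp.true()/exp.false() factory calls wherever a node is spliced into a tree. My reading (an assumption, the diff doesn't state it): splicing the same singleton into two places makes them literally one object, so a later in-place edit through one parent could corrupt the other. Comparisons keep working because equality is structural:

    from sqlglot import exp

    assert exp.true() == exp.true()       # structurally equal
    assert exp.true() is not exp.true()   # but a fresh node per call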
sqlglot/optimizer/optimize_joins.py

@@ -29,7 +29,7 @@ def optimize_joins(expression):
             if isinstance(on, exp.Connector):
                 for predicate in on.flatten():
                     if name in exp.column_table_names(predicate):
-                        predicate.replace(exp.TRUE)
+                        predicate.replace(exp.true())
                         join.on(predicate, copy=False)

     expression = reorder_joins(expression)

@@ -70,6 +70,6 @@ def normalize(expression):
 def other_table_names(join, exclude):
     return [
         name
-        for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
+        for name in (exp.column_table_names(join.args.get("on") or exp.true()))
         if name != exclude
     ]

sqlglot/optimizer/optimizer.py

@@ -1,4 +1,6 @@
 import sqlglot
+from sqlglot.optimizer.annotate_types import annotate_types
+from sqlglot.optimizer.canonicalize import canonicalize
 from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
 from sqlglot.optimizer.eliminate_joins import eliminate_joins
 from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

@@ -28,6 +30,8 @@ RULES = (
     merge_subqueries,
     eliminate_joins,
     eliminate_ctes,
+    annotate_types,
+    canonicalize,
     quote_identities,
 )
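annotate_types and canonicalize now run as part of the default rule set. A hedged sketch:

    import sqlglot
    from sqlglot.optimizer import optimize

    # Sketch: annotate_types runs before canonicalize, so TEXT + TEXT
    # should come out as CONCAT in the optimized SQL.
    print(
        optimize(
            sqlglot.parse_one("SELECT x.a + x.b AS c FROM x"),
            schema={"x": {"a": "TEXT", "b": "TEXT"}},
        ).sql()
    )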
sqlglot/optimizer/pushdown_predicates.py

@@ -64,11 +64,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count):
     for predicate in predicates:
         for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
             if isinstance(node, exp.Join):
-                predicate.replace(exp.TRUE)
+                predicate.replace(exp.true())
                 node.on(predicate, copy=False)
                 break
             if isinstance(node, exp.Select):
-                predicate.replace(exp.TRUE)
+                predicate.replace(exp.true())
                 node.where(replace_aliases(node, predicate), copy=False)

sqlglot/optimizer/qualify_columns.py

@@ -382,9 +382,7 @@ class _Resolver:
             raise OptimizeError(str(e)) from e

         if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
-            values_alias = source.expression.parent
-            if hasattr(values_alias, "alias_column_names"):
-                return values_alias.alias_column_names
+            return source.expression.alias_column_names

         # Otherwise, if referencing another scope, return that scope's named selects
         return source.expression.named_selects

sqlglot/optimizer/qualify_tables.py

@@ -1,10 +1,11 @@
 import itertools

 from sqlglot import alias, exp
+from sqlglot.helper import csv_reader
 from sqlglot.optimizer.scope import traverse_scope


-def qualify_tables(expression, db=None, catalog=None):
+def qualify_tables(expression, db=None, catalog=None, schema=None):
     """
     Rewrite sqlglot AST to have fully qualified tables.

@@ -18,6 +19,7 @@ def qualify_tables(expression, db=None, catalog=None):
         expression (sqlglot.Expression): expression to qualify
         db (str): Database name
         catalog (str): Catalog name
+        schema: A schema to populate
     Returns:
         sqlglot.Expression: qualified expression
     """

@@ -41,7 +43,7 @@ def qualify_tables(expression, db=None, catalog=None):
                     source.set("catalog", exp.to_identifier(catalog))

                 if not source.alias:
-                    source.replace(
+                    source = source.replace(
                         alias(
                             source.copy(),
                             source.this if identifier else f"_q_{next(sequence)}",

@@ -49,4 +51,12 @@ def qualify_tables(expression, db=None, catalog=None):
                         )
                     )

+                if schema and isinstance(source.this, exp.ReadCSV):
+                    with csv_reader(source.this) as reader:
+                        header = next(reader)
+                        columns = next(reader)
+                        schema.add_table(
+                            source, {k: type(v).__name__ for k, v in zip(header, columns)}
+                        )
+
     return expression
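When a schema object is passed, READ_CSV sources get a table definition inferred from the first two rows (header plus one data row, typed by the python value's type name). A hedged sketch ('people.csv' is a hypothetical file):

    import sqlglot
    from sqlglot.optimizer.qualify_tables import qualify_tables
    from sqlglot.schema import MappingSchema

    schema = MappingSchema()
    expression = qualify_tables(
        sqlglot.parse_one("SELECT * FROM READ_CSV('people.csv')"),
        schema=schema,
    )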
sqlglot/optimizer/simplify.py

@@ -189,11 +189,11 @@ def absorb_and_eliminate(expression):

             # absorb
             if is_complement(b, aa):
-                aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
+                aa.replace(exp.true() if kind == exp.And else exp.false())
             elif is_complement(b, ab):
-                ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
+                ab.replace(exp.true() if kind == exp.And else exp.false())
             elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
-                a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
+                a.replace(exp.false() if kind == exp.And else exp.true())
             elif isinstance(b, kind):
                 # eliminate
                 rhs = b.unnest_operands()

sqlglot/optimizer/unnest_subqueries.py

@@ -169,7 +169,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
         select.parent.replace(alias)

     for key, column, predicate in keys:
-        predicate.replace(exp.TRUE)
+        predicate.replace(exp.true())
         nested = exp.column(key_aliases[key], table_alias)

         if key in group_by:
sqlglot/parser.py

@@ -141,26 +141,29 @@ class Parser(metaclass=_Parser):

     ID_VAR_TOKENS = {
         TokenType.VAR,
-        TokenType.ALTER,
         TokenType.ALWAYS,
         TokenType.ANTI,
         TokenType.APPLY,
+        TokenType.AUTO_INCREMENT,
         TokenType.BEGIN,
         TokenType.BOTH,
         TokenType.BUCKET,
         TokenType.CACHE,
-        TokenType.CALL,
+        TokenType.CASCADE,
         TokenType.COLLATE,
+        TokenType.COMMAND,
         TokenType.COMMIT,
         TokenType.CONSTRAINT,
+        TokenType.CURRENT_TIME,
         TokenType.DEFAULT,
         TokenType.DELETE,
         TokenType.DESCRIBE,
         TokenType.DETERMINISTIC,
+        TokenType.DISTKEY,
+        TokenType.DISTSTYLE,
         TokenType.EXECUTE,
         TokenType.ENGINE,
         TokenType.ESCAPE,
-        TokenType.EXPLAIN,
         TokenType.FALSE,
         TokenType.FIRST,
         TokenType.FOLLOWING,

@@ -182,7 +185,6 @@ class Parser(metaclass=_Parser):
         TokenType.NATURAL,
         TokenType.NEXT,
         TokenType.ONLY,
-        TokenType.OPTIMIZE,
         TokenType.OPTIONS,
         TokenType.ORDINALITY,
         TokenType.PARTITIONED_BY,

@@ -199,6 +201,7 @@ class Parser(metaclass=_Parser):
         TokenType.SEMI,
         TokenType.SET,
         TokenType.SHOW,
+        TokenType.SORTKEY,
         TokenType.STABLE,
         TokenType.STORED,
         TokenType.TABLE,

@@ -207,7 +210,6 @@ class Parser(metaclass=_Parser):
         TokenType.TRANSIENT,
         TokenType.TOP,
         TokenType.TRAILING,
-        TokenType.TRUNCATE,
         TokenType.TRUE,
         TokenType.UNBOUNDED,
         TokenType.UNIQUE,

@@ -217,6 +219,7 @@ class Parser(metaclass=_Parser):
         TokenType.VOLATILE,
         *SUBQUERY_PREDICATES,
         *TYPE_TOKENS,
+        *NO_PAREN_FUNCTIONS,
     }

     TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY}

@@ -231,6 +234,7 @@ class Parser(metaclass=_Parser):
         TokenType.FILTER,
         TokenType.FIRST,
         TokenType.FORMAT,
+        TokenType.IDENTIFIER,
         TokenType.ISNULL,
         TokenType.OFFSET,
         TokenType.PRIMARY_KEY,

@@ -242,6 +246,7 @@ class Parser(metaclass=_Parser):
         TokenType.RIGHT,
         TokenType.DATE,
         TokenType.DATETIME,
+        TokenType.TABLE,
         TokenType.TIMESTAMP,
         TokenType.TIMESTAMPTZ,
         *TYPE_TOKENS,

@@ -277,6 +282,7 @@ class Parser(metaclass=_Parser):
         TokenType.DASH: exp.Sub,
         TokenType.PLUS: exp.Add,
         TokenType.MOD: exp.Mod,
+        TokenType.COLLATE: exp.Collate,
     }

     FACTOR = {

@@ -391,7 +397,10 @@ class Parser(metaclass=_Parser):
         TokenType.DELETE: lambda self: self._parse_delete(),
         TokenType.CACHE: lambda self: self._parse_cache(),
         TokenType.UNCACHE: lambda self: self._parse_uncache(),
-        TokenType.USE: lambda self: self._parse_use(),
+        TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
+        TokenType.BEGIN: lambda self: self._parse_transaction(),
+        TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
+        TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
     }

     PRIMARY_PARSERS = {

@@ -402,7 +411,8 @@ class Parser(metaclass=_Parser):
             exp.Literal, this=token.text, is_string=False
         ),
         TokenType.STAR: lambda self, _: self.expression(
-            exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
+            exp.Star,
+            **{"except": self._parse_except(), "replace": self._parse_replace()},
         ),
         TokenType.NULL: lambda self, _: self.expression(exp.Null),
         TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),

@@ -446,6 +456,9 @@ class Parser(metaclass=_Parser):
         TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(),
         TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(),
         TokenType.STORED: lambda self: self._parse_stored(),
+        TokenType.DISTKEY: lambda self: self._parse_distkey(),
+        TokenType.DISTSTYLE: lambda self: self._parse_diststyle(),
+        TokenType.SORTKEY: lambda self: self._parse_sortkey(),
         TokenType.RETURNS: lambda self: self._parse_returns(),
         TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
         TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),

@@ -471,7 +484,9 @@ class Parser(metaclass=_Parser):
     }

     CONSTRAINT_PARSERS = {
-        TokenType.CHECK: lambda self: self._parse_check(),
+        TokenType.CHECK: lambda self: self.expression(
+            exp.Check, this=self._parse_wrapped(self._parse_conjunction)
+        ),
         TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(),
         TokenType.UNIQUE: lambda self: self._parse_unique(),
     }

@@ -521,6 +536,8 @@ class Parser(metaclass=_Parser):
         TokenType.SCHEMA,
     }

+    TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
+
     STRICT_CAST = True

     __slots__ = (

@@ -740,6 +757,7 @@ class Parser(metaclass=_Parser):
             kind=kind,
             temporary=temporary,
             materialized=materialized,
+            cascade=self._match(TokenType.CASCADE),
         )

     def _parse_exists(self, not_=False):

@@ -777,7 +795,11 @@ class Parser(metaclass=_Parser):
             expression = self._parse_select_or_expression()
         elif create_token.token_type == TokenType.INDEX:
             this = self._parse_index()
-        elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW, TokenType.SCHEMA):
+        elif create_token.token_type in (
+            TokenType.TABLE,
+            TokenType.VIEW,
+            TokenType.SCHEMA,
+        ):
             this = self._parse_table(schema=True)
             properties = self._parse_properties()
             if self._match(TokenType.ALIAS):
@@ -834,7 +856,38 @@ class Parser(metaclass=_Parser):
         return self.expression(
             exp.FileFormatProperty,
             this=exp.Literal.string("FORMAT"),
-            value=exp.Literal.string(self._parse_var().name),
+            value=exp.Literal.string(self._parse_var_or_string().name),
+        )
+
+    def _parse_distkey(self):
+        self._match_l_paren()
+        this = exp.Literal.string("DISTKEY")
+        value = exp.Literal.string(self._parse_var().name)
+        self._match_r_paren()
+        return self.expression(
+            exp.DistKeyProperty,
+            this=this,
+            value=value,
+        )
+
+    def _parse_sortkey(self):
+        self._match_l_paren()
+        this = exp.Literal.string("SORTKEY")
+        value = exp.Literal.string(self._parse_var().name)
+        self._match_r_paren()
+        return self.expression(
+            exp.SortKeyProperty,
+            this=this,
+            value=value,
+        )
+
+    def _parse_diststyle(self):
+        this = exp.Literal.string("DISTSTYLE")
+        value = exp.Literal.string(self._parse_var().name)
+        return self.expression(
+            exp.DistStyleProperty,
+            this=this,
+            value=value,
         )

     def _parse_auto_increment(self):
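These three property parsers target Redshift-style table attributes. A hedged sketch (assuming the keywords tokenize in the default dialect):

    import sqlglot

    # Sketch: DISTSTYLE/DISTKEY/SORTKEY parse into dedicated property nodes.
    sqlglot.parse_one("CREATE TABLE t (a INT) DISTSTYLE KEY DISTKEY(a) SORTKEY(a)")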
@@ -842,7 +895,7 @@ class Parser(metaclass=_Parser):
         return self.expression(
             exp.AutoIncrementProperty,
             this=exp.Literal.string("AUTO_INCREMENT"),
-            value=self._parse_var() or self._parse_number(),
+            value=self._parse_number(),
         )

     def _parse_schema_comment(self):

@@ -898,13 +951,10 @@ class Parser(metaclass=_Parser):

         while True:
             if self._match(TokenType.WITH):
-                self._match_l_paren()
-                properties.extend(self._parse_csv(lambda: self._parse_property()))
-                self._match_r_paren()
+                properties.extend(self._parse_wrapped_csv(self._parse_property))
             elif self._match(TokenType.PROPERTIES):
-                self._match_l_paren()
                 properties.extend(
-                    self._parse_csv(
+                    self._parse_wrapped_csv(
                         lambda: self.expression(
                             exp.AnonymousProperty,
                             this=self._parse_string(),
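Several hunks here replace the manual _match_l_paren/_match_r_paren dance with _parse_wrapped and _parse_wrapped_csv. Their bodies aren't shown in this diff; a minimal sketch of what they presumably do, inferred from the call sites:

    # Assumed shape of the new helpers (not shown in the diff):
    def _parse_wrapped(self, parse_method):
        self._match_l_paren()
        this = parse_method()
        self._match_r_paren()
        return this

    def _parse_wrapped_csv(self, parse_method, sep=TokenType.COMMA):
        return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep))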
@@ -912,25 +962,24 @@ class Parser(metaclass=_Parser):
                         )
                     )
                 )
-                self._match_r_paren()
             else:
                 identified_property = self._parse_property()
                 if not identified_property:
                     break
                 properties.append(identified_property)

         if properties:
             return self.expression(exp.Properties, expressions=properties)
         return None

     def _parse_describe(self):
         self._match(TokenType.TABLE)
         return self.expression(exp.Describe, this=self._parse_id_var())

     def _parse_insert(self):
         overwrite = self._match(TokenType.OVERWRITE)
         local = self._match(TokenType.LOCAL)
-        if self._match_text("DIRECTORY"):
+        if self._match_text_seq("DIRECTORY"):
             this = self.expression(
                 exp.Directory,
                 this=self._parse_var_or_string(),

@@ -954,27 +1003,27 @@ class Parser(metaclass=_Parser):
         if not self._match_pair(TokenType.ROW, TokenType.FORMAT):
             return None

-        self._match_text("DELIMITED")
+        self._match_text_seq("DELIMITED")

         kwargs = {}

-        if self._match_text("FIELDS", "TERMINATED", "BY"):
+        if self._match_text_seq("FIELDS", "TERMINATED", "BY"):
             kwargs["fields"] = self._parse_string()
-        if self._match_text("ESCAPED", "BY"):
+        if self._match_text_seq("ESCAPED", "BY"):
             kwargs["escaped"] = self._parse_string()
-        if self._match_text("COLLECTION", "ITEMS", "TERMINATED", "BY"):
+        if self._match_text_seq("COLLECTION", "ITEMS", "TERMINATED", "BY"):
             kwargs["collection_items"] = self._parse_string()
-        if self._match_text("MAP", "KEYS", "TERMINATED", "BY"):
+        if self._match_text_seq("MAP", "KEYS", "TERMINATED", "BY"):
             kwargs["map_keys"] = self._parse_string()
-        if self._match_text("LINES", "TERMINATED", "BY"):
+        if self._match_text_seq("LINES", "TERMINATED", "BY"):
             kwargs["lines"] = self._parse_string()
-        if self._match_text("NULL", "DEFINED", "AS"):
+        if self._match_text_seq("NULL", "DEFINED", "AS"):
             kwargs["null"] = self._parse_string()
         return self.expression(exp.RowFormat, **kwargs)

     def _parse_load_data(self):
         local = self._match(TokenType.LOCAL)
-        self._match_text("INPATH")
+        self._match_text_seq("INPATH")
         inpath = self._parse_string()
         overwrite = self._match(TokenType.OVERWRITE)
         self._match_pair(TokenType.INTO, TokenType.TABLE)

@@ -986,8 +1035,8 @@ class Parser(metaclass=_Parser):
             overwrite=overwrite,
             inpath=inpath,
             partition=self._parse_partition(),
-            input_format=self._match_text("INPUTFORMAT") and self._parse_string(),
-            serde=self._match_text("SERDE") and self._parse_string(),
+            input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(),
+            serde=self._match_text_seq("SERDE") and self._parse_string(),
         )

     def _parse_delete(self):

@@ -996,9 +1045,7 @@ class Parser(metaclass=_Parser):
         return self.expression(
             exp.Delete,
             this=self._parse_table(schema=True),
-            using=self._parse_csv(
-                lambda: self._match(TokenType.USING) and self._parse_table(schema=True)
-            ),
+            using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()),
             where=self._parse_where(),
         )

@@ -1029,12 +1076,7 @@ class Parser(metaclass=_Parser):
         options = []

         if self._match(TokenType.OPTIONS):
-            self._match_l_paren()
-            k = self._parse_string()
-            self._match(TokenType.EQ)
-            v = self._parse_string()
-            options = [k, v]
-            self._match_r_paren()
+            options = self._parse_wrapped_csv(self._parse_string, sep=TokenType.EQ)

         self._match(TokenType.ALIAS)
         return self.expression(

@@ -1050,27 +1092,13 @@ class Parser(metaclass=_Parser):
             return None

         def parse_values():
-            key = self._parse_var()
-            value = None
-
-            if self._match(TokenType.EQ):
-                value = self._parse_string()
-
-            return exp.Property(this=key, value=value)
-
-        self._match_l_paren()
-        values = self._parse_csv(parse_values)
-        self._match_r_paren()
-
-        return self.expression(
-            exp.Partition,
-            this=values,
-        )
+            props = self._parse_csv(self._parse_var_or_string, sep=TokenType.EQ)
+            return exp.Property(this=seq_get(props, 0), value=seq_get(props, 1))
+
+        return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values))

     def _parse_value(self):
-        self._match_l_paren()
-        expressions = self._parse_csv(self._parse_conjunction)
-        self._match_r_paren()
+        expressions = self._parse_wrapped_csv(self._parse_conjunction)
         return self.expression(exp.Tuple, expressions=expressions)

     def _parse_select(self, nested=False, table=False):

@@ -1124,10 +1152,11 @@ class Parser(metaclass=_Parser):
             self._match_r_paren()
             this = self._parse_subquery(this)
         elif self._match(TokenType.VALUES):
-            this = self.expression(exp.Values, expressions=self._parse_csv(self._parse_value))
-            alias = self._parse_table_alias()
-            if alias:
-                this = self.expression(exp.Subquery, this=this, alias=alias)
+            this = self.expression(
+                exp.Values,
+                expressions=self._parse_csv(self._parse_value),
+                alias=self._parse_table_alias(),
+            )
         else:
             this = None

@@ -1140,7 +1169,6 @@ class Parser(metaclass=_Parser):
         recursive = self._match(TokenType.RECURSIVE)

         expressions = []
-
         while True:
             expressions.append(self._parse_cte())

@@ -1149,11 +1177,7 @@ class Parser(metaclass=_Parser):
         else:
             self._match(TokenType.WITH)

-        return self.expression(
-            exp.With,
-            expressions=expressions,
-            recursive=recursive,
-        )
+        return self.expression(exp.With, expressions=expressions, recursive=recursive)

     def _parse_cte(self):
         alias = self._parse_table_alias()

@@ -1163,13 +1187,9 @@ class Parser(metaclass=_Parser):
         if not self._match(TokenType.ALIAS):
             self.raise_error("Expected AS in CTE")

-        self._match_l_paren()
-        expression = self._parse_statement()
-        self._match_r_paren()
-
         return self.expression(
             exp.CTE,
-            this=expression,
+            this=self._parse_wrapped(self._parse_statement),
             alias=alias,
         )

@@ -1223,7 +1243,7 @@ class Parser(metaclass=_Parser):
     def _parse_hint(self):
         if self._match(TokenType.HINT):
             hints = self._parse_csv(self._parse_function)
-            if not self._match(TokenType.HINT):
+            if not self._match_pair(TokenType.STAR, TokenType.SLASH):
                 self.raise_error("Expected */ after HINT")
             return self.expression(exp.Hint, expressions=hints)
         return None

@@ -1259,26 +1279,18 @@ class Parser(metaclass=_Parser):
             columns = self._parse_csv(self._parse_id_var)
         elif self._match(TokenType.L_PAREN):
             columns = self._parse_csv(self._parse_id_var)
-            self._match(TokenType.R_PAREN)
+            self._match_r_paren()

         expression = self.expression(
             exp.Lateral,
             this=this,
             view=view,
             outer=outer,
-            alias=self.expression(
-                exp.TableAlias,
-                this=table_alias,
-                columns=columns,
-            ),
+            alias=self.expression(exp.TableAlias, this=table_alias, columns=columns),
         )

         if outer_apply or cross_apply:
-            return self.expression(
-                exp.Join,
-                this=expression,
-                side=None if cross_apply else "LEFT",
-            )
+            return self.expression(exp.Join, this=expression, side=None if cross_apply else "LEFT")

         return expression

@@ -1387,12 +1399,8 @@ class Parser(metaclass=_Parser):
         if not self._match(TokenType.UNNEST):
             return None

-        self._match_l_paren()
-        expressions = self._parse_csv(self._parse_column)
-        self._match_r_paren()
+        expressions = self._parse_wrapped_csv(self._parse_column)

         ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY))

         alias = self._parse_table_alias()

         if alias and self.unnest_column_only:

@@ -1402,10 +1410,7 @@ class Parser(metaclass=_Parser):
             alias.set("this", None)

         return self.expression(
-            exp.Unnest,
-            expressions=expressions,
-            ordinality=ordinality,
-            alias=alias,
+            exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias
         )

     def _parse_derived_table_values(self):

@@ -1418,13 +1423,7 @@ class Parser(metaclass=_Parser):
         if is_derived:
             self._match_r_paren()

-        alias = self._parse_table_alias()
-
-        return self.expression(
-            exp.Values,
-            expressions=expressions,
-            alias=alias,
-        )
+        return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias())

     def _parse_table_sample(self):
         if not self._match(TokenType.TABLE_SAMPLE):

@@ -1460,9 +1459,7 @@ class Parser(metaclass=_Parser):
             self._match_r_paren()

         if self._match(TokenType.SEED):
-            self._match_l_paren()
-            seed = self._parse_number()
-            self._match_r_paren()
+            seed = self._parse_wrapped(self._parse_number)

         return self.expression(
             exp.TableSample,

@@ -1513,12 +1510,7 @@ class Parser(metaclass=_Parser):

         self._match_r_paren()

-        return self.expression(
-            exp.Pivot,
-            expressions=expressions,
-            field=field,
-            unpivot=unpivot,
-        )
+        return self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot)

     def _parse_where(self, skip_where_token=False):
         if not skip_where_token and not self._match(TokenType.WHERE):

@@ -1539,11 +1531,7 @@ class Parser(metaclass=_Parser):
     def _parse_grouping_sets(self):
         if not self._match(TokenType.GROUPING_SETS):
             return None

-        self._match_l_paren()
-        grouping_sets = self._parse_csv(self._parse_grouping_set)
-        self._match_r_paren()
-        return grouping_sets
+        return self._parse_wrapped_csv(self._parse_grouping_set)

     def _parse_grouping_set(self):
         if self._match(TokenType.L_PAREN):

@@ -1573,7 +1561,6 @@ class Parser(metaclass=_Parser):
     def _parse_sort(self, token_type, exp_class):
         if not self._match(token_type):
             return None
-
         return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))

     def _parse_ordered(self):

@@ -1602,9 +1589,12 @@ class Parser(metaclass=_Parser):
         if self._match(TokenType.TOP if top else TokenType.LIMIT):
             limit_paren = self._match(TokenType.L_PAREN)
             limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number())
+
             if limit_paren:
-                self._match(TokenType.R_PAREN)
+                self._match_r_paren()
+
             return limit_exp

         if self._match(TokenType.FETCH):
             direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
             direction = self._prev.text if direction else "FIRST"

@@ -1612,11 +1602,13 @@ class Parser(metaclass=_Parser):
             self._match_set((TokenType.ROW, TokenType.ROWS))
             self._match(TokenType.ONLY)
             return self.expression(exp.Fetch, direction=direction, count=count)

         return this

     def _parse_offset(self, this=None):
         if not self._match_set((TokenType.OFFSET, TokenType.COMMA)):
             return this
+
         count = self._parse_number()
         self._match_set((TokenType.ROW, TokenType.ROWS))
         return self.expression(exp.Offset, this=this, expression=count)

@@ -1678,6 +1670,7 @@ class Parser(metaclass=_Parser):
         if self._match(TokenType.DISTINCT_FROM):
             klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
             return self.expression(klass, this=this, expression=self._parse_expression())
+
         this = self.expression(
             exp.Is,
             this=this,

@@ -1754,11 +1747,7 @@ class Parser(metaclass=_Parser):

     def _parse_type(self):
         if self._match(TokenType.INTERVAL):
-            return self.expression(
-                exp.Interval,
-                this=self._parse_term(),
-                unit=self._parse_var(),
-            )
+            return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_var())

         index = self._index
         type_token = self._parse_types(check_func=True)

@@ -1824,30 +1813,18 @@ class Parser(metaclass=_Parser):
         value = None
         if type_token in self.TIMESTAMPS:
             if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
-                value = exp.DataType(
-                    this=exp.DataType.Type.TIMESTAMPTZ,
-                    expressions=expressions,
-                )
+                value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
             elif (
                 self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
             ):
-                value = exp.DataType(
-                    this=exp.DataType.Type.TIMESTAMPLTZ,
-                    expressions=expressions,
-                )
+                value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
             elif self._match(TokenType.WITHOUT_TIME_ZONE):
-                value = exp.DataType(
-                    this=exp.DataType.Type.TIMESTAMP,
-                    expressions=expressions,
-                )
+                value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)

             maybe_func = maybe_func and value is None

             if value is None:
-                value = exp.DataType(
-                    this=exp.DataType.Type.TIMESTAMP,
-                    expressions=expressions,
-                )
+                value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)

         if maybe_func and check_func:
             index2 = self._index

@@ -1872,6 +1849,7 @@ class Parser(metaclass=_Parser):
         this = self._parse_id_var()
         self._match(TokenType.COLON)
         data_type = self._parse_types()
+
         if not data_type:
             return None
         return self.expression(exp.StructKwarg, this=this, expression=data_type)

@@ -1879,7 +1857,6 @@ class Parser(metaclass=_Parser):
     def _parse_at_time_zone(self, this):
         if not self._match(TokenType.AT_TIME_ZONE):
             return this
-
         return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary())

     def _parse_column(self):

@@ -1984,16 +1961,14 @@ class Parser(metaclass=_Parser):
         else:
             subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type)

-            if subquery_predicate and self._curr.token_type in (
-                TokenType.SELECT,
-                TokenType.WITH,
-            ):
+            if subquery_predicate and self._curr.token_type in (TokenType.SELECT, TokenType.WITH):
                 this = self.expression(subquery_predicate, this=self._parse_select())
                 self._match_r_paren()
                 return this

         if functions is None:
             functions = self.FUNCTIONS

         function = functions.get(upper)
         args = self._parse_csv(self._parse_lambda)

@@ -2014,6 +1989,7 @@ class Parser(metaclass=_Parser):

         if not self._match(TokenType.L_PAREN):
             return this
+
         expressions = self._parse_csv(self._parse_udf_kwarg)
         self._match_r_paren()
         return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)

@@ -2021,25 +1997,19 @@ class Parser(metaclass=_Parser):
     def _parse_introducer(self, token):
         literal = self._parse_primary()
         if literal:
-            return self.expression(
-                exp.Introducer,
-                this=token.text,
-                expression=literal,
-            )
+            return self.expression(exp.Introducer, this=token.text, expression=literal)

         return self.expression(exp.Identifier, this=token.text)

     def _parse_session_parameter(self):
         kind = None
         this = self._parse_id_var() or self._parse_primary()

         if self._match(TokenType.DOT):
             kind = this.name
             this = self._parse_var() or self._parse_primary()
-        return self.expression(
-            exp.SessionParameter,
-            this=this,
-            kind=kind,
-        )
+
+        return self.expression(exp.SessionParameter, this=this, kind=kind)

     def _parse_udf_kwarg(self):
         this = self._parse_id_var()

@@ -2106,7 +2076,10 @@ class Parser(metaclass=_Parser):
         return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints)

     def _parse_column_constraint(self):
-        this = None
+        this = self._parse_references()
+
+        if this:
+            return this

         if self._match(TokenType.CONSTRAINT):
             this = self._parse_id_var()

@@ -2114,13 +2087,12 @@ class Parser(metaclass=_Parser):
         if self._match(TokenType.AUTO_INCREMENT):
             kind = exp.AutoIncrementColumnConstraint()
         elif self._match(TokenType.CHECK):
-            self._match_l_paren()
-            kind = self.expression(exp.CheckColumnConstraint, this=self._parse_conjunction())
-            self._match_r_paren()
+            constraint = self._parse_wrapped(self._parse_conjunction)
+            kind = self.expression(exp.CheckColumnConstraint, this=constraint)
         elif self._match(TokenType.COLLATE):
             kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
         elif self._match(TokenType.DEFAULT):
-            kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_field())
+            kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction())
         elif self._match_pair(TokenType.NOT, TokenType.NULL):
             kind = exp.NotNullColumnConstraint()
         elif self._match(TokenType.SCHEMA_COMMENT):

@@ -2137,7 +2109,7 @@ class Parser(metaclass=_Parser):
             kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
             self._match_pair(TokenType.ALIAS, TokenType.IDENTITY)
         else:
-            return None
+            return this

         return self.expression(exp.ColumnConstraint, this=this, kind=kind)
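Net effect of the column-constraint refactor: an inline REFERENCES is recognized on its own, and DEFAULT now accepts full expressions rather than a single field. A hedged sketch:

    import sqlglot

    sqlglot.parse_one(
        "CREATE TABLE t (a INT REFERENCES u (id), b INT DEFAULT 1 + 1)"
    )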
@ -2159,37 +2131,29 @@ class Parser(metaclass=_Parser):
|
||||||
def _parse_unnamed_constraint(self):
|
def _parse_unnamed_constraint(self):
|
||||||
if not self._match_set(self.CONSTRAINT_PARSERS):
|
if not self._match_set(self.CONSTRAINT_PARSERS):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self.CONSTRAINT_PARSERS[self._prev.token_type](self)
|
return self.CONSTRAINT_PARSERS[self._prev.token_type](self)
|
||||||
|
|
||||||
def _parse_check(self):
|
|
||||||
self._match(TokenType.CHECK)
|
|
||||||
self._match_l_paren()
|
|
||||||
expression = self._parse_conjunction()
|
|
||||||
self._match_r_paren()
|
|
||||||
|
|
||||||
return self.expression(exp.Check, this=expression)
|
|
||||||
|
|
||||||
def _parse_unique(self):
|
def _parse_unique(self):
|
||||||
self._match(TokenType.UNIQUE)
|
return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars())
|
||||||
columns = self._parse_wrapped_id_vars()
|
|
||||||
|
|
||||||
return self.expression(exp.Unique, expressions=columns)
|
def _parse_references(self):
|
||||||
|
if not self._match(TokenType.REFERENCES):
|
||||||
def _parse_foreign_key(self):
|
return None
|
||||||
self._match(TokenType.FOREIGN_KEY)
|
return self.expression(
|
||||||
|
|
||||||
expressions = self._parse_wrapped_id_vars()
|
|
||||||
reference = self._match(TokenType.REFERENCES) and self.expression(
|
|
||||||
exp.Reference,
|
exp.Reference,
|
||||||
this=self._parse_id_var(),
|
this=self._parse_id_var(),
|
||||||
expressions=self._parse_wrapped_id_vars(),
|
expressions=self._parse_wrapped_id_vars(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _parse_foreign_key(self):
|
||||||
|
expressions = self._parse_wrapped_id_vars()
|
||||||
|
reference = self._parse_references()
|
||||||
options = {}
|
options = {}
|
||||||
|
|
||||||
while self._match(TokenType.ON):
|
while self._match(TokenType.ON):
|
||||||
if not self._match_set((TokenType.DELETE, TokenType.UPDATE)):
|
if not self._match_set((TokenType.DELETE, TokenType.UPDATE)):
|
||||||
self.raise_error("Expected DELETE or UPDATE")
|
self.raise_error("Expected DELETE or UPDATE")
|
||||||
|
|
||||||
kind = self._prev.text.lower()
|
kind = self._prev.text.lower()
|
||||||
|
|
||||||
if self._match(TokenType.NO_ACTION):
|
if self._match(TokenType.NO_ACTION):
|
||||||
|
@ -2200,6 +2164,7 @@ class Parser(metaclass=_Parser):
|
||||||
else:
|
else:
|
||||||
self._advance()
|
self._advance()
|
||||||
action = self._prev.text.upper()
|
action = self._prev.text.upper()
|
||||||
|
|
||||||
options[kind] = action
|
options[kind] = action
|
||||||
|
|
||||||
return self.expression(
|
return self.expression(
|
||||||
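An illustrative sketch of the new `_parse_references` factoring (not part of this commit): a column-level REFERENCES clause and a table-level FOREIGN KEY now share the same parser, so both yield an exp.Reference node.

    import sqlglot
    from sqlglot import exp

    ddl = "CREATE TABLE t (a INT REFERENCES y (b), FOREIGN KEY (a) REFERENCES y (b))"
    # one exp.Reference per constraint form
    print(len(list(sqlglot.parse_one(ddl).find_all(exp.Reference))))  # 2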
@@ -2363,20 +2328,14 @@

     def _parse_window(self, this, alias=False):
         if self._match(TokenType.FILTER):
-            self._match_l_paren()
-            this = self.expression(exp.Filter, this=this, expression=self._parse_where())
-            self._match_r_paren()
+            where = self._parse_wrapped(self._parse_where)
+            this = self.expression(exp.Filter, this=this, expression=where)

         # T-SQL allows the OVER (...) syntax after WITHIN GROUP.
         # https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16
         if self._match(TokenType.WITHIN_GROUP):
-            self._match_l_paren()
-            this = self.expression(
-                exp.WithinGroup,
-                this=this,
-                expression=self._parse_order(),
-            )
-            self._match_r_paren()
+            order = self._parse_wrapped(self._parse_order)
+            this = self.expression(exp.WithinGroup, this=this, expression=order)

         # SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER
         # Some dialects choose to implement and some do not.
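A hedged usage sketch (not part of this commit; the T-SQL example assumes this version's WITHIN GROUP support): both parenthesized clauses are now consumed through the shared `_parse_wrapped` helper.

    import sqlglot

    print(sqlglot.transpile("SELECT COUNT(*) FILTER(WHERE a > 0) FROM t")[0])
    print(sqlglot.transpile(
        "SELECT PERCENTILE_DISC(0.5) WITHIN GROUP (ORDER BY a) FROM t", read="tsql"
    )[0])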
@@ -2404,18 +2363,11 @@
             return this

         if not self._match(TokenType.L_PAREN):
-            alias = self._parse_id_var(False)
-
-            return self.expression(
-                exp.Window,
-                this=this,
-                alias=alias,
-            )
+            return self.expression(exp.Window, this=this, alias=self._parse_id_var(False))
+
+        alias = self._parse_id_var(False)

         partition = None

-        alias = self._parse_id_var(False)
-
         if self._match(TokenType.PARTITION_BY):
             partition = self._parse_csv(self._parse_conjunction)

@@ -2552,17 +2504,13 @@
     def _parse_replace(self):
         if not self._match(TokenType.REPLACE):
             return None
+        return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression()))

-        self._match_l_paren()
-        columns = self._parse_csv(lambda: self._parse_alias(self._parse_expression()))
-        self._match_r_paren()
-        return columns
-
-    def _parse_csv(self, parse_method):
+    def _parse_csv(self, parse_method, sep=TokenType.COMMA):
         parse_result = parse_method()
         items = [parse_result] if parse_result is not None else []

-        while self._match(TokenType.COMMA):
+        while self._match(sep):
             if parse_result and self._prev_comment is not None:
                 parse_result.comment = self._prev_comment

@@ -2583,16 +2531,53 @@
         return this

     def _parse_wrapped_id_vars(self):
+        return self._parse_wrapped_csv(self._parse_id_var)
+
+    def _parse_wrapped_csv(self, parse_method, sep=TokenType.COMMA):
+        return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep))
+
+    def _parse_wrapped(self, parse_method):
         self._match_l_paren()
-        expressions = self._parse_csv(self._parse_id_var)
+        parse_result = parse_method()
         self._match_r_paren()
-        return expressions
+        return parse_result

     def _parse_select_or_expression(self):
         return self._parse_select() or self._parse_expression()

-    def _parse_use(self):
-        return self.expression(exp.Use, this=self._parse_id_var())
+    def _parse_transaction(self):
+        this = None
+        if self._match_texts(self.TRANSACTION_KIND):
+            this = self._prev.text
+
+        self._match_texts({"TRANSACTION", "WORK"})
+
+        modes = []
+        while True:
+            mode = []
+            while self._match(TokenType.VAR):
+                mode.append(self._prev.text)
+
+            if mode:
+                modes.append(" ".join(mode))
+            if not self._match(TokenType.COMMA):
+                break
+
+        return self.expression(exp.Transaction, this=this, modes=modes)
+
+    def _parse_commit_or_rollback(self):
+        savepoint = None
+        is_rollback = self._prev.token_type == TokenType.ROLLBACK
+
+        self._match_texts({"TRANSACTION", "WORK"})
+
+        if self._match_text_seq("TO"):
+            self._match_text_seq("SAVEPOINT")
+            savepoint = self._parse_id_var()
+
+        if is_rollback:
+            return self.expression(exp.Rollback, savepoint=savepoint)
+        return self.expression(exp.Commit)

     def _parse_show(self):
         parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
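An illustrative sketch of the new transaction statements (not part of this commit; expression names are those introduced in the hunk above):

    import sqlglot

    for sql in ("BEGIN IMMEDIATE TRANSACTION", "COMMIT WORK", "ROLLBACK TO SAVEPOINT s1"):
        # expected: Transaction, Commit, Rollback
        print(type(sqlglot.parse_one(sql)).__name__)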
@@ -2675,7 +2660,13 @@
         if expression and self._prev_comment:
             expression.comment = self._prev_comment

-    def _match_text(self, *texts):
+    def _match_texts(self, texts):
+        if self._curr and self._curr.text.upper() in texts:
+            self._advance()
+            return True
+        return False
+
+    def _match_text_seq(self, *texts):
         index = self._index
         for text in texts:
             if self._curr and self._curr.text.upper() == text:
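A standalone sketch (deliberately not sqlglot code) of the contract split above: `_match_texts` consumes a single token whose upper-cased text is in a set, while `_match_text_seq` consumes an entire keyword sequence or nothing (the real method rewinds `self._index` on a partial match).

    tokens = ["TO", "SAVEPOINT", "s1"]
    pos = 0

    def match_texts(texts):
        global pos
        if pos < len(tokens) and tokens[pos].upper() in texts:
            pos += 1
            return True
        return False

    def match_text_seq(*texts):
        global pos
        start = pos
        for text in texts:
            if pos < len(tokens) and tokens[pos].upper() == text:
                pos += 1
            else:
                pos = start  # rewind on a partial match
                return False
        return True

    print(match_text_seq("TO", "SAVEPOINT"), pos)  # True 2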
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
 import itertools
 import math
+import typing as t

 from sqlglot import alias, exp
 from sqlglot.errors import UnsupportedError

@@ -7,15 +10,15 @@ from sqlglot.optimizer.eliminate_joins import join_condition


 class Plan:
-    def __init__(self, expression):
-        self.expression = expression
+    def __init__(self, expression: exp.Expression) -> None:
+        self.expression = expression.copy()
         self.root = Step.from_expression(self.expression)
-        self._dag = {}
+        self._dag: t.Dict[Step, t.Set[Step]] = {}

     @property
-    def dag(self):
+    def dag(self) -> t.Dict[Step, t.Set[Step]]:
         if not self._dag:
-            dag = {}
+            dag: t.Dict[Step, t.Set[Step]] = {}
             nodes = {self.root}

             while nodes:
@@ -29,32 +32,64 @@ class Plan:
         return self._dag

     @property
-    def leaves(self):
+    def leaves(self) -> t.Generator[Step, None, None]:
         return (node for node, deps in self.dag.items() if not deps)

+    def __repr__(self) -> str:
+        return f"Plan\n----\n{repr(self.root)}"
+

 class Step:
     @classmethod
-    def from_expression(cls, expression, ctes=None):
+    def from_expression(
+        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
+    ) -> Step:
         """
-        Build a DAG of Steps from a SQL expression.
+        Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
+
+        Note: the expression's tables and subqueries must be aliased for this method to work. For
+        example, given the following expression:

-        Giving an expression like:
-
-        SELECT x.a, SUM(x.b)
-        FROM x
-        JOIN y
-        ON x.a = y.a
-        GROUP BY x.a
+        SELECT
+          x.a,
+          SUM(x.b)
+        FROM x AS x
+        JOIN y AS y
+          ON x.a = y.a
+        GROUP BY x.a

-        Transform it into a DAG of the form:
-
-        Aggregate(x.a, SUM(x.b))
-          Join(y)
-            Scan(x)
-            Scan(y)
-
-        This can then more easily be executed on by an engine.
+        the following DAG is produced (the expression IDs might differ per execution):
+
+        - Aggregate: x (4347984624)
+            Context:
+              Aggregations:
+                - SUM(x.b)
+              Group:
+                - x.a
+            Projections:
+              - x.a
+              - "x".""
+            Dependencies:
+            - Join: x (4347985296)
+                Context:
+                  y:
+                  On: x.a = y.a
+                Projections:
+                Dependencies:
+                - Scan: x (4347983136)
+                    Context:
+                      Source: x AS x
+                    Projections:
+                - Scan: y (4343416624)
+                    Context:
+                      Source: y AS y
+                    Projections:
+
+        Args:
+            expression: the expression to build the DAG from.
+            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
+
+        Returns:
+            A Step DAG corresponding to `expression`.
         """
         ctes = ctes or {}
         with_ = expression.args.get("with")
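A short usage sketch of the planner (not part of this commit), honoring the aliasing requirement the new docstring calls out:

    import sqlglot
    from sqlglot.planner import Plan

    expression = sqlglot.parse_one(
        "SELECT x.a, SUM(x.b) FROM x AS x JOIN y AS y ON x.a = y.a GROUP BY x.a"
    )
    print(Plan(expression))  # the new __repr__ prints the Step DAG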
@@ -65,11 +100,11 @@ class Step:
         for cte in with_.expressions:
             step = Step.from_expression(cte.this, ctes)
             step.name = cte.alias
-            ctes[step.name] = step
+            ctes[step.name] = step  # type: ignore

         from_ = expression.args.get("from")

-        if from_:
+        if isinstance(expression, exp.Select) and from_:
             from_ = from_.expressions
             if len(from_) > 1:
                 raise UnsupportedError(

@@ -77,8 +112,10 @@ class Step:
             )

             step = Scan.from_expression(from_[0], ctes)
+        elif isinstance(expression, exp.Union):
+            step = SetOperation.from_expression(expression, ctes)
         else:
-            raise UnsupportedError("Static selects are unsupported.")
+            step = Scan()

         joins = expression.args.get("joins")

@@ -115,7 +152,7 @@ class Step:

         group = expression.args.get("group")

-        if group:
+        if group or aggregations:
             aggregate = Aggregate()
             aggregate.source = step.name
             aggregate.name = step.name

@@ -123,7 +160,15 @@ class Step:
                 alias(operand, alias_) for operand, alias_ in operands.items()
             )
             aggregate.aggregations = aggregations
-            aggregate.group = group.expressions
+            # give aggregates names and replace projections with references to them
+            aggregate.group = {
+                f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
+            }
+            for projection in projections:
+                for i, e in aggregate.group.items():
+                    for child, _, _ in projection.walk():
+                        if child == e:
+                            child.replace(exp.column(i, step.name))
             aggregate.add_dependency(step)
             step = aggregate
@@ -150,22 +195,22 @@ class Step:

         return step

-    def __init__(self):
-        self.name = None
-        self.dependencies = set()
-        self.dependents = set()
-        self.projections = []
-        self.limit = math.inf
-        self.condition = None
+    def __init__(self) -> None:
+        self.name: t.Optional[str] = None
+        self.dependencies: t.Set[Step] = set()
+        self.dependents: t.Set[Step] = set()
+        self.projections: t.Sequence[exp.Expression] = []
+        self.limit: float = math.inf
+        self.condition: t.Optional[exp.Expression] = None

-    def add_dependency(self, dependency):
+    def add_dependency(self, dependency: Step) -> None:
         self.dependencies.add(dependency)
         dependency.dependents.add(self)

-    def __repr__(self):
+    def __repr__(self) -> str:
         return self.to_s()

-    def to_s(self, level=0):
+    def to_s(self, level: int = 0) -> str:
         indent = "  " * level
         nested = f"{indent}  "

@@ -175,7 +220,7 @@ class Step:
             context = [f"{nested}Context:"] + context

         lines = [
-            f"{indent}- {self.__class__.__name__}: {self.name}",
+            f"{indent}- {self.id}",
             *context,
             f"{nested}Projections:",
         ]

@@ -193,13 +238,25 @@ class Step:

         return "\n".join(lines)

-    def _to_s(self, _indent):
+    @property
+    def type_name(self) -> str:
+        return self.__class__.__name__
+
+    @property
+    def id(self) -> str:
+        name = self.name
+        name = f" {name}" if name else ""
+        return f"{self.type_name}:{name} ({id(self)})"
+
+    def _to_s(self, _indent: str) -> t.List[str]:
         return []

 class Scan(Step):
     @classmethod
-    def from_expression(cls, expression, ctes=None):
+    def from_expression(
+        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
+    ) -> Step:
         table = expression
         alias_ = expression.alias

@@ -217,26 +274,24 @@ class Scan(Step):
         step = Scan()
         step.name = alias_
         step.source = expression
-        if table.name in ctes:
+        if ctes and table.name in ctes:
             step.add_dependency(ctes[table.name])

         return step

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
-        self.source = None
+        self.source: t.Optional[exp.Expression] = None

-    def _to_s(self, indent):
-        return [f"{indent}Source: {self.source.sql()}"]
+    def _to_s(self, indent: str) -> t.List[str]:
+        return [f"{indent}Source: {self.source.sql() if self.source else '-static-'}"]  # type: ignore


-class Write(Step):
-    pass
-
-
 class Join(Step):
     @classmethod
-    def from_joins(cls, joins, ctes=None):
+    def from_joins(
+        cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
+    ) -> Step:
         step = Join()

         for join in joins:

@@ -252,28 +307,28 @@ class Join(Step):

         return step

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
-        self.joins = {}
+        self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {}

-    def _to_s(self, indent):
+    def _to_s(self, indent: str) -> t.List[str]:
         lines = []
         for name, join in self.joins.items():
             lines.append(f"{indent}{name}: {join['side']}")
             if join.get("condition"):
-                lines.append(f"{indent}On: {join['condition'].sql()}")
+                lines.append(f"{indent}On: {join['condition'].sql()}")  # type: ignore
         return lines

 class Aggregate(Step):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
-        self.aggregations = []
-        self.operands = []
-        self.group = []
-        self.source = None
+        self.aggregations: t.List[exp.Expression] = []
+        self.operands: t.Tuple[exp.Expression, ...] = ()
+        self.group: t.Dict[str, exp.Expression] = {}
+        self.source: t.Optional[str] = None

-    def _to_s(self, indent):
+    def _to_s(self, indent: str) -> t.List[str]:
         lines = [f"{indent}Aggregations:"]

         for expression in self.aggregations:

@@ -281,7 +336,7 @@ class Aggregate(Step):

         if self.group:
             lines.append(f"{indent}Group:")
-            for expression in self.group:
+            for expression in self.group.values():
                 lines.append(f"{indent}  - {expression.sql()}")
         if self.operands:
             lines.append(f"{indent}Operands:")

@@ -292,14 +347,56 @@ class Aggregate(Step):


 class Sort(Step):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.key = None

-    def _to_s(self, indent):
+    def _to_s(self, indent: str) -> t.List[str]:
         lines = [f"{indent}Key:"]

-        for expression in self.key:
+        for expression in self.key:  # type: ignore
             lines.append(f"{indent}  - {expression.sql()}")

         return lines

+class SetOperation(Step):
+    def __init__(
+        self,
+        op: t.Type[exp.Expression],
+        left: str | None,
+        right: str | None,
+        distinct: bool = False,
+    ) -> None:
+        super().__init__()
+        self.op = op
+        self.left = left
+        self.right = right
+        self.distinct = distinct
+
+    @classmethod
+    def from_expression(
+        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
+    ) -> Step:
+        assert isinstance(expression, exp.Union)
+        left = Step.from_expression(expression.left, ctes)
+        right = Step.from_expression(expression.right, ctes)
+        step = cls(
+            op=expression.__class__,
+            left=left.name,
+            right=right.name,
+            distinct=expression.args.get("distinct"),
+        )
+        step.add_dependency(left)
+        step.add_dependency(right)
+        return step
+
+    def _to_s(self, indent: str) -> t.List[str]:
+        lines = []
+        if self.distinct:
+            lines.append(f"{indent}Distinct: {self.distinct}")
+        return lines
+
+    @property
+    def type_name(self) -> str:
+        return self.op.__name__
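An illustrative sketch (not part of this commit; per the hunks above, a plain UNION should plan into a SetOperation whose two sides are its dependencies):

    import sqlglot
    from sqlglot.planner import Plan, SetOperation

    plan = Plan(sqlglot.parse_one("SELECT a FROM x AS x UNION ALL SELECT a FROM y AS y"))
    print(isinstance(plan.root, SetOperation), len(plan.root.dependencies))  # True 2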
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -5,7 +5,7 @@ import typing as t

 from sqlglot import expressions as exp
 from sqlglot.errors import SchemaError
-from sqlglot.helper import csv_reader
+from sqlglot.helper import dict_depth
 from sqlglot.trie import in_trie, new_trie

 if t.TYPE_CHECKING:

@@ -15,6 +15,8 @@ if t.TYPE_CHECKING:

 TABLE_ARGS = ("this", "db", "catalog")

+T = t.TypeVar("T")
+

 class Schema(abc.ABC):
     """Abstract base class for database schemas"""
@@ -57,8 +59,81 @@ class Schema(abc.ABC):
             The resulting column type.
         """

+    @property
+    def supported_table_args(self) -> t.Tuple[str, ...]:
+        """
+        Table arguments this schema support, e.g. `("this", "db", "catalog")`
+        """
+        raise NotImplementedError

-class MappingSchema(Schema):
+
+class AbstractMappingSchema(t.Generic[T]):
+    def __init__(
+        self,
+        mapping: dict | None = None,
+    ) -> None:
+        self.mapping = mapping or {}
+        self.mapping_trie = self._build_trie(self.mapping)
+        self._supported_table_args: t.Tuple[str, ...] = tuple()
+
+    def _build_trie(self, schema: t.Dict) -> t.Dict:
+        return new_trie(tuple(reversed(t)) for t in flatten_schema(schema, depth=self._depth()))
+
+    def _depth(self) -> int:
+        return dict_depth(self.mapping)
+
+    @property
+    def supported_table_args(self) -> t.Tuple[str, ...]:
+        if not self._supported_table_args and self.mapping:
+            depth = self._depth()
+
+            if not depth:  # None
+                self._supported_table_args = tuple()
+            elif 1 <= depth <= 3:
+                self._supported_table_args = TABLE_ARGS[:depth]
+            else:
+                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
+
+        return self._supported_table_args
+
+    def table_parts(self, table: exp.Table) -> t.List[str]:
+        if isinstance(table.this, exp.ReadCSV):
+            return [table.this.name]
+        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
+
+    def find(
+        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
+    ) -> t.Optional[T]:
+        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
+        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
+
+        if value == 0:
+            if raise_on_missing:
+                raise SchemaError(f"Cannot find mapping for {table}.")
+            else:
+                return None
+        elif value == 1:
+            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
+            if len(possibilities) == 1:
+                parts.extend(possibilities[0])
+            else:
+                message = ", ".join(".".join(parts) for parts in possibilities)
+                if raise_on_missing:
+                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
+                return None
+        return self._nested_get(parts, raise_on_missing=raise_on_missing)
+
+    def _nested_get(
+        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
+    ) -> t.Optional[t.Any]:
+        return _nested_get(
+            d or self.mapping,
+            *zip(self.supported_table_args, reversed(parts)),
+            raise_on_missing=raise_on_missing,
+        )
+
+
+class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
     """
     Schema based on a nested mapping.

@@ -82,17 +157,17 @@ class MappingSchema(Schema):
         visible: t.Optional[t.Dict] = None,
         dialect: t.Optional[str] = None,
     ) -> None:
-        self.schema = schema or {}
+        super().__init__(schema)
         self.visible = visible or {}
-        self.schema_trie = self._build_trie(self.schema)
         self.dialect = dialect
-        self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {}
-        self._supported_table_args: t.Tuple[str, ...] = tuple()
+        self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {
+            "STR": exp.DataType.Type.TEXT,
+        }

     @classmethod
     def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
         return MappingSchema(
-            schema=mapping_schema.schema,
+            schema=mapping_schema.mapping,
             visible=mapping_schema.visible,
             dialect=mapping_schema.dialect,
         )
@@ -100,27 +175,13 @@ class MappingSchema(Schema):
     def copy(self, **kwargs) -> MappingSchema:
         return MappingSchema(
             **{  # type: ignore
-                "schema": self.schema.copy(),
+                "schema": self.mapping.copy(),
                 "visible": self.visible.copy(),
                 "dialect": self.dialect,
                 **kwargs,
             }
         )

-    @property
-    def supported_table_args(self):
-        if not self._supported_table_args and self.schema:
-            depth = _dict_depth(self.schema)
-
-            if not depth or depth == 1:  # {}
-                self._supported_table_args = tuple()
-            elif 2 <= depth <= 4:
-                self._supported_table_args = TABLE_ARGS[: depth - 1]
-            else:
-                raise SchemaError(f"Invalid schema shape. Depth: {depth}")
-
-        return self._supported_table_args
-
     def add_table(
         self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
     ) -> None:
@@ -133,17 +194,21 @@ class MappingSchema(Schema):
         """
         table_ = self._ensure_table(table)
         column_mapping = ensure_column_mapping(column_mapping)
-        schema = self.find_schema(table_, raise_on_missing=False)
+        schema = self.find(table_, raise_on_missing=False)

         if schema and not column_mapping:
             return

         _nested_set(
-            self.schema,
+            self.mapping,
             list(reversed(self.table_parts(table_))),
             column_mapping,
         )
-        self.schema_trie = self._build_trie(self.schema)
+        self.mapping_trie = self._build_trie(self.mapping)
+
+    def _depth(self) -> int:
+        # The columns themselves are a mapping, but we don't want to include those
+        return super()._depth() - 1

     def _ensure_table(self, table: exp.Table | str) -> exp.Table:
         table_ = exp.to_table(table)

@@ -153,16 +218,9 @@ class MappingSchema(Schema):

         return table_

-    def table_parts(self, table: exp.Table) -> t.List[str]:
-        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
-
     def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
         table_ = self._ensure_table(table)
-        if not isinstance(table_.this, exp.Identifier):
-            return fs_get(table)  # type: ignore
-
-        schema = self.find_schema(table_)
+        schema = self.find(table_)

         if schema is None:
             raise SchemaError(f"Could not find table schema {table}")
@@ -173,36 +231,13 @@ class MappingSchema(Schema):
         visible = self._nested_get(self.table_parts(table_), self.visible)
         return [col for col in schema if col in visible]  # type: ignore

-    def find_schema(
-        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
-    ) -> t.Optional[t.Dict[str, str]]:
-        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
-        value, trie = in_trie(self.schema_trie if trie is None else trie, parts)
-
-        if value == 0:
-            if raise_on_missing:
-                raise SchemaError(f"Cannot find schema for {table}.")
-            else:
-                return None
-        elif value == 1:
-            possibilities = flatten_schema(trie)
-            if len(possibilities) == 1:
-                parts.extend(possibilities[0])
-            else:
-                message = ", ".join(".".join(parts) for parts in possibilities)
-                if raise_on_missing:
-                    raise SchemaError(f"Ambiguous schema for {table}: {message}.")
-                return None
-
-        return self._nested_get(parts, raise_on_missing=raise_on_missing)
-
     def get_column_type(
         self, table: exp.Table | str, column: exp.Column | str
     ) -> exp.DataType.Type:
         column_name = column if isinstance(column, str) else column.name
         table_ = exp.to_table(table)
         if table_:
-            table_schema = self.find_schema(table_)
+            table_schema = self.find(table_)
             schema_type = table_schema.get(column_name).upper()  # type: ignore
             return self._convert_type(schema_type)
         raise SchemaError(f"Could not convert table '{table}'")
@@ -228,18 +263,6 @@ class MappingSchema(Schema):

         return self._type_mapping_cache[schema_type]

-    def _build_trie(self, schema: t.Dict):
-        return new_trie(tuple(reversed(t)) for t in flatten_schema(schema))
-
-    def _nested_get(
-        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
-    ) -> t.Optional[t.Any]:
-        return _nested_get(
-            d or self.schema,
-            *zip(self.supported_table_args, reversed(parts)),
-            raise_on_missing=raise_on_missing,
-        )
-

 def ensure_schema(schema: t.Any) -> Schema:
     if isinstance(schema, Schema):
@@ -267,29 +290,20 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
     raise ValueError(f"Invalid mapping provided: {type(mapping)}")


-def flatten_schema(schema: t.Dict, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]:
+def flatten_schema(
+    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
+) -> t.List[t.List[str]]:
     tables = []
     keys = keys or []
-    depth = _dict_depth(schema)

     for k, v in schema.items():
-        if depth >= 3:
-            tables.extend(flatten_schema(v, keys + [k]))
-        elif depth == 2:
+        if depth >= 2:
+            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
+        elif depth == 1:
             tables.append(keys + [k])
     return tables


-def fs_get(table: exp.Table) -> t.List[str]:
-    name = table.this.name
-
-    if name.upper() == "READ_CSV":
-        with csv_reader(table) as reader:
-            return next(reader)
-
-    raise ValueError(f"Cannot read schema for {table}")
-
-
 def _nested_get(
     d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
 ) -> t.Optional[t.Any]:

@@ -310,7 +324,7 @@ def _nested_get(
     if d is None:
         if raise_on_missing:
             name = "table" if name == "this" else name
-            raise ValueError(f"Unknown {name}")
+            raise ValueError(f"Unknown {name}: {key}")
         return None
     return d

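A small sketch of the new explicit-depth contract (not part of this commit): `dict_depth` counts the column mapping as a level, so `MappingSchema._depth()` subtracts one before flattening.

    from sqlglot.helper import dict_depth
    from sqlglot.schema import flatten_schema

    nested = {"db": {"t": {"a": "INT"}}}
    print(dict_depth(nested))               # 3 (columns count as a level)
    print(flatten_schema(nested, depth=2))  # [['db', 't']] (columns excluded)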
@@ -350,34 +364,3 @@ def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:

     subd[keys[-1]] = value
     return d
-
-
-def _dict_depth(d: t.Dict) -> int:
-    """
-    Get the nesting depth of a dictionary.
-
-    For example:
-    >>> _dict_depth(None)
-    0
-    >>> _dict_depth({})
-    1
-    >>> _dict_depth({"a": "b"})
-    1
-    >>> _dict_depth({"a": {}})
-    2
-    >>> _dict_depth({"a": {"b": {}}})
-    3
-
-    Args:
-        d (dict): dictionary
-    Returns:
-        int: depth
-    """
-    try:
-        return 1 + _dict_depth(next(iter(d.values())))
-    except AttributeError:
-        # d doesn't have attribute "values"
-        return 0
-    except StopIteration:
-        # d.values() returns an empty sequence
-        return 1
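An illustrative usage sketch (not part of this commit) of the lookup machinery that now lives in AbstractMappingSchema:

    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"db": {"t": {"a": "INT"}}})
    print(schema.column_names("db.t"))          # ['a']
    print(schema.get_column_type("db.t", "a"))  # Type.INT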
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -105,12 +105,9 @@ class TokenType(AutoName):
     OBJECT = auto()

     # keywords
-    ADD_FILE = auto()
     ALIAS = auto()
     ALWAYS = auto()
     ALL = auto()
-    ALTER = auto()
-    ANALYZE = auto()
     ANTI = auto()
     ANY = auto()
     APPLY = auto()

@@ -124,14 +121,14 @@ class TokenType(AutoName):
     BUCKET = auto()
     BY_DEFAULT = auto()
     CACHE = auto()
-    CALL = auto()
+    CASCADE = auto()
     CASE = auto()
     CHARACTER_SET = auto()
     CHECK = auto()
     CLUSTER_BY = auto()
     COLLATE = auto()
+    COMMAND = auto()
     COMMENT = auto()
-    COMMENT_ON = auto()
     COMMIT = auto()
     CONSTRAINT = auto()
     CREATE = auto()

@@ -149,7 +146,9 @@ class TokenType(AutoName):
     DETERMINISTIC = auto()
     DISTINCT = auto()
     DISTINCT_FROM = auto()
+    DISTKEY = auto()
     DISTRIBUTE_BY = auto()
+    DISTSTYLE = auto()
     DIV = auto()
     DROP = auto()
     ELSE = auto()

@@ -159,7 +158,6 @@ class TokenType(AutoName):
     EXCEPT = auto()
     EXECUTE = auto()
     EXISTS = auto()
-    EXPLAIN = auto()
     FALSE = auto()
     FETCH = auto()
     FILTER = auto()

@@ -216,7 +214,6 @@ class TokenType(AutoName):
     OFFSET = auto()
     ON = auto()
     ONLY = auto()
-    OPTIMIZE = auto()
     OPTIONS = auto()
     ORDER_BY = auto()
     ORDERED = auto()

@@ -258,6 +255,7 @@ class TokenType(AutoName):
     SHOW = auto()
     SIMILAR_TO = auto()
     SOME = auto()
+    SORTKEY = auto()
     SORT_BY = auto()
     STABLE = auto()
     STORED = auto()

@@ -268,9 +266,8 @@ class TokenType(AutoName):
     TRANSIENT = auto()
     TOP = auto()
     THEN = auto()
-    TRUE = auto()
     TRAILING = auto()
-    TRUNCATE = auto()
+    TRUE = auto()
     UNBOUNDED = auto()
     UNCACHE = auto()
     UNION = auto()

@@ -280,7 +277,6 @@ class TokenType(AutoName):
     USE = auto()
     USING = auto()
     VALUES = auto()
-    VACUUM = auto()
     VIEW = auto()
     VOLATILE = auto()
     WHEN = auto()
@@ -420,7 +416,6 @@ class Tokenizer(metaclass=_Tokenizer):

     KEYWORDS = {
         "/*+": TokenType.HINT,
-        "*/": TokenType.HINT,
         "==": TokenType.EQ,
         "::": TokenType.DCOLON,
         "||": TokenType.DPIPE,

@@ -435,15 +430,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "#>": TokenType.HASH_ARROW,
         "#>>": TokenType.DHASH_ARROW,
         "<->": TokenType.LR_ARROW,
-        "ADD ARCHIVE": TokenType.ADD_FILE,
-        "ADD ARCHIVES": TokenType.ADD_FILE,
-        "ADD FILE": TokenType.ADD_FILE,
-        "ADD FILES": TokenType.ADD_FILE,
-        "ADD JAR": TokenType.ADD_FILE,
-        "ADD JARS": TokenType.ADD_FILE,
         "ALL": TokenType.ALL,
-        "ALTER": TokenType.ALTER,
-        "ANALYZE": TokenType.ANALYZE,
         "AND": TokenType.AND,
         "ANTI": TokenType.ANTI,
         "ANY": TokenType.ANY,

@@ -455,10 +442,10 @@ class Tokenizer(metaclass=_Tokenizer):
         "BETWEEN": TokenType.BETWEEN,
         "BOTH": TokenType.BOTH,
         "BUCKET": TokenType.BUCKET,
-        "CALL": TokenType.CALL,
         "CACHE": TokenType.CACHE,
         "UNCACHE": TokenType.UNCACHE,
         "CASE": TokenType.CASE,
+        "CASCADE": TokenType.CASCADE,
         "CHARACTER SET": TokenType.CHARACTER_SET,
         "CHECK": TokenType.CHECK,
         "CLUSTER BY": TokenType.CLUSTER_BY,

@@ -479,7 +466,9 @@ class Tokenizer(metaclass=_Tokenizer):
         "DETERMINISTIC": TokenType.DETERMINISTIC,
         "DISTINCT": TokenType.DISTINCT,
         "DISTINCT FROM": TokenType.DISTINCT_FROM,
+        "DISTKEY": TokenType.DISTKEY,
         "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
+        "DISTSTYLE": TokenType.DISTSTYLE,
         "DIV": TokenType.DIV,
         "DROP": TokenType.DROP,
         "ELSE": TokenType.ELSE,

@@ -489,7 +478,6 @@ class Tokenizer(metaclass=_Tokenizer):
         "EXCEPT": TokenType.EXCEPT,
         "EXECUTE": TokenType.EXECUTE,
         "EXISTS": TokenType.EXISTS,
-        "EXPLAIN": TokenType.EXPLAIN,
         "FALSE": TokenType.FALSE,
         "FETCH": TokenType.FETCH,
         "FILTER": TokenType.FILTER,

@@ -541,7 +529,6 @@ class Tokenizer(metaclass=_Tokenizer):
         "OFFSET": TokenType.OFFSET,
         "ON": TokenType.ON,
         "ONLY": TokenType.ONLY,
-        "OPTIMIZE": TokenType.OPTIMIZE,
         "OPTIONS": TokenType.OPTIONS,
         "OR": TokenType.OR,
         "ORDER BY": TokenType.ORDER_BY,

@@ -579,6 +566,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "SET": TokenType.SET,
         "SHOW": TokenType.SHOW,
         "SOME": TokenType.SOME,
+        "SORTKEY": TokenType.SORTKEY,
         "SORT BY": TokenType.SORT_BY,
         "STABLE": TokenType.STABLE,
         "STORED": TokenType.STORED,

@@ -592,7 +580,6 @@ class Tokenizer(metaclass=_Tokenizer):
         "THEN": TokenType.THEN,
         "TRUE": TokenType.TRUE,
         "TRAILING": TokenType.TRAILING,
-        "TRUNCATE": TokenType.TRUNCATE,
         "UNBOUNDED": TokenType.UNBOUNDED,
         "UNION": TokenType.UNION,
         "UNPIVOT": TokenType.UNPIVOT,

@@ -600,7 +587,6 @@ class Tokenizer(metaclass=_Tokenizer):
         "UPDATE": TokenType.UPDATE,
         "USE": TokenType.USE,
         "USING": TokenType.USING,
-        "VACUUM": TokenType.VACUUM,
         "VALUES": TokenType.VALUES,
         "VIEW": TokenType.VIEW,
         "VOLATILE": TokenType.VOLATILE,

@@ -659,6 +645,14 @@ class Tokenizer(metaclass=_Tokenizer):
         "UNIQUE": TokenType.UNIQUE,
         "STRUCT": TokenType.STRUCT,
         "VARIANT": TokenType.VARIANT,
+        "ALTER": TokenType.COMMAND,
+        "ANALYZE": TokenType.COMMAND,
+        "CALL": TokenType.COMMAND,
+        "EXPLAIN": TokenType.COMMAND,
+        "OPTIMIZE": TokenType.COMMAND,
+        "PREPARE": TokenType.COMMAND,
+        "TRUNCATE": TokenType.COMMAND,
+        "VACUUM": TokenType.COMMAND,
     }

     WHITE_SPACE = {
@@ -670,20 +664,11 @@ class Tokenizer(metaclass=_Tokenizer):
     }

     COMMANDS = {
-        TokenType.ALTER,
-        TokenType.ADD_FILE,
-        TokenType.ANALYZE,
-        TokenType.BEGIN,
-        TokenType.CALL,
-        TokenType.COMMENT_ON,
-        TokenType.COMMIT,
-        TokenType.EXPLAIN,
-        TokenType.OPTIMIZE,
+        TokenType.COMMAND,
+        TokenType.EXECUTE,
+        TokenType.FETCH,
         TokenType.SET,
         TokenType.SHOW,
-        TokenType.TRUNCATE,
-        TokenType.VACUUM,
-        TokenType.ROLLBACK,
     }

     # handle numeric literals like in hive (3L = BIGINT)

@@ -885,6 +870,7 @@ class Tokenizer(metaclass=_Tokenizer):
             if comment_start_line == self._prev_token_line:
                 if self._prev_token_comment is None:
                     self.tokens[-1].comment = self._comment
+                    self._prev_token_comment = self._comment

                 self._comment = None
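An illustrative sketch (not part of this commit): keywords remapped to TokenType.COMMAND are parsed opaquely into exp.Command and emitted back as-is.

    import sqlglot

    node = sqlglot.parse_one("EXPLAIN SELECT 1")
    print(type(node).__name__)  # Command
    print(node.sql())           # EXPLAIN SELECT 1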
--- a/tests/dataframe/unit/test_dataframe.py
+++ b/tests/dataframe/unit/test_dataframe.py
@@ -4,6 +4,8 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator


 class TestDataframe(DataFrameSQLValidator):
+    maxDiff = None
+
     def test_hash_select_expression(self):
         expression = exp.select("cola").from_("table")
         self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))

@@ -16,26 +18,26 @@ class TestDataframe(DataFrameSQLValidator):
     def test_cache(self):
         df = self.df_employee.select("fname").cache()
         expected_statements = [
-            "DROP VIEW IF EXISTS t11623",
-            "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
-            "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
+            "DROP VIEW IF EXISTS t31563",
+            "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+            "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
         ]
         self.compare_sql(df, expected_statements)

     def test_persist_default(self):
         df = self.df_employee.select("fname").persist()
         expected_statements = [
-            "DROP VIEW IF EXISTS t11623",
-            "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
-            "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
+            "DROP VIEW IF EXISTS t31563",
+            "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+            "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
         ]
         self.compare_sql(df, expected_statements)

     def test_persist_storagelevel(self):
         df = self.df_employee.select("fname").persist("DISK_ONLY_2")
         expected_statements = [
-            "DROP VIEW IF EXISTS t11623",
-            "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
-            "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
+            "DROP VIEW IF EXISTS t31563",
+            "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+            "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
         ]
         self.compare_sql(df, expected_statements)
--- a/tests/dataframe/unit/test_dataframe_writer.py
+++ b/tests/dataframe/unit/test_dataframe_writer.py
@@ -6,39 +6,41 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator


 class TestDataFrameWriter(DataFrameSQLValidator):
+    maxDiff = None
+
     def test_insertInto_full_path(self):
         df = self.df_employee.write.insertInto("catalog.db.table_name")
-        expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)

     def test_insertInto_db_table(self):
         df = self.df_employee.write.insertInto("db.table_name")
-        expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)

     def test_insertInto_table(self):
         df = self.df_employee.write.insertInto("table_name")
-        expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)

     def test_insertInto_overwrite(self):
         df = self.df_employee.write.insertInto("table_name", overwrite=True)
-        expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)

     @mock.patch("sqlglot.schema", MappingSchema())
     def test_insertInto_byName(self):
         sqlglot.schema.add_table("table_name", {"employee_id": "INT"})
         df = self.df_employee.write.byName.insertInto("table_name")
-        expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)

     def test_insertInto_cache(self):
         df = self.df_employee.cache().write.insertInto("table_name")
         expected_statements = [
-            "DROP VIEW IF EXISTS t35612",
-            "CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+            "DROP VIEW IF EXISTS t37164",
+            "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
|
||||||
"INSERT INTO table_name SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
|
"INSERT INTO table_name SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
|
||||||
]
|
]
|
||||||
self.compare_sql(df, expected_statements)
|
self.compare_sql(df, expected_statements)
|
||||||
|
|
||||||
|
@ -48,39 +50,39 @@ class TestDataFrameWriter(DataFrameSQLValidator):
|
||||||
|
|
||||||
def test_saveAsTable_append(self):
|
def test_saveAsTable_append(self):
|
||||||
df = self.df_employee.write.saveAsTable("table_name", mode="append")
|
df = self.df_employee.write.saveAsTable("table_name", mode="append")
|
||||||
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||||
self.compare_sql(df, expected)
|
self.compare_sql(df, expected)
|
||||||
|
|
||||||
def test_saveAsTable_overwrite(self):
|
def test_saveAsTable_overwrite(self):
|
||||||
df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
|
df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
|
||||||
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||||
self.compare_sql(df, expected)
|
self.compare_sql(df, expected)
|
||||||
|
|
||||||
def test_saveAsTable_error(self):
|
def test_saveAsTable_error(self):
|
||||||
df = self.df_employee.write.saveAsTable("table_name", mode="error")
|
df = self.df_employee.write.saveAsTable("table_name", mode="error")
|
||||||
expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||||
self.compare_sql(df, expected)
|
self.compare_sql(df, expected)
|
||||||
|
|
||||||
def test_saveAsTable_ignore(self):
|
def test_saveAsTable_ignore(self):
|
||||||
df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
|
df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
|
||||||
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||||
self.compare_sql(df, expected)
|
self.compare_sql(df, expected)
|
||||||
|
|
||||||
def test_mode_standalone(self):
|
def test_mode_standalone(self):
|
||||||
df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
|
df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
|
||||||
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||||
self.compare_sql(df, expected)
|
self.compare_sql(df, expected)
|
||||||
|
|
||||||
def test_mode_override(self):
|
def test_mode_override(self):
|
||||||
df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
|
df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
|
||||||
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||||
self.compare_sql(df, expected)
|
self.compare_sql(df, expected)
|
||||||
|
|
||||||
def test_saveAsTable_cache(self):
|
def test_saveAsTable_cache(self):
|
||||||
df = self.df_employee.cache().write.saveAsTable("table_name")
|
df = self.df_employee.cache().write.saveAsTable("table_name")
|
||||||
expected_statements = [
|
expected_statements = [
|
||||||
"DROP VIEW IF EXISTS t35612",
|
"DROP VIEW IF EXISTS t37164",
|
||||||
"CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
|
"CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
|
||||||
"CREATE TABLE table_name AS SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
|
"CREATE TABLE table_name AS SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
|
||||||
]
|
]
|
||||||
self.compare_sql(df, expected_statements)
|
self.compare_sql(df, expected_statements)
|
||||||
|
|
|
@@ -11,32 +11,32 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
 class TestDataframeSession(DataFrameSQLValidator):
     def test_cdf_one_row(self):
         df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"])
-        expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2)) AS `a2`(`cola`, `colb`)"
+        expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2) AS `a2`(`cola`, `colb`)"
         self.compare_sql(df, expected)
 
     def test_cdf_multiple_rows(self):
         df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"])
-        expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`cola`, `colb`)"
+        expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`cola`, `colb`)"
         self.compare_sql(df, expected)
 
     def test_cdf_no_schema(self):
         df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
-        expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
+        expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`_1`, `_2`)"
         self.compare_sql(df, expected)
 
     def test_cdf_row_mixed_primitives(self):
         df = self.spark.createDataFrame([[1, 10.1, "test", False, None]])
-        expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM (VALUES (1, 10.1, 'test', FALSE, NULL)) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
+        expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM VALUES (1, 10.1, 'test', FALSE, NULL) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
         self.compare_sql(df, expected)
 
     def test_cdf_dict_rows(self):
         df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}])
-        expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 'test'), (2, 'test2')) AS `a2`(`cola`, `colb`)"
+        expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 'test'), (2, 'test2') AS `a2`(`cola`, `colb`)"
         self.compare_sql(df, expected)
 
     def test_cdf_str_schema(self):
         df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
-        expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
+        expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
         self.compare_sql(df, expected)
 
     def test_typed_schema_basic(self):
@@ -47,7 +47,7 @@ class TestDataframeSession(DataFrameSQLValidator):
             ]
         )
         df = self.spark.createDataFrame([[1, "test"]], schema)
-        expected = "SELECT CAST(`a2`.`cola` AS int) AS `cola`, CAST(`a2`.`colb` AS string) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
+        expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
         self.compare_sql(df, expected)
 
     def test_typed_schema_nested(self):
@@ -65,7 +65,8 @@ class TestDataframeSession(DataFrameSQLValidator):
             ]
         )
         df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema)
-        expected = "SELECT CAST(`a2`.`cola` AS struct<sub_cola:int, sub_colb:string>) AS `cola` FROM (VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`))) AS `a2`(`cola`)"
+        expected = "SELECT CAST(`a2`.`cola` AS STRUCT<`sub_cola`: INT, `sub_colb`: STRING>) AS `cola` FROM VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`)) AS `a2`(`cola`)"
+
         self.compare_sql(df, expected)
 
     @mock.patch("sqlglot.schema", MappingSchema())
@@ -286,6 +286,10 @@ class TestBigQuery(Validator):
                 "bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))",
             },
         )
+        self.validate_identity("BEGIN A B C D E F")
+        self.validate_identity("BEGIN TRANSACTION")
+        self.validate_identity("COMMIT TRANSACTION")
+        self.validate_identity("ROLLBACK TRANSACTION")
 
     def test_user_defined_functions(self):
         self.validate_identity(
@@ -69,6 +69,7 @@ class TestDialect(Validator):
             write={
                 "bigquery": "CAST(a AS STRING)",
                 "clickhouse": "CAST(a AS TEXT)",
+                "drill": "CAST(a AS VARCHAR)",
                 "duckdb": "CAST(a AS TEXT)",
                 "mysql": "CAST(a AS TEXT)",
                 "hive": "CAST(a AS STRING)",
@@ -86,6 +87,7 @@ class TestDialect(Validator):
             write={
                 "bigquery": "CAST(a AS BINARY(4))",
                 "clickhouse": "CAST(a AS BINARY(4))",
+                "drill": "CAST(a AS VARBINARY(4))",
                 "duckdb": "CAST(a AS BINARY(4))",
                 "mysql": "CAST(a AS BINARY(4))",
                 "hive": "CAST(a AS BINARY(4))",
@@ -146,6 +148,7 @@ class TestDialect(Validator):
             "CAST(a AS STRING)",
             write={
                 "bigquery": "CAST(a AS STRING)",
+                "drill": "CAST(a AS VARCHAR)",
                 "duckdb": "CAST(a AS TEXT)",
                 "mysql": "CAST(a AS TEXT)",
                 "hive": "CAST(a AS STRING)",
@@ -162,6 +165,7 @@ class TestDialect(Validator):
             "CAST(a AS VARCHAR)",
             write={
                 "bigquery": "CAST(a AS STRING)",
+                "drill": "CAST(a AS VARCHAR)",
                 "duckdb": "CAST(a AS TEXT)",
                 "mysql": "CAST(a AS VARCHAR)",
                 "hive": "CAST(a AS STRING)",
@@ -178,6 +182,7 @@ class TestDialect(Validator):
             "CAST(a AS VARCHAR(3))",
             write={
                 "bigquery": "CAST(a AS STRING(3))",
+                "drill": "CAST(a AS VARCHAR(3))",
                 "duckdb": "CAST(a AS TEXT(3))",
                 "mysql": "CAST(a AS VARCHAR(3))",
                 "hive": "CAST(a AS VARCHAR(3))",
@@ -194,6 +199,7 @@ class TestDialect(Validator):
             "CAST(a AS SMALLINT)",
             write={
                 "bigquery": "CAST(a AS INT64)",
+                "drill": "CAST(a AS INTEGER)",
                 "duckdb": "CAST(a AS SMALLINT)",
                 "mysql": "CAST(a AS SMALLINT)",
                 "hive": "CAST(a AS SMALLINT)",
@@ -215,6 +221,7 @@ class TestDialect(Validator):
             },
             write={
                 "duckdb": "TRY_CAST(a AS DOUBLE)",
+                "drill": "CAST(a AS DOUBLE)",
                 "postgres": "CAST(a AS DOUBLE PRECISION)",
                 "redshift": "CAST(a AS DOUBLE PRECISION)",
             },
@@ -225,6 +232,7 @@ class TestDialect(Validator):
             write={
                 "bigquery": "CAST(a AS FLOAT64)",
                 "clickhouse": "CAST(a AS Float64)",
+                "drill": "CAST(a AS DOUBLE)",
                 "duckdb": "CAST(a AS DOUBLE)",
                 "mysql": "CAST(a AS DOUBLE)",
                 "hive": "CAST(a AS DOUBLE)",
@@ -279,6 +287,7 @@ class TestDialect(Validator):
                 "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')",
                 "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)",
                 "presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')",
+                "drill": "TO_TIMESTAMP(x, 'yyyy-MM-dd''T''HH:mm:ss')",
                 "redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')",
                 "spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')",
             },
@@ -286,6 +295,7 @@ class TestDialect(Validator):
         self.validate_all(
             "STR_TO_TIME('2020-01-01', '%Y-%m-%d')",
             write={
+                "drill": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
                 "duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')",
                 "hive": "CAST('2020-01-01' AS TIMESTAMP)",
                 "oracle": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
@@ -298,6 +308,7 @@ class TestDialect(Validator):
         self.validate_all(
             "STR_TO_TIME(x, '%y')",
             write={
+                "drill": "TO_TIMESTAMP(x, 'yy')",
                 "duckdb": "STRPTIME(x, '%y')",
                 "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)",
                 "presto": "DATE_PARSE(x, '%y')",
@@ -319,6 +330,7 @@ class TestDialect(Validator):
         self.validate_all(
             "TIME_STR_TO_DATE('2020-01-01')",
             write={
+                "drill": "CAST('2020-01-01' AS DATE)",
                 "duckdb": "CAST('2020-01-01' AS DATE)",
                 "hive": "TO_DATE('2020-01-01')",
                 "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
@@ -328,6 +340,7 @@ class TestDialect(Validator):
         self.validate_all(
             "TIME_STR_TO_TIME('2020-01-01')",
             write={
+                "drill": "CAST('2020-01-01' AS TIMESTAMP)",
                 "duckdb": "CAST('2020-01-01' AS TIMESTAMP)",
                 "hive": "CAST('2020-01-01' AS TIMESTAMP)",
                 "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
@@ -344,6 +357,7 @@ class TestDialect(Validator):
         self.validate_all(
             "TIME_TO_STR(x, '%Y-%m-%d')",
             write={
+                "drill": "TO_CHAR(x, 'yyyy-MM-dd')",
                 "duckdb": "STRFTIME(x, '%Y-%m-%d')",
                 "hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
                 "oracle": "TO_CHAR(x, 'YYYY-MM-DD')",
@@ -355,6 +369,7 @@ class TestDialect(Validator):
         self.validate_all(
             "TIME_TO_TIME_STR(x)",
             write={
+                "drill": "CAST(x AS VARCHAR)",
                 "duckdb": "CAST(x AS TEXT)",
                 "hive": "CAST(x AS STRING)",
                 "presto": "CAST(x AS VARCHAR)",
@@ -364,6 +379,7 @@ class TestDialect(Validator):
         self.validate_all(
             "TIME_TO_UNIX(x)",
             write={
+                "drill": "UNIX_TIMESTAMP(x)",
                 "duckdb": "EPOCH(x)",
                 "hive": "UNIX_TIMESTAMP(x)",
                 "presto": "TO_UNIXTIME(x)",
@@ -425,6 +441,7 @@ class TestDialect(Validator):
         self.validate_all(
             "DATE_TO_DATE_STR(x)",
             write={
+                "drill": "CAST(x AS VARCHAR)",
                 "duckdb": "CAST(x AS TEXT)",
                 "hive": "CAST(x AS STRING)",
                 "presto": "CAST(x AS VARCHAR)",
@@ -433,6 +450,7 @@ class TestDialect(Validator):
         self.validate_all(
             "DATE_TO_DI(x)",
             write={
+                "drill": "CAST(TO_DATE(x, 'yyyyMMdd') AS INT)",
                 "duckdb": "CAST(STRFTIME(x, '%Y%m%d') AS INT)",
                 "hive": "CAST(DATE_FORMAT(x, 'yyyyMMdd') AS INT)",
                 "presto": "CAST(DATE_FORMAT(x, '%Y%m%d') AS INT)",
@@ -441,6 +459,7 @@ class TestDialect(Validator):
         self.validate_all(
             "DI_TO_DATE(x)",
             write={
+                "drill": "TO_DATE(CAST(x AS VARCHAR), 'yyyyMMdd')",
                 "duckdb": "CAST(STRPTIME(CAST(x AS TEXT), '%Y%m%d') AS DATE)",
                 "hive": "TO_DATE(CAST(x AS STRING), 'yyyyMMdd')",
                 "presto": "CAST(DATE_PARSE(CAST(x AS VARCHAR), '%Y%m%d') AS DATE)",
@@ -463,6 +482,7 @@ class TestDialect(Validator):
             },
             write={
                 "bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
+                "drill": "DATE_ADD(x, INTERVAL '1' DAY)",
                 "duckdb": "x + INTERVAL 1 day",
                 "hive": "DATE_ADD(x, 1)",
                 "mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
@@ -477,6 +497,7 @@ class TestDialect(Validator):
             "DATE_ADD(x, 1)",
             write={
                 "bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
+                "drill": "DATE_ADD(x, INTERVAL '1' DAY)",
                 "duckdb": "x + INTERVAL 1 DAY",
                 "hive": "DATE_ADD(x, 1)",
                 "mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
@@ -546,6 +567,7 @@ class TestDialect(Validator):
                 "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
             },
             write={
+                "drill": "TO_DATE(x, 'yyyy-MM-dd''T''HH:mm:ss')",
                 "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
                 "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
                 "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)",
@@ -556,6 +578,7 @@ class TestDialect(Validator):
         self.validate_all(
             "STR_TO_DATE(x, '%Y-%m-%d')",
             write={
+                "drill": "CAST(x AS DATE)",
                 "mysql": "STR_TO_DATE(x, '%Y-%m-%d')",
                 "starrocks": "STR_TO_DATE(x, '%Y-%m-%d')",
                 "hive": "CAST(x AS DATE)",
@@ -566,6 +589,7 @@ class TestDialect(Validator):
         self.validate_all(
             "DATE_STR_TO_DATE(x)",
             write={
+                "drill": "CAST(x AS DATE)",
                 "duckdb": "CAST(x AS DATE)",
                 "hive": "TO_DATE(x)",
                 "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
@@ -575,6 +599,7 @@ class TestDialect(Validator):
         self.validate_all(
             "TS_OR_DS_ADD('2021-02-01', 1, 'DAY')",
             write={
+                "drill": "DATE_ADD(CAST('2021-02-01' AS DATE), INTERVAL '1' DAY)",
                 "duckdb": "CAST('2021-02-01' AS DATE) + INTERVAL 1 DAY",
                 "hive": "DATE_ADD('2021-02-01', 1)",
                 "presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2021-02-01', 1, 10), '%Y-%m-%d'))",
@@ -584,6 +609,7 @@ class TestDialect(Validator):
         self.validate_all(
             "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
             write={
+                "drill": "DATE_ADD(CAST('2020-01-01' AS DATE), INTERVAL '1' DAY)",
                 "duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY",
                 "hive": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
                 "presto": "DATE_ADD('day', 1, CAST('2020-01-01' AS DATE))",
@@ -593,6 +619,7 @@ class TestDialect(Validator):
         self.validate_all(
             "TIMESTAMP '2022-01-01'",
             write={
+                "drill": "CAST('2022-01-01' AS TIMESTAMP)",
                 "mysql": "CAST('2022-01-01' AS TIMESTAMP)",
                 "starrocks": "CAST('2022-01-01' AS DATETIME)",
                 "hive": "CAST('2022-01-01' AS TIMESTAMP)",
@@ -614,6 +641,7 @@ class TestDialect(Validator):
                 dialect: f"{unit}(x)"
                 for dialect in (
                     "bigquery",
+                    "drill",
                     "duckdb",
                     "mysql",
                     "presto",
@@ -624,6 +652,7 @@ class TestDialect(Validator):
                 dialect: f"{unit}(x)"
                 for dialect in (
                     "bigquery",
+                    "drill",
                     "duckdb",
                     "mysql",
                     "presto",
@@ -649,6 +678,7 @@ class TestDialect(Validator):
             write={
                 "bigquery": "ARRAY_LENGTH(x)",
                 "duckdb": "ARRAY_LENGTH(x)",
+                "drill": "REPEATED_COUNT(x)",
                 "presto": "CARDINALITY(x)",
                 "spark": "SIZE(x)",
             },
@@ -736,6 +766,7 @@ class TestDialect(Validator):
         self.validate_all(
             "SELECT a FROM x CROSS JOIN UNNEST(y) AS t (a)",
             write={
+                "drill": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
                 "presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
                 "spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
             },
@@ -743,6 +774,7 @@ class TestDialect(Validator):
         self.validate_all(
             "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t (a, b)",
             write={
+                "drill": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
                 "presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
                 "spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b",
             },
@@ -775,6 +807,7 @@ class TestDialect(Validator):
             },
             write={
                 "bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b",
+                "drill": "SELECT * FROM a UNION SELECT * FROM b",
                 "duckdb": "SELECT * FROM a UNION SELECT * FROM b",
                 "presto": "SELECT * FROM a UNION SELECT * FROM b",
                 "spark": "SELECT * FROM a UNION SELECT * FROM b",
@@ -887,6 +920,7 @@ class TestDialect(Validator):
             write={
                 "bigquery": "LOWER(x) LIKE '%y'",
                 "clickhouse": "x ILIKE '%y'",
+                "drill": "x `ILIKE` '%y'",
                 "duckdb": "x ILIKE '%y'",
                 "hive": "LOWER(x) LIKE '%y'",
                 "mysql": "LOWER(x) LIKE '%y'",
@@ -910,32 +944,38 @@ class TestDialect(Validator):
         self.validate_all(
             "POSITION(' ' in x)",
             write={
+                "drill": "STRPOS(x, ' ')",
                 "duckdb": "STRPOS(x, ' ')",
                 "postgres": "STRPOS(x, ' ')",
                 "presto": "STRPOS(x, ' ')",
                 "spark": "LOCATE(' ', x)",
                 "clickhouse": "position(x, ' ')",
                 "snowflake": "POSITION(' ', x)",
+                "mysql": "LOCATE(' ', x)",
             },
         )
         self.validate_all(
             "STR_POSITION('a', x)",
             write={
+                "drill": "STRPOS(x, 'a')",
                 "duckdb": "STRPOS(x, 'a')",
                 "postgres": "STRPOS(x, 'a')",
                 "presto": "STRPOS(x, 'a')",
                 "spark": "LOCATE('a', x)",
                 "clickhouse": "position(x, 'a')",
                 "snowflake": "POSITION('a', x)",
+                "mysql": "LOCATE('a', x)",
             },
         )
         self.validate_all(
             "POSITION('a', x, 3)",
             write={
+                "drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
                 "presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
                 "spark": "LOCATE('a', x, 3)",
                 "clickhouse": "position(x, 'a', 3)",
                 "snowflake": "POSITION('a', x, 3)",
+                "mysql": "LOCATE('a', x, 3)",
             },
         )
         self.validate_all(
@@ -960,6 +1000,7 @@ class TestDialect(Validator):
         self.validate_all(
             "IF(x > 1, 1, 0)",
             write={
+                "drill": "`IF`(x > 1, 1, 0)",
                 "duckdb": "CASE WHEN x > 1 THEN 1 ELSE 0 END",
                 "presto": "IF(x > 1, 1, 0)",
                 "hive": "IF(x > 1, 1, 0)",
@@ -970,6 +1011,7 @@ class TestDialect(Validator):
         self.validate_all(
             "CASE WHEN 1 THEN x ELSE 0 END",
             write={
+                "drill": "CASE WHEN 1 THEN x ELSE 0 END",
                 "duckdb": "CASE WHEN 1 THEN x ELSE 0 END",
                 "presto": "CASE WHEN 1 THEN x ELSE 0 END",
                 "hive": "CASE WHEN 1 THEN x ELSE 0 END",
@@ -980,6 +1022,7 @@ class TestDialect(Validator):
         self.validate_all(
             "x[y]",
             write={
+                "drill": "x[y]",
                 "duckdb": "x[y]",
                 "presto": "x[y]",
                 "hive": "x[y]",
@@ -1000,6 +1043,7 @@ class TestDialect(Validator):
             'true or null as "foo"',
             write={
                 "bigquery": "TRUE OR NULL AS `foo`",
+                "drill": "TRUE OR NULL AS `foo`",
                 "duckdb": 'TRUE OR NULL AS "foo"',
                 "presto": 'TRUE OR NULL AS "foo"',
                 "hive": "TRUE OR NULL AS `foo`",
@@ -1020,6 +1064,7 @@ class TestDialect(Validator):
             "LEVENSHTEIN(col1, col2)",
             write={
                 "duckdb": "LEVENSHTEIN(col1, col2)",
+                "drill": "LEVENSHTEIN_DISTANCE(col1, col2)",
                 "presto": "LEVENSHTEIN_DISTANCE(col1, col2)",
                 "hive": "LEVENSHTEIN(col1, col2)",
                 "spark": "LEVENSHTEIN(col1, col2)",
@@ -1029,6 +1074,7 @@ class TestDialect(Validator):
             "LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))",
             write={
                 "duckdb": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
+                "drill": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
                 "presto": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
                 "hive": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
                 "spark": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
@@ -1152,6 +1198,7 @@ class TestDialect(Validator):
         self.validate_all(
             "SELECT a AS b FROM x GROUP BY b",
             write={
+                "drill": "SELECT a AS b FROM x GROUP BY b",
                 "duckdb": "SELECT a AS b FROM x GROUP BY b",
                 "presto": "SELECT a AS b FROM x GROUP BY 1",
                 "hive": "SELECT a AS b FROM x GROUP BY 1",
@@ -1162,6 +1209,7 @@ class TestDialect(Validator):
         self.validate_all(
             "SELECT y x FROM my_table t",
             write={
+                "drill": "SELECT y AS x FROM my_table AS t",
                 "hive": "SELECT y AS x FROM my_table AS t",
                 "oracle": "SELECT y AS x FROM my_table t",
                 "postgres": "SELECT y AS x FROM my_table AS t",
@@ -1230,3 +1278,36 @@ SELECT
             },
             pretty=True,
         )
+
+    def test_transactions(self):
+        self.validate_all(
+            "BEGIN TRANSACTION",
+            write={
+                "bigquery": "BEGIN TRANSACTION",
+                "mysql": "BEGIN",
+                "postgres": "BEGIN",
+                "presto": "START TRANSACTION",
+                "trino": "START TRANSACTION",
+                "redshift": "BEGIN",
+                "snowflake": "BEGIN",
+                "sqlite": "BEGIN TRANSACTION",
+            },
+        )
+        self.validate_all(
+            "BEGIN",
+            read={
+                "presto": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE",
+                "trino": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE",
+            },
+        )
+        self.validate_all(
+            "BEGIN",
+            read={
+                "presto": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ",
+                "trino": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ",
+            },
+        )
+        self.validate_all(
+            "BEGIN IMMEDIATE TRANSACTION",
+            write={"sqlite": "BEGIN IMMEDIATE TRANSACTION"},
+        )
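
The new `test_transactions` cases above pin down how a generic `BEGIN TRANSACTION` fans out per target dialect. A minimal illustrative sketch (not part of the commit) using sqlglot's top-level `transpile` helper, with the expected strings copied from the `write=` map:

```python
import sqlglot

# Each target dialect renders a generic BEGIN TRANSACTION differently;
# the outputs mirror the write= map in test_transactions above.
for dialect in ("mysql", "presto", "sqlite"):
    print(dialect, "->", sqlglot.transpile("BEGIN TRANSACTION", write=dialect)[0])
# mysql -> BEGIN
# presto -> START TRANSACTION
# sqlite -> BEGIN TRANSACTION
```
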
tests/dialects/test_drill.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+from tests.dialects.test_dialect import Validator
+
+
+class TestDrill(Validator):
+    dialect = "drill"
+
+    def test_string_literals(self):
+        self.validate_all(
+            "SELECT '2021-01-01' + INTERVAL 1 MONTH",
+            write={
+                "mysql": "SELECT '2021-01-01' + INTERVAL 1 MONTH",
+            },
+        )
+
+    def test_quotes(self):
+        self.validate_all(
+            "'\\''",
+            write={
+                "duckdb": "''''",
+                "presto": "''''",
+                "hive": "'\\''",
+                "spark": "'\\''",
+            },
+        )
+        self.validate_all(
+            "'\"x\"'",
+            write={
+                "duckdb": "'\"x\"'",
+                "presto": "'\"x\"'",
+                "hive": "'\"x\"'",
+                "spark": "'\"x\"'",
+            },
+        )
+        self.validate_all(
+            "'\\\\a'",
+            read={
+                "presto": "'\\a'",
+            },
+            write={
+                "duckdb": "'\\a'",
+                "presto": "'\\a'",
+                "hive": "'\\\\a'",
+                "spark": "'\\\\a'",
+            },
+        )
+
+    def test_table_function(self):
+        self.validate_all(
+            "SELECT * FROM table( dfs.`test_data.xlsx` (type => 'excel', sheetName => 'secondSheet'))",
+            write={
+                "drill": "SELECT * FROM table(dfs.`test_data.xlsx`(type => 'excel', sheetName => 'secondSheet'))",
+            },
+        )
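
The new `TestDrill` suite plugs the Drill dialect into the shared `Validator` harness. As an illustrative sketch (not part of the commit), the same round-trip can be driven directly through `sqlglot.transpile`, with the expected output copied from the `"drill"` entries added to test_dialect.py above:

```python
import sqlglot

# Generic LEVENSHTEIN maps to Drill's LEVENSHTEIN_DISTANCE,
# matching the "drill" write expectation in the dialect tests.
print(sqlglot.transpile("SELECT LEVENSHTEIN(col1, col2)", write="drill")[0])
# SELECT LEVENSHTEIN_DISTANCE(col1, col2)
```
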
@@ -58,6 +58,16 @@ class TestMySQL(Validator):
         self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'")
         self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci")
         self.validate_identity("SET autocommit = ON")
+        self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL SERIALIZABLE")
+        self.validate_identity("SET TRANSACTION READ ONLY")
+        self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE")
+        self.validate_identity("SELECT SCHEMA()")
+
+    def test_canonical_functions(self):
+        self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)")
+        self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')")
+        self.validate_identity("SELECT UCASE('foo')", "SELECT UPPER('foo')")
+        self.validate_identity("SELECT LCASE('foo')", "SELECT LOWER('foo')")
 
     def test_escape(self):
         self.validate_all(
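
The two-argument `validate_identity` calls above assert a canonicalization within a single dialect: the first string is parsed and must regenerate as the second. A minimal sketch (expected output copied from the test; the snippet itself is illustrative):

```python
import sqlglot

# MySQL's UCASE alias is canonicalized to UPPER when the SQL is regenerated.
print(sqlglot.transpile("SELECT UCASE('foo')", read="mysql", write="mysql")[0])
# SELECT UPPER('foo')
```
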
@@ -177,6 +177,15 @@ class TestPresto(Validator):
                 "spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
             },
         )
+        self.validate_all(
+            "CREATE TABLE test STORED = 'PARQUET' AS SELECT 1",
+            write={
+                "duckdb": "CREATE TABLE test AS SELECT 1",
+                "presto": "CREATE TABLE test WITH (FORMAT='PARQUET') AS SELECT 1",
+                "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
+                "spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
+            },
+        )
         self.validate_all(
             "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1",
             write={
@@ -427,3 +436,69 @@ class TestPresto(Validator):
                 "spark": UnsupportedError,
             },
         )
+        self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE")
+        self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")
+
+    def test_encode_decode(self):
+        self.validate_all(
+            "TO_UTF8(x)",
+            write={
+                "spark": "ENCODE(x, 'utf-8')",
+            },
+        )
+        self.validate_all(
+            "FROM_UTF8(x)",
+            write={
+                "spark": "DECODE(x, 'utf-8')",
+            },
+        )
+        self.validate_all(
+            "ENCODE(x, 'utf-8')",
+            write={
+                "presto": "TO_UTF8(x)",
+            },
+        )
+        self.validate_all(
+            "DECODE(x, 'utf-8')",
+            write={
+                "presto": "FROM_UTF8(x)",
+            },
+        )
+        self.validate_all(
+            "ENCODE(x, 'invalid')",
+            write={
+                "presto": UnsupportedError,
+            },
+        )
+        self.validate_all(
+            "DECODE(x, 'invalid')",
+            write={
+                "presto": UnsupportedError,
+            },
+        )
+
+    def test_hex_unhex(self):
+        self.validate_all(
+            "TO_HEX(x)",
+            write={
+                "spark": "HEX(x)",
+            },
+        )
+        self.validate_all(
+            "FROM_HEX(x)",
+            write={
+                "spark": "UNHEX(x)",
+            },
+        )
+        self.validate_all(
+            "HEX(x)",
+            write={
+                "presto": "TO_HEX(x)",
+            },
+        )
+        self.validate_all(
+            "UNHEX(x)",
+            write={
+                "presto": "FROM_HEX(x)",
+            },
+        )
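
The encode/decode additions above are symmetric cross-dialect mappings; a short sketch with `transpile` (expected outputs copied from the `write=` maps; illustrative only):

```python
import sqlglot

# Presto's TO_UTF8 corresponds to Spark's ENCODE(x, 'utf-8'), and back.
print(sqlglot.transpile("TO_UTF8(x)", read="presto", write="spark")[0])
# ENCODE(x, 'utf-8')
print(sqlglot.transpile("ENCODE(x, 'utf-8')", read="spark", write="presto")[0])
# TO_UTF8(x)
```
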
@@ -169,6 +169,17 @@ class TestSnowflake(Validator):
                 "snowflake": "SELECT a FROM test AS unpivot",
             },
         )
+        self.validate_all(
+            "trim(date_column, 'UTC')",
+            write={
+                "snowflake": "TRIM(date_column, 'UTC')",
+                "postgres": "TRIM('UTC' FROM date_column)",
+            },
+        )
+        self.validate_all(
+            "trim(date_column)",
+            write={"snowflake": "TRIM(date_column)"},
+        )
 
     def test_null_treatment(self):
         self.validate_all(
tests/fixtures/identity.sql (vendored, 21 lines changed)
@@ -122,13 +122,6 @@ x AT TIME ZONE 'UTC'
 CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo'
 SET x = 1
 SET -v
-ADD JAR s3://bucket
-ADD JARS s3://bucket, c
-ADD FILE s3://file
-ADD FILES s3://file, s3://a
-ADD ARCHIVE s3://file
-ADD ARCHIVES s3://file, s3://a
-BEGIN IMMEDIATE TRANSACTION
 COMMIT
 USE db
 NOT 1
@@ -278,6 +271,7 @@ SELECT CEIL(a, b) FROM test
 SELECT COUNT(a) FROM test
 SELECT COUNT(1) FROM test
 SELECT COUNT(*) FROM test
+SELECT COUNT() FROM test
 SELECT COUNT(DISTINCT a) FROM test
 SELECT EXP(a) FROM test
 SELECT FLOOR(a) FROM test
@@ -372,6 +366,8 @@ WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2
 WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2
 WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2
 WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2
+WITH sub_query AS (SELECT a FROM table) (SELECT a FROM sub_query)
+WITH sub_query AS (SELECT a FROM table) ((((SELECT a FROM sub_query))))
 (SELECT 1) UNION (SELECT 2)
 (SELECT 1) UNION SELECT 2
 SELECT 1 UNION (SELECT 2)
@@ -463,6 +459,7 @@ CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECI
 CREATE TABLE z (a INT(11) DEFAULT UUID())
 CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id')
 CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1)
+CREATE TABLE z (a INT(11) NOT NULL DEFAULT -1)
 CREATE TABLE z (a INT(11) NOT NULL COLLATE utf8_bin AUTO_INCREMENT)
 CREATE TABLE z (a INT, PRIMARY KEY(a))
 CREATE TABLE z WITH (FORMAT='parquet') AS SELECT 1
@@ -476,6 +473,9 @@ CREATE TABLE z AS ((WITH cte AS (SELECT 1) SELECT * FROM cte))
 CREATE TABLE z (a INT UNIQUE)
 CREATE TABLE z (a INT AUTO_INCREMENT)
 CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT)
+CREATE TABLE z (a INT REFERENCES parent(b, c))
+CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id))
+CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c))
 CREATE TEMPORARY FUNCTION f
 CREATE TEMPORARY FUNCTION f AS 'g'
 CREATE FUNCTION f
@@ -514,17 +514,23 @@ DELETE FROM x WHERE y > 1
 DELETE FROM y
 DELETE FROM event USING sales WHERE event.eventid = sales.eventid
 DELETE FROM event USING sales, USING bla WHERE event.eventid = sales.eventid
+DELETE FROM event USING sales AS s WHERE event.eventid = s.eventid
+PREPARE statement
+EXECUTE statement
 DROP TABLE a
 DROP TABLE a.b
 DROP TABLE IF EXISTS a
 DROP TABLE IF EXISTS a.b
+DROP TABLE a CASCADE
 DROP VIEW a
 DROP VIEW a.b
 DROP VIEW IF EXISTS a
 DROP VIEW IF EXISTS a.b
 SHOW TABLES
 USE db
+BEGIN
 ROLLBACK
+ROLLBACK TO b
 EXPLAIN SELECT * FROM x
 INSERT INTO x SELECT * FROM y
 INSERT INTO x (SELECT * FROM y)
@@ -581,3 +587,4 @@ SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */
 SELECT x FROM a.b.c /* x */, e.f.g /* x */
 SELECT FOO(x /* c */) /* FOO */, b /* b */
 SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */
+SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b'
tests/fixtures/optimizer/canonicalize.sql (vendored, new file, 5 additions)

@@ -0,0 +1,5 @@
+SELECT w.d + w.e AS c FROM w AS w;
+SELECT CONCAT(w.d, w.e) AS c FROM w AS w;
+
+SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w;
+SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w;
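
This new fixture pairs each input query with the expected output of the canonicalize rule: `+` over TEXT operands becomes CONCAT, and a comparison between a DATE and a TEXT column gains an explicit CAST. A minimal sketch of running that rule chain, mirroring the `test_canonicalize` setup added to tests/test_optimizer.py further down in this diff (the `w` schema comes from the fixture):

```python
from functools import partial

from sqlglot import parse_one
from sqlglot.optimizer import optimizer
from sqlglot.optimizer.annotate_types import annotate_types

# canonicalize needs qualified, type-annotated columns, so the rule chain
# qualifies tables and columns and annotates types before running it.
optimize = partial(
    optimizer.optimize,
    rules=[
        optimizer.qualify_tables.qualify_tables,
        optimizer.qualify_columns.qualify_columns,
        annotate_types,
        optimizer.canonicalize.canonicalize,
    ],
)

schema = {"w": {"d": "TEXT", "e": "TEXT"}}
expression = optimize(parse_one("SELECT w.d + w.e AS c FROM w AS w"), schema=schema)
print(expression.sql())  # per the fixture: SELECT CONCAT(w.d, w.e) AS c FROM w AS w
```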
tests/fixtures/optimizer/optimizer.sql (vendored, 4 changes)

@@ -119,7 +119,7 @@ GROUP BY
 LIMIT 1;

 # title: Root subquery is union
-(SELECT b FROM x UNION SELECT b FROM y) LIMIT 1;
+(SELECT b FROM x UNION SELECT b FROM y ORDER BY b) LIMIT 1;
 (
   SELECT
     "x"."b" AS "b"
@@ -128,6 +128,8 @@ LIMIT 1;
   SELECT
     "y"."b" AS "b"
   FROM "y" AS "y"
+  ORDER BY
+    "b"
 )
 LIMIT 1;

tests/fixtures/optimizer/tpc-h/tpc-h.sql (vendored, 50 changes)

@@ -15,7 +15,7 @@ select
 from
   lineitem
 where
-  CAST(l_shipdate AS DATE) <= date '1998-12-01' - interval '90' day
+  l_shipdate <= date '1998-12-01' - interval '90' day
 group by
   l_returnflag,
   l_linestatus
@@ -250,8 +250,8 @@ FROM "orders" AS "orders"
 LEFT JOIN "_u_0" AS "_u_0"
   ON "_u_0"."l_orderkey" = "orders"."o_orderkey"
 WHERE
-  "orders"."o_orderdate" < CAST('1993-10-01' AS DATE)
-  AND "orders"."o_orderdate" >= CAST('1993-07-01' AS DATE)
+  CAST("orders"."o_orderdate" AS DATE) < CAST('1993-10-01' AS DATE)
+  AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1993-07-01' AS DATE)
   AND NOT "_u_0"."l_orderkey" IS NULL
 GROUP BY
   "orders"."o_orderpriority"
@@ -293,8 +293,8 @@ SELECT
 FROM "customer" AS "customer"
 JOIN "orders" AS "orders"
   ON "customer"."c_custkey" = "orders"."o_custkey"
-  AND "orders"."o_orderdate" < CAST('1995-01-01' AS DATE)
-  AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE)
+  AND CAST("orders"."o_orderdate" AS DATE) < CAST('1995-01-01' AS DATE)
+  AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1994-01-01' AS DATE)
 JOIN "region" AS "region"
   ON "region"."r_name" = 'ASIA'
 JOIN "nation" AS "nation"
@@ -328,8 +328,8 @@ FROM "lineitem" AS "lineitem"
 WHERE
   "lineitem"."l_discount" BETWEEN 0.05 AND 0.07
   AND "lineitem"."l_quantity" < 24
-  AND "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE)
-  AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE);
+  AND CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE)
+  AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE);

 --------------------------------------
 -- TPC-H 7
@@ -384,13 +384,13 @@ WITH "n1" AS (
 SELECT
   "n1"."n_name" AS "supp_nation",
   "n2"."n_name" AS "cust_nation",
-  EXTRACT(year FROM "lineitem"."l_shipdate") AS "l_year",
+  EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) AS "l_year",
   SUM("lineitem"."l_extendedprice" * (
     1 - "lineitem"."l_discount"
   )) AS "revenue"
 FROM "supplier" AS "supplier"
 JOIN "lineitem" AS "lineitem"
-  ON "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+  ON CAST("lineitem"."l_shipdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
   AND "supplier"."s_suppkey" = "lineitem"."l_suppkey"
 JOIN "orders" AS "orders"
   ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
@@ -409,7 +409,7 @@ JOIN "n1" AS "n2"
 GROUP BY
   "n1"."n_name",
   "n2"."n_name",
-  EXTRACT(year FROM "lineitem"."l_shipdate")
+  EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME))
 ORDER BY
   "supp_nation",
   "cust_nation",
@@ -456,7 +456,7 @@ group by
 order by
   o_year;
 SELECT
-  EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
+  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
   SUM(
     CASE
       WHEN "nation_2"."n_name" = 'BRAZIL'
@@ -477,7 +477,7 @@ JOIN "customer" AS "customer"
   ON "customer"."c_nationkey" = "nation"."n_nationkey"
 JOIN "orders" AS "orders"
   ON "orders"."o_custkey" = "customer"."c_custkey"
-  AND "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+  AND CAST("orders"."o_orderdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
 JOIN "lineitem" AS "lineitem"
   ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
   AND "part"."p_partkey" = "lineitem"."l_partkey"
@@ -488,7 +488,7 @@ JOIN "nation" AS "nation_2"
 WHERE
   "part"."p_type" = 'ECONOMY ANODIZED STEEL'
 GROUP BY
-  EXTRACT(year FROM "orders"."o_orderdate")
+  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
 ORDER BY
   "o_year";

@@ -529,7 +529,7 @@ order by
 o_year desc;
 SELECT
   "nation"."n_name" AS "nation",
-  EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
+  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
   SUM(
     "lineitem"."l_extendedprice" * (
       1 - "lineitem"."l_discount"
@@ -551,7 +551,7 @@ WHERE
   "part"."p_name" LIKE '%green%'
 GROUP BY
   "nation"."n_name",
-  EXTRACT(year FROM "orders"."o_orderdate")
+  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
 ORDER BY
   "nation",
   "o_year" DESC;
@@ -606,8 +606,8 @@ SELECT
 FROM "customer" AS "customer"
 JOIN "orders" AS "orders"
   ON "customer"."c_custkey" = "orders"."o_custkey"
-  AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE)
-  AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE)
+  AND CAST("orders"."o_orderdate" AS DATE) < CAST('1994-01-01' AS DATE)
+  AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1993-10-01' AS DATE)
 JOIN "lineitem" AS "lineitem"
   ON "lineitem"."l_orderkey" = "orders"."o_orderkey" AND "lineitem"."l_returnflag" = 'R'
 JOIN "nation" AS "nation"
@@ -740,8 +740,8 @@ SELECT
 FROM "orders" AS "orders"
 JOIN "lineitem" AS "lineitem"
   ON "lineitem"."l_commitdate" < "lineitem"."l_receiptdate"
-  AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE)
-  AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE)
+  AND CAST("lineitem"."l_receiptdate" AS DATE) < CAST('1995-01-01' AS DATE)
+  AND CAST("lineitem"."l_receiptdate" AS DATE) >= CAST('1994-01-01' AS DATE)
   AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate"
   AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP')
   AND "orders"."o_orderkey" = "lineitem"."l_orderkey"
@@ -832,8 +832,8 @@ FROM "lineitem" AS "lineitem"
 JOIN "part" AS "part"
   ON "lineitem"."l_partkey" = "part"."p_partkey"
 WHERE
-  "lineitem"."l_shipdate" < CAST('1995-10-01' AS DATE)
-  AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE);
+  CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-10-01' AS DATE)
+  AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1995-09-01' AS DATE);

 --------------------------------------
 -- TPC-H 15
@@ -876,8 +876,8 @@ WITH "revenue" AS (
   )) AS "total_revenue"
   FROM "lineitem" AS "lineitem"
   WHERE
-    "lineitem"."l_shipdate" < CAST('1996-04-01' AS DATE)
-    AND "lineitem"."l_shipdate" >= CAST('1996-01-01' AS DATE)
+    CAST("lineitem"."l_shipdate" AS DATE) < CAST('1996-04-01' AS DATE)
+    AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE)
   GROUP BY
     "lineitem"."l_suppkey"
 )
@@ -1220,8 +1220,8 @@ WITH "_u_0" AS (
     "lineitem"."l_suppkey" AS "_u_2"
   FROM "lineitem" AS "lineitem"
   WHERE
-    "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE)
-    AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE)
+    CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE)
+    AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE)
   GROUP BY
     "lineitem"."l_partkey",
     "lineitem"."l_suppkey"
tests/fixtures/pretty.sql (vendored, 7 additions)

@@ -315,3 +315,10 @@ FROM (
   WHERE
     id = 1
 ) /* x */;
+SELECT * /* multi
+line
+comment */;
+SELECT
+  * /* multi
+line
+comment */;
tests/helpers.py

@@ -57,79 +57,79 @@ SKIP_INTEGRATION = string_to_bool(os.environ.get("SKIP_INTEGRATION", "0").lower(

 TPCH_SCHEMA = {
     "lineitem": {
-        "l_orderkey": "uint64",
-        "l_partkey": "uint64",
-        "l_suppkey": "uint64",
-        "l_linenumber": "uint64",
-        "l_quantity": "float64",
-        "l_extendedprice": "float64",
-        "l_discount": "float64",
-        "l_tax": "float64",
+        "l_orderkey": "bigint",
+        "l_partkey": "bigint",
+        "l_suppkey": "bigint",
+        "l_linenumber": "bigint",
+        "l_quantity": "double",
+        "l_extendedprice": "double",
+        "l_discount": "double",
+        "l_tax": "double",
         "l_returnflag": "string",
         "l_linestatus": "string",
-        "l_shipdate": "date32",
-        "l_commitdate": "date32",
-        "l_receiptdate": "date32",
+        "l_shipdate": "string",
+        "l_commitdate": "string",
+        "l_receiptdate": "string",
         "l_shipinstruct": "string",
         "l_shipmode": "string",
         "l_comment": "string",
     },
     "orders": {
-        "o_orderkey": "uint64",
-        "o_custkey": "uint64",
+        "o_orderkey": "bigint",
+        "o_custkey": "bigint",
         "o_orderstatus": "string",
-        "o_totalprice": "float64",
-        "o_orderdate": "date32",
+        "o_totalprice": "double",
+        "o_orderdate": "string",
         "o_orderpriority": "string",
         "o_clerk": "string",
-        "o_shippriority": "int32",
+        "o_shippriority": "int",
         "o_comment": "string",
     },
     "customer": {
-        "c_custkey": "uint64",
+        "c_custkey": "bigint",
         "c_name": "string",
         "c_address": "string",
-        "c_nationkey": "uint64",
+        "c_nationkey": "bigint",
         "c_phone": "string",
-        "c_acctbal": "float64",
+        "c_acctbal": "double",
         "c_mktsegment": "string",
         "c_comment": "string",
     },
     "part": {
-        "p_partkey": "uint64",
+        "p_partkey": "bigint",
         "p_name": "string",
         "p_mfgr": "string",
         "p_brand": "string",
         "p_type": "string",
-        "p_size": "int32",
+        "p_size": "int",
         "p_container": "string",
-        "p_retailprice": "float64",
+        "p_retailprice": "double",
         "p_comment": "string",
     },
     "supplier": {
-        "s_suppkey": "uint64",
+        "s_suppkey": "bigint",
         "s_name": "string",
         "s_address": "string",
-        "s_nationkey": "uint64",
+        "s_nationkey": "bigint",
         "s_phone": "string",
-        "s_acctbal": "float64",
+        "s_acctbal": "double",
         "s_comment": "string",
     },
     "partsupp": {
-        "ps_partkey": "uint64",
-        "ps_suppkey": "uint64",
-        "ps_availqty": "int32",
-        "ps_supplycost": "float64",
+        "ps_partkey": "bigint",
+        "ps_suppkey": "bigint",
+        "ps_availqty": "int",
+        "ps_supplycost": "double",
         "ps_comment": "string",
     },
     "nation": {
-        "n_nationkey": "uint64",
+        "n_nationkey": "bigint",
         "n_name": "string",
-        "n_regionkey": "uint64",
+        "n_regionkey": "bigint",
         "n_comment": "string",
     },
     "region": {
-        "r_regionkey": "uint64",
+        "r_regionkey": "bigint",
         "r_name": "string",
         "r_comment": "string",
     },
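
The shared TPCH test schema moves from Arrow-style types (uint64, float64, date32) to generic SQL types (bigint, double, string). Dates are now plain strings, which is what forces the explicit CASTs seen throughout the tpc-h fixture above. A hedged sketch of the effect, assuming a one-column slice of the schema (the expected output shape follows the fixture, not a verified run of this snippet):

```python
from sqlglot import parse_one
from sqlglot.optimizer import optimize

# Hypothetical one-column slice of TPCH_SCHEMA: l_shipdate is now a string.
schema = {"lineitem": {"l_shipdate": "string"}}

expression = optimize(
    parse_one("SELECT l_shipdate FROM lineitem WHERE l_shipdate <= DATE '1998-12-01'"),
    schema=schema,
)
# Per the tpc-h fixture above, the string column should come out wrapped in a
# cast, e.g. CAST("lineitem"."l_shipdate" AS DATE) <= CAST('1998-12-01' AS DATE).
print(expression.sql(pretty=True))
```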
tests/test_executor.py

@@ -1,12 +1,15 @@
 import unittest
+from datetime import date

 import duckdb
 import pandas as pd
 from pandas.testing import assert_frame_equal

 from sqlglot import exp, parse_one
+from sqlglot.errors import ExecuteError
 from sqlglot.executor import execute
 from sqlglot.executor.python import Python
+from sqlglot.executor.table import Table, ensure_tables
 from tests.helpers import (
     FIXTURES_DIR,
     SKIP_INTEGRATION,
@@ -67,13 +70,399 @@ class TestExecutor(unittest.TestCase):
         def to_csv(expression):
             if isinstance(expression, exp.Table):
                 return parse_one(
-                    f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
+                    f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
                 )
             return expression

-        for sql, _ in self.sqls[0:3]:
-            a = self.cached_execute(sql)
-            sql = parse_one(sql).transform(to_csv).sql(pretty=True)
-            table = execute(sql, TPCH_SCHEMA)
-            b = pd.DataFrame(table.rows, columns=table.columns)
-            assert_frame_equal(a, b, check_dtype=False)
+        for i, (sql, _) in enumerate(self.sqls[0:7]):
+            with self.subTest(f"tpch-h {i + 1}"):
+                a = self.cached_execute(sql)
+                sql = parse_one(sql).transform(to_csv).sql(pretty=True)
+                table = execute(sql, TPCH_SCHEMA)
+                b = pd.DataFrame(table.rows, columns=table.columns)
+                assert_frame_equal(a, b, check_dtype=False)
+
+    def test_execute_callable(self):
+        tables = {
+            "x": [
+                {"a": "a", "b": "d"},
+                {"a": "b", "b": "e"},
+                {"a": "c", "b": "f"},
+            ],
+            "y": [
+                {"b": "d", "c": "g"},
+                {"b": "e", "c": "h"},
+                {"b": "f", "c": "i"},
+            ],
+            "z": [],
+        }
+        schema = {
+            "x": {
+                "a": "VARCHAR",
+                "b": "VARCHAR",
+            },
+            "y": {
+                "b": "VARCHAR",
+                "c": "VARCHAR",
+            },
+            "z": {"d": "VARCHAR"},
+        }
+
+        for sql, cols, rows in [
+            ("SELECT * FROM x", ["a", "b"], [("a", "d"), ("b", "e"), ("c", "f")]),
+            (
+                "SELECT * FROM x JOIN y ON x.b = y.b",
+                ["a", "b", "b", "c"],
+                [("a", "d", "d", "g"), ("b", "e", "e", "h"), ("c", "f", "f", "i")],
+            ),
+            (
+                "SELECT j.c AS d FROM x AS i JOIN y AS j ON i.b = j.b",
+                ["d"],
+                [("g",), ("h",), ("i",)],
+            ),
+            (
+                "SELECT CONCAT(x.a, y.c) FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'",
+                ["_col_0"],
+                [("bh",)],
+            ),
+            (
+                "SELECT * FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'",
+                ["a", "b", "b", "c"],
+                [("b", "e", "e", "h")],
+            ),
+            (
+                "SELECT * FROM z",
+                ["d"],
+                [],
+            ),
+            (
+                "SELECT d FROM z ORDER BY d",
+                ["d"],
+                [],
+            ),
+            (
+                "SELECT a FROM x WHERE x.a <> 'b'",
+                ["a"],
+                [("a",), ("c",)],
+            ),
+            (
+                "SELECT a AS i FROM x ORDER BY a",
+                ["i"],
+                [("a",), ("b",), ("c",)],
+            ),
+            (
+                "SELECT a AS i FROM x ORDER BY i",
+                ["i"],
+                [("a",), ("b",), ("c",)],
+            ),
+            (
+                "SELECT 100 - ORD(a) AS a, a AS i FROM x ORDER BY a",
+                ["a", "i"],
+                [(1, "c"), (2, "b"), (3, "a")],
+            ),
+            (
+                "SELECT a /* test */ FROM x LIMIT 1",
+                ["a"],
+                [("a",)],
+            ),
+        ]:
+            with self.subTest(sql):
+                result = execute(sql, schema=schema, tables=tables)
+                self.assertEqual(result.columns, tuple(cols))
+                self.assertEqual(result.rows, rows)
+
+    def test_set_operations(self):
+        tables = {
+            "x": [
+                {"a": "a"},
+                {"a": "b"},
+                {"a": "c"},
+            ],
+            "y": [
+                {"a": "b"},
+                {"a": "c"},
+                {"a": "d"},
+            ],
+        }
+        schema = {
+            "x": {
+                "a": "VARCHAR",
+            },
+            "y": {
+                "a": "VARCHAR",
+            },
+        }
+
+        for sql, cols, rows in [
+            (
+                "SELECT a FROM x UNION ALL SELECT a FROM y",
+                ["a"],
+                [("a",), ("b",), ("c",), ("b",), ("c",), ("d",)],
+            ),
+            (
+                "SELECT a FROM x UNION SELECT a FROM y",
+                ["a"],
+                [("a",), ("b",), ("c",), ("d",)],
+            ),
+            (
+                "SELECT a FROM x EXCEPT SELECT a FROM y",
+                ["a"],
+                [("a",)],
+            ),
+            (
+                "SELECT a FROM x INTERSECT SELECT a FROM y",
+                ["a"],
+                [("b",), ("c",)],
+            ),
+            (
+                """SELECT i.a
+                FROM (
+                    SELECT a FROM x UNION SELECT a FROM y
+                ) AS i
+                JOIN (
+                    SELECT a FROM x UNION SELECT a FROM y
+                ) AS j
+                ON i.a = j.a""",
+                ["a"],
+                [("a",), ("b",), ("c",), ("d",)],
+            ),
+            (
+                "SELECT 1 AS a UNION SELECT 2 AS a UNION SELECT 3 AS a",
+                ["a"],
+                [(1,), (2,), (3,)],
+            ),
+        ]:
+            with self.subTest(sql):
+                result = execute(sql, schema=schema, tables=tables)
+                self.assertEqual(result.columns, tuple(cols))
+                self.assertEqual(set(result.rows), set(rows))
+
+    def test_execute_catalog_db_table(self):
+        tables = {
+            "catalog": {
+                "db": {
+                    "x": [
+                        {"a": "a"},
+                        {"a": "b"},
+                        {"a": "c"},
+                    ],
+                }
+            }
+        }
+        schema = {
+            "catalog": {
+                "db": {
+                    "x": {
+                        "a": "VARCHAR",
+                    }
+                }
+            }
+        }
+        result1 = execute("SELECT * FROM x", schema=schema, tables=tables)
+        result2 = execute("SELECT * FROM catalog.db.x", schema=schema, tables=tables)
+        assert result1.columns == result2.columns
+        assert result1.rows == result2.rows
+
+    def test_execute_tables(self):
+        tables = {
+            "sushi": [
+                {"id": 1, "price": 1.0},
+                {"id": 2, "price": 2.0},
+                {"id": 3, "price": 3.0},
+            ],
+            "order_items": [
+                {"sushi_id": 1, "order_id": 1},
+                {"sushi_id": 1, "order_id": 1},
+                {"sushi_id": 2, "order_id": 1},
+                {"sushi_id": 3, "order_id": 2},
+            ],
+            "orders": [
+                {"id": 1, "user_id": 1},
+                {"id": 2, "user_id": 2},
+            ],
+        }
+
+        self.assertEqual(
+            execute(
+                """
+                SELECT
+                  o.user_id,
+                  SUM(s.price) AS price
+                FROM orders o
+                JOIN order_items i
+                  ON o.id = i.order_id
+                JOIN sushi s
+                  ON i.sushi_id = s.id
+                GROUP BY o.user_id
+                """,
+                tables=tables,
+            ).rows,
+            [
+                (1, 4.0),
+                (2, 3.0),
+            ],
+        )
+
+        self.assertEqual(
+            execute(
+                """
+                SELECT
+                  o.id, x.*
+                FROM orders o
+                LEFT JOIN (
+                  SELECT
+                    1 AS id, 'b' AS x
+                  UNION ALL
+                  SELECT
+                    3 AS id, 'c' AS x
+                ) x
+                  ON o.id = x.id
+                """,
+                tables=tables,
+            ).rows,
+            [(1, 1, "b"), (2, None, None)],
+        )
+        self.assertEqual(
+            execute(
+                """
+                SELECT
+                  o.id, x.*
+                FROM orders o
+                RIGHT JOIN (
+                  SELECT
+                    1 AS id,
+                    'b' AS x
+                  UNION ALL
+                  SELECT
+                    3 AS id, 'c' AS x
+                ) x
+                  ON o.id = x.id
+                """,
+                tables=tables,
+            ).rows,
+            [
+                (1, 1, "b"),
+                (None, 3, "c"),
+            ],
+        )
+
+    def test_table_depth_mismatch(self):
+        tables = {"table": []}
+        schema = {"db": {"table": {"col": "VARCHAR"}}}
+        with self.assertRaises(ExecuteError):
+            execute("SELECT * FROM table", schema=schema, tables=tables)
+
+    def test_tables(self):
+        tables = ensure_tables(
+            {
+                "catalog1": {
+                    "db1": {
+                        "t1": [
+                            {"a": 1},
+                        ],
+                        "t2": [
+                            {"a": 1},
+                        ],
+                    },
+                    "db2": {
+                        "t3": [
+                            {"a": 1},
+                        ],
+                        "t4": [
+                            {"a": 1},
+                        ],
+                    },
+                },
+                "catalog2": {
+                    "db3": {
+                        "t5": Table(columns=("a",), rows=[(1,)]),
+                        "t6": Table(columns=("a",), rows=[(1,)]),
+                    },
+                    "db4": {
+                        "t7": Table(columns=("a",), rows=[(1,)]),
+                        "t8": Table(columns=("a",), rows=[(1,)]),
+                    },
+                },
+            }
+        )
+
+        t1 = tables.find(exp.table_(table="t1", db="db1", catalog="catalog1"))
+        self.assertEqual(t1.columns, ("a",))
+        self.assertEqual(t1.rows, [(1,)])
+
+        t8 = tables.find(exp.table_(table="t8"))
+        self.assertEqual(t1.columns, t8.columns)
+        self.assertEqual(t1.rows, t8.rows)
+
+    def test_static_queries(self):
+        for sql, cols, rows in [
+            ("SELECT 1", ["_col_0"], [(1,)]),
+            ("SELECT 1 + 2 AS x", ["x"], [(3,)]),
+            ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
+            ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
+            ("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
+        ]:
+            result = execute(sql)
+            self.assertEqual(result.columns, tuple(cols))
+            self.assertEqual(result.rows, rows)
+
+    def test_aggregate_without_group_by(self):
+        result = execute("SELECT SUM(x) FROM t", tables={"t": [{"x": 1}, {"x": 2}]})
+        self.assertEqual(result.columns, ("_col_0",))
+        self.assertEqual(result.rows, [(3,)])
+
+    def test_scalar_functions(self):
+        for sql, expected in [
+            ("CONCAT('a', 'b')", "ab"),
+            ("CONCAT('a', NULL)", None),
+            ("CONCAT_WS('_', 'a', 'b')", "a_b"),
+            ("STR_POSITION('bar', 'foobarbar')", 4),
+            ("STR_POSITION('bar', 'foobarbar', 5)", 7),
+            ("STR_POSITION(NULL, 'foobarbar')", None),
+            ("STR_POSITION('bar', NULL)", None),
+            ("UPPER('foo')", "FOO"),
+            ("UPPER(NULL)", None),
+            ("LOWER('FOO')", "foo"),
+            ("LOWER(NULL)", None),
+            ("IFNULL('a', 'b')", "a"),
+            ("IFNULL(NULL, 'b')", "b"),
+            ("IFNULL(NULL, NULL)", None),
+            ("SUBSTRING('12345')", "12345"),
+            ("SUBSTRING('12345', 3)", "345"),
+            ("SUBSTRING('12345', 3, 0)", ""),
+            ("SUBSTRING('12345', 3, 1)", "3"),
+            ("SUBSTRING('12345', 3, 2)", "34"),
+            ("SUBSTRING('12345', 3, 3)", "345"),
+            ("SUBSTRING('12345', 3, 4)", "345"),
+            ("SUBSTRING('12345', -3)", "345"),
+            ("SUBSTRING('12345', -3, 0)", ""),
+            ("SUBSTRING('12345', -3, 1)", "3"),
+            ("SUBSTRING('12345', -3, 2)", "34"),
+            ("SUBSTRING('12345', 0)", ""),
+            ("SUBSTRING('12345', 0, 1)", ""),
+            ("SUBSTRING(NULL)", None),
+            ("SUBSTRING(NULL, 1)", None),
+            ("CAST(1 AS TEXT)", "1"),
+            ("CAST('1' AS LONG)", 1),
+            ("CAST('1.1' AS FLOAT)", 1.1),
+            ("COALESCE(NULL)", None),
+            ("COALESCE(NULL, NULL)", None),
+            ("COALESCE(NULL, 'b')", "b"),
+            ("COALESCE('a', 'b')", "a"),
+            ("1 << 1", 2),
+            ("1 >> 1", 0),
+            ("1 & 1", 1),
+            ("1 | 1", 1),
+            ("1 < 1", False),
+            ("1 <= 1", True),
+            ("1 > 1", False),
+            ("1 >= 1", True),
+            ("1 + NULL", None),
+            ("IF(true, 1, 0)", 1),
+            ("IF(false, 1, 0)", 0),
+            ("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"),
+            ("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)),
+        ]:
+            with self.subTest(sql):
+                result = execute(f"SELECT {sql}")
+                self.assertEqual(result.rows, [(expected,)])
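
The bulk of these additions exercise the pure-Python executor. A minimal usage sketch of the `execute` API the new tests rely on (table data as lists of dicts; the schema argument is optional, as `test_execute_tables` above shows):

```python
from sqlglot.executor import execute

tables = {
    "x": [
        {"a": "a", "b": "d"},
        {"a": "b", "b": "e"},
    ],
}

result = execute(
    "SELECT a FROM x WHERE b = 'e'",
    schema={"x": {"a": "VARCHAR", "b": "VARCHAR"}},
    tables=tables,
)
print(result.columns)  # ("a",)
print(result.rows)     # [("b",)]
```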
tests/test_expressions.py

@@ -441,6 +441,9 @@ class TestExpressions(unittest.TestCase):
         self.assertIsInstance(parse_one("VARIANCE(a)"), exp.Variance)
         self.assertIsInstance(parse_one("VARIANCE_POP(a)"), exp.VariancePop)
         self.assertIsInstance(parse_one("YEAR(a)"), exp.Year)
+        self.assertIsInstance(parse_one("BEGIN DEFERRED TRANSACTION"), exp.Transaction)
+        self.assertIsInstance(parse_one("COMMIT"), exp.Commit)
+        self.assertIsInstance(parse_one("ROLLBACK"), exp.Rollback)

     def test_column(self):
         dot = parse_one("a.b.c")
@@ -479,9 +482,9 @@ class TestExpressions(unittest.TestCase):
         self.assertEqual(column.text("expression"), "c")
         self.assertEqual(column.text("y"), "")
         self.assertEqual(parse_one("select * from x.y").find(exp.Table).text("db"), "x")
-        self.assertEqual(parse_one("select *").text("this"), "")
-        self.assertEqual(parse_one("1 + 1").text("this"), "1")
-        self.assertEqual(parse_one("'a'").text("this"), "a")
+        self.assertEqual(parse_one("select *").name, "")
+        self.assertEqual(parse_one("1 + 1").name, "1")
+        self.assertEqual(parse_one("'a'").name, "a")

     def test_alias(self):
         self.assertEqual(alias("foo", "bar").sql(), "foo AS bar")
@@ -538,8 +541,8 @@ class TestExpressions(unittest.TestCase):
                     this=exp.Literal.string("TABLE_FORMAT"),
                     value=exp.to_identifier("test_format"),
                 ),
-                exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL),
-                exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE),
+                exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.null()),
+                exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.true()),
             ]
         ),
     )
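
These hunks swap raw `text("this")` lookups for the `name` property and the shared `exp.NULL`/`exp.TRUE` constants for the `exp.null()`/`exp.true()` builder functions. A small sketch of the accessors involved, following the assertions above:

```python
from sqlglot import exp, parse_one

# .name is shorthand for .text("this"): the string value of a node's `this` arg.
print(parse_one("1 + 1").name)  # "1"
print(parse_one("'a'").name)    # "a"

# The builder functions return fresh expression nodes rather than a shared
# module-level constant, so callers can mutate them safely.
print(exp.null().sql())  # NULL
print(exp.true().sql())  # TRUE
```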
tests/test_optimizer.py

@@ -29,6 +29,7 @@ class TestOptimizer(unittest.TestCase):
             CREATE TABLE x (a INT, b INT);
             CREATE TABLE y (b INT, c INT);
             CREATE TABLE z (b INT, c INT);
+            CREATE TABLE w (d TEXT, e TEXT);

             INSERT INTO x VALUES (1, 1);
             INSERT INTO x VALUES (2, 2);
@@ -47,6 +48,8 @@ class TestOptimizer(unittest.TestCase):
             INSERT INTO y VALUES (4, 4);
             INSERT INTO y VALUES (5, 5);
             INSERT INTO y VALUES (null, null);
+
+            INSERT INTO w VALUES ('a', 'b');
             """
         )

@@ -64,6 +67,10 @@ class TestOptimizer(unittest.TestCase):
                 "b": "INT",
                 "c": "INT",
             },
+            "w": {
+                "d": "TEXT",
+                "e": "TEXT",
+            },
         }

     def check_file(self, file, func, pretty=False, execute=False, **kwargs):
@@ -224,6 +231,18 @@ class TestOptimizer(unittest.TestCase):
     def test_eliminate_subqueries(self):
         self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries)

+    def test_canonicalize(self):
+        optimize = partial(
+            optimizer.optimize,
+            rules=[
+                optimizer.qualify_tables.qualify_tables,
+                optimizer.qualify_columns.qualify_columns,
+                annotate_types,
+                optimizer.canonicalize.canonicalize,
+            ],
+        )
+        self.check_file("canonicalize", optimize, schema=self.schema)
+
     def test_tpch(self):
         self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)
tests/test_parser.py

@@ -41,12 +41,41 @@ class TestParser(unittest.TestCase):
         )

     def test_command(self):
-        expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1")
+        expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1", read="hive")
         self.assertEqual(len(expressions), 3)
         self.assertEqual(expressions[0].sql(), "SET x = 1")
         self.assertEqual(expressions[1].sql(), "ADD JAR s3://a")
         self.assertEqual(expressions[2].sql(), "SELECT 1")

+    def test_transactions(self):
+        expression = parse_one("BEGIN TRANSACTION")
+        self.assertIsNone(expression.this)
+        self.assertEqual(expression.args["modes"], [])
+        self.assertEqual(expression.sql(), "BEGIN")
+
+        expression = parse_one("START TRANSACTION", read="mysql")
+        self.assertIsNone(expression.this)
+        self.assertEqual(expression.args["modes"], [])
+        self.assertEqual(expression.sql(), "BEGIN")
+
+        expression = parse_one("BEGIN DEFERRED TRANSACTION")
+        self.assertEqual(expression.this, "DEFERRED")
+        self.assertEqual(expression.args["modes"], [])
+        self.assertEqual(expression.sql(), "BEGIN")
+
+        expression = parse_one(
+            "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", read="presto"
+        )
+        self.assertIsNone(expression.this)
+        self.assertEqual(expression.args["modes"][0], "READ WRITE")
+        self.assertEqual(expression.args["modes"][1], "ISOLATION LEVEL SERIALIZABLE")
+        self.assertEqual(expression.sql(), "BEGIN")
+
+        expression = parse_one("BEGIN", read="bigquery")
+        self.assertNotIsInstance(expression, exp.Transaction)
+        self.assertIsNone(expression.expression)
+        self.assertEqual(expression.sql(), "BEGIN")
+
     def test_identify(self):
         expression = parse_one(
             """
@@ -55,14 +84,14 @@ class TestParser(unittest.TestCase):
             """
         )

-        assert expression.expressions[0].text("this") == "a"
-        assert expression.expressions[1].text("this") == "b"
-        assert expression.expressions[2].text("alias") == "c"
-        assert expression.expressions[3].text("alias") == "D"
-        assert expression.expressions[4].text("alias") == "y|z'"
+        assert expression.expressions[0].name == "a"
+        assert expression.expressions[1].name == "b"
+        assert expression.expressions[2].alias == "c"
+        assert expression.expressions[3].alias == "D"
+        assert expression.expressions[4].alias == "y|z'"
         table = expression.args["from"].expressions[0]
-        assert table.args["this"].args["this"] == "z"
-        assert table.args["db"].args["this"] == "y"
+        assert table.this.name == "z"
+        assert table.args["db"].name == "y"

     def test_multi(self):
         expressions = parse(
@@ -72,8 +101,8 @@ class TestParser(unittest.TestCase):
         )

         assert len(expressions) == 2
-        assert expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a"
-        assert expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b"
+        assert expressions[0].args["from"].expressions[0].this.name == "a"
+        assert expressions[1].args["from"].expressions[0].this.name == "b"

     def test_expression(self):
         ignore = Parser(error_level=ErrorLevel.IGNORE)
@@ -200,7 +229,7 @@ class TestParser(unittest.TestCase):
     @patch("sqlglot.parser.logger")
     def test_comment_error_n(self, logger):
         parse_one(
-            """CREATE TABLE x
+            """SUM
             (
             -- test
             )""",
@@ -208,19 +237,19 @@ class TestParser(unittest.TestCase):
         )

         assert_logger_contains(
-            "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 4, Col: 1.",
+            "Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. Line 4, Col: 1.",
             logger,
         )

     @patch("sqlglot.parser.logger")
     def test_comment_error_r(self, logger):
         parse_one(
-            """CREATE TABLE x (-- test\r)""",
+            """SUM(-- test\r)""",
             error_level=ErrorLevel.WARN,
         )

         assert_logger_contains(
-            "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 2, Col: 1.",
+            "Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. Line 2, Col: 1.",
             logger,
         )
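
The new `test_transactions` pins down how BEGIN / START TRANSACTION parses across dialects into an `exp.Transaction` node that always generates a plain `BEGIN`. A quick sketch, taken directly from the shapes asserted above:

```python
from sqlglot import parse_one

expression = parse_one(
    "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", read="presto"
)
print(type(expression).__name__)  # Transaction
print(expression.args["modes"])   # ['READ WRITE', 'ISOLATION LEVEL SERIALIZABLE']
print(expression.sql())           # BEGIN
```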
tests/test_tokens.py

@@ -12,6 +12,7 @@ class TestTokens(unittest.TestCase):
             ("--comment\nfoo --test", "comment"),
             ("foo --comment", "comment"),
             ("foo", None),
+            ("foo /*comment 1*/ /*comment 2*/", "comment 1"),
         ]

         for sql, comment in sql_comment:
tests/test_transpile.py

@@ -20,6 +20,13 @@ class TestTranspile(unittest.TestCase):
         self.assertEqual(transpile(sql, **kwargs)[0], target)

     def test_alias(self):
+        self.assertEqual(transpile("SELECT 1 current_time")[0], "SELECT 1 AS current_time")
+        self.assertEqual(
+            transpile("SELECT 1 current_timestamp")[0], "SELECT 1 AS current_timestamp"
+        )
+        self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date")
+        self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
+
         for key in ("union", "filter", "over", "from", "join"):
             with self.subTest(f"alias {key}"):
                 self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}")
@@ -69,6 +76,10 @@ class TestTranspile(unittest.TestCase):
         self.validate("SELECT 3>=3", "SELECT 3 >= 3")

     def test_comments(self):
+        self.validate("SELECT */*comment*/", "SELECT * /* comment */")
+        self.validate(
+            "SELECT * FROM table /*comment 1*/ /*comment 2*/", "SELECT * FROM table /* comment 1 */"
+        )
         self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
         self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
         self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo")
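
The added `test_comments` cases show how inline comments survive transpilation: line comments are normalized to block comments, and only the first of several trailing block comments is kept. A minimal sketch using the public `transpile` API:

```python
import sqlglot

# A line comment is preserved and normalized to a block comment.
print(sqlglot.transpile("SELECT 1 FROM foo -- comment")[0])
# SELECT 1 FROM foo /* comment */
```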