
Merging upstream version 10.0.8.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:54:32 +01:00
parent 407314e8d2
commit efc1e37108
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
67 changed files with 2461 additions and 840 deletions


@ -7,21 +7,37 @@ v10.0.0
Changes:
- Breaking: replaced SQLGlot annotations with comments. Now comments can be preserved after transpilation, and they can appear in other places besides SELECT's expressions (see the sketch after this list).
- Breaking: renamed list_get to seq_get.
- Breaking: activated mypy type checking for SQLGlot.
- New: Azure Databricks support.
- New: placeholders can now be replaced in an expression.
- New: null safe equal operator (<=>).
- New: [SET statements](https://github.com/tobymao/sqlglot/pull/673) for MySQL.
- New: [SHOW commands](https://dev.mysql.com/doc/refman/8.0/en/show.html) for MySQL.
- New: [FORMAT function](https://www.w3schools.com/sql/func_sqlserver_format.asp) for TSQL.
- New: CROSS APPLY / OUTER APPLY [support](https://github.com/tobymao/sqlglot/pull/641) for TSQL.
- New: added formats for TSQL's [DATENAME/DATEPART functions](https://learn.microsoft.com/en-us/sql/t-sql/functions/datename-transact-sql?view=sql-server-ver16).
- New: added styles for TSQL's [CONVERT function](https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16).
- Improvement: [refactored the schema](https://github.com/tobymao/sqlglot/pull/668) to be more lenient; before it needed an exact match of db.table, now it finds the table if there are no ambiguities.
- Improvement: allow functions to [inherit](https://github.com/tobymao/sqlglot/pull/674) their arguments' types, so that annotating CASE, IF etc. is possible.
- Improvement: allow [joining with same names](https://github.com/tobymao/sqlglot/pull/660) in the python executor.
- Improvement: the "using" field can now be set for the [join expression builders](https://github.com/tobymao/sqlglot/pull/636).
- Improvement: qualify_columns [now qualifies](https://github.com/tobymao/sqlglot/pull/635) only non-alias columns in the having clause.
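A quick sketch of two of the entries above, comment preservation and the null safe equal operator; the printed strings are illustrative and may differ slightly between sqlglot versions:

```python
import sqlglot

# Comments now survive transpilation instead of being dropped.
print(sqlglot.transpile("SELECT a /* keep me */ FROM t")[0])
# e.g. SELECT a /* keep me */ FROM t

# MySQL's null safe equal operator round-trips.
print(sqlglot.transpile("SELECT a <=> b FROM t", read="mysql", write="mysql")[0])
# e.g. SELECT a <=> b FROM t
```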
v9.0.0
@ -37,6 +53,7 @@ v8.0.0
Changes:
- Breaking: New add\_table method in Schema ABC.
- New: SQLGlot now supports the [PySpark](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dataframe) dataframe API. This is still relatively experimental.
v7.1.0
@ -45,8 +62,11 @@ v7.1.0
Changes:
- Improvement: Pretty generator now takes max\_text\_width which breaks segments into new lines
- New: exp.to\_table helper to turn table names into table expression objects
- New: int[] type parsers
- New: annotations are now generated in sql
v7.0.0


@ -21,7 +21,7 @@ Pull requests are the best way to propose changes to the codebase. We actively w
5. Issue that pull request and wait for it to be reviewed by a maintainer or contributor!
## Report bugs using Github's [issues](https://github.com/tobymao/sqlglot/issues)
-We use GitHub issues to track public bugs. Report a bug by [opening a new issue]().
+We use GitHub issues to track public bugs. Report a bug by opening a new issue.
**Great Bug Reports** tend to have:


@ -90,7 +90,7 @@ sqlglot.transpile("SELECT STRFTIME(x, '%y-%-m-%S')", read="duckdb", write="hive"
"SELECT DATE_FORMAT(x, 'yy-M-ss')" "SELECT DATE_FORMAT(x, 'yy-M-ss')"
``` ```
As another example, let's suppose that we want to read in a SQL query that contains a CTE and a cast to `REAL`, and then transpile it to Spark, which uses backticks as identifiers and `FLOAT` instead of `REAL`: As another example, let's suppose that we want to read in a SQL query that contains a CTE and a cast to `REAL`, and then transpile it to Spark, which uses backticks for identifiers and `FLOAT` instead of `REAL`:
```python ```python
import sqlglot import sqlglot
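# A hedged, illustrative sketch (table/column names are made up; the exact
# output formatting may differ by sqlglot version).
sql = "WITH baz AS (SELECT a, c FROM foo) SELECT f.a, baz.c, CAST(baz.c AS REAL) AS c2 FROM foo f JOIN baz ON f.a = baz.a"
print(sqlglot.transpile(sql, write="spark", identify=True)[0])
# Spark output should quote identifiers with backticks and map REAL to FLOAT,
# e.g. ... CAST(`baz`.`c` AS FLOAT) ...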
@ -376,12 +376,12 @@ print(Dialect["custom"])
[Benchmarks](benchmarks) run on Python 3.10.5 in seconds.

-| Query | sqlglot | sqltree | sqlparse | moz_sql_parser | sqloxide |
-| --------------- | --------------- | --------------- | --------------- | --------------- | --------------- |
-| tpch | 0.01178 (1.0) | 0.01173 (0.995) | 0.04676 (3.966) | 0.06800 (5.768) | 0.00094 (0.080) |
-| short | 0.00084 (1.0) | 0.00079 (0.948) | 0.00296 (3.524) | 0.00443 (5.266) | 0.00006 (0.072) |
-| long | 0.01102 (1.0) | 0.01044 (0.947) | 0.04349 (3.945) | 0.05998 (5.440) | 0.00084 (0.077) |
-| crazy | 0.03751 (1.0) | 0.03471 (0.925) | 11.0796 (295.3) | 1.03355 (27.55) | 0.00529 (0.141) |
+| Query | sqlglot | sqlfluff | sqltree | sqlparse | moz_sql_parser | sqloxide |
+| --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- |
+| tpch | 0.01308 (1.0) | 1.60626 (122.7) | 0.01168 (0.893) | 0.04958 (3.791) | 0.08543 (6.531) | 0.00136 (0.104) |
+| short | 0.00109 (1.0) | 0.14134 (129.2) | 0.00099 (0.906) | 0.00342 (3.131) | 0.00652 (5.970) | 8.76621 (0.080) |
+| long | 0.01399 (1.0) | 2.12632 (151.9) | 0.01126 (0.805) | 0.04410 (3.151) | 0.06671 (4.767) | 0.00107 (0.076) |
+| crazy | 0.03969 (1.0) | 24.3777 (614.1) | 0.03917 (0.987) | 11.7043 (294.8) | 1.03280 (26.02) | 0.00625 (0.157) |

## Optional Dependencies


@ -5,8 +5,10 @@ collections.Iterable = collections.abc.Iterable
import gc
import timeit

-import moz_sql_parser
import numpy as np
import sqlfluff
import moz_sql_parser
import sqloxide
import sqlparse
import sqltree
@ -177,6 +179,10 @@ def sqloxide_parse(sql):
    sqloxide.parse_sql(sql, dialect="ansi")

def sqlfluff_parse(sql):
    sqlfluff.parse(sql)

def border(columns):
    columns = " | ".join(columns)
    return f"| {columns} |"
@ -193,6 +199,7 @@ def diff(row, column):
libs = [
    "sqlglot",
    "sqlfluff",
    "sqltree",
    "sqlparse",
    "moz_sql_parser",
@ -206,7 +213,8 @@ for name, sql in {"tpch": tpch, "short": short, "long": long, "crazy": crazy}.it
    for lib in libs:
        try:
            row[lib] = np.mean(timeit.repeat(lambda: globals()[lib + "_parse"](sql), number=3))
-        except:
+        except Exception as e:
+            print(e)
            row[lib] = "error"
columns = ["Query"] + libs


@ -30,7 +30,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType

-__version__ = "10.0.1"
+__version__ = "10.0.8"

pretty = False


@ -260,7 +260,10 @@ class Column:
""" """
if isinstance(dataType, DataType): if isinstance(dataType, DataType):
dataType = dataType.simpleString() dataType = dataType.simpleString()
new_expression = exp.Cast(this=self.column_expression, to=dataType) new_expression = exp.Cast(
this=self.column_expression,
to=sqlglot.parse_one(dataType, into=exp.DataType, read="spark"), # type: ignore
)
return Column(new_expression) return Column(new_expression)
def startswith(self, value: t.Union[str, Column]) -> Column: def startswith(self, value: t.Union[str, Column]) -> Column:


@ -314,7 +314,13 @@ class DataFrame:
        replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
            cache_table_name
        )
-        sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
+        sqlglot.schema.add_table(
+            cache_table_name,
+            {
+                expression.alias_or_name: expression.type.name
+                for expression in select_expression.expressions
+            },
+        )
        cache_storage_level = select_expression.args["cache_storage_level"]
        options = [
            exp.Literal.string("storageLevel"),
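For context, the cached table is now registered with an explicit column-to-type mapping rather than just column names; a minimal sketch of that schema call (the table and column names below are hypothetical):

```python
import sqlglot

# Hypothetical table: register its columns with their types so later
# optimization/qualification can resolve them.
sqlglot.schema.add_table("cache_table_1", {"a": "INT", "b": "STRING"})
```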


@ -757,11 +757,15 @@ def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
def decode(col: ColumnOrName, charset: str) -> Column:
-    return Column.invoke_anonymous_function(col, "DECODE", lit(charset))
+    return Column.invoke_expression_over_column(
+        col, glotexp.Decode, charset=glotexp.Literal.string(charset)
+    )

def encode(col: ColumnOrName, charset: str) -> Column:
-    return Column.invoke_anonymous_function(col, "ENCODE", lit(charset))
+    return Column.invoke_expression_over_column(
+        col, glotexp.Encode, charset=glotexp.Literal.string(charset)
+    )

def format_number(col: ColumnOrName, d: int) -> Column:
@ -867,11 +871,11 @@ def bin(col: ColumnOrName) -> Column:
def hex(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "HEX")
+    return Column.invoke_expression_over_column(col, glotexp.Hex)

def unhex(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "UNHEX")
+    return Column.invoke_expression_over_column(col, glotexp.Unhex)

def length(col: ColumnOrName) -> Column:
@ -939,11 +943,7 @@ def array_join(
def concat(*cols: ColumnOrName) -> Column:
-    if len(cols) == 1:
-        return Column.invoke_anonymous_function(cols[0], "CONCAT")
-    return Column.invoke_anonymous_function(
-        cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]
-    )
+    return Column.invoke_expression_over_column(None, glotexp.Concat, expressions=cols)

def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:


@ -88,14 +88,14 @@ class SparkSession:
"expressions": sel_columns, "expressions": sel_columns,
"from": exp.From( "from": exp.From(
expressions=[ expressions=[
exp.Subquery( exp.Values(
this=exp.Values(expressions=data_expressions), expressions=data_expressions,
alias=exp.TableAlias( alias=exp.TableAlias(
this=exp.to_identifier(self._auto_incrementing_name), this=exp.to_identifier(self._auto_incrementing_name),
columns=[exp.to_identifier(col_name) for col_name in column_mapping], columns=[exp.to_identifier(col_name) for col_name in column_mapping],
), ),
) ),
] ],
), ),
} }


@ -2,6 +2,7 @@ from sqlglot.dialects.bigquery import BigQuery
from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.dialects.drill import Drill
from sqlglot.dialects.duckdb import DuckDB
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL


@ -119,6 +119,8 @@ class BigQuery(Dialect):
"UNKNOWN": TokenType.NULL, "UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW, "WINDOW": TokenType.WINDOW,
"NOT DETERMINISTIC": TokenType.VOLATILE, "NOT DETERMINISTIC": TokenType.VOLATILE,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
} }
KEYWORDS.pop("DIV") KEYWORDS.pop("DIV")
@ -204,6 +206,15 @@ class BigQuery(Dialect):
EXPLICIT_UNION = True EXPLICIT_UNION = True
def transaction_sql(self, *_):
return "BEGIN TRANSACTION"
def commit_sql(self, *_):
return "COMMIT TRANSACTION"
def rollback_sql(self, *_):
return "ROLLBACK TRANSACTION"
def in_unnest_op(self, unnest): def in_unnest_op(self, unnest):
return self.sql(unnest) return self.sql(unnest)
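A hedged sketch of what the new transaction hooks enable; the round-tripped string is an assumption rather than something taken from this diff:

```python
import sqlglot

# BigQuery transactions should now round-trip through tokenizer and generator.
print(sqlglot.transpile("BEGIN TRANSACTION", read="bigquery", write="bigquery")[0])
# e.g. BEGIN TRANSACTION
```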


@ -32,6 +32,7 @@ class Dialects(str, Enum):
TRINO = "trino" TRINO = "trino"
TSQL = "tsql" TSQL = "tsql"
DATABRICKS = "databricks" DATABRICKS = "databricks"
DRILL = "drill"
class _Dialect(type): class _Dialect(type):
@ -362,3 +363,18 @@ def parse_date_delta(exp_class, unit_mapping=None):
return exp_class(this=this, expression=expression, unit=unit) return exp_class(this=this, expression=expression, unit=unit)
return inner_func return inner_func
def locate_to_strposition(args):
return exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
)
def strposition_to_local_sql(self, expression):
args = self.format_args(
expression.args.get("substr"), expression.this, expression.args.get("position")
)
return f"LOCATE({args})"

sqlglot/dialects/drill.py (new file, 174 lines)

@ -0,0 +1,174 @@
from __future__ import annotations
import re
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
create_with_partitions_sql,
format_time_lambda,
no_pivot_sql,
no_trycast_sql,
rename_func,
str_position_sql,
)
from sqlglot.dialects.postgres import _lateral_sql
def _to_timestamp(args):
# TO_TIMESTAMP accepts either a single double argument or (text, text)
if len(args) == 1 and args[0].is_number:
return exp.UnixToTime.from_arg_list(args)
return format_time_lambda(exp.StrToTime, "drill")(args)
def _str_to_time_sql(self, expression):
return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
def _ts_or_ds_to_date_sql(self, expression):
time_format = self.format_time(expression)
if time_format and time_format not in (Drill.time_format, Drill.date_format):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
return f"CAST({self.sql(expression, 'this')} AS DATE)"
def _date_add_sql(kind):
def func(self, expression):
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
expression = self.sql(expression, "expression")
return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})"
return func
def if_sql(self, expression):
    """
    Drill requires backticks around certain SQL reserved words, IF being one of them. This function
    adds backticks around the IF keyword.

    Args:
        self: The Drill dialect
        expression: The input IF expression

    Returns: The expression with IF in backticks.
    """
    expressions = self.format_args(
        expression.this, expression.args.get("true"), expression.args.get("false")
    )
    return f"`IF`({expressions})"
def _str_to_date(self, expression):
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Drill.date_format:
return f"CAST({this} AS DATE)"
return f"TO_DATE({this}, {time_format})"
class Drill(Dialect):
normalize_functions = None
null_ordering = "nulls_are_last"
date_format = "'yyyy-MM-dd'"
dateint_format = "'yyyyMMdd'"
time_format = "'yyyy-MM-dd HH:mm:ss'"
time_mapping = {
"y": "%Y",
"Y": "%Y",
"YYYY": "%Y",
"yyyy": "%Y",
"YY": "%y",
"yy": "%y",
"MMMM": "%B",
"MMM": "%b",
"MM": "%m",
"M": "%-m",
"dd": "%d",
"d": "%-d",
"HH": "%H",
"H": "%-H",
"hh": "%I",
"h": "%-I",
"mm": "%M",
"m": "%-M",
"ss": "%S",
"s": "%-S",
"SSSSSS": "%f",
"a": "%p",
"DD": "%j",
"D": "%-j",
"E": "%a",
"EE": "%a",
"EEE": "%a",
"EEEE": "%A",
"''T''": "T",
}
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'"]
IDENTIFIERS = ["`"]
ESCAPES = ["\\"]
ENCODE = "utf-8"
class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
}
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
exp.DataType.Type.BINARY: "VARBINARY",
exp.DataType.Type.TEXT: "VARCHAR",
exp.DataType.Type.NCHAR: "VARCHAR",
exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.DATETIME: "TIMESTAMP",
}
ROOT_PROPERTIES = {exp.PartitionedByProperty}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.Lateral: _lateral_sql,
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"),
exp.Create: create_with_partitions_sql,
exp.DateAdd: _date_add_sql("ADD"),
exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.DateSub: _date_add_sql("SUB"),
exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
exp.If: if_sql,
exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
exp.Pivot: no_pivot_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.StrPosition: str_position_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)",
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}
def normalize_func(self, name):
return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"


@ -55,13 +55,13 @@ def _array_sort_sql(self, expression):
def _sort_array_sql(self, expression):
    this = self.sql(expression, "this")
-    if expression.args.get("asc") == exp.FALSE:
+    if expression.args.get("asc") == exp.false():
        return f"ARRAY_REVERSE_SORT({this})"
    return f"ARRAY_SORT({this})"

def _sort_array_reverse(args):
-    return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE)
+    return exp.SortArray(this=seq_get(args, 0), asc=exp.false())

def _struct_pack_sql(self, expression):


@ -7,16 +7,19 @@ from sqlglot.dialects.dialect import (
    create_with_partitions_sql,
    format_time_lambda,
    if_sql,
    locate_to_strposition,
    no_ilike_sql,
    no_recursive_cte_sql,
    no_safe_divide_sql,
    no_trycast_sql,
    rename_func,
    strposition_to_local_sql,
    struct_extract_sql,
    var_map_sql,
)
from sqlglot.helper import seq_get
from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType

# (FuncType, Multiplier)
DATE_DELTA_INTERVAL = {
@ -181,6 +184,15 @@ class Hive(Dialect):
            "F": "FLOAT",
            "BD": "DECIMAL",
        }

        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
            "ADD ARCHIVE": TokenType.COMMAND,
            "ADD ARCHIVES": TokenType.COMMAND,
            "ADD FILE": TokenType.COMMAND,
            "ADD FILES": TokenType.COMMAND,
            "ADD JAR": TokenType.COMMAND,
            "ADD JARS": TokenType.COMMAND,
        }

    class Parser(parser.Parser):
        STRICT_CAST = False
@ -210,11 +222,7 @@ class Hive(Dialect):
            "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
            "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
            "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
-            "LOCATE": lambda args: exp.StrPosition(
-                this=seq_get(args, 1),
-                substr=seq_get(args, 0),
-                position=seq_get(args, 2),
-            ),
+            "LOCATE": locate_to_strposition,
            "LOG": (
                lambda args: exp.Log.from_arg_list(args)
                if len(args) > 1
@ -272,7 +280,7 @@ class Hive(Dialect):
            exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
            exp.SetAgg: rename_func("COLLECT_SET"),
            exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
-            exp.StrPosition: lambda self, e: f"LOCATE({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
+            exp.StrPosition: strposition_to_local_sql,
            exp.StrToDate: _str_to_date,
            exp.StrToTime: _str_to_time,
            exp.StrToUnix: _str_to_unix,


@ -5,10 +5,12 @@ import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
    Dialect,
    locate_to_strposition,
    no_ilike_sql,
    no_paren_current_date_sql,
    no_tablesample_sql,
    no_trycast_sql,
    strposition_to_local_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@ -120,6 +122,7 @@ class MySQL(Dialect):
        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
            "START": TokenType.BEGIN,
            "SEPARATOR": TokenType.SEPARATOR,
            "_ARMSCII8": TokenType.INTRODUCER,
            "_ASCII": TokenType.INTRODUCER,
@ -172,13 +175,18 @@ class MySQL(Dialect):
        COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}

    class Parser(parser.Parser):
-        STRICT_CAST = False
+        FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA}

        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,
            "DATE_ADD": _date_add(exp.DateAdd),
            "DATE_SUB": _date_add(exp.DateSub),
            "STR_TO_DATE": _str_to_date,
            "LOCATE": locate_to_strposition,
            "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
            "LEFT": lambda args: exp.Substring(
                this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
            ),
        }

        FUNCTION_PARSERS = {
@ -264,6 +272,7 @@ class MySQL(Dialect):
            "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
            "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
            "NAMES": lambda self: self._parse_set_item_names(),
            "TRANSACTION": lambda self: self._parse_set_transaction(),
        }

        PROFILE_TYPES = {
@ -278,39 +287,48 @@ class MySQL(Dialect):
            "SWAPS",
        }

        TRANSACTION_CHARACTERISTICS = {
            "ISOLATION LEVEL REPEATABLE READ",
            "ISOLATION LEVEL READ COMMITTED",
            "ISOLATION LEVEL READ UNCOMMITTED",
            "ISOLATION LEVEL SERIALIZABLE",
            "READ WRITE",
            "READ ONLY",
        }
        def _parse_show_mysql(self, this, target=False, full=None, global_=None):
            if target:
                if isinstance(target, str):
-                    self._match_text(target)
+                    self._match_text_seq(target)
                target_id = self._parse_id_var()
            else:
                target_id = None

-            log = self._parse_string() if self._match_text("IN") else None
+            log = self._parse_string() if self._match_text_seq("IN") else None

            if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
-                position = self._parse_number() if self._match_text("FROM") else None
+                position = self._parse_number() if self._match_text_seq("FROM") else None
                db = None
            else:
                position = None
-                db = self._parse_id_var() if self._match_text("FROM") else None
+                db = self._parse_id_var() if self._match_text_seq("FROM") else None

-            channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None
+            channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None

-            like = self._parse_string() if self._match_text("LIKE") else None
+            like = self._parse_string() if self._match_text_seq("LIKE") else None
            where = self._parse_where()

            if this == "PROFILE":
-                types = self._parse_csv(self._parse_show_profile_type)
+                types = self._parse_csv(lambda: self._parse_var_from_options(self.PROFILE_TYPES))
-                query = self._parse_number() if self._match_text("FOR", "QUERY") else None
+                query = self._parse_number() if self._match_text_seq("FOR", "QUERY") else None
-                offset = self._parse_number() if self._match_text("OFFSET") else None
+                offset = self._parse_number() if self._match_text_seq("OFFSET") else None
-                limit = self._parse_number() if self._match_text("LIMIT") else None
+                limit = self._parse_number() if self._match_text_seq("LIMIT") else None
            else:
                types, query = None, None
                offset, limit = self._parse_oldstyle_limit()

-            mutex = True if self._match_text("MUTEX") else None
-            mutex = False if self._match_text("STATUS") else mutex
+            mutex = True if self._match_text_seq("MUTEX") else None
+            mutex = False if self._match_text_seq("STATUS") else mutex

            return self.expression(
                exp.Show,
@ -331,16 +349,16 @@ class MySQL(Dialect):
                **{"global": global_},
            )

-        def _parse_show_profile_type(self):
-            for type_ in self.PROFILE_TYPES:
-                if self._match_text(*type_.split(" ")):
-                    return exp.Var(this=type_)
+        def _parse_var_from_options(self, options):
+            for option in options:
+                if self._match_text_seq(*option.split(" ")):
+                    return exp.Var(this=option)
            return None

        def _parse_oldstyle_limit(self):
            limit = None
            offset = None
-            if self._match_text("LIMIT"):
+            if self._match_text_seq("LIMIT"):
                parts = self._parse_csv(self._parse_number)
                if len(parts) == 1:
                    limit = parts[0]
@ -353,6 +371,9 @@ class MySQL(Dialect):
            return self._parse_set_item_assignment(kind=None)

        def _parse_set_item_assignment(self, kind):
            if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"):
                return self._parse_set_transaction(global_=kind == "GLOBAL")

            left = self._parse_primary() or self._parse_id_var()
            if not self._match(TokenType.EQ):
                self.raise_error("Expected =")
@ -381,7 +402,7 @@ class MySQL(Dialect):
        def _parse_set_item_names(self):
            charset = self._parse_string() or self._parse_id_var()
-            if self._match_text("COLLATE"):
+            if self._match_text_seq("COLLATE"):
                collate = self._parse_string() or self._parse_id_var()
            else:
                collate = None
@ -392,6 +413,18 @@ class MySQL(Dialect):
                kind="NAMES",
            )

        def _parse_set_transaction(self, global_=False):
            self._match_text_seq("TRANSACTION")
            characteristics = self._parse_csv(
                lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS)
            )
            return self.expression(
                exp.SetItem,
                expressions=characteristics,
                kind="TRANSACTION",
                **{"global": global_},
            )
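A hedged sketch of what the new SET TRANSACTION support should handle; the exact generated formatting is an assumption, not taken from this diff:

```python
import sqlglot

# MySQL SET TRANSACTION statements with characteristics should now parse and re-emit.
sql = "SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE"
print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])
```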
    class Generator(generator.Generator):
        NULL_ORDERING_SUPPORTED = False
@ -411,6 +444,7 @@ class MySQL(Dialect):
            exp.Trim: _trim_sql,
            exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
            exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
            exp.StrPosition: strposition_to_local_sql,
        }

        ROOT_PROPERTIES = {
@ -481,9 +515,11 @@ class MySQL(Dialect):
            kind = self.sql(expression, "kind")
            kind = f"{kind} " if kind else ""
            this = self.sql(expression, "this")
            expressions = self.expressions(expression)
            collate = self.sql(expression, "collate")
            collate = f" COLLATE {collate}" if collate else ""
-            return f"{kind}{this}{collate}"
+            global_ = "GLOBAL " if expression.args.get("global") else ""
+            return f"{global_}{kind}{this}{expressions}{collate}"

        def set_sql(self, expression):
            return f"SET {self.expressions(expression)}"


@ -91,6 +91,7 @@ class Oracle(Dialect):
    class Tokenizer(tokens.Tokenizer):
        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
            "START": TokenType.BEGIN,
            "TOP": TokenType.TOP,
            "VARCHAR2": TokenType.VARCHAR,
            "NVARCHAR2": TokenType.NVARCHAR,


@ -164,11 +164,34 @@ class Postgres(Dialect):
        BIT_STRINGS = [("b'", "'"), ("B'", "'")]
        HEX_STRINGS = [("x'", "'"), ("X'", "'")]
        BYTE_STRINGS = [("e'", "'"), ("E'", "'")]

        CREATABLES = (
            "AGGREGATE",
            "CAST",
            "CONVERSION",
            "COLLATION",
            "DEFAULT CONVERSION",
            "CONSTRAINT",
            "DOMAIN",
            "EXTENSION",
            "FOREIGN",
            "FUNCTION",
            "OPERATOR",
            "POLICY",
            "ROLE",
            "RULE",
            "SEQUENCE",
            "TEXT",
            "TRIGGER",
            "TYPE",
            "UNLOGGED",
            "USER",
        )

        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
            "ALWAYS": TokenType.ALWAYS,
            "BY DEFAULT": TokenType.BY_DEFAULT,
-            "COMMENT ON": TokenType.COMMENT_ON,
            "IDENTITY": TokenType.IDENTITY,
            "GENERATED": TokenType.GENERATED,
            "DOUBLE PRECISION": TokenType.DOUBLE,
@ -176,6 +199,19 @@ class Postgres(Dialect):
            "SERIAL": TokenType.SERIAL,
            "SMALLSERIAL": TokenType.SMALLSERIAL,
            "UUID": TokenType.UUID,
            "TEMP": TokenType.TEMPORARY,
            "BEGIN TRANSACTION": TokenType.BEGIN,
            "BEGIN": TokenType.COMMAND,
            "COMMENT ON": TokenType.COMMAND,
            "DECLARE": TokenType.COMMAND,
            "DO": TokenType.COMMAND,
            "REFRESH": TokenType.COMMAND,
            "REINDEX": TokenType.COMMAND,
            "RESET": TokenType.COMMAND,
            "REVOKE": TokenType.COMMAND,
            "GRANT": TokenType.COMMAND,
            **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
            **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
        }

        QUOTES = ["'", "$$"]
        SINGLE_TOKENS = {


@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
    struct_extract_sql,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.errors import UnsupportedError
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@ -61,8 +62,18 @@ def _initcap_sql(self, expression):
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
def _decode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
return f"FROM_UTF8({self.sql(expression, 'this')})"
def _encode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
return f"TO_UTF8({self.sql(expression, 'this')})"
def _no_sort_array(self, expression):
-    if expression.args.get("asc") == exp.FALSE:
+    if expression.args.get("asc") == exp.false():
        comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
    else:
        comparator = None
@ -72,7 +83,7 @@ def _no_sort_array(self, expression):
def _schema_sql(self, expression):
    if isinstance(expression.parent, exp.Property):
-        columns = ", ".join(f"'{c.text('this')}'" for c in expression.expressions)
+        columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
        return f"ARRAY[{columns}]"

    for schema in expression.parent.find_all(exp.Schema):
@ -106,6 +117,11 @@ def _ts_or_ds_add_sql(self, expression):
    return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"

def _ensure_utf8(charset):
    if charset.name.lower() != "utf-8":
        raise UnsupportedError(f"Unsupported charset {charset}")
class Presto(Dialect):
    index_offset = 1
    null_ordering = "nulls_are_last"
@ -115,6 +131,7 @@ class Presto(Dialect):
    class Tokenizer(tokens.Tokenizer):
        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
            "START": TokenType.BEGIN,
            "ROW": TokenType.STRUCT,
        }
@ -140,6 +157,14 @@ class Presto(Dialect):
            "STRPOS": exp.StrPosition.from_arg_list,
            "TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
            "FROM_HEX": exp.Unhex.from_arg_list,
            "TO_HEX": exp.Hex.from_arg_list,
            "TO_UTF8": lambda args: exp.Encode(
                this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
            ),
            "FROM_UTF8": lambda args: exp.Decode(
                this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
            ),
        }

    class Generator(generator.Generator):
@ -187,7 +212,10 @@ class Presto(Dialect):
            exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
            exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
            exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
            exp.Decode: _decode_sql,
            exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
            exp.Encode: _encode_sql,
            exp.Hex: rename_func("TO_HEX"),
            exp.If: if_sql,
            exp.ILike: no_ilike_sql,
            exp.Initcap: _initcap_sql,
@ -212,7 +240,13 @@ class Presto(Dialect):
            exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
            exp.TsOrDsAdd: _ts_or_ds_add_sql,
            exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
            exp.Unhex: rename_func("FROM_HEX"),
            exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
            exp.UnixToTime: rename_func("FROM_UNIXTIME"),
            exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
        }

        def transaction_sql(self, expression):
            modes = expression.args.get("modes")
            modes = f" {', '.join(modes)}" if modes else ""
            return f"START TRANSACTION{modes}"


@ -148,6 +148,7 @@ class Snowflake(Dialect):
            **parser.Parser.FUNCTION_PARSERS,
            "DATE_PART": _parse_date_part,
        }
        FUNCTION_PARSERS.pop("TRIM")

        FUNC_TOKENS = {
            *parser.Parser.FUNC_TOKENS,
@ -203,6 +204,7 @@ class Snowflake(Dialect):
            exp.StrPosition: rename_func("POSITION"),
            exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
            exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
        }

        TYPE_MAPPING = {


@ -63,3 +63,8 @@ class SQLite(Dialect):
            exp.TableSample: no_tablesample_sql,
            exp.TryCast: no_trycast_sql,
        }

        def transaction_sql(self, expression):
            this = expression.this
            this = f" {this}" if this else ""
            return f"BEGIN{this} TRANSACTION"


@ -248,7 +248,7 @@ class TSQL(Dialect):
    def _parse_convert(self, strict):
        to = self._parse_types()
        self._match(TokenType.COMMA)
-        this = self._parse_column()
+        this = self._parse_conjunction()

        # Retrieve length of datatype and override to default if not specified
        if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:


@ -1,3 +1,6 @@
from __future__ import annotations
import typing as t
from collections import defaultdict
from dataclasses import dataclass
from heapq import heappop, heappush
@ -6,6 +9,10 @@ from sqlglot import Dialect
from sqlglot import expressions as exp
from sqlglot.helper import ensure_collection

if t.TYPE_CHECKING:
    T = t.TypeVar("T")
    Edit = t.Union[Insert, Remove, Move, Update, Keep]

@dataclass(frozen=True)
class Insert:
@ -44,7 +51,7 @@ class Keep:
    target: exp.Expression
-def diff(source, target):
+def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
    """
    Returns the list of changes between the source and the target expressions.
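A minimal usage sketch of this function (the two statements compared here are illustrative):

```python
import sqlglot
from sqlglot.diff import diff

# Compare two parsed expressions; the result is a list of
# Insert/Remove/Move/Update/Keep edits as typed above.
edits = diff(sqlglot.parse_one("SELECT a + b"), sqlglot.parse_one("SELECT a - b"))
```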
@ -89,25 +96,25 @@ class ChangeDistiller:
Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf. Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf.
""" """
def __init__(self, f=0.6, t=0.6): def __init__(self, f: float = 0.6, t: float = 0.6) -> None:
self.f = f self.f = f
self.t = t self.t = t
self._sql_generator = Dialect().generator() self._sql_generator = Dialect().generator()
def diff(self, source, target): def diff(self, source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
self._source = source self._source = source
self._target = target self._target = target
self._source_index = {id(n[0]): n[0] for n in source.bfs()} self._source_index = {id(n[0]): n[0] for n in source.bfs()}
self._target_index = {id(n[0]): n[0] for n in target.bfs()} self._target_index = {id(n[0]): n[0] for n in target.bfs()}
self._unmatched_source_nodes = set(self._source_index) self._unmatched_source_nodes = set(self._source_index)
self._unmatched_target_nodes = set(self._target_index) self._unmatched_target_nodes = set(self._target_index)
self._bigram_histo_cache = {} self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
matching_set = self._compute_matching_set() matching_set = self._compute_matching_set()
return self._generate_edit_script(matching_set) return self._generate_edit_script(matching_set)
def _generate_edit_script(self, matching_set): def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.List[Edit]:
edit_script = [] edit_script: t.List[Edit] = []
for removed_node_id in self._unmatched_source_nodes: for removed_node_id in self._unmatched_source_nodes:
edit_script.append(Remove(self._source_index[removed_node_id])) edit_script.append(Remove(self._source_index[removed_node_id]))
for inserted_node_id in self._unmatched_target_nodes: for inserted_node_id in self._unmatched_target_nodes:
@ -125,7 +132,9 @@ class ChangeDistiller:
return edit_script return edit_script
def _generate_move_edits(self, source, target, matching_set): def _generate_move_edits(
self, source: exp.Expression, target: exp.Expression, matching_set: t.Set[t.Tuple[int, int]]
) -> t.List[Move]:
source_args = [id(e) for e in _expression_only_args(source)] source_args = [id(e) for e in _expression_only_args(source)]
target_args = [id(e) for e in _expression_only_args(target)] target_args = [id(e) for e in _expression_only_args(target)]
@ -138,7 +147,7 @@ class ChangeDistiller:
return move_edits return move_edits
def _compute_matching_set(self): def _compute_matching_set(self) -> t.Set[t.Tuple[int, int]]:
leaves_matching_set = self._compute_leaf_matching_set() leaves_matching_set = self._compute_leaf_matching_set()
matching_set = leaves_matching_set.copy() matching_set = leaves_matching_set.copy()
@ -183,8 +192,8 @@ class ChangeDistiller:
return matching_set return matching_set
def _compute_leaf_matching_set(self): def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]:
candidate_matchings = [] candidate_matchings: t.List[t.Tuple[float, int, exp.Expression, exp.Expression]] = []
source_leaves = list(_get_leaves(self._source)) source_leaves = list(_get_leaves(self._source))
target_leaves = list(_get_leaves(self._target)) target_leaves = list(_get_leaves(self._target))
for source_leaf in source_leaves: for source_leaf in source_leaves:
@ -216,7 +225,7 @@ class ChangeDistiller:
return matching_set return matching_set
def _dice_coefficient(self, source, target): def _dice_coefficient(self, source: exp.Expression, target: exp.Expression) -> float:
source_histo = self._bigram_histo(source) source_histo = self._bigram_histo(source)
target_histo = self._bigram_histo(target) target_histo = self._bigram_histo(target)
@ -231,13 +240,13 @@ class ChangeDistiller:
return 2 * overlap_len / total_grams return 2 * overlap_len / total_grams
def _bigram_histo(self, expression): def _bigram_histo(self, expression: exp.Expression) -> t.DefaultDict[str, int]:
if id(expression) in self._bigram_histo_cache: if id(expression) in self._bigram_histo_cache:
return self._bigram_histo_cache[id(expression)] return self._bigram_histo_cache[id(expression)]
expression_str = self._sql_generator.generate(expression) expression_str = self._sql_generator.generate(expression)
count = max(0, len(expression_str) - 1) count = max(0, len(expression_str) - 1)
bigram_histo = defaultdict(int) bigram_histo: t.DefaultDict[str, int] = defaultdict(int)
for i in range(count): for i in range(count):
bigram_histo[expression_str[i : i + 2]] += 1 bigram_histo[expression_str[i : i + 2]] += 1
@ -245,7 +254,7 @@ class ChangeDistiller:
return bigram_histo return bigram_histo
def _get_leaves(expression): def _get_leaves(expression: exp.Expression) -> t.Generator[exp.Expression, None, None]:
has_child_exprs = False has_child_exprs = False
for a in expression.args.values(): for a in expression.args.values():
@ -258,7 +267,7 @@ def _get_leaves(expression):
yield expression yield expression
def _is_same_type(source, target): def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
if type(source) is type(target): if type(source) is type(target):
if isinstance(source, exp.Join): if isinstance(source, exp.Join):
return source.args.get("side") == target.args.get("side") return source.args.get("side") == target.args.get("side")
@ -271,15 +280,17 @@ def _is_same_type(source, target):
return False return False
def _expression_only_args(expression): def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
args = [] args: t.List[t.Union[exp.Expression, t.List]] = []
if expression: if expression:
for a in expression.args.values(): for a in expression.args.values():
args.extend(ensure_collection(a)) args.extend(ensure_collection(a))
return [a for a in args if isinstance(a, exp.Expression)] return [a for a in args if isinstance(a, exp.Expression)]
def _lcs(seq_a, seq_b, equal): def _lcs(
seq_a: t.Sequence[T], seq_b: t.Sequence[T], equal: t.Callable[[T, T], bool]
) -> t.Sequence[t.Optional[T]]:
"""Calculates the longest common subsequence""" """Calculates the longest common subsequence"""
len_a = len(seq_a) len_a = len(seq_a)
@ -289,14 +300,14 @@ def _lcs(seq_a, seq_b, equal):
for i in range(len_a + 1): for i in range(len_a + 1):
for j in range(len_b + 1): for j in range(len_b + 1):
if i == 0 or j == 0: if i == 0 or j == 0:
lcs_result[i][j] = [] lcs_result[i][j] = [] # type: ignore
elif equal(seq_a[i - 1], seq_b[j - 1]): elif equal(seq_a[i - 1], seq_b[j - 1]):
lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]] lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]] # type: ignore
else: else:
lcs_result[i][j] = ( lcs_result[i][j] = (
lcs_result[i - 1][j] lcs_result[i - 1][j]
if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1]) if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1]) # type: ignore
else lcs_result[i][j - 1] else lcs_result[i][j - 1]
) )
return lcs_result[len_a][len_b] return lcs_result[len_a][len_b] # type: ignore


@ -37,6 +37,10 @@ class SchemaError(SqlglotError):
    pass

class ExecuteError(SqlglotError):
    pass

def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
    msg = [str(e) for e in errors[:maximum]]
    remaining = len(errors) - maximum


@ -1,20 +1,23 @@
import logging
import time

-from sqlglot import parse_one
+from sqlglot import maybe_parse
from sqlglot.errors import ExecuteError
from sqlglot.executor.python import PythonExecutor
from sqlglot.executor.table import Table, ensure_tables
from sqlglot.optimizer import optimize
from sqlglot.planner import Plan
from sqlglot.schema import ensure_schema

logger = logging.getLogger("sqlglot")

-def execute(sql, schema, read=None):
+def execute(sql, schema=None, read=None, tables=None):
    """
    Run a sql query against data.

    Args:
-        sql (str): a sql statement
+        sql (str|sqlglot.Expression): a sql statement
        schema (dict|sqlglot.optimizer.Schema): database schema.
            This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
            the following forms:
@ -23,10 +26,20 @@ def execute(sql, schema, read=None):
            3. {catalog: {db: {table: {col: type}}}}
        read (str): the SQL dialect to apply during parsing
            (eg. "spark", "hive", "presto", "mysql").
        tables (dict): additional tables to register.

    Returns:
        sqlglot.executor.Table: Simple columnar data structure.
    """
-    expression = parse_one(sql, read=read)
+    tables = ensure_tables(tables)
+    if not schema:
+        schema = {
+            name: {column: type(table[0][column]).__name__ for column in table.columns}
+            for name, table in tables.mapping.items()
+        }
+    schema = ensure_schema(schema)
+    if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args:
+        raise ExecuteError("Tables must support the same table args as schema")
+    expression = maybe_parse(sql, dialect=read)
    now = time.time()
    expression = optimize(expression, schema, leave_tables_isolated=True)
    logger.debug("Optimization finished: %f", time.time() - now)
@ -34,6 +47,6 @@ def execute(sql, schema, read=None):
    plan = Plan(expression)
    logger.debug("Logical Plan: %s", plan)
    now = time.time()
-    result = PythonExecutor().execute(plan)
+    result = PythonExecutor(tables=tables).execute(plan)
    logger.debug("Query finished: %f", time.time() - now)
    return result
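With the new `tables` argument the executor can infer a schema directly from Python data; a small hedged sketch (table name, columns, and rows are illustrative):

```python
from sqlglot.executor import execute

tables = {"sushi": [{"id": 1, "price": 1.0}, {"id": 2, "price": 2.0}]}
result = execute("SELECT id, price FROM sushi WHERE price > 1.5", tables=tables)
print(result.columns, result.rows)
```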


@ -1,5 +1,12 @@
from __future__ import annotations
import typing as t
from sqlglot.executor.env import ENV from sqlglot.executor.env import ENV
if t.TYPE_CHECKING:
from sqlglot.executor.table import Table, TableIter
class Context: class Context:
""" """
@ -12,14 +19,14 @@ class Context:
evaluation of aggregation functions. evaluation of aggregation functions.
""" """
def __init__(self, tables, env=None): def __init__(self, tables: t.Dict[str, Table], env: t.Optional[t.Dict] = None) -> None:
""" """
Args Args
tables (dict): table_name -> Table, representing the scope of the current execution context tables: representing the scope of the current execution context.
env (Optional[dict]): dictionary of functions within the execution context env: dictionary of functions within the execution context.
""" """
self.tables = tables self.tables = tables
self._table = None self._table: t.Optional[Table] = None
self.range_readers = {name: table.range_reader for name, table in self.tables.items()} self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()} self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers} self.env = {**(env or {}), "scope": self.row_readers}
@ -31,7 +38,7 @@ class Context:
return tuple(self.eval(code) for code in codes) return tuple(self.eval(code) for code in codes)
@property @property
def table(self): def table(self) -> Table:
if self._table is None: if self._table is None:
self._table = list(self.tables.values())[0] self._table = list(self.tables.values())[0]
for other in self.tables.values(): for other in self.tables.values():
@ -41,8 +48,12 @@ class Context:
raise Exception(f"Rows are different.") raise Exception(f"Rows are different.")
return self._table return self._table
def add_columns(self, *columns: str) -> None:
for table in self.tables.values():
table.add_columns(*columns)
@property @property
def columns(self): def columns(self) -> t.Tuple:
return self.table.columns return self.table.columns
def __iter__(self): def __iter__(self):
@ -52,35 +63,39 @@ class Context:
reader = table[i] reader = table[i]
yield reader, self yield reader, self
def table_iter(self, table): def table_iter(self, table: str) -> t.Generator[t.Tuple[TableIter, Context], None, None]:
self.env["scope"] = self.row_readers self.env["scope"] = self.row_readers
for reader in self.tables[table]: for reader in self.tables[table]:
yield reader, self yield reader, self
def sort(self, key): def filter(self, condition) -> None:
table = self.table rows = [reader.row for reader, _ in self if self.eval(condition)]
def sort_key(row): for table in self.tables.values():
table.reader.row = row table.rows = rows
def sort(self, key) -> None:
def sort_key(row: t.Tuple) -> t.Tuple:
self.set_row(row)
return self.eval_tuple(key) return self.eval_tuple(key)
table.rows.sort(key=sort_key) self.table.rows.sort(key=sort_key)
def set_row(self, row): def set_row(self, row: t.Tuple) -> None:
for table in self.tables.values(): for table in self.tables.values():
table.reader.row = row table.reader.row = row
self.env["scope"] = self.row_readers self.env["scope"] = self.row_readers
def set_index(self, index): def set_index(self, index: int) -> None:
for table in self.tables.values(): for table in self.tables.values():
table[index] table[index]
self.env["scope"] = self.row_readers self.env["scope"] = self.row_readers
def set_range(self, start, end): def set_range(self, start: int, end: int) -> None:
for name in self.tables: for name in self.tables:
self.range_readers[name].range = range(start, end) self.range_readers[name].range = range(start, end)
self.env["scope"] = self.range_readers self.env["scope"] = self.range_readers
def __contains__(self, table): def __contains__(self, table: str) -> bool:
return table in self.tables return table in self.tables
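# Example (sketch): building an execution Context over an in-memory Table and
# iterating its rows; the table name, columns and values below are illustrative.
from sqlglot.executor.context import Context
from sqlglot.executor.table import Table

t = Table(("a", "b"), rows=[(1, 2), (3, 4)])
ctx = Context({"t": t})
for reader, _ in ctx.table_iter("t"):
    print(reader["a"], reader["b"])  # row values keyed by column name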
@ -1,7 +1,10 @@
import datetime import datetime
import inspect
import re import re
import statistics import statistics
from functools import wraps
from sqlglot import exp
from sqlglot.helper import PYTHON_VERSION from sqlglot.helper import PYTHON_VERSION
@ -16,20 +19,153 @@ class reverse_key:
return other.obj < self.obj return other.obj < self.obj
def filter_nulls(func):
@wraps(func)
def _func(values):
return func(v for v in values if v is not None)
return _func
def null_if_any(*required):
"""
Decorator that makes a function return `None` if any of the `required` arguments are `None`.
This also supports decoration with no arguments, e.g.:
@null_if_any
def foo(a, b): ...
In which case all arguments are required.
"""
f = None
if len(required) == 1 and callable(required[0]):
f = required[0]
required = ()
def decorator(func):
if required:
required_indices = [
i for i, param in enumerate(inspect.signature(func).parameters) if param in required
]
def predicate(*args):
return any(args[i] is None for i in required_indices)
else:
def predicate(*args):
return any(a is None for a in args)
@wraps(func)
def _func(*args):
if predicate(*args):
return None
return func(*args)
return _func
if f:
return decorator(f)
return decorator
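# Example (sketch): how the decorator above behaves; the function name is illustrative.
from sqlglot.executor.env import null_if_any

@null_if_any("a")
def add_one(a, b=0):
    return a + (b or 0)

add_one(None)     # -> None, because the required argument "a" is None
add_one(1, None)  # -> 1, "b" is not listed as required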
@null_if_any("substr", "this")
def str_position(substr, this, position=None):
position = position - 1 if position is not None else position
return this.find(substr, position) + 1
@null_if_any("this")
def substring(this, start=None, length=None):
if start is None:
return this
elif start == 0:
return ""
elif start < 0:
start = len(this) + start
else:
start -= 1
end = None if length is None else start + length
return this[start:end]
@null_if_any
def cast(this, to):
if to == exp.DataType.Type.DATE:
return datetime.date.fromisoformat(this)
if to == exp.DataType.Type.DATETIME:
return datetime.datetime.fromisoformat(this)
if to in exp.DataType.TEXT_TYPES:
return str(this)
if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
return float(this)
if to in exp.DataType.NUMERIC_TYPES:
return int(this)
raise NotImplementedError(f"Casting to '{to}' not implemented.")
def ordered(this, desc, nulls_first):
if desc:
return reverse_key(this)
return this
@null_if_any
def interval(this, unit):
if unit == "DAY":
return datetime.timedelta(days=float(this))
raise NotImplementedError
ENV = { ENV = {
"__builtins__": {}, "__builtins__": {},
"datetime": datetime, "exp": exp,
"locals": locals, # aggs
"re": re, "SUM": filter_nulls(sum),
"bool": bool, "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore
"float": float, "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)),
"int": int, "MAX": filter_nulls(max),
"str": str, "MIN": filter_nulls(min),
"desc": reverse_key, # scalar functions
"SUM": sum, "ABS": null_if_any(lambda this: abs(this)),
"AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean, # type: ignore "ADD": null_if_any(lambda e, this: e + this),
"COUNT": lambda acc: sum(1 for e in acc if e is not None), "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
"MAX": max, "BITWISEAND": null_if_any(lambda this, e: this & e),
"MIN": min, "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
"BITWISEOR": null_if_any(lambda this, e: this | e),
"BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e),
"BITWISEXOR": null_if_any(lambda this, e: this ^ e),
"CAST": cast,
"COALESCE": lambda *args: next((a for a in args if a is not None), None),
"CONCAT": null_if_any(lambda *args: "".join(args)),
"CONCATWS": null_if_any(lambda this, *args: this.join(args)),
"DIV": null_if_any(lambda e, this: e / this),
"EQ": null_if_any(lambda this, e: this == e),
"EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
"GT": null_if_any(lambda this, e: this > e),
"GTE": null_if_any(lambda this, e: this >= e),
"IFNULL": lambda e, alt: alt if e is None else e,
"IF": lambda predicate, true, false: true if predicate else false,
"INTDIV": null_if_any(lambda e, this: e // this),
"INTERVAL": interval,
"LIKE": null_if_any(
lambda this, e: bool(re.match(e.replace("_", ".").replace("%", ".*"), this))
),
"LOWER": null_if_any(lambda arg: arg.lower()),
"LT": null_if_any(lambda this, e: this < e),
"LTE": null_if_any(lambda this, e: this <= e),
"MOD": null_if_any(lambda e, this: e % this),
"MUL": null_if_any(lambda e, this: e * this),
"NEQ": null_if_any(lambda this, e: this != e),
"ORD": null_if_any(ord),
"ORDERED": ordered,
"POW": pow, "POW": pow,
"STRPOSITION": str_position,
"SUB": null_if_any(lambda e, this: e - this),
"SUBSTRING": substring,
"UPPER": null_if_any(lambda arg: arg.upper()),
} }
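# Example (sketch): the scalar helpers registered above follow SQL NULL semantics.
from sqlglot.executor.env import ENV

ENV["UPPER"]("sqlglot")         # -> 'SQLGLOT'
ENV["UPPER"](None)              # -> None
ENV["CONCAT"]("a", None)        # -> None
ENV["COALESCE"](None, "x")      # -> 'x'
ENV["STRPOSITION"]("b", "abc")  # -> 2 (1-based, like SQL)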
@ -5,16 +5,18 @@ import math
from sqlglot import exp, generator, planner, tokens from sqlglot import exp, generator, planner, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql from sqlglot.dialects.dialect import Dialect, inline_array_sql
from sqlglot.errors import ExecuteError
from sqlglot.executor.context import Context from sqlglot.executor.context import Context
from sqlglot.executor.env import ENV from sqlglot.executor.env import ENV
from sqlglot.executor.table import Table from sqlglot.executor.table import RowReader, Table
from sqlglot.helper import csv_reader from sqlglot.helper import csv_reader, subclasses
class PythonExecutor: class PythonExecutor:
def __init__(self, env=None): def __init__(self, env=None, tables=None):
self.generator = Python().generator(identify=True) self.generator = Python().generator(identify=True, comments=False)
self.env = {**ENV, **(env or {})} self.env = {**ENV, **(env or {})}
self.tables = tables or {}
def execute(self, plan): def execute(self, plan):
running = set() running = set()
@ -24,6 +26,7 @@ class PythonExecutor:
while queue: while queue:
node = queue.pop() node = queue.pop()
try:
context = self.context( context = self.context(
{ {
name: table name: table
@ -41,6 +44,8 @@ class PythonExecutor:
contexts[node] = self.join(node, context) contexts[node] = self.join(node, context)
elif isinstance(node, planner.Sort): elif isinstance(node, planner.Sort):
contexts[node] = self.sort(node, context) contexts[node] = self.sort(node, context)
elif isinstance(node, planner.SetOperation):
contexts[node] = self.set_operation(node, context)
else: else:
raise NotImplementedError raise NotImplementedError
@ -54,6 +59,8 @@ class PythonExecutor:
for dep in node.dependencies: for dep in node.dependencies:
if all(d in finished for d in dep.dependents): if all(d in finished for d in dep.dependents):
contexts.pop(dep) contexts.pop(dep)
except Exception as e:
raise ExecuteError(f"Step '{node.id}' failed: {e}") from e
root = plan.root root = plan.root
return contexts[root].tables[root.name] return contexts[root].tables[root.name]
@ -76,38 +83,43 @@ class PythonExecutor:
return Context(tables, env=self.env) return Context(tables, env=self.env)
def table(self, expressions): def table(self, expressions):
return Table(expression.alias_or_name for expression in expressions) return Table(
expression.alias_or_name if isinstance(expression, exp.Expression) else expression
for expression in expressions
)
def scan(self, step, context): def scan(self, step, context):
source = step.source source = step.source
if isinstance(source, exp.Expression): if source and isinstance(source, exp.Expression):
source = source.name or source.alias source = source.name or source.alias
condition = self.generate(step.condition) condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections) projections = self.generate_tuple(step.projections)
if source in context: if source is None:
context, table_iter = self.static()
elif source in context:
if not projections and not condition: if not projections and not condition:
return self.context({step.name: context.tables[source]}) return self.context({step.name: context.tables[source]})
table_iter = context.table_iter(source) table_iter = context.table_iter(source)
else: elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV):
table_iter = self.scan_csv(step) table_iter = self.scan_csv(step)
context = next(table_iter)
else:
context, table_iter = self.scan_table(step)
if projections: if projections:
sink = self.table(step.projections) sink = self.table(step.projections)
else: else:
sink = None sink = self.table(context.columns)
for reader, ctx in table_iter: for reader in table_iter:
if sink is None: if condition and not context.eval(condition):
sink = Table(reader.columns)
if condition and not ctx.eval(condition):
continue continue
if projections: if projections:
sink.append(ctx.eval_tuple(projections)) sink.append(context.eval_tuple(projections))
else: else:
sink.append(reader.row) sink.append(reader.row)
@ -116,14 +128,23 @@ class PythonExecutor:
return self.context({step.name: sink}) return self.context({step.name: sink})
def static(self):
return self.context({}), [RowReader(())]
def scan_table(self, step):
table = self.tables.find(step.source)
context = self.context({step.source.alias_or_name: table})
return context, iter(table)
def scan_csv(self, step): def scan_csv(self, step):
source = step.source alias = step.source.alias
alias = source.alias source = step.source.this
with csv_reader(source) as reader: with csv_reader(source) as reader:
columns = next(reader) columns = next(reader)
table = Table(columns) table = Table(columns)
context = self.context({alias: table}) context = self.context({alias: table})
yield context
types = [] types = []
for row in reader: for row in reader:
@ -134,7 +155,7 @@ class PythonExecutor:
except (ValueError, SyntaxError): except (ValueError, SyntaxError):
types.append(str) types.append(str)
context.set_row(tuple(t(v) for t, v in zip(types, row))) context.set_row(tuple(t(v) for t, v in zip(types, row)))
yield context.table.reader, context yield context.table.reader
def join(self, step, context): def join(self, step, context):
source = step.name source = step.name
@ -160,16 +181,19 @@ class PythonExecutor:
for name, column_range in column_ranges.items() for name, column_range in column_ranges.items()
} }
) )
condition = self.generate(join["condition"])
if condition:
source_context.filter(condition)
condition = self.generate(step.condition) condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections) projections = self.generate_tuple(step.projections)
if not condition or not projections: if not condition and not projections:
return source_context return source_context
sink = self.table(step.projections if projections else source_context.columns) sink = self.table(step.projections if projections else source_context.columns)
for reader, ctx in join_context: for reader, ctx in source_context:
if condition and not ctx.eval(condition): if condition and not ctx.eval(condition):
continue continue
@ -181,7 +205,15 @@ class PythonExecutor:
if len(sink) >= step.limit: if len(sink) >= step.limit:
break break
if projections:
return self.context({step.name: sink}) return self.context({step.name: sink})
else:
return self.context(
{
name: Table(table.columns, sink.rows, table.column_range)
for name, table in source_context.tables.items()
}
)
def nested_loop_join(self, _join, source_context, join_context): def nested_loop_join(self, _join, source_context, join_context):
table = Table(source_context.columns + join_context.columns) table = Table(source_context.columns + join_context.columns)
@ -195,6 +227,8 @@ class PythonExecutor:
def hash_join(self, join, source_context, join_context): def hash_join(self, join, source_context, join_context):
source_key = self.generate_tuple(join["source_key"]) source_key = self.generate_tuple(join["source_key"])
join_key = self.generate_tuple(join["join_key"]) join_key = self.generate_tuple(join["join_key"])
left = join.get("side") == "LEFT"
right = join.get("side") == "RIGHT"
results = collections.defaultdict(lambda: ([], [])) results = collections.defaultdict(lambda: ([], []))
@ -204,28 +238,47 @@ class PythonExecutor:
results[ctx.eval_tuple(join_key)][1].append(reader.row) results[ctx.eval_tuple(join_key)][1].append(reader.row)
table = Table(source_context.columns + join_context.columns) table = Table(source_context.columns + join_context.columns)
nulls = [(None,) * len(join_context.columns if left else source_context.columns)]
for a_group, b_group in results.values(): for a_group, b_group in results.values():
if left:
b_group = b_group or nulls
elif right:
a_group = a_group or nulls
for a_row, b_row in itertools.product(a_group, b_group): for a_row, b_row in itertools.product(a_group, b_group):
table.append(a_row + b_row) table.append(a_row + b_row)
return table return table
def aggregate(self, step, context): def aggregate(self, step, context):
source = step.source group_by = self.generate_tuple(step.group.values())
group_by = self.generate_tuple(step.group)
aggregations = self.generate_tuple(step.aggregations) aggregations = self.generate_tuple(step.aggregations)
operands = self.generate_tuple(step.operands) operands = self.generate_tuple(step.operands)
if operands: if operands:
source_table = context.tables[source] operand_table = Table(self.table(step.operands).columns)
operand_table = Table(source_table.columns + self.table(step.operands).columns)
for reader, ctx in context: for reader, ctx in context:
operand_table.append(reader.row + ctx.eval_tuple(operands)) operand_table.append(ctx.eval_tuple(operands))
for i, (a, b) in enumerate(zip(context.table.rows, operand_table.rows)):
context.table.rows[i] = a + b
width = len(context.columns)
context.add_columns(*operand_table.columns)
operand_table = Table(
context.columns,
context.table.rows,
range(width, width + len(operand_table.columns)),
)
context = self.context( context = self.context(
{None: operand_table, **{table: operand_table for table in context.tables}} {
None: operand_table,
**context.tables,
}
) )
context.sort(group_by) context.sort(group_by)
@ -233,25 +286,22 @@ class PythonExecutor:
group = None group = None
start = 0 start = 0
end = 1 end = 1
length = len(context.tables[source]) length = len(context.table)
table = self.table(step.group + step.aggregations) table = self.table(list(step.group) + step.aggregations)
for i in range(length): for i in range(length):
context.set_index(i) context.set_index(i)
key = context.eval_tuple(group_by) key = context.eval_tuple(group_by)
group = key if group is None else group group = key if group is None else group
end += 1 end += 1
if key != group:
if i == length - 1:
context.set_range(start, end - 1)
elif key != group:
context.set_range(start, end - 2) context.set_range(start, end - 2)
else:
continue
table.append(group + context.eval_tuple(aggregations)) table.append(group + context.eval_tuple(aggregations))
group = key group = key
start = end - 2 start = end - 2
if i == length - 1:
context.set_range(start, end - 1)
table.append(group + context.eval_tuple(aggregations))
context = self.context({step.name: table, **{name: table for name in context.tables}}) context = self.context({step.name: table, **{name: table for name in context.tables}})
@ -262,87 +312,67 @@ class PythonExecutor:
def sort(self, step, context): def sort(self, step, context):
projections = self.generate_tuple(step.projections) projections = self.generate_tuple(step.projections)
sink = self.table(step.projections) projection_columns = [p.alias_or_name for p in step.projections]
all_columns = list(context.columns) + projection_columns
sink = self.table(all_columns)
for reader, ctx in context: for reader, ctx in context:
sink.append(ctx.eval_tuple(projections)) sink.append(reader.row + ctx.eval_tuple(projections))
context = self.context( sort_ctx = self.context(
{ {
None: sink, None: sink,
**{table: sink for table in context.tables}, **{table: sink for table in context.tables},
} }
) )
context.sort(self.generate_tuple(step.key)) sort_ctx.sort(self.generate_tuple(step.key))
if not math.isinf(step.limit): if not math.isinf(step.limit):
context.table.rows = context.table.rows[0 : step.limit] sort_ctx.table.rows = sort_ctx.table.rows[0 : step.limit]
return self.context({step.name: context.table}) output = Table(
projection_columns,
rows=[r[len(context.columns) : len(all_columns)] for r in sort_ctx.table.rows],
)
return self.context({step.name: output})
def set_operation(self, step, context):
left = context.tables[step.left]
right = context.tables[step.right]
def _cast_py(self, expression): sink = self.table(left.columns)
to = expression.args["to"].this
this = self.sql(expression, "this")
if to == exp.DataType.Type.DATE: if issubclass(step.op, exp.Intersect):
return f"datetime.date.fromisoformat({this})" sink.rows = list(set(left.rows).intersection(set(right.rows)))
if to == exp.DataType.Type.TEXT: elif issubclass(step.op, exp.Except):
return f"str({this})" sink.rows = list(set(left.rows).difference(set(right.rows)))
raise NotImplementedError elif issubclass(step.op, exp.Union) and step.distinct:
sink.rows = list(set(left.rows).union(set(right.rows)))
else:
sink.rows = left.rows + right.rows
return self.context({step.name: sink})
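# Example (sketch): running SQL over in-memory tables; assumes the top-level
# execute() helper forwards the new `tables` mapping to PythonExecutor.
from sqlglot.executor import execute

tables = {"sushi": [{"id": 1, "price": 1.0}, {"id": 2, "price": 2.0}]}
result = execute("SELECT id FROM sushi UNION ALL SELECT id FROM sushi", tables=tables)
print(result.columns, result.rows)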
def _column_py(self, expression):
table = self.sql(expression, "table") or None
this = self.sql(expression, "this")
return f"scope[{table}][{this}]"
def _interval_py(self, expression):
this = self.sql(expression, "this")
unit = expression.text("unit").upper()
if unit == "DAY":
return f"datetime.timedelta(days=float({this}))"
raise NotImplementedError
def _like_py(self, expression):
this = self.sql(expression, "this")
expression = self.sql(expression, "expression")
return f"""bool(re.match({expression}.replace("_", ".").replace("%", ".*"), {this}))"""
def _ordered_py(self, expression): def _ordered_py(self, expression):
this = self.sql(expression, "this") this = self.sql(expression, "this")
desc = expression.args.get("desc") desc = "True" if expression.args.get("desc") else "False"
return f"desc({this})" if desc else this nulls_first = "True" if expression.args.get("nulls_first") else "False"
return f"ORDERED({this}, {desc}, {nulls_first})"
class Python(Dialect): def _rename(self, e):
class Tokenizer(tokens.Tokenizer): try:
ESCAPES = ["\\"] if "expressions" in e.args:
this = self.sql(e, "this")
this = f"{this}, " if this else ""
return f"{e.key.upper()}({this}{self.expressions(e)})"
return f"{e.key.upper()}({self.format_args(*e.args.values())})"
except Exception as ex:
raise Exception(f"Could not rename {repr(e)}") from ex
class Generator(generator.Generator):
TRANSFORMS = {
exp.Alias: lambda self, e: self.sql(e.this),
exp.Array: inline_array_sql,
exp.And: lambda self, e: self.binary(e, "and"),
exp.Boolean: lambda self, e: "True" if e.this else "False",
exp.Cast: _cast_py,
exp.Column: _column_py,
exp.EQ: lambda self, e: self.binary(e, "=="),
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
exp.Interval: _interval_py,
exp.Is: lambda self, e: self.binary(e, "is"),
exp.Like: _like_py,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
exp.Or: lambda self, e: self.binary(e, "or"),
exp.Ordered: _ordered_py,
exp.Star: lambda *_: "1",
}
def case_sql(self, expression): def _case_sql(self, expression):
this = self.sql(expression, "this") this = self.sql(expression, "this")
chain = self.sql(expression, "default") or "None" chain = self.sql(expression, "default") or "None"
@ -353,3 +383,30 @@ class Python(Dialect):
chain = f"{true} if {condition} else ({chain})" chain = f"{true} if {condition} else ({chain})"
return chain return chain
class Python(Dialect):
class Tokenizer(tokens.Tokenizer):
ESCAPES = ["\\"]
class Generator(generator.Generator):
TRANSFORMS = {
**{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)},
**{klass: _rename for klass in exp.ALL_FUNCTIONS},
exp.Case: _case_sql,
exp.Alias: lambda self, e: self.sql(e.this),
exp.Array: inline_array_sql,
exp.And: lambda self, e: self.binary(e, "and"),
exp.Between: _rename,
exp.Boolean: lambda self, e: "True" if e.this else "False",
exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
exp.Is: lambda self, e: self.binary(e, "is"),
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
exp.Or: lambda self, e: self.binary(e, "or"),
exp.Ordered: _ordered_py,
exp.Star: lambda *_: "1",
}
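# Example (sketch): the Python dialect renders a SQL expression into Python source
# that is later evaluated against ENV and the row "scope"; the string shown in the
# comment is approximate.
from sqlglot import parse_one
from sqlglot.executor.python import Python

code = Python().generator(identify=True).generate(parse_one("a + 1 > b"))
# code is roughly: GT(ADD(scope[None]["a"], 1), scope[None]["b"])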
@ -1,14 +1,27 @@
from __future__ import annotations
from sqlglot.helper import dict_depth
from sqlglot.schema import AbstractMappingSchema
class Table: class Table:
def __init__(self, columns, rows=None, column_range=None): def __init__(self, columns, rows=None, column_range=None):
self.columns = tuple(columns) self.columns = tuple(columns)
self.column_range = column_range self.column_range = column_range
self.reader = RowReader(self.columns, self.column_range) self.reader = RowReader(self.columns, self.column_range)
self.rows = rows or [] self.rows = rows or []
if rows: if rows:
assert len(rows[0]) == len(self.columns) assert len(rows[0]) == len(self.columns)
self.range_reader = RangeReader(self) self.range_reader = RangeReader(self)
def add_columns(self, *columns: str) -> None:
self.columns += columns
if self.column_range:
self.column_range = range(
self.column_range.start, self.column_range.stop + len(columns)
)
self.reader = RowReader(self.columns, self.column_range)
def append(self, row): def append(self, row):
assert len(row) == len(self.columns) assert len(row) == len(self.columns)
self.rows.append(row) self.rows.append(row)
@ -87,3 +100,31 @@ class RowReader:
def __getitem__(self, column): def __getitem__(self, column):
return self.row[self.columns[column]] return self.row[self.columns[column]]
class Tables(AbstractMappingSchema[Table]):
pass
def ensure_tables(d: dict | None) -> Tables:
return Tables(_ensure_tables(d))
def _ensure_tables(d: dict | None) -> dict:
if not d:
return {}
depth = dict_depth(d)
if depth > 1:
return {k: _ensure_tables(v) for k, v in d.items()}
result = {}
for name, table in d.items():
if isinstance(table, Table):
result[name] = table
else:
columns = tuple(table[0]) if table else ()
rows = [tuple(row[c] for c in columns) for row in table]
result[name] = Table(columns=columns, rows=rows)
return result
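# Example (sketch): ensure_tables turns plain dicts of row dicts into Table objects,
# optionally nested by db/catalog, ready to be passed to PythonExecutor(tables=...).
from sqlglot.executor.table import ensure_tables

tables = ensure_tables({"db": {"t": [{"a": 1, "b": 2}, {"a": 3, "b": 4}]}})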
@ -641,9 +641,11 @@ class Set(Expression):
class SetItem(Expression): class SetItem(Expression):
arg_types = { arg_types = {
"this": True, "this": False,
"expressions": False,
"kind": False, "kind": False,
"collate": False, # MySQL SET NAMES statement "collate": False, # MySQL SET NAMES statement
"global": False,
} }
@ -787,6 +789,7 @@ class Drop(Expression):
"exists": False, "exists": False,
"temporary": False, "temporary": False,
"materialized": False, "materialized": False,
"cascade": False,
} }
@ -1073,6 +1076,18 @@ class FileFormatProperty(Property):
pass pass
class DistKeyProperty(Property):
pass
class SortKeyProperty(Property):
pass
class DistStyleProperty(Property):
pass
class LocationProperty(Property): class LocationProperty(Property):
pass pass
@ -1130,6 +1145,9 @@ class Properties(Expression):
"LOCATION": LocationProperty, "LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty, "PARTITIONED_BY": PartitionedByProperty,
"TABLE_FORMAT": TableFormatProperty, "TABLE_FORMAT": TableFormatProperty,
"DISTKEY": DistKeyProperty,
"DISTSTYLE": DistStyleProperty,
"SORTKEY": SortKeyProperty,
} }
@classmethod @classmethod
@ -1356,7 +1374,7 @@ class Var(Expression):
class Schema(Expression): class Schema(Expression):
arg_types = {"this": False, "expressions": True} arg_types = {"this": False, "expressions": False}
class Select(Subqueryable): class Select(Subqueryable):
@ -1741,7 +1759,7 @@ class Select(Subqueryable):
) )
if join_alias: if join_alias:
join.set("this", alias_(join.args["this"], join_alias, table=True)) join.set("this", alias_(join.this, join_alias, table=True))
return _apply_list_builder( return _apply_list_builder(
join, join,
instance=self, instance=self,
@ -1884,6 +1902,7 @@ class Subquery(DerivedTable, Unionable):
arg_types = { arg_types = {
"this": True, "this": True,
"alias": False, "alias": False,
"with": False,
**QUERY_MODIFIERS, **QUERY_MODIFIERS,
} }
@ -2025,6 +2044,31 @@ class DataType(Expression):
NULL = auto() NULL = auto()
UNKNOWN = auto() # Sentinel value, useful for type annotation UNKNOWN = auto() # Sentinel value, useful for type annotation
TEXT_TYPES = {
Type.CHAR,
Type.NCHAR,
Type.VARCHAR,
Type.NVARCHAR,
Type.TEXT,
}
NUMERIC_TYPES = {
Type.INT,
Type.TINYINT,
Type.SMALLINT,
Type.BIGINT,
Type.FLOAT,
Type.DOUBLE,
}
TEMPORAL_TYPES = {
Type.TIMESTAMP,
Type.TIMESTAMPTZ,
Type.TIMESTAMPLTZ,
Type.DATE,
Type.DATETIME,
}
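# Example (sketch): the new type groups make membership checks on annotated
# expressions straightforward.
from sqlglot import exp

exp.DataType.build("varchar").this in exp.DataType.TEXT_TYPES    # True
exp.DataType.build("bigint").this in exp.DataType.NUMERIC_TYPES  # True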
@classmethod @classmethod
def build(cls, dtype, **kwargs) -> DataType: def build(cls, dtype, **kwargs) -> DataType:
return DataType( return DataType(
@ -2054,16 +2098,25 @@ class Exists(SubqueryPredicate):
pass pass
# Commands to interact with the databases or engines # Commands to interact with the databases or engines. For most of the command
# These expressions don't truly parse the expression and consume # expressions we parse whatever comes after the command's name as a string.
# whatever exists as a string until the end or a semicolon
class Command(Expression): class Command(Expression):
arg_types = {"this": True, "expression": False} arg_types = {"this": True, "expression": False}
# Binary Expressions class Transaction(Command):
# (ADD a b) arg_types = {"this": False, "modes": False}
# (FROM table selects)
class Commit(Command):
arg_types = {} # type: ignore
class Rollback(Command):
arg_types = {"savepoint": False}
# Binary expressions like (ADD a b)
class Binary(Expression): class Binary(Expression):
arg_types = {"this": True, "expression": True} arg_types = {"this": True, "expression": True}
@ -2215,7 +2268,7 @@ class Not(Unary, Condition):
class Paren(Unary, Condition): class Paren(Unary, Condition):
pass arg_types = {"this": True, "with": False}
class Neg(Unary): class Neg(Unary):
@ -2428,6 +2481,10 @@ class Cast(Func):
return self.args["to"] return self.args["to"]
class Collate(Binary):
pass
class TryCast(Cast): class TryCast(Cast):
pass pass
@ -2442,13 +2499,17 @@ class Coalesce(Func):
is_var_len_args = True is_var_len_args = True
class ConcatWs(Func): class Concat(Func):
arg_types = {"expressions": False} arg_types = {"expressions": True}
is_var_len_args = True is_var_len_args = True
class ConcatWs(Concat):
_sql_names = ["CONCAT_WS"]
class Count(AggFunc): class Count(AggFunc):
pass arg_types = {"this": False}
class CurrentDate(Func): class CurrentDate(Func):
@ -2556,10 +2617,18 @@ class Day(Func):
pass pass
class Decode(Func):
arg_types = {"this": True, "charset": True}
class DiToDate(Func): class DiToDate(Func):
pass pass
class Encode(Func):
arg_types = {"this": True, "charset": True}
class Exp(Func): class Exp(Func):
pass pass
@ -2581,6 +2650,10 @@ class GroupConcat(Func):
arg_types = {"this": True, "separator": False} arg_types = {"this": True, "separator": False}
class Hex(Func):
pass
class If(Func): class If(Func):
arg_types = {"this": True, "true": True, "false": False} arg_types = {"this": True, "true": True, "false": False}
@ -2641,7 +2714,7 @@ class Log10(Func):
class Lower(Func): class Lower(Func):
pass _sql_names = ["LOWER", "LCASE"]
class Map(Func): class Map(Func):
@ -2686,6 +2759,12 @@ class ApproxQuantile(Quantile):
arg_types = {"this": True, "quantile": True, "accuracy": False} arg_types = {"this": True, "quantile": True, "accuracy": False}
class ReadCSV(Func):
_sql_names = ["READ_CSV"]
is_var_len_args = True
arg_types = {"this": True, "expressions": False}
class Reduce(Func): class Reduce(Func):
arg_types = {"this": True, "initial": True, "merge": True, "finish": True} arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
@ -2804,8 +2883,8 @@ class TimeStrToUnix(Func):
class Trim(Func): class Trim(Func):
arg_types = { arg_types = {
"this": True, "this": True,
"position": False,
"expression": False, "expression": False,
"position": False,
"collation": False, "collation": False,
} }
@ -2826,6 +2905,10 @@ class TsOrDiToDi(Func):
pass pass
class Unhex(Func):
pass
class UnixToStr(Func): class UnixToStr(Func):
arg_types = {"this": True, "format": False} arg_types = {"this": True, "format": False}
@ -2843,7 +2926,7 @@ class UnixToTimeStr(Func):
class Upper(Func): class Upper(Func):
pass _sql_names = ["UPPER", "UCASE"]
class Variance(AggFunc): class Variance(AggFunc):
@ -3701,6 +3784,19 @@ def replace_placeholders(expression, *args, **kwargs):
return expression.transform(_replace_placeholders, iter(args), **kwargs) return expression.transform(_replace_placeholders, iter(args), **kwargs)
def true():
return Boolean(this=True)
def false():
return Boolean(this=False)
def null():
return Null()
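# Example (sketch): each factory call builds a fresh node, so optimizer rewrites can
# mutate what they substitute without sharing a single TRUE/FALSE/NULL instance.
from sqlglot import exp, parse_one

assert exp.true() is not exp.true()
select = parse_one("SELECT * FROM t WHERE a = 1 AND b = 2")
select.find(exp.EQ).replace(exp.true())
select.sql()  # e.g. SELECT * FROM t WHERE TRUE AND b = 2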
# TODO: deprecate this
TRUE = Boolean(this=True) TRUE = Boolean(this=True)
FALSE = Boolean(this=False) FALSE = Boolean(this=False)
NULL = Null() NULL = Null()
@ -67,7 +67,7 @@ class Generator:
exp.LocationProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e), exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e), exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.VolatilityProperty: lambda self, e: self.sql(e.name), exp.VolatilityProperty: lambda self, e: e.name,
} }
# Whether 'CREATE ... TRANSIENT ... TABLE' is allowed # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
@ -94,6 +94,9 @@ class Generator:
ROOT_PROPERTIES = { ROOT_PROPERTIES = {
exp.ReturnsProperty, exp.ReturnsProperty,
exp.LanguageProperty, exp.LanguageProperty,
exp.DistStyleProperty,
exp.DistKeyProperty,
exp.SortKeyProperty,
} }
WITH_PROPERTIES = { WITH_PROPERTIES = {
@ -241,7 +244,7 @@ class Generator:
if not NEWLINE_RE.search(comment): if not NEWLINE_RE.search(comment):
return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/" return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
return f"/*{comment}*/\n{sql}" return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/"
def wrap(self, expression): def wrap(self, expression):
this_sql = self.indent( this_sql = self.indent(
@ -475,7 +478,8 @@ class Generator:
exists_sql = " IF EXISTS " if expression.args.get("exists") else " " exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
temporary = " TEMPORARY" if expression.args.get("temporary") else "" temporary = " TEMPORARY" if expression.args.get("temporary") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else "" materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}" cascade = " CASCADE" if expression.args.get("cascade") else ""
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}"
def except_sql(self, expression): def except_sql(self, expression):
return self.prepend_ctes( return self.prepend_ctes(
@ -915,13 +919,15 @@ class Generator:
def subquery_sql(self, expression): def subquery_sql(self, expression):
alias = self.sql(expression, "alias") alias = self.sql(expression, "alias")
return self.query_modifiers( sql = self.query_modifiers(
expression, expression,
self.wrap(expression), self.wrap(expression),
self.expressions(expression, key="pivots", sep=" "), self.expressions(expression, key="pivots", sep=" "),
f" AS {alias}" if alias else "", f" AS {alias}" if alias else "",
) )
return self.prepend_ctes(expression, sql)
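# Example (sketch): CTEs attached to subqueries and parens are re-emitted now;
# round-tripping a statement like this should preserve the inner WITH.
import sqlglot

sqlglot.transpile("SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS sub")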
def qualify_sql(self, expression): def qualify_sql(self, expression):
this = self.indent(self.sql(expression, "this")) this = self.indent(self.sql(expression, "this"))
return f"{self.seg('QUALIFY')}{self.sep()}{this}" return f"{self.seg('QUALIFY')}{self.sep()}{this}"
@ -1111,9 +1117,12 @@ class Generator:
def paren_sql(self, expression): def paren_sql(self, expression):
if isinstance(expression.unnest(), exp.Select): if isinstance(expression.unnest(), exp.Select):
return self.wrap(expression) sql = self.wrap(expression)
else:
sql = self.seg(self.indent(self.sql(expression, "this")), sep="") sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
return f"({sql}{self.seg(')', sep='')}" sql = f"({sql}{self.seg(')', sep='')}"
return self.prepend_ctes(expression, sql)
def neg_sql(self, expression): def neg_sql(self, expression):
return f"-{self.sql(expression, 'this')}" return f"-{self.sql(expression, 'this')}"
@ -1173,9 +1182,23 @@ class Generator:
zone = self.sql(expression, "this") zone = self.sql(expression, "this")
return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE" return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE"
def collate_sql(self, expression):
return self.binary(expression, "COLLATE")
def command_sql(self, expression): def command_sql(self, expression):
return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}" return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"
def transaction_sql(self, *_):
return "BEGIN"
def commit_sql(self, *_):
return "COMMIT"
def rollback_sql(self, expression):
savepoint = expression.args.get("savepoint")
savepoint = f" TO {savepoint}" if savepoint else ""
return f"ROLLBACK{savepoint}"
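# Example (sketch): BEGIN / COMMIT / ROLLBACK now parse into dedicated expressions
# instead of opaque commands.
import sqlglot

sqlglot.transpile("BEGIN")     # -> ['BEGIN']
sqlglot.transpile("ROLLBACK")  # -> ['ROLLBACK']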
def distinct_sql(self, expression): def distinct_sql(self, expression):
this = self.expressions(expression, flat=True) this = self.expressions(expression, flat=True)
this = f" {this}" if this else "" this = f" {this}" if this else ""
@ -1193,10 +1216,7 @@ class Generator:
def intdiv_sql(self, expression): def intdiv_sql(self, expression):
return self.sql( return self.sql(
exp.Cast( exp.Cast(
this=exp.Div( this=exp.Div(this=expression.this, expression=expression.expression),
this=expression.args["this"],
expression=expression.args["expression"],
),
to=exp.DataType(this=exp.DataType.Type.INT), to=exp.DataType(this=exp.DataType.Type.INT),
) )
) )
@ -11,7 +11,8 @@ from copy import copy
from enum import Enum from enum import Enum
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from sqlglot.expressions import Expression, Table from sqlglot import exp
from sqlglot.expressions import Expression
T = t.TypeVar("T") T = t.TypeVar("T")
E = t.TypeVar("E", bound=Expression) E = t.TypeVar("E", bound=Expression)
@ -150,7 +151,7 @@ def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
if expression.is_int: if expression.is_int:
expression = expression.copy() expression = expression.copy()
logger.warning("Applying array index offset (%s)", offset) logger.warning("Applying array index offset (%s)", offset)
expression.args["this"] = str(int(expression.args["this"]) + offset) expression.args["this"] = str(int(expression.this) + offset)
return [expression] return [expression]
return expressions return expressions
@ -228,19 +229,18 @@ def open_file(file_name: str) -> t.TextIO:
@contextmanager @contextmanager
def csv_reader(table: Table) -> t.Any: def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
""" """
Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`. Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
Args: Args:
table: a `Table` expression with an anonymous function `READ_CSV` in it. read_csv: a `ReadCSV` function call
Yields: Yields:
A python csv reader. A python csv reader.
""" """
file, *args = table.this.expressions args = read_csv.expressions
file = file.name file = open_file(read_csv.name)
file = open_file(file)
delimiter = "," delimiter = ","
args = iter(arg.name for arg in args) args = iter(arg.name for arg in args)
@ -354,3 +354,34 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any,
yield from flatten(value) yield from flatten(value)
else: else:
yield value yield value
def dict_depth(d: t.Dict) -> int:
"""
Get the nesting depth of a dictionary.
For example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
Args:
d (dict): dictionary
Returns:
int: depth
"""
try:
return 1 + dict_depth(next(iter(d.values())))
except AttributeError:
# d doesn't have attribute "values"
return 0
except StopIteration:
# d.values() returns an empty sequence
return 1
@ -245,23 +245,31 @@ class TypeAnnotator:
def annotate(self, expression): def annotate(self, expression):
if isinstance(expression, self.TRAVERSABLES): if isinstance(expression, self.TRAVERSABLES):
for scope in traverse_scope(expression): for scope in traverse_scope(expression):
subscope_selects = { selects = {}
name: {select.alias_or_name: select for select in source.selects} for name, source in scope.sources.items():
for name, source in scope.sources.items() if not isinstance(source, Scope):
if isinstance(source, Scope) continue
if isinstance(source.expression, exp.Values):
selects[name] = {
alias: column
for alias, column in zip(
source.expression.alias_column_names,
source.expression.expressions[0].expressions,
)
}
else:
selects[name] = {
select.alias_or_name: select for select in source.expression.selects
} }
# First annotate the current scope's column references # First annotate the current scope's column references
for col in scope.columns: for col in scope.columns:
source = scope.sources[col.table] source = scope.sources[col.table]
if isinstance(source, exp.Table): if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col) col.type = self.schema.get_column_type(source, col)
else: else:
col.type = subscope_selects[col.table][col.name].type col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope # Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression) self._maybe_annotate(scope.expression)
return self._maybe_annotate(expression) # This takes care of non-traversable expressions return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def _maybe_annotate(self, expression): def _maybe_annotate(self, expression):
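# Example (sketch): columns selected from a VALUES source now pick up types from the
# literals in its first row; exact attribute shapes may differ slightly.
from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types

expr = annotate_types(parse_one("SELECT v.a FROM (VALUES (1), (2)) AS v(a)"))
expr.selects[0].type  # e.g. exp.DataType.Type.INT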
@ -0,0 +1,48 @@
import itertools
from sqlglot import exp
def canonicalize(expression: exp.Expression) -> exp.Expression:
"""Converts a sql expression into a standard form.
This method relies on annotate_types because many of the
conversions rely on type inference.
Args:
expression: The expression to canonicalize.
"""
exp.replace_children(expression, canonicalize)
expression = add_text_to_concat(expression)
expression = coerce_type(expression)
return expression
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES:
node = exp.Concat(this=node.this, expression=node.expression)
return node
def coerce_type(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Binary):
_coerce_date(node.left, node.right)
elif isinstance(node, exp.Between):
_coerce_date(node.this, node.args["low"])
elif isinstance(node, exp.Extract):
if node.expression.type not in exp.DataType.TEMPORAL_TYPES:
_replace_cast(node.expression, "datetime")
return node
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE:
_replace_cast(b, "date")
def _replace_cast(node: exp.Expression, to: str) -> None:
data_type = exp.DataType.build(to)
cast = exp.Cast(this=node.copy(), to=data_type)
cast.type = data_type
node.replace(cast)
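# Example (sketch): canonicalize depends on annotate_types; string addition is
# rewritten to CONCAT and date/string comparisons gain casts.
from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize

expr = annotate_types(parse_one("SELECT CAST(x AS TEXT) + 'y' FROM t"))
canonicalize(expr).sql()  # the addition becomes CONCAT(CAST(x AS TEXT), 'y')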
@ -128,8 +128,8 @@ def join_condition(join):
Tuple of (source key, join key, remaining predicate) Tuple of (source key, join key, remaining predicate)
""" """
name = join.this.alias_or_name name = join.this.alias_or_name
on = join.args.get("on") or exp.TRUE on = (join.args.get("on") or exp.true()).copy()
on = on.copy() on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
source_key = [] source_key = []
join_key = [] join_key = []
@ -141,7 +141,7 @@ def join_condition(join):
# #
# should pull y.b as the join key and x.a as the source key # should pull y.b as the join key and x.a as the source key
if normalized(on): if normalized(on):
for condition in on.flatten() if isinstance(on, exp.And) else [on]: for condition in on.flatten():
if isinstance(condition, exp.EQ): if isinstance(condition, exp.EQ):
left, right = condition.unnest_operands() left, right = condition.unnest_operands()
left_tables = exp.column_table_names(left) left_tables = exp.column_table_names(left)
@ -150,13 +150,12 @@ def join_condition(join):
if name in left_tables and name not in right_tables: if name in left_tables and name not in right_tables:
join_key.append(left) join_key.append(left)
source_key.append(right) source_key.append(right)
condition.replace(exp.TRUE) condition.replace(exp.true())
elif name in right_tables and name not in left_tables: elif name in right_tables and name not in left_tables:
join_key.append(right) join_key.append(right)
source_key.append(left) source_key.append(left)
condition.replace(exp.TRUE) condition.replace(exp.true())
on = simplify(on) on = simplify(on)
remaining_condition = None if on == exp.TRUE else on remaining_condition = None if on == exp.true() else on
return source_key, join_key, remaining_condition return source_key, join_key, remaining_condition
@ -29,7 +29,7 @@ def optimize_joins(expression):
if isinstance(on, exp.Connector): if isinstance(on, exp.Connector):
for predicate in on.flatten(): for predicate in on.flatten():
if name in exp.column_table_names(predicate): if name in exp.column_table_names(predicate):
predicate.replace(exp.TRUE) predicate.replace(exp.true())
join.on(predicate, copy=False) join.on(predicate, copy=False)
expression = reorder_joins(expression) expression = reorder_joins(expression)
@ -70,6 +70,6 @@ def normalize(expression):
def other_table_names(join, exclude): def other_table_names(join, exclude):
return [ return [
name name
for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) for name in (exp.column_table_names(join.args.get("on") or exp.true()))
if name != exclude if name != exclude
] ]
@ -1,4 +1,6 @@
import sqlglot import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
@ -28,6 +30,8 @@ RULES = (
merge_subqueries, merge_subqueries,
eliminate_joins, eliminate_joins,
eliminate_ctes, eliminate_ctes,
annotate_types,
canonicalize,
quote_identities, quote_identities,
) )
@ -64,11 +64,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count):
for predicate in predicates: for predicate in predicates:
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values(): for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
if isinstance(node, exp.Join): if isinstance(node, exp.Join):
predicate.replace(exp.TRUE) predicate.replace(exp.true())
node.on(predicate, copy=False) node.on(predicate, copy=False)
break break
if isinstance(node, exp.Select): if isinstance(node, exp.Select):
predicate.replace(exp.TRUE) predicate.replace(exp.true())
node.where(replace_aliases(node, predicate), copy=False) node.where(replace_aliases(node, predicate), copy=False)

raise OptimizeError(str(e)) from e raise OptimizeError(str(e)) from e
if isinstance(source, Scope) and isinstance(source.expression, exp.Values): if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
values_alias = source.expression.parent return source.expression.alias_column_names
if hasattr(values_alias, "alias_column_names"):
return values_alias.alias_column_names
# Otherwise, if referencing another scope, return that scope's named selects # Otherwise, if referencing another scope, return that scope's named selects
return source.expression.named_selects return source.expression.named_selects
@ -1,10 +1,11 @@
import itertools import itertools
from sqlglot import alias, exp from sqlglot import alias, exp
from sqlglot.helper import csv_reader
from sqlglot.optimizer.scope import traverse_scope from sqlglot.optimizer.scope import traverse_scope
def qualify_tables(expression, db=None, catalog=None): def qualify_tables(expression, db=None, catalog=None, schema=None):
""" """
Rewrite sqlglot AST to have fully qualified tables. Rewrite sqlglot AST to have fully qualified tables.
@ -18,6 +19,7 @@ def qualify_tables(expression, db=None, catalog=None):
expression (sqlglot.Expression): expression to qualify expression (sqlglot.Expression): expression to qualify
db (str): Database name db (str): Database name
catalog (str): Catalog name catalog (str): Catalog name
schema: A schema to populate
Returns: Returns:
sqlglot.Expression: qualified expression sqlglot.Expression: qualified expression
""" """
@ -41,7 +43,7 @@ def qualify_tables(expression, db=None, catalog=None):
source.set("catalog", exp.to_identifier(catalog)) source.set("catalog", exp.to_identifier(catalog))
if not source.alias: if not source.alias:
source.replace( source = source.replace(
alias( alias(
source.copy(), source.copy(),
source.this if identifier else f"_q_{next(sequence)}", source.this if identifier else f"_q_{next(sequence)}",
@ -49,4 +51,12 @@ def qualify_tables(expression, db=None, catalog=None):
) )
) )
if schema and isinstance(source.this, exp.ReadCSV):
with csv_reader(source.this) as reader:
header = next(reader)
columns = next(reader)
schema.add_table(
source, {k: type(v).__name__ for k, v in zip(header, columns)}
)
return expression return expression
@ -189,11 +189,11 @@ def absorb_and_eliminate(expression):
# absorb # absorb
if is_complement(b, aa): if is_complement(b, aa):
aa.replace(exp.TRUE if kind == exp.And else exp.FALSE) aa.replace(exp.true() if kind == exp.And else exp.false())
elif is_complement(b, ab): elif is_complement(b, ab):
ab.replace(exp.TRUE if kind == exp.And else exp.FALSE) ab.replace(exp.true() if kind == exp.And else exp.false())
elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
a.replace(exp.FALSE if kind == exp.And else exp.TRUE) a.replace(exp.false() if kind == exp.And else exp.true())
elif isinstance(b, kind): elif isinstance(b, kind):
# eliminate # eliminate
rhs = b.unnest_operands() rhs = b.unnest_operands()

select.parent.replace(alias) select.parent.replace(alias)
for key, column, predicate in keys: for key, column, predicate in keys:
predicate.replace(exp.TRUE) predicate.replace(exp.true())
nested = exp.column(key_aliases[key], table_alias) nested = exp.column(key_aliases[key], table_alias)
if key in group_by: if key in group_by:

ID_VAR_TOKENS = { ID_VAR_TOKENS = {
TokenType.VAR, TokenType.VAR,
TokenType.ALTER,
TokenType.ALWAYS, TokenType.ALWAYS,
TokenType.ANTI, TokenType.ANTI,
TokenType.APPLY, TokenType.APPLY,
TokenType.AUTO_INCREMENT,
TokenType.BEGIN, TokenType.BEGIN,
TokenType.BOTH, TokenType.BOTH,
TokenType.BUCKET, TokenType.BUCKET,
TokenType.CACHE, TokenType.CACHE,
TokenType.CALL, TokenType.CASCADE,
TokenType.COLLATE, TokenType.COLLATE,
TokenType.COMMAND,
TokenType.COMMIT, TokenType.COMMIT,
TokenType.CONSTRAINT, TokenType.CONSTRAINT,
TokenType.CURRENT_TIME,
TokenType.DEFAULT, TokenType.DEFAULT,
TokenType.DELETE, TokenType.DELETE,
TokenType.DESCRIBE, TokenType.DESCRIBE,
TokenType.DETERMINISTIC, TokenType.DETERMINISTIC,
TokenType.DISTKEY,
TokenType.DISTSTYLE,
TokenType.EXECUTE, TokenType.EXECUTE,
TokenType.ENGINE, TokenType.ENGINE,
TokenType.ESCAPE, TokenType.ESCAPE,
TokenType.EXPLAIN,
TokenType.FALSE, TokenType.FALSE,
TokenType.FIRST, TokenType.FIRST,
TokenType.FOLLOWING, TokenType.FOLLOWING,
@ -182,7 +185,6 @@ class Parser(metaclass=_Parser):
TokenType.NATURAL, TokenType.NATURAL,
TokenType.NEXT, TokenType.NEXT,
TokenType.ONLY, TokenType.ONLY,
TokenType.OPTIMIZE,
TokenType.OPTIONS, TokenType.OPTIONS,
TokenType.ORDINALITY, TokenType.ORDINALITY,
TokenType.PARTITIONED_BY, TokenType.PARTITIONED_BY,
@ -199,6 +201,7 @@ class Parser(metaclass=_Parser):
TokenType.SEMI, TokenType.SEMI,
TokenType.SET, TokenType.SET,
TokenType.SHOW, TokenType.SHOW,
TokenType.SORTKEY,
TokenType.STABLE, TokenType.STABLE,
TokenType.STORED, TokenType.STORED,
TokenType.TABLE, TokenType.TABLE,
@ -207,7 +210,6 @@ class Parser(metaclass=_Parser):
TokenType.TRANSIENT, TokenType.TRANSIENT,
TokenType.TOP, TokenType.TOP,
TokenType.TRAILING, TokenType.TRAILING,
TokenType.TRUNCATE,
TokenType.TRUE, TokenType.TRUE,
TokenType.UNBOUNDED, TokenType.UNBOUNDED,
TokenType.UNIQUE, TokenType.UNIQUE,
@ -217,6 +219,7 @@ class Parser(metaclass=_Parser):
TokenType.VOLATILE, TokenType.VOLATILE,
*SUBQUERY_PREDICATES, *SUBQUERY_PREDICATES,
*TYPE_TOKENS, *TYPE_TOKENS,
*NO_PAREN_FUNCTIONS,
} }
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY} TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY}
@ -231,6 +234,7 @@ class Parser(metaclass=_Parser):
TokenType.FILTER, TokenType.FILTER,
TokenType.FIRST, TokenType.FIRST,
TokenType.FORMAT, TokenType.FORMAT,
TokenType.IDENTIFIER,
TokenType.ISNULL, TokenType.ISNULL,
TokenType.OFFSET, TokenType.OFFSET,
TokenType.PRIMARY_KEY, TokenType.PRIMARY_KEY,
@ -242,6 +246,7 @@ class Parser(metaclass=_Parser):
TokenType.RIGHT, TokenType.RIGHT,
TokenType.DATE, TokenType.DATE,
TokenType.DATETIME, TokenType.DATETIME,
TokenType.TABLE,
TokenType.TIMESTAMP, TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ, TokenType.TIMESTAMPTZ,
*TYPE_TOKENS, *TYPE_TOKENS,
@ -277,6 +282,7 @@ class Parser(metaclass=_Parser):
TokenType.DASH: exp.Sub, TokenType.DASH: exp.Sub,
TokenType.PLUS: exp.Add, TokenType.PLUS: exp.Add,
TokenType.MOD: exp.Mod, TokenType.MOD: exp.Mod,
TokenType.COLLATE: exp.Collate,
} }
FACTOR = { FACTOR = {
@ -391,7 +397,10 @@ class Parser(metaclass=_Parser):
TokenType.DELETE: lambda self: self._parse_delete(), TokenType.DELETE: lambda self: self._parse_delete(),
TokenType.CACHE: lambda self: self._parse_cache(), TokenType.CACHE: lambda self: self._parse_cache(),
TokenType.UNCACHE: lambda self: self._parse_uncache(), TokenType.UNCACHE: lambda self: self._parse_uncache(),
TokenType.USE: lambda self: self._parse_use(), TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
TokenType.BEGIN: lambda self: self._parse_transaction(),
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
} }
PRIMARY_PARSERS = { PRIMARY_PARSERS = {
@ -402,7 +411,8 @@ class Parser(metaclass=_Parser):
exp.Literal, this=token.text, is_string=False exp.Literal, this=token.text, is_string=False
), ),
TokenType.STAR: lambda self, _: self.expression( TokenType.STAR: lambda self, _: self.expression(
exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()} exp.Star,
**{"except": self._parse_except(), "replace": self._parse_replace()},
), ),
TokenType.NULL: lambda self, _: self.expression(exp.Null), TokenType.NULL: lambda self, _: self.expression(exp.Null),
TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True), TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
@ -446,6 +456,9 @@ class Parser(metaclass=_Parser):
TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(), TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(),
TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(), TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(),
TokenType.STORED: lambda self: self._parse_stored(), TokenType.STORED: lambda self: self._parse_stored(),
TokenType.DISTKEY: lambda self: self._parse_distkey(),
TokenType.DISTSTYLE: lambda self: self._parse_diststyle(),
TokenType.SORTKEY: lambda self: self._parse_sortkey(),
TokenType.RETURNS: lambda self: self._parse_returns(), TokenType.RETURNS: lambda self: self._parse_returns(),
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty), TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
@ -471,7 +484,9 @@ class Parser(metaclass=_Parser):
} }
CONSTRAINT_PARSERS = { CONSTRAINT_PARSERS = {
TokenType.CHECK: lambda self: self._parse_check(), TokenType.CHECK: lambda self: self.expression(
exp.Check, this=self._parse_wrapped(self._parse_conjunction)
),
TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(), TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(),
TokenType.UNIQUE: lambda self: self._parse_unique(), TokenType.UNIQUE: lambda self: self._parse_unique(),
} }
@ -521,6 +536,8 @@ class Parser(metaclass=_Parser):
TokenType.SCHEMA, TokenType.SCHEMA,
} }
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
STRICT_CAST = True STRICT_CAST = True
__slots__ = ( __slots__ = (
@ -740,6 +757,7 @@ class Parser(metaclass=_Parser):
kind=kind, kind=kind,
temporary=temporary, temporary=temporary,
materialized=materialized, materialized=materialized,
cascade=self._match(TokenType.CASCADE),
) )
def _parse_exists(self, not_=False): def _parse_exists(self, not_=False):
@ -777,7 +795,11 @@ class Parser(metaclass=_Parser):
expression = self._parse_select_or_expression() expression = self._parse_select_or_expression()
elif create_token.token_type == TokenType.INDEX: elif create_token.token_type == TokenType.INDEX:
this = self._parse_index() this = self._parse_index()
elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW, TokenType.SCHEMA): elif create_token.token_type in (
TokenType.TABLE,
TokenType.VIEW,
TokenType.SCHEMA,
):
this = self._parse_table(schema=True) this = self._parse_table(schema=True)
properties = self._parse_properties() properties = self._parse_properties()
if self._match(TokenType.ALIAS): if self._match(TokenType.ALIAS):
@ -834,7 +856,38 @@ class Parser(metaclass=_Parser):
return self.expression( return self.expression(
exp.FileFormatProperty, exp.FileFormatProperty,
this=exp.Literal.string("FORMAT"), this=exp.Literal.string("FORMAT"),
value=exp.Literal.string(self._parse_var().name), value=exp.Literal.string(self._parse_var_or_string().name),
)
def _parse_distkey(self):
self._match_l_paren()
this = exp.Literal.string("DISTKEY")
value = exp.Literal.string(self._parse_var().name)
self._match_r_paren()
return self.expression(
exp.DistKeyProperty,
this=this,
value=value,
)
def _parse_sortkey(self):
self._match_l_paren()
this = exp.Literal.string("SORTKEY")
value = exp.Literal.string(self._parse_var().name)
self._match_r_paren()
return self.expression(
exp.SortKeyProperty,
this=this,
value=value,
)
def _parse_diststyle(self):
this = exp.Literal.string("DISTSTYLE")
value = exp.Literal.string(self._parse_var().name)
return self.expression(
exp.DistStyleProperty,
this=this,
value=value,
) )
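# Example (sketch): the new property parsers cover Redshift-style table options;
# the exact generated formatting may differ.
import sqlglot

sqlglot.transpile("CREATE TABLE t (a INT) DISTKEY(a) DISTSTYLE KEY SORTKEY(a)")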
def _parse_auto_increment(self): def _parse_auto_increment(self):
@ -842,7 +895,7 @@ class Parser(metaclass=_Parser):
return self.expression( return self.expression(
exp.AutoIncrementProperty, exp.AutoIncrementProperty,
this=exp.Literal.string("AUTO_INCREMENT"), this=exp.Literal.string("AUTO_INCREMENT"),
value=self._parse_var() or self._parse_number(), value=self._parse_number(),
) )
def _parse_schema_comment(self): def _parse_schema_comment(self):
@ -898,13 +951,10 @@ class Parser(metaclass=_Parser):
while True: while True:
if self._match(TokenType.WITH): if self._match(TokenType.WITH):
self._match_l_paren() properties.extend(self._parse_wrapped_csv(self._parse_property))
properties.extend(self._parse_csv(lambda: self._parse_property()))
self._match_r_paren()
elif self._match(TokenType.PROPERTIES): elif self._match(TokenType.PROPERTIES):
self._match_l_paren()
properties.extend( properties.extend(
self._parse_csv( self._parse_wrapped_csv(
lambda: self.expression( lambda: self.expression(
exp.AnonymousProperty, exp.AnonymousProperty,
this=self._parse_string(), this=self._parse_string(),
@ -912,25 +962,24 @@ class Parser(metaclass=_Parser):
) )
) )
) )
self._match_r_paren()
else: else:
identified_property = self._parse_property() identified_property = self._parse_property()
if not identified_property: if not identified_property:
break break
properties.append(identified_property) properties.append(identified_property)
if properties: if properties:
return self.expression(exp.Properties, expressions=properties) return self.expression(exp.Properties, expressions=properties)
return None return None
def _parse_describe(self): def _parse_describe(self):
self._match(TokenType.TABLE) self._match(TokenType.TABLE)
return self.expression(exp.Describe, this=self._parse_id_var()) return self.expression(exp.Describe, this=self._parse_id_var())
def _parse_insert(self): def _parse_insert(self):
overwrite = self._match(TokenType.OVERWRITE) overwrite = self._match(TokenType.OVERWRITE)
local = self._match(TokenType.LOCAL) local = self._match(TokenType.LOCAL)
if self._match_text("DIRECTORY"): if self._match_text_seq("DIRECTORY"):
this = self.expression( this = self.expression(
exp.Directory, exp.Directory,
this=self._parse_var_or_string(), this=self._parse_var_or_string(),
@ -954,27 +1003,27 @@ class Parser(metaclass=_Parser):
if not self._match_pair(TokenType.ROW, TokenType.FORMAT): if not self._match_pair(TokenType.ROW, TokenType.FORMAT):
return None return None
self._match_text("DELIMITED") self._match_text_seq("DELIMITED")
kwargs = {} kwargs = {}
if self._match_text("FIELDS", "TERMINATED", "BY"): if self._match_text_seq("FIELDS", "TERMINATED", "BY"):
kwargs["fields"] = self._parse_string() kwargs["fields"] = self._parse_string()
if self._match_text("ESCAPED", "BY"): if self._match_text_seq("ESCAPED", "BY"):
kwargs["escaped"] = self._parse_string() kwargs["escaped"] = self._parse_string()
if self._match_text("COLLECTION", "ITEMS", "TERMINATED", "BY"): if self._match_text_seq("COLLECTION", "ITEMS", "TERMINATED", "BY"):
kwargs["collection_items"] = self._parse_string() kwargs["collection_items"] = self._parse_string()
if self._match_text("MAP", "KEYS", "TERMINATED", "BY"): if self._match_text_seq("MAP", "KEYS", "TERMINATED", "BY"):
kwargs["map_keys"] = self._parse_string() kwargs["map_keys"] = self._parse_string()
if self._match_text("LINES", "TERMINATED", "BY"): if self._match_text_seq("LINES", "TERMINATED", "BY"):
kwargs["lines"] = self._parse_string() kwargs["lines"] = self._parse_string()
if self._match_text("NULL", "DEFINED", "AS"): if self._match_text_seq("NULL", "DEFINED", "AS"):
kwargs["null"] = self._parse_string() kwargs["null"] = self._parse_string()
return self.expression(exp.RowFormat, **kwargs) return self.expression(exp.RowFormat, **kwargs)
def _parse_load_data(self): def _parse_load_data(self):
local = self._match(TokenType.LOCAL) local = self._match(TokenType.LOCAL)
self._match_text("INPATH") self._match_text_seq("INPATH")
inpath = self._parse_string() inpath = self._parse_string()
overwrite = self._match(TokenType.OVERWRITE) overwrite = self._match(TokenType.OVERWRITE)
self._match_pair(TokenType.INTO, TokenType.TABLE) self._match_pair(TokenType.INTO, TokenType.TABLE)
@ -986,8 +1035,8 @@ class Parser(metaclass=_Parser):
overwrite=overwrite, overwrite=overwrite,
inpath=inpath, inpath=inpath,
partition=self._parse_partition(), partition=self._parse_partition(),
input_format=self._match_text("INPUTFORMAT") and self._parse_string(), input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(),
serde=self._match_text("SERDE") and self._parse_string(), serde=self._match_text_seq("SERDE") and self._parse_string(),
) )
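`_match_text` is renamed to `_match_text_seq` throughout the row-format and LOAD DATA parsing above. A sketch of the statement this code targets, assuming the hive dialect routes LOAD DATA into `_parse_load_data`; the path and table name are made up, and the resulting node type is an inference:

```python
import sqlglot

expr = sqlglot.parse_one(
    "LOAD DATA LOCAL INPATH '/tmp/x.csv' OVERWRITE INTO TABLE t", read="hive"
)
print(type(expr).__name__)  # expected: LoadData
```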
def _parse_delete(self): def _parse_delete(self):
@ -996,9 +1045,7 @@ class Parser(metaclass=_Parser):
return self.expression( return self.expression(
exp.Delete, exp.Delete,
this=self._parse_table(schema=True), this=self._parse_table(schema=True),
using=self._parse_csv( using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()),
lambda: self._match(TokenType.USING) and self._parse_table(schema=True)
),
where=self._parse_where(), where=self._parse_where(),
) )
@ -1029,12 +1076,7 @@ class Parser(metaclass=_Parser):
options = [] options = []
if self._match(TokenType.OPTIONS): if self._match(TokenType.OPTIONS):
self._match_l_paren() options = self._parse_wrapped_csv(self._parse_string, sep=TokenType.EQ)
k = self._parse_string()
self._match(TokenType.EQ)
v = self._parse_string()
options = [k, v]
self._match_r_paren()
self._match(TokenType.ALIAS) self._match(TokenType.ALIAS)
return self.expression( return self.expression(
@ -1050,27 +1092,13 @@ class Parser(metaclass=_Parser):
return None return None
def parse_values(): def parse_values():
key = self._parse_var() props = self._parse_csv(self._parse_var_or_string, sep=TokenType.EQ)
value = None return exp.Property(this=seq_get(props, 0), value=seq_get(props, 1))
if self._match(TokenType.EQ): return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values))
value = self._parse_string()
return exp.Property(this=key, value=value)
self._match_l_paren()
values = self._parse_csv(parse_values)
self._match_r_paren()
return self.expression(
exp.Partition,
this=values,
)
def _parse_value(self): def _parse_value(self):
self._match_l_paren() expressions = self._parse_wrapped_csv(self._parse_conjunction)
expressions = self._parse_csv(self._parse_conjunction)
self._match_r_paren()
return self.expression(exp.Tuple, expressions=expressions) return self.expression(exp.Tuple, expressions=expressions)
def _parse_select(self, nested=False, table=False): def _parse_select(self, nested=False, table=False):
@ -1124,10 +1152,11 @@ class Parser(metaclass=_Parser):
self._match_r_paren() self._match_r_paren()
this = self._parse_subquery(this) this = self._parse_subquery(this)
elif self._match(TokenType.VALUES): elif self._match(TokenType.VALUES):
this = self.expression(exp.Values, expressions=self._parse_csv(self._parse_value)) this = self.expression(
alias = self._parse_table_alias() exp.Values,
if alias: expressions=self._parse_csv(self._parse_value),
this = self.expression(exp.Subquery, this=this, alias=alias) alias=self._parse_table_alias(),
)
else: else:
this = None this = None
@ -1140,7 +1169,6 @@ class Parser(metaclass=_Parser):
recursive = self._match(TokenType.RECURSIVE) recursive = self._match(TokenType.RECURSIVE)
expressions = [] expressions = []
while True: while True:
expressions.append(self._parse_cte()) expressions.append(self._parse_cte())
@ -1149,11 +1177,7 @@ class Parser(metaclass=_Parser):
else: else:
self._match(TokenType.WITH) self._match(TokenType.WITH)
return self.expression( return self.expression(exp.With, expressions=expressions, recursive=recursive)
exp.With,
expressions=expressions,
recursive=recursive,
)
def _parse_cte(self): def _parse_cte(self):
alias = self._parse_table_alias() alias = self._parse_table_alias()
@ -1163,13 +1187,9 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.ALIAS): if not self._match(TokenType.ALIAS):
self.raise_error("Expected AS in CTE") self.raise_error("Expected AS in CTE")
self._match_l_paren()
expression = self._parse_statement()
self._match_r_paren()
return self.expression( return self.expression(
exp.CTE, exp.CTE,
this=expression, this=self._parse_wrapped(self._parse_statement),
alias=alias, alias=alias,
) )
@ -1223,7 +1243,7 @@ class Parser(metaclass=_Parser):
def _parse_hint(self): def _parse_hint(self):
if self._match(TokenType.HINT): if self._match(TokenType.HINT):
hints = self._parse_csv(self._parse_function) hints = self._parse_csv(self._parse_function)
if not self._match(TokenType.HINT): if not self._match_pair(TokenType.STAR, TokenType.SLASH):
self.raise_error("Expected */ after HINT") self.raise_error("Expected */ after HINT")
return self.expression(exp.Hint, expressions=hints) return self.expression(exp.Hint, expressions=hints)
return None return None
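The hint is now closed on a STAR + SLASH token pair because `*/` is no longer tokenized as HINT (see the tokenizer changes further down). A small sketch, assuming the default dialect parses optimizer hints; the hint name is illustrative:

```python
import sqlglot

expr = sqlglot.parse_one("SELECT /*+ BROADCAST(y) */ x.a FROM x JOIN y ON x.a = y.a")
print(expr.sql())  # the hint is expected to survive the round trip
```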
@ -1259,26 +1279,18 @@ class Parser(metaclass=_Parser):
columns = self._parse_csv(self._parse_id_var) columns = self._parse_csv(self._parse_id_var)
elif self._match(TokenType.L_PAREN): elif self._match(TokenType.L_PAREN):
columns = self._parse_csv(self._parse_id_var) columns = self._parse_csv(self._parse_id_var)
self._match(TokenType.R_PAREN) self._match_r_paren()
expression = self.expression( expression = self.expression(
exp.Lateral, exp.Lateral,
this=this, this=this,
view=view, view=view,
outer=outer, outer=outer,
alias=self.expression( alias=self.expression(exp.TableAlias, this=table_alias, columns=columns),
exp.TableAlias,
this=table_alias,
columns=columns,
),
) )
if outer_apply or cross_apply: if outer_apply or cross_apply:
return self.expression( return self.expression(exp.Join, this=expression, side=None if cross_apply else "LEFT")
exp.Join,
this=expression,
side=None if cross_apply else "LEFT",
)
return expression return expression
@ -1387,12 +1399,8 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.UNNEST): if not self._match(TokenType.UNNEST):
return None return None
self._match_l_paren() expressions = self._parse_wrapped_csv(self._parse_column)
expressions = self._parse_csv(self._parse_column)
self._match_r_paren()
ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY)) ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY))
alias = self._parse_table_alias() alias = self._parse_table_alias()
if alias and self.unnest_column_only: if alias and self.unnest_column_only:
@ -1402,10 +1410,7 @@ class Parser(metaclass=_Parser):
alias.set("this", None) alias.set("this", None)
return self.expression( return self.expression(
exp.Unnest, exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias
expressions=expressions,
ordinality=ordinality,
alias=alias,
) )
def _parse_derived_table_values(self): def _parse_derived_table_values(self):
@ -1418,13 +1423,7 @@ class Parser(metaclass=_Parser):
if is_derived: if is_derived:
self._match_r_paren() self._match_r_paren()
alias = self._parse_table_alias() return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias())
return self.expression(
exp.Values,
expressions=expressions,
alias=alias,
)
def _parse_table_sample(self): def _parse_table_sample(self):
if not self._match(TokenType.TABLE_SAMPLE): if not self._match(TokenType.TABLE_SAMPLE):
@ -1460,9 +1459,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren() self._match_r_paren()
if self._match(TokenType.SEED): if self._match(TokenType.SEED):
self._match_l_paren() seed = self._parse_wrapped(self._parse_number)
seed = self._parse_number()
self._match_r_paren()
return self.expression( return self.expression(
exp.TableSample, exp.TableSample,
@ -1513,12 +1510,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren() self._match_r_paren()
return self.expression( return self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot)
exp.Pivot,
expressions=expressions,
field=field,
unpivot=unpivot,
)
def _parse_where(self, skip_where_token=False): def _parse_where(self, skip_where_token=False):
if not skip_where_token and not self._match(TokenType.WHERE): if not skip_where_token and not self._match(TokenType.WHERE):
@ -1539,11 +1531,7 @@ class Parser(metaclass=_Parser):
def _parse_grouping_sets(self): def _parse_grouping_sets(self):
if not self._match(TokenType.GROUPING_SETS): if not self._match(TokenType.GROUPING_SETS):
return None return None
return self._parse_wrapped_csv(self._parse_grouping_set)
self._match_l_paren()
grouping_sets = self._parse_csv(self._parse_grouping_set)
self._match_r_paren()
return grouping_sets
def _parse_grouping_set(self): def _parse_grouping_set(self):
if self._match(TokenType.L_PAREN): if self._match(TokenType.L_PAREN):
@ -1573,7 +1561,6 @@ class Parser(metaclass=_Parser):
def _parse_sort(self, token_type, exp_class): def _parse_sort(self, token_type, exp_class):
if not self._match(token_type): if not self._match(token_type):
return None return None
return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered)) return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
def _parse_ordered(self): def _parse_ordered(self):
@ -1602,9 +1589,12 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.TOP if top else TokenType.LIMIT): if self._match(TokenType.TOP if top else TokenType.LIMIT):
limit_paren = self._match(TokenType.L_PAREN) limit_paren = self._match(TokenType.L_PAREN)
limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number()) limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number())
if limit_paren: if limit_paren:
self._match(TokenType.R_PAREN) self._match_r_paren()
return limit_exp return limit_exp
if self._match(TokenType.FETCH): if self._match(TokenType.FETCH):
direction = self._match_set((TokenType.FIRST, TokenType.NEXT)) direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
direction = self._prev.text if direction else "FIRST" direction = self._prev.text if direction else "FIRST"
@ -1612,11 +1602,13 @@ class Parser(metaclass=_Parser):
self._match_set((TokenType.ROW, TokenType.ROWS)) self._match_set((TokenType.ROW, TokenType.ROWS))
self._match(TokenType.ONLY) self._match(TokenType.ONLY)
return self.expression(exp.Fetch, direction=direction, count=count) return self.expression(exp.Fetch, direction=direction, count=count)
return this return this
def _parse_offset(self, this=None): def _parse_offset(self, this=None):
if not self._match_set((TokenType.OFFSET, TokenType.COMMA)): if not self._match_set((TokenType.OFFSET, TokenType.COMMA)):
return this return this
count = self._parse_number() count = self._parse_number()
self._match_set((TokenType.ROW, TokenType.ROWS)) self._match_set((TokenType.ROW, TokenType.ROWS))
return self.expression(exp.Offset, this=this, expression=count) return self.expression(exp.Offset, this=this, expression=count)
@ -1678,6 +1670,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.DISTINCT_FROM): if self._match(TokenType.DISTINCT_FROM):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
return self.expression(klass, this=this, expression=self._parse_expression()) return self.expression(klass, this=this, expression=self._parse_expression())
this = self.expression( this = self.expression(
exp.Is, exp.Is,
this=this, this=this,
@ -1754,11 +1747,7 @@ class Parser(metaclass=_Parser):
def _parse_type(self): def _parse_type(self):
if self._match(TokenType.INTERVAL): if self._match(TokenType.INTERVAL):
return self.expression( return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_var())
exp.Interval,
this=self._parse_term(),
unit=self._parse_var(),
)
index = self._index index = self._index
type_token = self._parse_types(check_func=True) type_token = self._parse_types(check_func=True)
@ -1824,30 +1813,18 @@ class Parser(metaclass=_Parser):
value = None value = None
if type_token in self.TIMESTAMPS: if type_token in self.TIMESTAMPS:
if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ: if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
value = exp.DataType( value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
this=exp.DataType.Type.TIMESTAMPTZ,
expressions=expressions,
)
elif ( elif (
self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
): ):
value = exp.DataType( value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
this=exp.DataType.Type.TIMESTAMPLTZ,
expressions=expressions,
)
elif self._match(TokenType.WITHOUT_TIME_ZONE): elif self._match(TokenType.WITHOUT_TIME_ZONE):
value = exp.DataType( value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
this=exp.DataType.Type.TIMESTAMP,
expressions=expressions,
)
maybe_func = maybe_func and value is None maybe_func = maybe_func and value is None
if value is None: if value is None:
value = exp.DataType( value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
this=exp.DataType.Type.TIMESTAMP,
expressions=expressions,
)
if maybe_func and check_func: if maybe_func and check_func:
index2 = self._index index2 = self._index
@ -1872,6 +1849,7 @@ class Parser(metaclass=_Parser):
this = self._parse_id_var() this = self._parse_id_var()
self._match(TokenType.COLON) self._match(TokenType.COLON)
data_type = self._parse_types() data_type = self._parse_types()
if not data_type: if not data_type:
return None return None
return self.expression(exp.StructKwarg, this=this, expression=data_type) return self.expression(exp.StructKwarg, this=this, expression=data_type)
@ -1879,7 +1857,6 @@ class Parser(metaclass=_Parser):
def _parse_at_time_zone(self, this): def _parse_at_time_zone(self, this):
if not self._match(TokenType.AT_TIME_ZONE): if not self._match(TokenType.AT_TIME_ZONE):
return this return this
return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary()) return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary())
def _parse_column(self): def _parse_column(self):
@ -1984,16 +1961,14 @@ class Parser(metaclass=_Parser):
else: else:
subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type) subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type)
if subquery_predicate and self._curr.token_type in ( if subquery_predicate and self._curr.token_type in (TokenType.SELECT, TokenType.WITH):
TokenType.SELECT,
TokenType.WITH,
):
this = self.expression(subquery_predicate, this=self._parse_select()) this = self.expression(subquery_predicate, this=self._parse_select())
self._match_r_paren() self._match_r_paren()
return this return this
if functions is None: if functions is None:
functions = self.FUNCTIONS functions = self.FUNCTIONS
function = functions.get(upper) function = functions.get(upper)
args = self._parse_csv(self._parse_lambda) args = self._parse_csv(self._parse_lambda)
@ -2014,6 +1989,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.L_PAREN): if not self._match(TokenType.L_PAREN):
return this return this
expressions = self._parse_csv(self._parse_udf_kwarg) expressions = self._parse_csv(self._parse_udf_kwarg)
self._match_r_paren() self._match_r_paren()
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
@ -2021,25 +1997,19 @@ class Parser(metaclass=_Parser):
def _parse_introducer(self, token): def _parse_introducer(self, token):
literal = self._parse_primary() literal = self._parse_primary()
if literal: if literal:
return self.expression( return self.expression(exp.Introducer, this=token.text, expression=literal)
exp.Introducer,
this=token.text,
expression=literal,
)
return self.expression(exp.Identifier, this=token.text) return self.expression(exp.Identifier, this=token.text)
def _parse_session_parameter(self): def _parse_session_parameter(self):
kind = None kind = None
this = self._parse_id_var() or self._parse_primary() this = self._parse_id_var() or self._parse_primary()
if self._match(TokenType.DOT): if self._match(TokenType.DOT):
kind = this.name kind = this.name
this = self._parse_var() or self._parse_primary() this = self._parse_var() or self._parse_primary()
return self.expression(
exp.SessionParameter, return self.expression(exp.SessionParameter, this=this, kind=kind)
this=this,
kind=kind,
)
def _parse_udf_kwarg(self): def _parse_udf_kwarg(self):
this = self._parse_id_var() this = self._parse_id_var()
@ -2106,7 +2076,10 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints) return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints)
def _parse_column_constraint(self): def _parse_column_constraint(self):
this = None this = self._parse_references()
if this:
return this
if self._match(TokenType.CONSTRAINT): if self._match(TokenType.CONSTRAINT):
this = self._parse_id_var() this = self._parse_id_var()
@ -2114,13 +2087,12 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.AUTO_INCREMENT): if self._match(TokenType.AUTO_INCREMENT):
kind = exp.AutoIncrementColumnConstraint() kind = exp.AutoIncrementColumnConstraint()
elif self._match(TokenType.CHECK): elif self._match(TokenType.CHECK):
self._match_l_paren() constraint = self._parse_wrapped(self._parse_conjunction)
kind = self.expression(exp.CheckColumnConstraint, this=self._parse_conjunction()) kind = self.expression(exp.CheckColumnConstraint, this=constraint)
self._match_r_paren()
elif self._match(TokenType.COLLATE): elif self._match(TokenType.COLLATE):
kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var()) kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
elif self._match(TokenType.DEFAULT): elif self._match(TokenType.DEFAULT):
kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_field()) kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction())
elif self._match_pair(TokenType.NOT, TokenType.NULL): elif self._match_pair(TokenType.NOT, TokenType.NULL):
kind = exp.NotNullColumnConstraint() kind = exp.NotNullColumnConstraint()
elif self._match(TokenType.SCHEMA_COMMENT): elif self._match(TokenType.SCHEMA_COMMENT):
@ -2137,7 +2109,7 @@ class Parser(metaclass=_Parser):
kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True) kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
self._match_pair(TokenType.ALIAS, TokenType.IDENTITY) self._match_pair(TokenType.ALIAS, TokenType.IDENTITY)
else: else:
return None return this
return self.expression(exp.ColumnConstraint, this=this, kind=kind) return self.expression(exp.ColumnConstraint, this=this, kind=kind)
@ -2159,37 +2131,29 @@ class Parser(metaclass=_Parser):
def _parse_unnamed_constraint(self): def _parse_unnamed_constraint(self):
if not self._match_set(self.CONSTRAINT_PARSERS): if not self._match_set(self.CONSTRAINT_PARSERS):
return None return None
return self.CONSTRAINT_PARSERS[self._prev.token_type](self) return self.CONSTRAINT_PARSERS[self._prev.token_type](self)
def _parse_check(self):
self._match(TokenType.CHECK)
self._match_l_paren()
expression = self._parse_conjunction()
self._match_r_paren()
return self.expression(exp.Check, this=expression)
def _parse_unique(self): def _parse_unique(self):
self._match(TokenType.UNIQUE) return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars())
columns = self._parse_wrapped_id_vars()
return self.expression(exp.Unique, expressions=columns) def _parse_references(self):
if not self._match(TokenType.REFERENCES):
def _parse_foreign_key(self): return None
self._match(TokenType.FOREIGN_KEY) return self.expression(
expressions = self._parse_wrapped_id_vars()
reference = self._match(TokenType.REFERENCES) and self.expression(
exp.Reference, exp.Reference,
this=self._parse_id_var(), this=self._parse_id_var(),
expressions=self._parse_wrapped_id_vars(), expressions=self._parse_wrapped_id_vars(),
) )
def _parse_foreign_key(self):
expressions = self._parse_wrapped_id_vars()
reference = self._parse_references()
options = {} options = {}
while self._match(TokenType.ON): while self._match(TokenType.ON):
if not self._match_set((TokenType.DELETE, TokenType.UPDATE)): if not self._match_set((TokenType.DELETE, TokenType.UPDATE)):
self.raise_error("Expected DELETE or UPDATE") self.raise_error("Expected DELETE or UPDATE")
kind = self._prev.text.lower() kind = self._prev.text.lower()
if self._match(TokenType.NO_ACTION): if self._match(TokenType.NO_ACTION):
@ -2200,6 +2164,7 @@ class Parser(metaclass=_Parser):
else: else:
self._advance() self._advance()
action = self._prev.text.upper() action = self._prev.text.upper()
options[kind] = action options[kind] = action
return self.expression( return self.expression(
@ -2363,20 +2328,14 @@ class Parser(metaclass=_Parser):
def _parse_window(self, this, alias=False): def _parse_window(self, this, alias=False):
if self._match(TokenType.FILTER): if self._match(TokenType.FILTER):
self._match_l_paren() where = self._parse_wrapped(self._parse_where)
this = self.expression(exp.Filter, this=this, expression=self._parse_where()) this = self.expression(exp.Filter, this=this, expression=where)
self._match_r_paren()
# T-SQL allows the OVER (...) syntax after WITHIN GROUP. # T-SQL allows the OVER (...) syntax after WITHIN GROUP.
# https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16 # https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16
if self._match(TokenType.WITHIN_GROUP): if self._match(TokenType.WITHIN_GROUP):
self._match_l_paren() order = self._parse_wrapped(self._parse_order)
this = self.expression( this = self.expression(exp.WithinGroup, this=this, expression=order)
exp.WithinGroup,
this=this,
expression=self._parse_order(),
)
self._match_r_paren()
# SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER # SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER
# Some dialects choose to implement and some do not. # Some dialects choose to implement and some do not.
@ -2404,18 +2363,11 @@ class Parser(metaclass=_Parser):
return this return this
if not self._match(TokenType.L_PAREN): if not self._match(TokenType.L_PAREN):
alias = self._parse_id_var(False) return self.expression(exp.Window, this=this, alias=self._parse_id_var(False))
return self.expression( alias = self._parse_id_var(False)
exp.Window,
this=this,
alias=alias,
)
partition = None partition = None
alias = self._parse_id_var(False)
if self._match(TokenType.PARTITION_BY): if self._match(TokenType.PARTITION_BY):
partition = self._parse_csv(self._parse_conjunction) partition = self._parse_csv(self._parse_conjunction)
@ -2552,17 +2504,13 @@ class Parser(metaclass=_Parser):
def _parse_replace(self): def _parse_replace(self):
if not self._match(TokenType.REPLACE): if not self._match(TokenType.REPLACE):
return None return None
return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression()))
self._match_l_paren() def _parse_csv(self, parse_method, sep=TokenType.COMMA):
columns = self._parse_csv(lambda: self._parse_alias(self._parse_expression()))
self._match_r_paren()
return columns
def _parse_csv(self, parse_method):
parse_result = parse_method() parse_result = parse_method()
items = [parse_result] if parse_result is not None else [] items = [parse_result] if parse_result is not None else []
while self._match(TokenType.COMMA): while self._match(sep):
if parse_result and self._prev_comment is not None: if parse_result and self._prev_comment is not None:
parse_result.comment = self._prev_comment parse_result.comment = self._prev_comment
@ -2583,16 +2531,53 @@ class Parser(metaclass=_Parser):
return this return this
def _parse_wrapped_id_vars(self): def _parse_wrapped_id_vars(self):
return self._parse_wrapped_csv(self._parse_id_var)
def _parse_wrapped_csv(self, parse_method, sep=TokenType.COMMA):
return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep))
def _parse_wrapped(self, parse_method):
self._match_l_paren() self._match_l_paren()
expressions = self._parse_csv(self._parse_id_var) parse_result = parse_method()
self._match_r_paren() self._match_r_paren()
return expressions return parse_result
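`_parse_csv` now takes an arbitrary separator, and the repeated paren-matching boilerplate is factored into `_parse_wrapped` / `_parse_wrapped_csv`. A standalone sketch of the pattern only; the real methods operate on Token objects and track comments:

```python
# Simplified, string-based illustration of the helper pattern (not the Parser methods).
def parse_csv(tokens, parse_item, sep=","):
    items = [parse_item(tokens)]
    while tokens and tokens[0] == sep:
        tokens.pop(0)  # consume the separator
        items.append(parse_item(tokens))
    return items

def parse_wrapped(tokens, parse_method):
    assert tokens.pop(0) == "("  # _match_l_paren
    result = parse_method(tokens)
    assert tokens.pop(0) == ")"  # _match_r_paren
    return result

# e.g. OPTIONS ('k' = 'v') becomes: parse_wrapped(tokens, lambda ts: parse_csv(ts, parse_item, sep="="))
```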
def _parse_select_or_expression(self): def _parse_select_or_expression(self):
return self._parse_select() or self._parse_expression() return self._parse_select() or self._parse_expression()
def _parse_use(self): def _parse_transaction(self):
return self.expression(exp.Use, this=self._parse_id_var()) this = None
if self._match_texts(self.TRANSACTION_KIND):
this = self._prev.text
self._match_texts({"TRANSACTION", "WORK"})
modes = []
while True:
mode = []
while self._match(TokenType.VAR):
mode.append(self._prev.text)
if mode:
modes.append(" ".join(mode))
if not self._match(TokenType.COMMA):
break
return self.expression(exp.Transaction, this=this, modes=modes)
def _parse_commit_or_rollback(self):
savepoint = None
is_rollback = self._prev.token_type == TokenType.ROLLBACK
self._match_texts({"TRANSACTION", "WORK"})
if self._match_text_seq("TO"):
self._match_text_seq("SAVEPOINT")
savepoint = self._parse_id_var()
if is_rollback:
return self.expression(exp.Rollback, savepoint=savepoint)
return self.expression(exp.Commit)
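`_parse_transaction` and `_parse_commit_or_rollback` add basic transaction-statement support. A hedged sketch, assuming BEGIN/COMMIT/ROLLBACK are wired into the default dialect's statement parsers; exact round-trip formatting may differ:

```python
import sqlglot

print(sqlglot.parse_one("BEGIN IMMEDIATE TRANSACTION").sql())
print(sqlglot.parse_one("ROLLBACK TO SAVEPOINT sp1").sql())
```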
def _parse_show(self): def _parse_show(self):
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
@ -2675,7 +2660,13 @@ class Parser(metaclass=_Parser):
if expression and self._prev_comment: if expression and self._prev_comment:
expression.comment = self._prev_comment expression.comment = self._prev_comment
def _match_text(self, *texts): def _match_texts(self, texts):
if self._curr and self._curr.text.upper() in texts:
self._advance()
return True
return False
def _match_text_seq(self, *texts):
index = self._index index = self._index
for text in texts: for text in texts:
if self._curr and self._curr.text.upper() == text: if self._curr and self._curr.text.upper() == text:
View file
@ -1,5 +1,8 @@
from __future__ import annotations
import itertools import itertools
import math import math
import typing as t
from sqlglot import alias, exp from sqlglot import alias, exp
from sqlglot.errors import UnsupportedError from sqlglot.errors import UnsupportedError
@ -7,15 +10,15 @@ from sqlglot.optimizer.eliminate_joins import join_condition
class Plan: class Plan:
def __init__(self, expression): def __init__(self, expression: exp.Expression) -> None:
self.expression = expression self.expression = expression.copy()
self.root = Step.from_expression(self.expression) self.root = Step.from_expression(self.expression)
self._dag = {} self._dag: t.Dict[Step, t.Set[Step]] = {}
@property @property
def dag(self): def dag(self) -> t.Dict[Step, t.Set[Step]]:
if not self._dag: if not self._dag:
dag = {} dag: t.Dict[Step, t.Set[Step]] = {}
nodes = {self.root} nodes = {self.root}
while nodes: while nodes:
@ -29,32 +32,64 @@ class Plan:
return self._dag return self._dag
@property @property
def leaves(self): def leaves(self) -> t.Generator[Step, None, None]:
return (node for node, deps in self.dag.items() if not deps) return (node for node, deps in self.dag.items() if not deps)
def __repr__(self) -> str:
return f"Plan\n----\n{repr(self.root)}"
class Step: class Step:
@classmethod @classmethod
def from_expression(cls, expression, ctes=None): def from_expression(
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
""" """
Build a DAG of Steps from a SQL expression. Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
Note: the expression's tables and subqueries must be aliased for this method to work. For
example, given the following expression:
Giving an expression like: SELECT
x.a,
SELECT x.a, SUM(x.b) SUM(x.b)
FROM x FROM x AS x
JOIN y JOIN y AS y
ON x.a = y.a ON x.a = y.a
GROUP BY x.a GROUP BY x.a
Transform it into a DAG of the form: the following DAG is produced (the expression IDs might differ per execution):
Aggregate(x.a, SUM(x.b)) - Aggregate: x (4347984624)
Join(y) Context:
Scan(x) Aggregations:
Scan(y) - SUM(x.b)
Group:
- x.a
Projections:
- x.a
- "x".""
Dependencies:
- Join: x (4347985296)
Context:
y:
On: x.a = y.a
Projections:
Dependencies:
- Scan: x (4347983136)
Context:
Source: x AS x
Projections:
- Scan: y (4343416624)
Context:
Source: y AS y
Projections:
This can then more easily be executed on by an engine. Args:
expression: the expression to build the DAG from.
ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
Returns:
A Step DAG corresponding to `expression`.
""" """
ctes = ctes or {} ctes = ctes or {}
with_ = expression.args.get("with") with_ = expression.args.get("with")
@ -65,11 +100,11 @@ class Step:
for cte in with_.expressions: for cte in with_.expressions:
step = Step.from_expression(cte.this, ctes) step = Step.from_expression(cte.this, ctes)
step.name = cte.alias step.name = cte.alias
ctes[step.name] = step ctes[step.name] = step # type: ignore
from_ = expression.args.get("from") from_ = expression.args.get("from")
if from_: if isinstance(expression, exp.Select) and from_:
from_ = from_.expressions from_ = from_.expressions
if len(from_) > 1: if len(from_) > 1:
raise UnsupportedError( raise UnsupportedError(
@ -77,8 +112,10 @@ class Step:
) )
step = Scan.from_expression(from_[0], ctes) step = Scan.from_expression(from_[0], ctes)
elif isinstance(expression, exp.Union):
step = SetOperation.from_expression(expression, ctes)
else: else:
raise UnsupportedError("Static selects are unsupported.") step = Scan()
joins = expression.args.get("joins") joins = expression.args.get("joins")
@ -115,7 +152,7 @@ class Step:
group = expression.args.get("group") group = expression.args.get("group")
if group: if group or aggregations:
aggregate = Aggregate() aggregate = Aggregate()
aggregate.source = step.name aggregate.source = step.name
aggregate.name = step.name aggregate.name = step.name
@ -123,7 +160,15 @@ class Step:
alias(operand, alias_) for operand, alias_ in operands.items() alias(operand, alias_) for operand, alias_ in operands.items()
) )
aggregate.aggregations = aggregations aggregate.aggregations = aggregations
aggregate.group = group.expressions # give aggregates names and replace projections with references to them
aggregate.group = {
f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
}
for projection in projections:
for i, e in aggregate.group.items():
for child, _, _ in projection.walk():
if child == e:
child.replace(exp.column(i, step.name))
aggregate.add_dependency(step) aggregate.add_dependency(step)
step = aggregate step = aggregate
@ -150,22 +195,22 @@ class Step:
return step return step
def __init__(self): def __init__(self) -> None:
self.name = None self.name: t.Optional[str] = None
self.dependencies = set() self.dependencies: t.Set[Step] = set()
self.dependents = set() self.dependents: t.Set[Step] = set()
self.projections = [] self.projections: t.Sequence[exp.Expression] = []
self.limit = math.inf self.limit: float = math.inf
self.condition = None self.condition: t.Optional[exp.Expression] = None
def add_dependency(self, dependency): def add_dependency(self, dependency: Step) -> None:
self.dependencies.add(dependency) self.dependencies.add(dependency)
dependency.dependents.add(self) dependency.dependents.add(self)
def __repr__(self): def __repr__(self) -> str:
return self.to_s() return self.to_s()
def to_s(self, level=0): def to_s(self, level: int = 0) -> str:
indent = " " * level indent = " " * level
nested = f"{indent} " nested = f"{indent} "
@ -175,7 +220,7 @@ class Step:
context = [f"{nested}Context:"] + context context = [f"{nested}Context:"] + context
lines = [ lines = [
f"{indent}- {self.__class__.__name__}: {self.name}", f"{indent}- {self.id}",
*context, *context,
f"{nested}Projections:", f"{nested}Projections:",
] ]
@ -193,13 +238,25 @@ class Step:
return "\n".join(lines) return "\n".join(lines)
def _to_s(self, _indent): @property
def type_name(self) -> str:
return self.__class__.__name__
@property
def id(self) -> str:
name = self.name
name = f" {name}" if name else ""
return f"{self.type_name}:{name} ({id(self)})"
def _to_s(self, _indent: str) -> t.List[str]:
return [] return []
class Scan(Step): class Scan(Step):
@classmethod @classmethod
def from_expression(cls, expression, ctes=None): def from_expression(
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
table = expression table = expression
alias_ = expression.alias alias_ = expression.alias
@ -217,26 +274,24 @@ class Scan(Step):
step = Scan() step = Scan()
step.name = alias_ step.name = alias_
step.source = expression step.source = expression
if table.name in ctes: if ctes and table.name in ctes:
step.add_dependency(ctes[table.name]) step.add_dependency(ctes[table.name])
return step return step
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self.source = None self.source: t.Optional[exp.Expression] = None
def _to_s(self, indent): def _to_s(self, indent: str) -> t.List[str]:
return [f"{indent}Source: {self.source.sql()}"] return [f"{indent}Source: {self.source.sql() if self.source else '-static-'}"] # type: ignore
class Write(Step):
pass
class Join(Step): class Join(Step):
@classmethod @classmethod
def from_joins(cls, joins, ctes=None): def from_joins(
cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
step = Join() step = Join()
for join in joins: for join in joins:
@ -252,28 +307,28 @@ class Join(Step):
return step return step
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self.joins = {} self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {}
def _to_s(self, indent): def _to_s(self, indent: str) -> t.List[str]:
lines = [] lines = []
for name, join in self.joins.items(): for name, join in self.joins.items():
lines.append(f"{indent}{name}: {join['side']}") lines.append(f"{indent}{name}: {join['side']}")
if join.get("condition"): if join.get("condition"):
lines.append(f"{indent}On: {join['condition'].sql()}") lines.append(f"{indent}On: {join['condition'].sql()}") # type: ignore
return lines return lines
class Aggregate(Step): class Aggregate(Step):
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self.aggregations = [] self.aggregations: t.List[exp.Expression] = []
self.operands = [] self.operands: t.Tuple[exp.Expression, ...] = ()
self.group = [] self.group: t.Dict[str, exp.Expression] = {}
self.source = None self.source: t.Optional[str] = None
def _to_s(self, indent): def _to_s(self, indent: str) -> t.List[str]:
lines = [f"{indent}Aggregations:"] lines = [f"{indent}Aggregations:"]
for expression in self.aggregations: for expression in self.aggregations:
@ -281,7 +336,7 @@ class Aggregate(Step):
if self.group: if self.group:
lines.append(f"{indent}Group:") lines.append(f"{indent}Group:")
for expression in self.group: for expression in self.group.values():
lines.append(f"{indent} - {expression.sql()}") lines.append(f"{indent} - {expression.sql()}")
if self.operands: if self.operands:
lines.append(f"{indent}Operands:") lines.append(f"{indent}Operands:")
@ -292,14 +347,56 @@ class Aggregate(Step):
class Sort(Step): class Sort(Step):
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self.key = None self.key = None
def _to_s(self, indent): def _to_s(self, indent: str) -> t.List[str]:
lines = [f"{indent}Key:"] lines = [f"{indent}Key:"]
for expression in self.key: for expression in self.key: # type: ignore
lines.append(f"{indent} - {expression.sql()}") lines.append(f"{indent} - {expression.sql()}")
return lines return lines
class SetOperation(Step):
def __init__(
self,
op: t.Type[exp.Expression],
left: str | None,
right: str | None,
distinct: bool = False,
) -> None:
super().__init__()
self.op = op
self.left = left
self.right = right
self.distinct = distinct
@classmethod
def from_expression(
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
assert isinstance(expression, exp.Union)
left = Step.from_expression(expression.left, ctes)
right = Step.from_expression(expression.right, ctes)
step = cls(
op=expression.__class__,
left=left.name,
right=right.name,
distinct=expression.args.get("distinct"),
)
step.add_dependency(left)
step.add_dependency(right)
return step
def _to_s(self, indent: str) -> t.List[str]:
lines = []
if self.distinct:
lines.append(f"{indent}Distinct: {self.distinct}")
return lines
@property
def type_name(self) -> str:
return self.op.__name__
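With `SetOperation` in place, the planner can now build a DAG for UNION queries as well as plain selects. A minimal sketch; step ids in the printed DAG vary per run, and the table aliases are required, as the `from_expression` docstring notes:

```python
import sqlglot
from sqlglot.planner import Plan

plan = Plan(sqlglot.parse_one("SELECT a FROM x AS x UNION SELECT a FROM y AS y"))
print(plan.root)  # expected: a SetOperation step depending on two Scan steps
```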
View file
@ -5,7 +5,7 @@ import typing as t
from sqlglot import expressions as exp from sqlglot import expressions as exp
from sqlglot.errors import SchemaError from sqlglot.errors import SchemaError
from sqlglot.helper import csv_reader from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie from sqlglot.trie import in_trie, new_trie
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
@ -15,6 +15,8 @@ if t.TYPE_CHECKING:
TABLE_ARGS = ("this", "db", "catalog") TABLE_ARGS = ("this", "db", "catalog")
T = t.TypeVar("T")
class Schema(abc.ABC): class Schema(abc.ABC):
"""Abstract base class for database schemas""" """Abstract base class for database schemas"""
@ -57,8 +59,81 @@ class Schema(abc.ABC):
The resulting column type. The resulting column type.
""" """
@property
def supported_table_args(self) -> t.Tuple[str, ...]:
"""
Table arguments this schema supports, e.g. `("this", "db", "catalog")`
"""
raise NotImplementedError
class MappingSchema(Schema):
class AbstractMappingSchema(t.Generic[T]):
def __init__(
self,
mapping: dict | None = None,
) -> None:
self.mapping = mapping or {}
self.mapping_trie = self._build_trie(self.mapping)
self._supported_table_args: t.Tuple[str, ...] = tuple()
def _build_trie(self, schema: t.Dict) -> t.Dict:
return new_trie(tuple(reversed(t)) for t in flatten_schema(schema, depth=self._depth()))
def _depth(self) -> int:
return dict_depth(self.mapping)
@property
def supported_table_args(self) -> t.Tuple[str, ...]:
if not self._supported_table_args and self.mapping:
depth = self._depth()
if not depth: # None
self._supported_table_args = tuple()
elif 1 <= depth <= 3:
self._supported_table_args = TABLE_ARGS[:depth]
else:
raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
return self._supported_table_args
def table_parts(self, table: exp.Table) -> t.List[str]:
if isinstance(table.this, exp.ReadCSV):
return [table.this.name]
return [table.text(part) for part in TABLE_ARGS if table.text(part)]
def find(
self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
) -> t.Optional[T]:
parts = self.table_parts(table)[0 : len(self.supported_table_args)]
value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
if value == 0:
if raise_on_missing:
raise SchemaError(f"Cannot find mapping for {table}.")
else:
return None
elif value == 1:
possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
if len(possibilities) == 1:
parts.extend(possibilities[0])
else:
message = ", ".join(".".join(parts) for parts in possibilities)
if raise_on_missing:
raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
return None
return self._nested_get(parts, raise_on_missing=raise_on_missing)
def _nested_get(
self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
) -> t.Optional[t.Any]:
return _nested_get(
d or self.mapping,
*zip(self.supported_table_args, reversed(parts)),
raise_on_missing=raise_on_missing,
)
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
""" """
Schema based on a nested mapping. Schema based on a nested mapping.
@ -82,17 +157,17 @@ class MappingSchema(Schema):
visible: t.Optional[t.Dict] = None, visible: t.Optional[t.Dict] = None,
dialect: t.Optional[str] = None, dialect: t.Optional[str] = None,
) -> None: ) -> None:
self.schema = schema or {} super().__init__(schema)
self.visible = visible or {} self.visible = visible or {}
self.schema_trie = self._build_trie(self.schema)
self.dialect = dialect self.dialect = dialect
self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {} self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {
self._supported_table_args: t.Tuple[str, ...] = tuple() "STR": exp.DataType.Type.TEXT,
}
@classmethod @classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
return MappingSchema( return MappingSchema(
schema=mapping_schema.schema, schema=mapping_schema.mapping,
visible=mapping_schema.visible, visible=mapping_schema.visible,
dialect=mapping_schema.dialect, dialect=mapping_schema.dialect,
) )
@ -100,27 +175,13 @@ class MappingSchema(Schema):
def copy(self, **kwargs) -> MappingSchema: def copy(self, **kwargs) -> MappingSchema:
return MappingSchema( return MappingSchema(
**{ # type: ignore **{ # type: ignore
"schema": self.schema.copy(), "schema": self.mapping.copy(),
"visible": self.visible.copy(), "visible": self.visible.copy(),
"dialect": self.dialect, "dialect": self.dialect,
**kwargs, **kwargs,
} }
) )
@property
def supported_table_args(self):
if not self._supported_table_args and self.schema:
depth = _dict_depth(self.schema)
if not depth or depth == 1: # {}
self._supported_table_args = tuple()
elif 2 <= depth <= 4:
self._supported_table_args = TABLE_ARGS[: depth - 1]
else:
raise SchemaError(f"Invalid schema shape. Depth: {depth}")
return self._supported_table_args
def add_table( def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None: ) -> None:
@ -133,17 +194,21 @@ class MappingSchema(Schema):
""" """
table_ = self._ensure_table(table) table_ = self._ensure_table(table)
column_mapping = ensure_column_mapping(column_mapping) column_mapping = ensure_column_mapping(column_mapping)
schema = self.find_schema(table_, raise_on_missing=False) schema = self.find(table_, raise_on_missing=False)
if schema and not column_mapping: if schema and not column_mapping:
return return
_nested_set( _nested_set(
self.schema, self.mapping,
list(reversed(self.table_parts(table_))), list(reversed(self.table_parts(table_))),
column_mapping, column_mapping,
) )
self.schema_trie = self._build_trie(self.schema) self.mapping_trie = self._build_trie(self.mapping)
def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those
return super()._depth() - 1
def _ensure_table(self, table: exp.Table | str) -> exp.Table: def _ensure_table(self, table: exp.Table | str) -> exp.Table:
table_ = exp.to_table(table) table_ = exp.to_table(table)
@ -153,16 +218,9 @@ class MappingSchema(Schema):
return table_ return table_
def table_parts(self, table: exp.Table) -> t.List[str]:
return [table.text(part) for part in TABLE_ARGS if table.text(part)]
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
table_ = self._ensure_table(table) table_ = self._ensure_table(table)
schema = self.find(table_)
if not isinstance(table_.this, exp.Identifier):
return fs_get(table) # type: ignore
schema = self.find_schema(table_)
if schema is None: if schema is None:
raise SchemaError(f"Could not find table schema {table}") raise SchemaError(f"Could not find table schema {table}")
@ -173,36 +231,13 @@ class MappingSchema(Schema):
visible = self._nested_get(self.table_parts(table_), self.visible) visible = self._nested_get(self.table_parts(table_), self.visible)
return [col for col in schema if col in visible] # type: ignore return [col for col in schema if col in visible] # type: ignore
def find_schema(
self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
) -> t.Optional[t.Dict[str, str]]:
parts = self.table_parts(table)[0 : len(self.supported_table_args)]
value, trie = in_trie(self.schema_trie if trie is None else trie, parts)
if value == 0:
if raise_on_missing:
raise SchemaError(f"Cannot find schema for {table}.")
else:
return None
elif value == 1:
possibilities = flatten_schema(trie)
if len(possibilities) == 1:
parts.extend(possibilities[0])
else:
message = ", ".join(".".join(parts) for parts in possibilities)
if raise_on_missing:
raise SchemaError(f"Ambiguous schema for {table}: {message}.")
return None
return self._nested_get(parts, raise_on_missing=raise_on_missing)
def get_column_type( def get_column_type(
self, table: exp.Table | str, column: exp.Column | str self, table: exp.Table | str, column: exp.Column | str
) -> exp.DataType.Type: ) -> exp.DataType.Type:
column_name = column if isinstance(column, str) else column.name column_name = column if isinstance(column, str) else column.name
table_ = exp.to_table(table) table_ = exp.to_table(table)
if table_: if table_:
table_schema = self.find_schema(table_) table_schema = self.find(table_)
schema_type = table_schema.get(column_name).upper() # type: ignore schema_type = table_schema.get(column_name).upper() # type: ignore
return self._convert_type(schema_type) return self._convert_type(schema_type)
raise SchemaError(f"Could not convert table '{table}'") raise SchemaError(f"Could not convert table '{table}'")
@ -228,18 +263,6 @@ class MappingSchema(Schema):
return self._type_mapping_cache[schema_type] return self._type_mapping_cache[schema_type]
def _build_trie(self, schema: t.Dict):
return new_trie(tuple(reversed(t)) for t in flatten_schema(schema))
def _nested_get(
self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
) -> t.Optional[t.Any]:
return _nested_get(
d or self.schema,
*zip(self.supported_table_args, reversed(parts)),
raise_on_missing=raise_on_missing,
)
def ensure_schema(schema: t.Any) -> Schema: def ensure_schema(schema: t.Any) -> Schema:
if isinstance(schema, Schema): if isinstance(schema, Schema):
@ -267,29 +290,20 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
raise ValueError(f"Invalid mapping provided: {type(mapping)}") raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(schema: t.Dict, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]: def flatten_schema(
schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
tables = [] tables = []
keys = keys or [] keys = keys or []
depth = _dict_depth(schema)
for k, v in schema.items(): for k, v in schema.items():
if depth >= 3: if depth >= 2:
tables.extend(flatten_schema(v, keys + [k])) tables.extend(flatten_schema(v, depth - 1, keys + [k]))
elif depth == 2: elif depth == 1:
tables.append(keys + [k]) tables.append(keys + [k])
return tables return tables
def fs_get(table: exp.Table) -> t.List[str]:
name = table.this.name
if name.upper() == "READ_CSV":
with csv_reader(table) as reader:
return next(reader)
raise ValueError(f"Cannot read schema for {table}")
def _nested_get( def _nested_get(
d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]: ) -> t.Optional[t.Any]:
@ -310,7 +324,7 @@ def _nested_get(
if d is None: if d is None:
if raise_on_missing: if raise_on_missing:
name = "table" if name == "this" else name name = "table" if name == "this" else name
raise ValueError(f"Unknown {name}") raise ValueError(f"Unknown {name}: {key}")
return None return None
return d return d
@ -350,34 +364,3 @@ def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
subd[keys[-1]] = value subd[keys[-1]] = value
return d return d
def _dict_depth(d: t.Dict) -> int:
"""
Get the nesting depth of a dictionary.
For example:
>>> _dict_depth(None)
0
>>> _dict_depth({})
1
>>> _dict_depth({"a": "b"})
1
>>> _dict_depth({"a": {}})
2
>>> _dict_depth({"a": {"b": {}}})
3
Args:
d (dict): dictionary
Returns:
int: depth
"""
try:
return 1 + _dict_depth(next(iter(d.values())))
except AttributeError:
# d doesn't have attribute "values"
return 0
except StopIteration:
# d.values() returns an empty sequence
return 1
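The depth-based lookup now lives in `AbstractMappingSchema`, with `MappingSchema._depth` subtracting one level for the column mapping itself. A small sketch of the expected behaviour; the printed outputs are inferences from this diff, not verified:

```python
from sqlglot.schema import MappingSchema

schema = MappingSchema({"db": {"t": {"a": "INT", "b": "TEXT"}}})
print(schema.supported_table_args)  # expected: ('this', 'db')
print(schema.column_names("db.t"))  # expected: ['a', 'b']
```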
View file
@ -105,12 +105,9 @@ class TokenType(AutoName):
OBJECT = auto() OBJECT = auto()
# keywords # keywords
ADD_FILE = auto()
ALIAS = auto() ALIAS = auto()
ALWAYS = auto() ALWAYS = auto()
ALL = auto() ALL = auto()
ALTER = auto()
ANALYZE = auto()
ANTI = auto() ANTI = auto()
ANY = auto() ANY = auto()
APPLY = auto() APPLY = auto()
@ -124,14 +121,14 @@ class TokenType(AutoName):
BUCKET = auto() BUCKET = auto()
BY_DEFAULT = auto() BY_DEFAULT = auto()
CACHE = auto() CACHE = auto()
CALL = auto() CASCADE = auto()
CASE = auto() CASE = auto()
CHARACTER_SET = auto() CHARACTER_SET = auto()
CHECK = auto() CHECK = auto()
CLUSTER_BY = auto() CLUSTER_BY = auto()
COLLATE = auto() COLLATE = auto()
COMMAND = auto()
COMMENT = auto() COMMENT = auto()
COMMENT_ON = auto()
COMMIT = auto() COMMIT = auto()
CONSTRAINT = auto() CONSTRAINT = auto()
CREATE = auto() CREATE = auto()
@ -149,7 +146,9 @@ class TokenType(AutoName):
DETERMINISTIC = auto() DETERMINISTIC = auto()
DISTINCT = auto() DISTINCT = auto()
DISTINCT_FROM = auto() DISTINCT_FROM = auto()
DISTKEY = auto()
DISTRIBUTE_BY = auto() DISTRIBUTE_BY = auto()
DISTSTYLE = auto()
DIV = auto() DIV = auto()
DROP = auto() DROP = auto()
ELSE = auto() ELSE = auto()
@ -159,7 +158,6 @@ class TokenType(AutoName):
EXCEPT = auto() EXCEPT = auto()
EXECUTE = auto() EXECUTE = auto()
EXISTS = auto() EXISTS = auto()
EXPLAIN = auto()
FALSE = auto() FALSE = auto()
FETCH = auto() FETCH = auto()
FILTER = auto() FILTER = auto()
@ -216,7 +214,6 @@ class TokenType(AutoName):
OFFSET = auto() OFFSET = auto()
ON = auto() ON = auto()
ONLY = auto() ONLY = auto()
OPTIMIZE = auto()
OPTIONS = auto() OPTIONS = auto()
ORDER_BY = auto() ORDER_BY = auto()
ORDERED = auto() ORDERED = auto()
@ -258,6 +255,7 @@ class TokenType(AutoName):
SHOW = auto() SHOW = auto()
SIMILAR_TO = auto() SIMILAR_TO = auto()
SOME = auto() SOME = auto()
SORTKEY = auto()
SORT_BY = auto() SORT_BY = auto()
STABLE = auto() STABLE = auto()
STORED = auto() STORED = auto()
@ -268,9 +266,8 @@ class TokenType(AutoName):
TRANSIENT = auto() TRANSIENT = auto()
TOP = auto() TOP = auto()
THEN = auto() THEN = auto()
TRUE = auto()
TRAILING = auto() TRAILING = auto()
TRUNCATE = auto() TRUE = auto()
UNBOUNDED = auto() UNBOUNDED = auto()
UNCACHE = auto() UNCACHE = auto()
UNION = auto() UNION = auto()
@ -280,7 +277,6 @@ class TokenType(AutoName):
USE = auto() USE = auto()
USING = auto() USING = auto()
VALUES = auto() VALUES = auto()
VACUUM = auto()
VIEW = auto() VIEW = auto()
VOLATILE = auto() VOLATILE = auto()
WHEN = auto() WHEN = auto()
@ -420,7 +416,6 @@ class Tokenizer(metaclass=_Tokenizer):
KEYWORDS = { KEYWORDS = {
"/*+": TokenType.HINT, "/*+": TokenType.HINT,
"*/": TokenType.HINT,
"==": TokenType.EQ, "==": TokenType.EQ,
"::": TokenType.DCOLON, "::": TokenType.DCOLON,
"||": TokenType.DPIPE, "||": TokenType.DPIPE,
@ -435,15 +430,7 @@ class Tokenizer(metaclass=_Tokenizer):
"#>": TokenType.HASH_ARROW, "#>": TokenType.HASH_ARROW,
"#>>": TokenType.DHASH_ARROW, "#>>": TokenType.DHASH_ARROW,
"<->": TokenType.LR_ARROW, "<->": TokenType.LR_ARROW,
"ADD ARCHIVE": TokenType.ADD_FILE,
"ADD ARCHIVES": TokenType.ADD_FILE,
"ADD FILE": TokenType.ADD_FILE,
"ADD FILES": TokenType.ADD_FILE,
"ADD JAR": TokenType.ADD_FILE,
"ADD JARS": TokenType.ADD_FILE,
"ALL": TokenType.ALL, "ALL": TokenType.ALL,
"ALTER": TokenType.ALTER,
"ANALYZE": TokenType.ANALYZE,
"AND": TokenType.AND, "AND": TokenType.AND,
"ANTI": TokenType.ANTI, "ANTI": TokenType.ANTI,
"ANY": TokenType.ANY, "ANY": TokenType.ANY,
@ -455,10 +442,10 @@ class Tokenizer(metaclass=_Tokenizer):
"BETWEEN": TokenType.BETWEEN, "BETWEEN": TokenType.BETWEEN,
"BOTH": TokenType.BOTH, "BOTH": TokenType.BOTH,
"BUCKET": TokenType.BUCKET, "BUCKET": TokenType.BUCKET,
"CALL": TokenType.CALL,
"CACHE": TokenType.CACHE, "CACHE": TokenType.CACHE,
"UNCACHE": TokenType.UNCACHE, "UNCACHE": TokenType.UNCACHE,
"CASE": TokenType.CASE, "CASE": TokenType.CASE,
"CASCADE": TokenType.CASCADE,
"CHARACTER SET": TokenType.CHARACTER_SET, "CHARACTER SET": TokenType.CHARACTER_SET,
"CHECK": TokenType.CHECK, "CHECK": TokenType.CHECK,
"CLUSTER BY": TokenType.CLUSTER_BY, "CLUSTER BY": TokenType.CLUSTER_BY,
@ -479,7 +466,9 @@ class Tokenizer(metaclass=_Tokenizer):
"DETERMINISTIC": TokenType.DETERMINISTIC, "DETERMINISTIC": TokenType.DETERMINISTIC,
"DISTINCT": TokenType.DISTINCT, "DISTINCT": TokenType.DISTINCT,
"DISTINCT FROM": TokenType.DISTINCT_FROM, "DISTINCT FROM": TokenType.DISTINCT_FROM,
"DISTKEY": TokenType.DISTKEY,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY, "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
"DISTSTYLE": TokenType.DISTSTYLE,
"DIV": TokenType.DIV, "DIV": TokenType.DIV,
"DROP": TokenType.DROP, "DROP": TokenType.DROP,
"ELSE": TokenType.ELSE, "ELSE": TokenType.ELSE,
@ -489,7 +478,6 @@ class Tokenizer(metaclass=_Tokenizer):
"EXCEPT": TokenType.EXCEPT, "EXCEPT": TokenType.EXCEPT,
"EXECUTE": TokenType.EXECUTE, "EXECUTE": TokenType.EXECUTE,
"EXISTS": TokenType.EXISTS, "EXISTS": TokenType.EXISTS,
"EXPLAIN": TokenType.EXPLAIN,
"FALSE": TokenType.FALSE, "FALSE": TokenType.FALSE,
"FETCH": TokenType.FETCH, "FETCH": TokenType.FETCH,
"FILTER": TokenType.FILTER, "FILTER": TokenType.FILTER,
@ -541,7 +529,6 @@ class Tokenizer(metaclass=_Tokenizer):
"OFFSET": TokenType.OFFSET, "OFFSET": TokenType.OFFSET,
"ON": TokenType.ON, "ON": TokenType.ON,
"ONLY": TokenType.ONLY, "ONLY": TokenType.ONLY,
"OPTIMIZE": TokenType.OPTIMIZE,
"OPTIONS": TokenType.OPTIONS, "OPTIONS": TokenType.OPTIONS,
"OR": TokenType.OR, "OR": TokenType.OR,
"ORDER BY": TokenType.ORDER_BY, "ORDER BY": TokenType.ORDER_BY,
@ -579,6 +566,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SET": TokenType.SET, "SET": TokenType.SET,
"SHOW": TokenType.SHOW, "SHOW": TokenType.SHOW,
"SOME": TokenType.SOME, "SOME": TokenType.SOME,
"SORTKEY": TokenType.SORTKEY,
"SORT BY": TokenType.SORT_BY, "SORT BY": TokenType.SORT_BY,
"STABLE": TokenType.STABLE, "STABLE": TokenType.STABLE,
"STORED": TokenType.STORED, "STORED": TokenType.STORED,
@ -592,7 +580,6 @@ class Tokenizer(metaclass=_Tokenizer):
"THEN": TokenType.THEN, "THEN": TokenType.THEN,
"TRUE": TokenType.TRUE, "TRUE": TokenType.TRUE,
"TRAILING": TokenType.TRAILING, "TRAILING": TokenType.TRAILING,
"TRUNCATE": TokenType.TRUNCATE,
"UNBOUNDED": TokenType.UNBOUNDED, "UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION, "UNION": TokenType.UNION,
"UNPIVOT": TokenType.UNPIVOT, "UNPIVOT": TokenType.UNPIVOT,
@ -600,7 +587,6 @@ class Tokenizer(metaclass=_Tokenizer):
"UPDATE": TokenType.UPDATE, "UPDATE": TokenType.UPDATE,
"USE": TokenType.USE, "USE": TokenType.USE,
"USING": TokenType.USING, "USING": TokenType.USING,
"VACUUM": TokenType.VACUUM,
"VALUES": TokenType.VALUES, "VALUES": TokenType.VALUES,
"VIEW": TokenType.VIEW, "VIEW": TokenType.VIEW,
"VOLATILE": TokenType.VOLATILE, "VOLATILE": TokenType.VOLATILE,
@ -659,6 +645,14 @@ class Tokenizer(metaclass=_Tokenizer):
"UNIQUE": TokenType.UNIQUE, "UNIQUE": TokenType.UNIQUE,
"STRUCT": TokenType.STRUCT, "STRUCT": TokenType.STRUCT,
"VARIANT": TokenType.VARIANT, "VARIANT": TokenType.VARIANT,
"ALTER": TokenType.COMMAND,
"ANALYZE": TokenType.COMMAND,
"CALL": TokenType.COMMAND,
"EXPLAIN": TokenType.COMMAND,
"OPTIMIZE": TokenType.COMMAND,
"PREPARE": TokenType.COMMAND,
"TRUNCATE": TokenType.COMMAND,
"VACUUM": TokenType.COMMAND,
} }
WHITE_SPACE = { WHITE_SPACE = {
@ -670,20 +664,11 @@ class Tokenizer(metaclass=_Tokenizer):
} }
COMMANDS = { COMMANDS = {
TokenType.ALTER, TokenType.COMMAND,
TokenType.ADD_FILE, TokenType.EXECUTE,
TokenType.ANALYZE, TokenType.FETCH,
TokenType.BEGIN,
TokenType.CALL,
TokenType.COMMENT_ON,
TokenType.COMMIT,
TokenType.EXPLAIN,
TokenType.OPTIMIZE,
TokenType.SET, TokenType.SET,
TokenType.SHOW, TokenType.SHOW,
TokenType.TRUNCATE,
TokenType.VACUUM,
TokenType.ROLLBACK,
} }
# handle numeric literals like in hive (3L = BIGINT) # handle numeric literals like in hive (3L = BIGINT)
@ -885,6 +870,7 @@ class Tokenizer(metaclass=_Tokenizer):
if comment_start_line == self._prev_token_line: if comment_start_line == self._prev_token_line:
if self._prev_token_comment is None: if self._prev_token_comment is None:
self.tokens[-1].comment = self._comment self.tokens[-1].comment = self._comment
self._prev_token_comment = self._comment
self._comment = None self._comment = None
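For context on the tokenizer hunks above: single-statement keywords such as ALTER, ANALYZE, CALL, EXPLAIN, OPTIMIZE, PREPARE, TRUNCATE and VACUUM now tokenize as one generic TokenType.COMMAND instead of dedicated token types, and Redshift's DISTKEY, DISTSTYLE and SORTKEY gain keywords. A minimal sketch of the observable effect, assuming such statements surface as sqlglot's generic Command expression (the node type is an assumption inferred from the COMMANDS set above):

    import sqlglot
    from sqlglot import exp

    # EXPLAIN is one of the keywords collapsed into TokenType.COMMAND above;
    # the whole statement is expected to round-trip as an opaque command.
    node = sqlglot.parse_one("EXPLAIN SELECT * FROM x")
    print(isinstance(node, exp.Command))  # assumption: generic Command node
    print(node.sql())                     # EXPLAIN SELECT * FROM x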

View file

@ -4,6 +4,8 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataframe(DataFrameSQLValidator): class TestDataframe(DataFrameSQLValidator):
maxDiff = None
def test_hash_select_expression(self): def test_hash_select_expression(self):
expression = exp.select("cola").from_("table") expression = exp.select("cola").from_("table")
self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression)) self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))
@ -16,26 +18,26 @@ class TestDataframe(DataFrameSQLValidator):
def test_cache(self): def test_cache(self):
df = self.df_employee.select("fname").cache() df = self.df_employee.select("fname").cache()
expected_statements = [ expected_statements = [
"DROP VIEW IF EXISTS t11623", "DROP VIEW IF EXISTS t31563",
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`", "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
] ]
self.compare_sql(df, expected_statements) self.compare_sql(df, expected_statements)
def test_persist_default(self): def test_persist_default(self):
df = self.df_employee.select("fname").persist() df = self.df_employee.select("fname").persist()
expected_statements = [ expected_statements = [
"DROP VIEW IF EXISTS t11623", "DROP VIEW IF EXISTS t31563",
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`", "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
] ]
self.compare_sql(df, expected_statements) self.compare_sql(df, expected_statements)
def test_persist_storagelevel(self): def test_persist_storagelevel(self):
df = self.df_employee.select("fname").persist("DISK_ONLY_2") df = self.df_employee.select("fname").persist("DISK_ONLY_2")
expected_statements = [ expected_statements = [
"DROP VIEW IF EXISTS t11623", "DROP VIEW IF EXISTS t31563",
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`", "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
] ]
self.compare_sql(df, expected_statements) self.compare_sql(df, expected_statements)
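The updated expectations above reflect two generator changes: Spark type names are now emitted in upper case (STRING, INT) and VALUES clauses are no longer wrapped in parentheses; the cached temp-view names also shift with the new hash scheme. A rough usage sketch, assuming the dataframe API layout of this release:

    from sqlglot.dataframe.sql.session import SparkSession  # import path assumed

    spark = SparkSession()
    df = spark.createDataFrame([[1, "Jack"]], ["employee_id", "fname"])
    # cache() should expand into the DROP VIEW / CACHE LAZY TABLE / SELECT
    # triple shown in the expectations above (view names are hash-derived).
    for statement in df.select("fname").cache().sql():
        print(statement)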

View file

@ -6,39 +6,41 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataFrameWriter(DataFrameSQLValidator): class TestDataFrameWriter(DataFrameSQLValidator):
maxDiff = None
def test_insertInto_full_path(self): def test_insertInto_full_path(self):
df = self.df_employee.write.insertInto("catalog.db.table_name") df = self.df_employee.write.insertInto("catalog.db.table_name")
expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_insertInto_db_table(self): def test_insertInto_db_table(self):
df = self.df_employee.write.insertInto("db.table_name") df = self.df_employee.write.insertInto("db.table_name")
expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_insertInto_table(self): def test_insertInto_table(self):
df = self.df_employee.write.insertInto("table_name") df = self.df_employee.write.insertInto("table_name")
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_insertInto_overwrite(self): def test_insertInto_overwrite(self):
df = self.df_employee.write.insertInto("table_name", overwrite=True) df = self.df_employee.write.insertInto("table_name", overwrite=True)
expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema()) @mock.patch("sqlglot.schema", MappingSchema())
def test_insertInto_byName(self): def test_insertInto_byName(self):
sqlglot.schema.add_table("table_name", {"employee_id": "INT"}) sqlglot.schema.add_table("table_name", {"employee_id": "INT"})
df = self.df_employee.write.byName.insertInto("table_name") df = self.df_employee.write.byName.insertInto("table_name")
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_insertInto_cache(self): def test_insertInto_cache(self):
df = self.df_employee.cache().write.insertInto("table_name") df = self.df_employee.cache().write.insertInto("table_name")
expected_statements = [ expected_statements = [
"DROP VIEW IF EXISTS t35612", "DROP VIEW IF EXISTS t37164",
"CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"INSERT INTO table_name SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`", "INSERT INTO table_name SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
] ]
self.compare_sql(df, expected_statements) self.compare_sql(df, expected_statements)
@ -48,39 +50,39 @@ class TestDataFrameWriter(DataFrameSQLValidator):
def test_saveAsTable_append(self): def test_saveAsTable_append(self):
df = self.df_employee.write.saveAsTable("table_name", mode="append") df = self.df_employee.write.saveAsTable("table_name", mode="append")
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_saveAsTable_overwrite(self): def test_saveAsTable_overwrite(self):
df = self.df_employee.write.saveAsTable("table_name", mode="overwrite") df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_saveAsTable_error(self): def test_saveAsTable_error(self):
df = self.df_employee.write.saveAsTable("table_name", mode="error") df = self.df_employee.write.saveAsTable("table_name", mode="error")
expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_saveAsTable_ignore(self): def test_saveAsTable_ignore(self):
df = self.df_employee.write.saveAsTable("table_name", mode="ignore") df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_mode_standalone(self): def test_mode_standalone(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name") df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_mode_override(self): def test_mode_override(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite") df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_saveAsTable_cache(self): def test_saveAsTable_cache(self):
df = self.df_employee.cache().write.saveAsTable("table_name") df = self.df_employee.cache().write.saveAsTable("table_name")
expected_statements = [ expected_statements = [
"DROP VIEW IF EXISTS t35612", "DROP VIEW IF EXISTS t37164",
"CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"CREATE TABLE table_name AS SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`", "CREATE TABLE table_name AS SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
] ]
self.compare_sql(df, expected_statements) self.compare_sql(df, expected_statements)
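As the expectations above show, saveAsTable maps its mode onto DDL: append becomes INSERT INTO, overwrite becomes CREATE OR REPLACE TABLE, error a plain CREATE TABLE, and ignore CREATE TABLE IF NOT EXISTS. A sketch, assuming the object returned by the writer exposes the same sql() accessor the validator compares against:

    from sqlglot.dataframe.sql.session import SparkSession  # import path assumed

    spark = SparkSession()
    df = spark.createDataFrame([[1, "Jack"]], ["employee_id", "fname"])
    plan = df.write.saveAsTable("table_name", mode="overwrite")
    # assumption: the returned plan object exposes sql() like a DataFrame does
    print(plan.sql()[0])  # expected to start with CREATE OR REPLACE TABLE table_name AS SELECT ...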

View file

@ -11,32 +11,32 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataframeSession(DataFrameSQLValidator): class TestDataframeSession(DataFrameSQLValidator):
def test_cdf_one_row(self): def test_cdf_one_row(self):
df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"]) df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2)) AS `a2`(`cola`, `colb`)" expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_cdf_multiple_rows(self): def test_cdf_multiple_rows(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"]) df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`cola`, `colb`)" expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_cdf_no_schema(self): def test_cdf_no_schema(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]]) df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)" expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`_1`, `_2`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_cdf_row_mixed_primitives(self): def test_cdf_row_mixed_primitives(self):
df = self.spark.createDataFrame([[1, 10.1, "test", False, None]]) df = self.spark.createDataFrame([[1, 10.1, "test", False, None]])
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM (VALUES (1, 10.1, 'test', FALSE, NULL)) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)" expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM VALUES (1, 10.1, 'test', FALSE, NULL) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_cdf_dict_rows(self): def test_cdf_dict_rows(self):
df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}]) df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 'test'), (2, 'test2')) AS `a2`(`cola`, `colb`)" expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 'test'), (2, 'test2') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_cdf_str_schema(self): def test_cdf_str_schema(self):
df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING") df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)" expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_typed_schema_basic(self): def test_typed_schema_basic(self):
@ -47,7 +47,7 @@ class TestDataframeSession(DataFrameSQLValidator):
] ]
) )
df = self.spark.createDataFrame([[1, "test"]], schema) df = self.spark.createDataFrame([[1, "test"]], schema)
expected = "SELECT CAST(`a2`.`cola` AS int) AS `cola`, CAST(`a2`.`colb` AS string) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)" expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
def test_typed_schema_nested(self): def test_typed_schema_nested(self):
@ -65,7 +65,8 @@ class TestDataframeSession(DataFrameSQLValidator):
] ]
) )
df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema) df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema)
expected = "SELECT CAST(`a2`.`cola` AS struct<sub_cola:int, sub_colb:string>) AS `cola` FROM (VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`))) AS `a2`(`cola`)" expected = "SELECT CAST(`a2`.`cola` AS STRUCT<`sub_cola`: INT, `sub_colb`: STRING>) AS `cola` FROM VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`)) AS `a2`(`cola`)"
self.compare_sql(df, expected) self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema()) @mock.patch("sqlglot.schema", MappingSchema())
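The string-schema case above gives the clearest view of the type-name change: the declared INT/STRING types now come back upper-cased in the generated CASTs. A small sketch, assuming the same session API:

    from sqlglot.dataframe.sql.session import SparkSession  # import path assumed

    spark = SparkSession()
    df = spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
    print(df.sql()[0])
    # expected shape (alias counters vary between runs):
    # SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb`
    #   FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)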

View file

@ -286,6 +286,10 @@ class TestBigQuery(Validator):
"bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))", "bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))",
}, },
) )
self.validate_identity("BEGIN A B C D E F")
self.validate_identity("BEGIN TRANSACTION")
self.validate_identity("COMMIT TRANSACTION")
self.validate_identity("ROLLBACK TRANSACTION")
def test_user_defined_functions(self): def test_user_defined_functions(self):
self.validate_identity( self.validate_identity(
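The four new BigQuery identities above mean transaction statements survive a parse/generate round trip unchanged. The equivalent check outside the test harness:

    import sqlglot

    for sql in ("BEGIN TRANSACTION", "COMMIT TRANSACTION", "ROLLBACK TRANSACTION"):
        assert sqlglot.transpile(sql, read="bigquery", write="bigquery")[0] == sql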

View file

@ -69,6 +69,7 @@ class TestDialect(Validator):
write={ write={
"bigquery": "CAST(a AS STRING)", "bigquery": "CAST(a AS STRING)",
"clickhouse": "CAST(a AS TEXT)", "clickhouse": "CAST(a AS TEXT)",
"drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)", "duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS TEXT)", "mysql": "CAST(a AS TEXT)",
"hive": "CAST(a AS STRING)", "hive": "CAST(a AS STRING)",
@ -86,6 +87,7 @@ class TestDialect(Validator):
write={ write={
"bigquery": "CAST(a AS BINARY(4))", "bigquery": "CAST(a AS BINARY(4))",
"clickhouse": "CAST(a AS BINARY(4))", "clickhouse": "CAST(a AS BINARY(4))",
"drill": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS BINARY(4))", "duckdb": "CAST(a AS BINARY(4))",
"mysql": "CAST(a AS BINARY(4))", "mysql": "CAST(a AS BINARY(4))",
"hive": "CAST(a AS BINARY(4))", "hive": "CAST(a AS BINARY(4))",
@ -146,6 +148,7 @@ class TestDialect(Validator):
"CAST(a AS STRING)", "CAST(a AS STRING)",
write={ write={
"bigquery": "CAST(a AS STRING)", "bigquery": "CAST(a AS STRING)",
"drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)", "duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS TEXT)", "mysql": "CAST(a AS TEXT)",
"hive": "CAST(a AS STRING)", "hive": "CAST(a AS STRING)",
@ -162,6 +165,7 @@ class TestDialect(Validator):
"CAST(a AS VARCHAR)", "CAST(a AS VARCHAR)",
write={ write={
"bigquery": "CAST(a AS STRING)", "bigquery": "CAST(a AS STRING)",
"drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)", "duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS VARCHAR)", "mysql": "CAST(a AS VARCHAR)",
"hive": "CAST(a AS STRING)", "hive": "CAST(a AS STRING)",
@ -178,6 +182,7 @@ class TestDialect(Validator):
"CAST(a AS VARCHAR(3))", "CAST(a AS VARCHAR(3))",
write={ write={
"bigquery": "CAST(a AS STRING(3))", "bigquery": "CAST(a AS STRING(3))",
"drill": "CAST(a AS VARCHAR(3))",
"duckdb": "CAST(a AS TEXT(3))", "duckdb": "CAST(a AS TEXT(3))",
"mysql": "CAST(a AS VARCHAR(3))", "mysql": "CAST(a AS VARCHAR(3))",
"hive": "CAST(a AS VARCHAR(3))", "hive": "CAST(a AS VARCHAR(3))",
@ -194,6 +199,7 @@ class TestDialect(Validator):
"CAST(a AS SMALLINT)", "CAST(a AS SMALLINT)",
write={ write={
"bigquery": "CAST(a AS INT64)", "bigquery": "CAST(a AS INT64)",
"drill": "CAST(a AS INTEGER)",
"duckdb": "CAST(a AS SMALLINT)", "duckdb": "CAST(a AS SMALLINT)",
"mysql": "CAST(a AS SMALLINT)", "mysql": "CAST(a AS SMALLINT)",
"hive": "CAST(a AS SMALLINT)", "hive": "CAST(a AS SMALLINT)",
@ -215,6 +221,7 @@ class TestDialect(Validator):
}, },
write={ write={
"duckdb": "TRY_CAST(a AS DOUBLE)", "duckdb": "TRY_CAST(a AS DOUBLE)",
"drill": "CAST(a AS DOUBLE)",
"postgres": "CAST(a AS DOUBLE PRECISION)", "postgres": "CAST(a AS DOUBLE PRECISION)",
"redshift": "CAST(a AS DOUBLE PRECISION)", "redshift": "CAST(a AS DOUBLE PRECISION)",
}, },
@ -225,6 +232,7 @@ class TestDialect(Validator):
write={ write={
"bigquery": "CAST(a AS FLOAT64)", "bigquery": "CAST(a AS FLOAT64)",
"clickhouse": "CAST(a AS Float64)", "clickhouse": "CAST(a AS Float64)",
"drill": "CAST(a AS DOUBLE)",
"duckdb": "CAST(a AS DOUBLE)", "duckdb": "CAST(a AS DOUBLE)",
"mysql": "CAST(a AS DOUBLE)", "mysql": "CAST(a AS DOUBLE)",
"hive": "CAST(a AS DOUBLE)", "hive": "CAST(a AS DOUBLE)",
@ -279,6 +287,7 @@ class TestDialect(Validator):
"duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')", "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')", "presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')",
"drill": "TO_TIMESTAMP(x, 'yyyy-MM-dd''T''HH:mm:ss')",
"redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')", "redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')",
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')", "spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')",
}, },
@ -286,6 +295,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"STR_TO_TIME('2020-01-01', '%Y-%m-%d')", "STR_TO_TIME('2020-01-01', '%Y-%m-%d')",
write={ write={
"drill": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
"duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')", "duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')",
"hive": "CAST('2020-01-01' AS TIMESTAMP)", "hive": "CAST('2020-01-01' AS TIMESTAMP)",
"oracle": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')", "oracle": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
@ -298,6 +308,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"STR_TO_TIME(x, '%y')", "STR_TO_TIME(x, '%y')",
write={ write={
"drill": "TO_TIMESTAMP(x, 'yy')",
"duckdb": "STRPTIME(x, '%y')", "duckdb": "STRPTIME(x, '%y')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%y')", "presto": "DATE_PARSE(x, '%y')",
@ -319,6 +330,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"TIME_STR_TO_DATE('2020-01-01')", "TIME_STR_TO_DATE('2020-01-01')",
write={ write={
"drill": "CAST('2020-01-01' AS DATE)",
"duckdb": "CAST('2020-01-01' AS DATE)", "duckdb": "CAST('2020-01-01' AS DATE)",
"hive": "TO_DATE('2020-01-01')", "hive": "TO_DATE('2020-01-01')",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
@ -328,6 +340,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"TIME_STR_TO_TIME('2020-01-01')", "TIME_STR_TO_TIME('2020-01-01')",
write={ write={
"drill": "CAST('2020-01-01' AS TIMESTAMP)",
"duckdb": "CAST('2020-01-01' AS TIMESTAMP)", "duckdb": "CAST('2020-01-01' AS TIMESTAMP)",
"hive": "CAST('2020-01-01' AS TIMESTAMP)", "hive": "CAST('2020-01-01' AS TIMESTAMP)",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
@ -344,6 +357,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"TIME_TO_STR(x, '%Y-%m-%d')", "TIME_TO_STR(x, '%Y-%m-%d')",
write={ write={
"drill": "TO_CHAR(x, 'yyyy-MM-dd')",
"duckdb": "STRFTIME(x, '%Y-%m-%d')", "duckdb": "STRFTIME(x, '%Y-%m-%d')",
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
"oracle": "TO_CHAR(x, 'YYYY-MM-DD')", "oracle": "TO_CHAR(x, 'YYYY-MM-DD')",
@ -355,6 +369,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"TIME_TO_TIME_STR(x)", "TIME_TO_TIME_STR(x)",
write={ write={
"drill": "CAST(x AS VARCHAR)",
"duckdb": "CAST(x AS TEXT)", "duckdb": "CAST(x AS TEXT)",
"hive": "CAST(x AS STRING)", "hive": "CAST(x AS STRING)",
"presto": "CAST(x AS VARCHAR)", "presto": "CAST(x AS VARCHAR)",
@ -364,6 +379,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"TIME_TO_UNIX(x)", "TIME_TO_UNIX(x)",
write={ write={
"drill": "UNIX_TIMESTAMP(x)",
"duckdb": "EPOCH(x)", "duckdb": "EPOCH(x)",
"hive": "UNIX_TIMESTAMP(x)", "hive": "UNIX_TIMESTAMP(x)",
"presto": "TO_UNIXTIME(x)", "presto": "TO_UNIXTIME(x)",
@ -425,6 +441,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"DATE_TO_DATE_STR(x)", "DATE_TO_DATE_STR(x)",
write={ write={
"drill": "CAST(x AS VARCHAR)",
"duckdb": "CAST(x AS TEXT)", "duckdb": "CAST(x AS TEXT)",
"hive": "CAST(x AS STRING)", "hive": "CAST(x AS STRING)",
"presto": "CAST(x AS VARCHAR)", "presto": "CAST(x AS VARCHAR)",
@ -433,6 +450,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"DATE_TO_DI(x)", "DATE_TO_DI(x)",
write={ write={
"drill": "CAST(TO_DATE(x, 'yyyyMMdd') AS INT)",
"duckdb": "CAST(STRFTIME(x, '%Y%m%d') AS INT)", "duckdb": "CAST(STRFTIME(x, '%Y%m%d') AS INT)",
"hive": "CAST(DATE_FORMAT(x, 'yyyyMMdd') AS INT)", "hive": "CAST(DATE_FORMAT(x, 'yyyyMMdd') AS INT)",
"presto": "CAST(DATE_FORMAT(x, '%Y%m%d') AS INT)", "presto": "CAST(DATE_FORMAT(x, '%Y%m%d') AS INT)",
@ -441,6 +459,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"DI_TO_DATE(x)", "DI_TO_DATE(x)",
write={ write={
"drill": "TO_DATE(CAST(x AS VARCHAR), 'yyyyMMdd')",
"duckdb": "CAST(STRPTIME(CAST(x AS TEXT), '%Y%m%d') AS DATE)", "duckdb": "CAST(STRPTIME(CAST(x AS TEXT), '%Y%m%d') AS DATE)",
"hive": "TO_DATE(CAST(x AS STRING), 'yyyyMMdd')", "hive": "TO_DATE(CAST(x AS STRING), 'yyyyMMdd')",
"presto": "CAST(DATE_PARSE(CAST(x AS VARCHAR), '%Y%m%d') AS DATE)", "presto": "CAST(DATE_PARSE(CAST(x AS VARCHAR), '%Y%m%d') AS DATE)",
@ -463,6 +482,7 @@ class TestDialect(Validator):
}, },
write={ write={
"bigquery": "DATE_ADD(x, INTERVAL 1 'day')", "bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
"drill": "DATE_ADD(x, INTERVAL '1' DAY)",
"duckdb": "x + INTERVAL 1 day", "duckdb": "x + INTERVAL 1 day",
"hive": "DATE_ADD(x, 1)", "hive": "DATE_ADD(x, 1)",
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)", "mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
@ -477,6 +497,7 @@ class TestDialect(Validator):
"DATE_ADD(x, 1)", "DATE_ADD(x, 1)",
write={ write={
"bigquery": "DATE_ADD(x, INTERVAL 1 'day')", "bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
"drill": "DATE_ADD(x, INTERVAL '1' DAY)",
"duckdb": "x + INTERVAL 1 DAY", "duckdb": "x + INTERVAL 1 DAY",
"hive": "DATE_ADD(x, 1)", "hive": "DATE_ADD(x, 1)",
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)", "mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
@ -546,6 +567,7 @@ class TestDialect(Validator):
"starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
}, },
write={ write={
"drill": "TO_DATE(x, 'yyyy-MM-dd''T''HH:mm:ss')",
"mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
"starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)",
@ -556,6 +578,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"STR_TO_DATE(x, '%Y-%m-%d')", "STR_TO_DATE(x, '%Y-%m-%d')",
write={ write={
"drill": "CAST(x AS DATE)",
"mysql": "STR_TO_DATE(x, '%Y-%m-%d')", "mysql": "STR_TO_DATE(x, '%Y-%m-%d')",
"starrocks": "STR_TO_DATE(x, '%Y-%m-%d')", "starrocks": "STR_TO_DATE(x, '%Y-%m-%d')",
"hive": "CAST(x AS DATE)", "hive": "CAST(x AS DATE)",
@ -566,6 +589,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"DATE_STR_TO_DATE(x)", "DATE_STR_TO_DATE(x)",
write={ write={
"drill": "CAST(x AS DATE)",
"duckdb": "CAST(x AS DATE)", "duckdb": "CAST(x AS DATE)",
"hive": "TO_DATE(x)", "hive": "TO_DATE(x)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)", "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
@ -575,6 +599,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"TS_OR_DS_ADD('2021-02-01', 1, 'DAY')", "TS_OR_DS_ADD('2021-02-01', 1, 'DAY')",
write={ write={
"drill": "DATE_ADD(CAST('2021-02-01' AS DATE), INTERVAL '1' DAY)",
"duckdb": "CAST('2021-02-01' AS DATE) + INTERVAL 1 DAY", "duckdb": "CAST('2021-02-01' AS DATE) + INTERVAL 1 DAY",
"hive": "DATE_ADD('2021-02-01', 1)", "hive": "DATE_ADD('2021-02-01', 1)",
"presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2021-02-01', 1, 10), '%Y-%m-%d'))", "presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2021-02-01', 1, 10), '%Y-%m-%d'))",
@ -584,6 +609,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"DATE_ADD(CAST('2020-01-01' AS DATE), 1)", "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
write={ write={
"drill": "DATE_ADD(CAST('2020-01-01' AS DATE), INTERVAL '1' DAY)",
"duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY", "duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY",
"hive": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", "hive": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
"presto": "DATE_ADD('day', 1, CAST('2020-01-01' AS DATE))", "presto": "DATE_ADD('day', 1, CAST('2020-01-01' AS DATE))",
@ -593,6 +619,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"TIMESTAMP '2022-01-01'", "TIMESTAMP '2022-01-01'",
write={ write={
"drill": "CAST('2022-01-01' AS TIMESTAMP)",
"mysql": "CAST('2022-01-01' AS TIMESTAMP)", "mysql": "CAST('2022-01-01' AS TIMESTAMP)",
"starrocks": "CAST('2022-01-01' AS DATETIME)", "starrocks": "CAST('2022-01-01' AS DATETIME)",
"hive": "CAST('2022-01-01' AS TIMESTAMP)", "hive": "CAST('2022-01-01' AS TIMESTAMP)",
@ -614,6 +641,7 @@ class TestDialect(Validator):
dialect: f"{unit}(x)" dialect: f"{unit}(x)"
for dialect in ( for dialect in (
"bigquery", "bigquery",
"drill",
"duckdb", "duckdb",
"mysql", "mysql",
"presto", "presto",
@ -624,6 +652,7 @@ class TestDialect(Validator):
dialect: f"{unit}(x)" dialect: f"{unit}(x)"
for dialect in ( for dialect in (
"bigquery", "bigquery",
"drill",
"duckdb", "duckdb",
"mysql", "mysql",
"presto", "presto",
@ -649,6 +678,7 @@ class TestDialect(Validator):
write={ write={
"bigquery": "ARRAY_LENGTH(x)", "bigquery": "ARRAY_LENGTH(x)",
"duckdb": "ARRAY_LENGTH(x)", "duckdb": "ARRAY_LENGTH(x)",
"drill": "REPEATED_COUNT(x)",
"presto": "CARDINALITY(x)", "presto": "CARDINALITY(x)",
"spark": "SIZE(x)", "spark": "SIZE(x)",
}, },
@ -736,6 +766,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"SELECT a FROM x CROSS JOIN UNNEST(y) AS t (a)", "SELECT a FROM x CROSS JOIN UNNEST(y) AS t (a)",
write={ write={
"drill": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
"presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)", "presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
"spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a", "spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
}, },
@ -743,6 +774,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t (a, b)", "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t (a, b)",
write={ write={
"drill": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
"presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)", "presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
"spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b", "spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b",
}, },
@ -775,6 +807,7 @@ class TestDialect(Validator):
}, },
write={ write={
"bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b", "bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b",
"drill": "SELECT * FROM a UNION SELECT * FROM b",
"duckdb": "SELECT * FROM a UNION SELECT * FROM b", "duckdb": "SELECT * FROM a UNION SELECT * FROM b",
"presto": "SELECT * FROM a UNION SELECT * FROM b", "presto": "SELECT * FROM a UNION SELECT * FROM b",
"spark": "SELECT * FROM a UNION SELECT * FROM b", "spark": "SELECT * FROM a UNION SELECT * FROM b",
@ -887,6 +920,7 @@ class TestDialect(Validator):
write={ write={
"bigquery": "LOWER(x) LIKE '%y'", "bigquery": "LOWER(x) LIKE '%y'",
"clickhouse": "x ILIKE '%y'", "clickhouse": "x ILIKE '%y'",
"drill": "x `ILIKE` '%y'",
"duckdb": "x ILIKE '%y'", "duckdb": "x ILIKE '%y'",
"hive": "LOWER(x) LIKE '%y'", "hive": "LOWER(x) LIKE '%y'",
"mysql": "LOWER(x) LIKE '%y'", "mysql": "LOWER(x) LIKE '%y'",
@ -910,32 +944,38 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"POSITION(' ' in x)", "POSITION(' ' in x)",
write={ write={
"drill": "STRPOS(x, ' ')",
"duckdb": "STRPOS(x, ' ')", "duckdb": "STRPOS(x, ' ')",
"postgres": "STRPOS(x, ' ')", "postgres": "STRPOS(x, ' ')",
"presto": "STRPOS(x, ' ')", "presto": "STRPOS(x, ' ')",
"spark": "LOCATE(' ', x)", "spark": "LOCATE(' ', x)",
"clickhouse": "position(x, ' ')", "clickhouse": "position(x, ' ')",
"snowflake": "POSITION(' ', x)", "snowflake": "POSITION(' ', x)",
"mysql": "LOCATE(' ', x)",
}, },
) )
self.validate_all( self.validate_all(
"STR_POSITION('a', x)", "STR_POSITION('a', x)",
write={ write={
"drill": "STRPOS(x, 'a')",
"duckdb": "STRPOS(x, 'a')", "duckdb": "STRPOS(x, 'a')",
"postgres": "STRPOS(x, 'a')", "postgres": "STRPOS(x, 'a')",
"presto": "STRPOS(x, 'a')", "presto": "STRPOS(x, 'a')",
"spark": "LOCATE('a', x)", "spark": "LOCATE('a', x)",
"clickhouse": "position(x, 'a')", "clickhouse": "position(x, 'a')",
"snowflake": "POSITION('a', x)", "snowflake": "POSITION('a', x)",
"mysql": "LOCATE('a', x)",
}, },
) )
self.validate_all( self.validate_all(
"POSITION('a', x, 3)", "POSITION('a', x, 3)",
write={ write={
"drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
"presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", "presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
"spark": "LOCATE('a', x, 3)", "spark": "LOCATE('a', x, 3)",
"clickhouse": "position(x, 'a', 3)", "clickhouse": "position(x, 'a', 3)",
"snowflake": "POSITION('a', x, 3)", "snowflake": "POSITION('a', x, 3)",
"mysql": "LOCATE('a', x, 3)",
}, },
) )
self.validate_all( self.validate_all(
@ -960,6 +1000,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"IF(x > 1, 1, 0)", "IF(x > 1, 1, 0)",
write={ write={
"drill": "`IF`(x > 1, 1, 0)",
"duckdb": "CASE WHEN x > 1 THEN 1 ELSE 0 END", "duckdb": "CASE WHEN x > 1 THEN 1 ELSE 0 END",
"presto": "IF(x > 1, 1, 0)", "presto": "IF(x > 1, 1, 0)",
"hive": "IF(x > 1, 1, 0)", "hive": "IF(x > 1, 1, 0)",
@ -970,6 +1011,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"CASE WHEN 1 THEN x ELSE 0 END", "CASE WHEN 1 THEN x ELSE 0 END",
write={ write={
"drill": "CASE WHEN 1 THEN x ELSE 0 END",
"duckdb": "CASE WHEN 1 THEN x ELSE 0 END", "duckdb": "CASE WHEN 1 THEN x ELSE 0 END",
"presto": "CASE WHEN 1 THEN x ELSE 0 END", "presto": "CASE WHEN 1 THEN x ELSE 0 END",
"hive": "CASE WHEN 1 THEN x ELSE 0 END", "hive": "CASE WHEN 1 THEN x ELSE 0 END",
@ -980,6 +1022,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"x[y]", "x[y]",
write={ write={
"drill": "x[y]",
"duckdb": "x[y]", "duckdb": "x[y]",
"presto": "x[y]", "presto": "x[y]",
"hive": "x[y]", "hive": "x[y]",
@ -1000,6 +1043,7 @@ class TestDialect(Validator):
'true or null as "foo"', 'true or null as "foo"',
write={ write={
"bigquery": "TRUE OR NULL AS `foo`", "bigquery": "TRUE OR NULL AS `foo`",
"drill": "TRUE OR NULL AS `foo`",
"duckdb": 'TRUE OR NULL AS "foo"', "duckdb": 'TRUE OR NULL AS "foo"',
"presto": 'TRUE OR NULL AS "foo"', "presto": 'TRUE OR NULL AS "foo"',
"hive": "TRUE OR NULL AS `foo`", "hive": "TRUE OR NULL AS `foo`",
@ -1020,6 +1064,7 @@ class TestDialect(Validator):
"LEVENSHTEIN(col1, col2)", "LEVENSHTEIN(col1, col2)",
write={ write={
"duckdb": "LEVENSHTEIN(col1, col2)", "duckdb": "LEVENSHTEIN(col1, col2)",
"drill": "LEVENSHTEIN_DISTANCE(col1, col2)",
"presto": "LEVENSHTEIN_DISTANCE(col1, col2)", "presto": "LEVENSHTEIN_DISTANCE(col1, col2)",
"hive": "LEVENSHTEIN(col1, col2)", "hive": "LEVENSHTEIN(col1, col2)",
"spark": "LEVENSHTEIN(col1, col2)", "spark": "LEVENSHTEIN(col1, col2)",
@ -1029,6 +1074,7 @@ class TestDialect(Validator):
"LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))", "LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))",
write={ write={
"duckdb": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))", "duckdb": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
"drill": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
"presto": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))", "presto": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
"hive": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))", "hive": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
"spark": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))", "spark": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
@ -1152,6 +1198,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"SELECT a AS b FROM x GROUP BY b", "SELECT a AS b FROM x GROUP BY b",
write={ write={
"drill": "SELECT a AS b FROM x GROUP BY b",
"duckdb": "SELECT a AS b FROM x GROUP BY b", "duckdb": "SELECT a AS b FROM x GROUP BY b",
"presto": "SELECT a AS b FROM x GROUP BY 1", "presto": "SELECT a AS b FROM x GROUP BY 1",
"hive": "SELECT a AS b FROM x GROUP BY 1", "hive": "SELECT a AS b FROM x GROUP BY 1",
@ -1162,6 +1209,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"SELECT y x FROM my_table t", "SELECT y x FROM my_table t",
write={ write={
"drill": "SELECT y AS x FROM my_table AS t",
"hive": "SELECT y AS x FROM my_table AS t", "hive": "SELECT y AS x FROM my_table AS t",
"oracle": "SELECT y AS x FROM my_table t", "oracle": "SELECT y AS x FROM my_table t",
"postgres": "SELECT y AS x FROM my_table AS t", "postgres": "SELECT y AS x FROM my_table AS t",
@ -1230,3 +1278,36 @@ SELECT
}, },
pretty=True, pretty=True,
) )
def test_transactions(self):
self.validate_all(
"BEGIN TRANSACTION",
write={
"bigquery": "BEGIN TRANSACTION",
"mysql": "BEGIN",
"postgres": "BEGIN",
"presto": "START TRANSACTION",
"trino": "START TRANSACTION",
"redshift": "BEGIN",
"snowflake": "BEGIN",
"sqlite": "BEGIN TRANSACTION",
},
)
self.validate_all(
"BEGIN",
read={
"presto": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE",
"trino": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE",
},
)
self.validate_all(
"BEGIN",
read={
"presto": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ",
"trino": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ",
},
)
self.validate_all(
"BEGIN IMMEDIATE TRANSACTION",
write={"sqlite": "BEGIN IMMEDIATE TRANSACTION"},
)
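The new test_transactions case pins down how BEGIN TRANSACTION is normalized per target dialect. The same mappings driven directly through transpile, with expected strings taken from the write dict above:

    import sqlglot

    for dialect in ("presto", "mysql", "postgres", "sqlite"):
        print(dialect, sqlglot.transpile("BEGIN TRANSACTION", write=dialect)[0])
    # presto START TRANSACTION
    # mysql BEGIN
    # postgres BEGIN
    # sqlite BEGIN TRANSACTION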

View file

@ -0,0 +1,53 @@
from tests.dialects.test_dialect import Validator
class TestDrill(Validator):
dialect = "drill"
def test_string_literals(self):
self.validate_all(
"SELECT '2021-01-01' + INTERVAL 1 MONTH",
write={
"mysql": "SELECT '2021-01-01' + INTERVAL 1 MONTH",
},
)
def test_quotes(self):
self.validate_all(
"'\\''",
write={
"duckdb": "''''",
"presto": "''''",
"hive": "'\\''",
"spark": "'\\''",
},
)
self.validate_all(
"'\"x\"'",
write={
"duckdb": "'\"x\"'",
"presto": "'\"x\"'",
"hive": "'\"x\"'",
"spark": "'\"x\"'",
},
)
self.validate_all(
"'\\\\a'",
read={
"presto": "'\\a'",
},
write={
"duckdb": "'\\a'",
"presto": "'\\a'",
"hive": "'\\\\a'",
"spark": "'\\\\a'",
},
)
def test_table_function(self):
self.validate_all(
"SELECT * FROM table( dfs.`test_data.xlsx` (type => 'excel', sheetName => 'secondSheet'))",
write={
"drill": "SELECT * FROM table(dfs.`test_data.xlsx`(type => 'excel', sheetName => 'secondSheet'))",
},
)
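Beyond the dialect-wide additions in test_dialect.py, the new Drill suite covers string escapes and Drill's table() function syntax. The table-function identity, checked directly:

    import sqlglot

    sql = "SELECT * FROM table( dfs.`test_data.xlsx` (type => 'excel', sheetName => 'secondSheet'))"
    print(sqlglot.transpile(sql, read="drill", write="drill")[0])
    # SELECT * FROM table(dfs.`test_data.xlsx`(type => 'excel', sheetName => 'secondSheet'))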

View file

@ -58,6 +58,16 @@ class TestMySQL(Validator):
self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'") self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'")
self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci") self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci")
self.validate_identity("SET autocommit = ON") self.validate_identity("SET autocommit = ON")
self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL SERIALIZABLE")
self.validate_identity("SET TRANSACTION READ ONLY")
self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE")
self.validate_identity("SELECT SCHEMA()")
def test_canonical_functions(self):
self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)")
self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')")
self.validate_identity("SELECT UCASE('foo')", "SELECT UPPER('foo')")
self.validate_identity("SELECT LCASE('foo')", "SELECT LOWER('foo')")
def test_escape(self): def test_escape(self):
self.validate_all( self.validate_all(
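test_canonical_functions asserts that MySQL aliases are rewritten to their canonical forms (LEFT to SUBSTRING, INSTR to LOCATE, UCASE/LCASE to UPPER/LOWER). One of them, outside the harness:

    import sqlglot

    print(sqlglot.transpile("SELECT LCASE('foo')", read="mysql", write="mysql")[0])
    # SELECT LOWER('foo')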

View file

@ -177,6 +177,15 @@ class TestPresto(Validator):
"spark": "CREATE TABLE test USING PARQUET AS SELECT 1", "spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
}, },
) )
self.validate_all(
"CREATE TABLE test STORED = 'PARQUET' AS SELECT 1",
write={
"duckdb": "CREATE TABLE test AS SELECT 1",
"presto": "CREATE TABLE test WITH (FORMAT='PARQUET') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
"spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
},
)
self.validate_all( self.validate_all(
"CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1", "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1",
write={ write={
@ -427,3 +436,69 @@ class TestPresto(Validator):
"spark": UnsupportedError, "spark": UnsupportedError,
}, },
) )
self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE")
self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")
def test_encode_decode(self):
self.validate_all(
"TO_UTF8(x)",
write={
"spark": "ENCODE(x, 'utf-8')",
},
)
self.validate_all(
"FROM_UTF8(x)",
write={
"spark": "DECODE(x, 'utf-8')",
},
)
self.validate_all(
"ENCODE(x, 'utf-8')",
write={
"presto": "TO_UTF8(x)",
},
)
self.validate_all(
"DECODE(x, 'utf-8')",
write={
"presto": "FROM_UTF8(x)",
},
)
self.validate_all(
"ENCODE(x, 'invalid')",
write={
"presto": UnsupportedError,
},
)
self.validate_all(
"DECODE(x, 'invalid')",
write={
"presto": UnsupportedError,
},
)
def test_hex_unhex(self):
self.validate_all(
"TO_HEX(x)",
write={
"spark": "HEX(x)",
},
)
self.validate_all(
"FROM_HEX(x)",
write={
"spark": "UNHEX(x)",
},
)
self.validate_all(
"HEX(x)",
write={
"presto": "TO_HEX(x)",
},
)
self.validate_all(
"UNHEX(x)",
write={
"presto": "FROM_HEX(x)",
},
)
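test_encode_decode and test_hex_unhex establish a two-way mapping between Presto's TO_/FROM_ helpers and the Spark spellings, with UnsupportedError for non-UTF-8 charsets. The Presto-to-Spark direction, mirroring the cases above:

    import sqlglot

    print(sqlglot.transpile("TO_HEX(x)", read="presto", write="spark")[0])   # HEX(x)
    print(sqlglot.transpile("TO_UTF8(x)", read="presto", write="spark")[0])  # ENCODE(x, 'utf-8')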

View file

@ -169,6 +169,17 @@ class TestSnowflake(Validator):
"snowflake": "SELECT a FROM test AS unpivot", "snowflake": "SELECT a FROM test AS unpivot",
}, },
) )
self.validate_all(
"trim(date_column, 'UTC')",
write={
"snowflake": "TRIM(date_column, 'UTC')",
"postgres": "TRIM('UTC' FROM date_column)",
},
)
self.validate_all(
"trim(date_column)",
write={"snowflake": "TRIM(date_column)"},
)
def test_null_treatment(self): def test_null_treatment(self):
self.validate_all( self.validate_all(
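The new TRIM cases show Snowflake's two-argument form being preserved for Snowflake but rewritten to the standard TRIM(... FROM ...) syntax for Postgres:

    import sqlglot

    print(sqlglot.transpile("trim(date_column, 'UTC')", read="snowflake", write="postgres")[0])
    # TRIM('UTC' FROM date_column)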

View file

@ -122,13 +122,6 @@ x AT TIME ZONE 'UTC'
CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo' CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo'
SET x = 1 SET x = 1
SET -v SET -v
ADD JAR s3://bucket
ADD JARS s3://bucket, c
ADD FILE s3://file
ADD FILES s3://file, s3://a
ADD ARCHIVE s3://file
ADD ARCHIVES s3://file, s3://a
BEGIN IMMEDIATE TRANSACTION
COMMIT COMMIT
USE db USE db
NOT 1 NOT 1
@ -278,6 +271,7 @@ SELECT CEIL(a, b) FROM test
SELECT COUNT(a) FROM test SELECT COUNT(a) FROM test
SELECT COUNT(1) FROM test SELECT COUNT(1) FROM test
SELECT COUNT(*) FROM test SELECT COUNT(*) FROM test
SELECT COUNT() FROM test
SELECT COUNT(DISTINCT a) FROM test SELECT COUNT(DISTINCT a) FROM test
SELECT EXP(a) FROM test SELECT EXP(a) FROM test
SELECT FLOOR(a) FROM test SELECT FLOOR(a) FROM test
@ -372,6 +366,8 @@ WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2
WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2 WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2
WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2 WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2
WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2 WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2
WITH sub_query AS (SELECT a FROM table) (SELECT a FROM sub_query)
WITH sub_query AS (SELECT a FROM table) ((((SELECT a FROM sub_query))))
(SELECT 1) UNION (SELECT 2)
(SELECT 1) UNION SELECT 2
SELECT 1 UNION (SELECT 2)
@ -463,6 +459,7 @@ CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECI
CREATE TABLE z (a INT(11) DEFAULT UUID())
CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id')
CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1)
CREATE TABLE z (a INT(11) NOT NULL DEFAULT -1)
CREATE TABLE z (a INT(11) NOT NULL COLLATE utf8_bin AUTO_INCREMENT)
CREATE TABLE z (a INT, PRIMARY KEY(a))
CREATE TABLE z WITH (FORMAT='parquet') AS SELECT 1
@ -476,6 +473,9 @@ CREATE TABLE z AS ((WITH cte AS (SELECT 1) SELECT * FROM cte))
CREATE TABLE z (a INT UNIQUE)
CREATE TABLE z (a INT AUTO_INCREMENT)
CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT)
CREATE TABLE z (a INT REFERENCES parent(b, c))
CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id))
CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c))
CREATE TEMPORARY FUNCTION f
CREATE TEMPORARY FUNCTION f AS 'g'
CREATE FUNCTION f
@ -514,17 +514,23 @@ DELETE FROM x WHERE y > 1
DELETE FROM y
DELETE FROM event USING sales WHERE event.eventid = sales.eventid
DELETE FROM event USING sales, USING bla WHERE event.eventid = sales.eventid
DELETE FROM event USING sales AS s WHERE event.eventid = s.eventid
PREPARE statement
EXECUTE statement
DROP TABLE a
DROP TABLE a.b
DROP TABLE IF EXISTS a
DROP TABLE IF EXISTS a.b
DROP TABLE a CASCADE
DROP VIEW a
DROP VIEW a.b
DROP VIEW IF EXISTS a
DROP VIEW IF EXISTS a.b
SHOW TABLES
USE db
BEGIN
ROLLBACK
ROLLBACK TO b
EXPLAIN SELECT * FROM x
INSERT INTO x SELECT * FROM y
INSERT INTO x (SELECT * FROM y)
@ -581,3 +587,4 @@ SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */
SELECT x FROM a.b.c /* x */, e.f.g /* x */
SELECT FOO(x /* c */) /* FOO */, b /* b */
SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */
SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b'
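Each line in this fixture is expected to survive a parse/generate round trip unchanged. A minimal sketch of that check for two of the new entries (the helper below is illustrative, not the fixture runner itself):

from sqlglot import parse_one

def roundtrip(sql):
    # identity fixtures expect parse -> generate to reproduce the input exactly
    assert parse_one(sql).sql() == sql, sql

roundtrip("CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c))")
roundtrip("ROLLBACK TO b")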

View file

@ -0,0 +1,5 @@
SELECT w.d + w.e AS c FROM w AS w;
SELECT CONCAT(w.d, w.e) AS c FROM w AS w;
SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w;
SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w;
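These new canonicalize fixtures rely on the column types of the test schema (d and e are TEXT), so string addition becomes CONCAT and date comparisons gain explicit casts. A rough sketch of the rule chain they run through, mirroring the test_canonicalize method added further down:

from sqlglot.optimizer import optimize
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize

schema = {"w": {"d": "TEXT", "e": "TEXT"}}
expression = optimize(
    "SELECT w.d + w.e AS c FROM w AS w",
    schema=schema,
    rules=[qualify_tables, qualify_columns, annotate_types, canonicalize],
)
print(expression.sql())  # expected: SELECT CONCAT(w.d, w.e) AS c FROM w AS w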

View file

@ -119,7 +119,7 @@ GROUP BY
LIMIT 1;
# title: Root subquery is union
(SELECT b FROM x UNION SELECT b FROM y) LIMIT 1; (SELECT b FROM x UNION SELECT b FROM y ORDER BY b) LIMIT 1;
(
SELECT
"x"."b" AS "b"
@ -128,6 +128,8 @@ LIMIT 1;
SELECT
"y"."b" AS "b"
FROM "y" AS "y"
ORDER BY
"b"
)
LIMIT 1;

View file

@ -15,7 +15,7 @@ select
from
lineitem
where
CAST(l_shipdate AS DATE) <= date '1998-12-01' - interval '90' day l_shipdate <= date '1998-12-01' - interval '90' day
group by
l_returnflag,
l_linestatus
@ -250,8 +250,8 @@ FROM "orders" AS "orders"
LEFT JOIN "_u_0" AS "_u_0" LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."l_orderkey" = "orders"."o_orderkey" ON "_u_0"."l_orderkey" = "orders"."o_orderkey"
WHERE WHERE
"orders"."o_orderdate" < CAST('1993-10-01' AS DATE) CAST("orders"."o_orderdate" AS DATE) < CAST('1993-10-01' AS DATE)
AND "orders"."o_orderdate" >= CAST('1993-07-01' AS DATE) AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1993-07-01' AS DATE)
AND NOT "_u_0"."l_orderkey" IS NULL AND NOT "_u_0"."l_orderkey" IS NULL
GROUP BY GROUP BY
"orders"."o_orderpriority" "orders"."o_orderpriority"
@ -293,8 +293,8 @@ SELECT
FROM "customer" AS "customer" FROM "customer" AS "customer"
JOIN "orders" AS "orders" JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey" ON "customer"."c_custkey" = "orders"."o_custkey"
AND "orders"."o_orderdate" < CAST('1995-01-01' AS DATE) AND CAST("orders"."o_orderdate" AS DATE) < CAST('1995-01-01' AS DATE)
AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE) AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1994-01-01' AS DATE)
JOIN "region" AS "region" JOIN "region" AS "region"
ON "region"."r_name" = 'ASIA' ON "region"."r_name" = 'ASIA'
JOIN "nation" AS "nation" JOIN "nation" AS "nation"
@ -328,8 +328,8 @@ FROM "lineitem" AS "lineitem"
WHERE WHERE
"lineitem"."l_discount" BETWEEN 0.05 AND 0.07 "lineitem"."l_discount" BETWEEN 0.05 AND 0.07
AND "lineitem"."l_quantity" < 24 AND "lineitem"."l_quantity" < 24
AND "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE) AND CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE)
AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE); AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE);
-------------------------------------- --------------------------------------
-- TPC-H 7 -- TPC-H 7
@ -384,13 +384,13 @@ WITH "n1" AS (
SELECT SELECT
"n1"."n_name" AS "supp_nation", "n1"."n_name" AS "supp_nation",
"n2"."n_name" AS "cust_nation", "n2"."n_name" AS "cust_nation",
EXTRACT(year FROM "lineitem"."l_shipdate") AS "l_year", EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) AS "l_year",
SUM("lineitem"."l_extendedprice" * ( SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount" 1 - "lineitem"."l_discount"
)) AS "revenue" )) AS "revenue"
FROM "supplier" AS "supplier" FROM "supplier" AS "supplier"
JOIN "lineitem" AS "lineitem" JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) ON CAST("lineitem"."l_shipdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
AND "supplier"."s_suppkey" = "lineitem"."l_suppkey" AND "supplier"."s_suppkey" = "lineitem"."l_suppkey"
JOIN "orders" AS "orders" JOIN "orders" AS "orders"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey" ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
@ -409,7 +409,7 @@ JOIN "n1" AS "n2"
GROUP BY GROUP BY
"n1"."n_name", "n1"."n_name",
"n2"."n_name", "n2"."n_name",
EXTRACT(year FROM "lineitem"."l_shipdate") EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME))
ORDER BY ORDER BY
"supp_nation", "supp_nation",
"cust_nation", "cust_nation",
@ -456,7 +456,7 @@ group by
order by order by
o_year; o_year;
SELECT SELECT
EXTRACT(year FROM "orders"."o_orderdate") AS "o_year", EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
SUM( SUM(
CASE CASE
WHEN "nation_2"."n_name" = 'BRAZIL' WHEN "nation_2"."n_name" = 'BRAZIL'
@ -477,7 +477,7 @@ JOIN "customer" AS "customer"
ON "customer"."c_nationkey" = "nation"."n_nationkey" ON "customer"."c_nationkey" = "nation"."n_nationkey"
JOIN "orders" AS "orders" JOIN "orders" AS "orders"
ON "orders"."o_custkey" = "customer"."c_custkey" ON "orders"."o_custkey" = "customer"."c_custkey"
AND "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) AND CAST("orders"."o_orderdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
JOIN "lineitem" AS "lineitem" JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey" ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "part"."p_partkey" = "lineitem"."l_partkey" AND "part"."p_partkey" = "lineitem"."l_partkey"
@ -488,7 +488,7 @@ JOIN "nation" AS "nation_2"
WHERE WHERE
"part"."p_type" = 'ECONOMY ANODIZED STEEL' "part"."p_type" = 'ECONOMY ANODIZED STEEL'
GROUP BY GROUP BY
EXTRACT(year FROM "orders"."o_orderdate") EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
ORDER BY ORDER BY
"o_year"; "o_year";
@ -529,7 +529,7 @@ order by
o_year desc; o_year desc;
SELECT SELECT
"nation"."n_name" AS "nation", "nation"."n_name" AS "nation",
EXTRACT(year FROM "orders"."o_orderdate") AS "o_year", EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
SUM( SUM(
"lineitem"."l_extendedprice" * ( "lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount" 1 - "lineitem"."l_discount"
@ -551,7 +551,7 @@ WHERE
"part"."p_name" LIKE '%green%' "part"."p_name" LIKE '%green%'
GROUP BY GROUP BY
"nation"."n_name", "nation"."n_name",
EXTRACT(year FROM "orders"."o_orderdate") EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
ORDER BY ORDER BY
"nation", "nation",
"o_year" DESC; "o_year" DESC;
@ -606,8 +606,8 @@ SELECT
FROM "customer" AS "customer" FROM "customer" AS "customer"
JOIN "orders" AS "orders" JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey" ON "customer"."c_custkey" = "orders"."o_custkey"
AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE) AND CAST("orders"."o_orderdate" AS DATE) < CAST('1994-01-01' AS DATE)
AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE) AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1993-10-01' AS DATE)
JOIN "lineitem" AS "lineitem" JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey" AND "lineitem"."l_returnflag" = 'R' ON "lineitem"."l_orderkey" = "orders"."o_orderkey" AND "lineitem"."l_returnflag" = 'R'
JOIN "nation" AS "nation" JOIN "nation" AS "nation"
@ -740,8 +740,8 @@ SELECT
FROM "orders" AS "orders" FROM "orders" AS "orders"
JOIN "lineitem" AS "lineitem" JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_commitdate" < "lineitem"."l_receiptdate" ON "lineitem"."l_commitdate" < "lineitem"."l_receiptdate"
AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE) AND CAST("lineitem"."l_receiptdate" AS DATE) < CAST('1995-01-01' AS DATE)
AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE) AND CAST("lineitem"."l_receiptdate" AS DATE) >= CAST('1994-01-01' AS DATE)
AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate" AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate"
AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP') AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP')
AND "orders"."o_orderkey" = "lineitem"."l_orderkey" AND "orders"."o_orderkey" = "lineitem"."l_orderkey"
@ -832,8 +832,8 @@ FROM "lineitem" AS "lineitem"
JOIN "part" AS "part" JOIN "part" AS "part"
ON "lineitem"."l_partkey" = "part"."p_partkey" ON "lineitem"."l_partkey" = "part"."p_partkey"
WHERE WHERE
"lineitem"."l_shipdate" < CAST('1995-10-01' AS DATE) CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-10-01' AS DATE)
AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE); AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1995-09-01' AS DATE);
-------------------------------------- --------------------------------------
-- TPC-H 15 -- TPC-H 15
@ -876,8 +876,8 @@ WITH "revenue" AS (
)) AS "total_revenue" )) AS "total_revenue"
FROM "lineitem" AS "lineitem" FROM "lineitem" AS "lineitem"
WHERE WHERE
"lineitem"."l_shipdate" < CAST('1996-04-01' AS DATE) CAST("lineitem"."l_shipdate" AS DATE) < CAST('1996-04-01' AS DATE)
AND "lineitem"."l_shipdate" >= CAST('1996-01-01' AS DATE) AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE)
GROUP BY GROUP BY
"lineitem"."l_suppkey" "lineitem"."l_suppkey"
) )
@ -1220,8 +1220,8 @@ WITH "_u_0" AS (
"lineitem"."l_suppkey" AS "_u_2" "lineitem"."l_suppkey" AS "_u_2"
FROM "lineitem" AS "lineitem" FROM "lineitem" AS "lineitem"
WHERE WHERE
"lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE) CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE)
AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE) AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE)
GROUP BY GROUP BY
"lineitem"."l_partkey", "lineitem"."l_partkey",
"lineitem"."l_suppkey" "lineitem"."l_suppkey"

View file

@ -315,3 +315,10 @@ FROM (
WHERE WHERE
id = 1 id = 1
) /* x */; ) /* x */;
SELECT * /* multi
line
comment */;
SELECT
* /* multi
line
comment */;
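The new pretty-printing fixture keeps a comment that spans multiple lines attached to its expression. A tiny sketch of the same behavior through the public API:

import sqlglot

# Multi-line comments are preserved and re-attached after transpilation.
sql = "SELECT * /* multi\n   line\n   comment */"
print(sqlglot.transpile(sql, pretty=True)[0])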

View file

@ -57,79 +57,79 @@ SKIP_INTEGRATION = string_to_bool(os.environ.get("SKIP_INTEGRATION", "0").lower(
TPCH_SCHEMA = { TPCH_SCHEMA = {
"lineitem": { "lineitem": {
"l_orderkey": "uint64", "l_orderkey": "bigint",
"l_partkey": "uint64", "l_partkey": "bigint",
"l_suppkey": "uint64", "l_suppkey": "bigint",
"l_linenumber": "uint64", "l_linenumber": "bigint",
"l_quantity": "float64", "l_quantity": "double",
"l_extendedprice": "float64", "l_extendedprice": "double",
"l_discount": "float64", "l_discount": "double",
"l_tax": "float64", "l_tax": "double",
"l_returnflag": "string", "l_returnflag": "string",
"l_linestatus": "string", "l_linestatus": "string",
"l_shipdate": "date32", "l_shipdate": "string",
"l_commitdate": "date32", "l_commitdate": "string",
"l_receiptdate": "date32", "l_receiptdate": "string",
"l_shipinstruct": "string", "l_shipinstruct": "string",
"l_shipmode": "string", "l_shipmode": "string",
"l_comment": "string", "l_comment": "string",
}, },
"orders": { "orders": {
"o_orderkey": "uint64", "o_orderkey": "bigint",
"o_custkey": "uint64", "o_custkey": "bigint",
"o_orderstatus": "string", "o_orderstatus": "string",
"o_totalprice": "float64", "o_totalprice": "double",
"o_orderdate": "date32", "o_orderdate": "string",
"o_orderpriority": "string", "o_orderpriority": "string",
"o_clerk": "string", "o_clerk": "string",
"o_shippriority": "int32", "o_shippriority": "int",
"o_comment": "string", "o_comment": "string",
}, },
"customer": { "customer": {
"c_custkey": "uint64", "c_custkey": "bigint",
"c_name": "string", "c_name": "string",
"c_address": "string", "c_address": "string",
"c_nationkey": "uint64", "c_nationkey": "bigint",
"c_phone": "string", "c_phone": "string",
"c_acctbal": "float64", "c_acctbal": "double",
"c_mktsegment": "string", "c_mktsegment": "string",
"c_comment": "string", "c_comment": "string",
}, },
"part": { "part": {
"p_partkey": "uint64", "p_partkey": "bigint",
"p_name": "string", "p_name": "string",
"p_mfgr": "string", "p_mfgr": "string",
"p_brand": "string", "p_brand": "string",
"p_type": "string", "p_type": "string",
"p_size": "int32", "p_size": "int",
"p_container": "string", "p_container": "string",
"p_retailprice": "float64", "p_retailprice": "double",
"p_comment": "string", "p_comment": "string",
}, },
"supplier": { "supplier": {
"s_suppkey": "uint64", "s_suppkey": "bigint",
"s_name": "string", "s_name": "string",
"s_address": "string", "s_address": "string",
"s_nationkey": "uint64", "s_nationkey": "bigint",
"s_phone": "string", "s_phone": "string",
"s_acctbal": "float64", "s_acctbal": "double",
"s_comment": "string", "s_comment": "string",
}, },
"partsupp": { "partsupp": {
"ps_partkey": "uint64", "ps_partkey": "bigint",
"ps_suppkey": "uint64", "ps_suppkey": "bigint",
"ps_availqty": "int32", "ps_availqty": "int",
"ps_supplycost": "float64", "ps_supplycost": "double",
"ps_comment": "string", "ps_comment": "string",
}, },
"nation": { "nation": {
"n_nationkey": "uint64", "n_nationkey": "bigint",
"n_name": "string", "n_name": "string",
"n_regionkey": "uint64", "n_regionkey": "bigint",
"n_comment": "string", "n_comment": "string",
}, },
"region": { "region": {
"r_regionkey": "uint64", "r_regionkey": "bigint",
"r_name": "string", "r_name": "string",
"r_comment": "string", "r_comment": "string",
}, },
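The TPC-H test schema switches from Arrow-style type names (uint64, float64, date32) to SQL type names (bigint, double, string), presumably so the same dictionary can be handed directly to the optimizer and the Python executor. A small sketch of that usage with made-up rows:

from sqlglot.executor import execute

# SQL-style type names let one schema drive both optimization and execution.
schema = {"region": {"r_regionkey": "bigint", "r_name": "string", "r_comment": "string"}}
tables = {"region": [{"r_regionkey": 1, "r_name": "AMERICA", "r_comment": ""}]}
result = execute("SELECT r_name FROM region", schema=schema, tables=tables)
print(result.columns, result.rows)  # ('r_name',) [('AMERICA',)]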

View file

@ -1,12 +1,15 @@
import unittest
from datetime import date
import duckdb
import pandas as pd
from pandas.testing import assert_frame_equal
from sqlglot import exp, parse_one
from sqlglot.errors import ExecuteError
from sqlglot.executor import execute
from sqlglot.executor.python import Python
from sqlglot.executor.table import Table, ensure_tables
from tests.helpers import (
FIXTURES_DIR,
SKIP_INTEGRATION,
@ -67,13 +70,399 @@ class TestExecutor(unittest.TestCase):
def to_csv(expression):
if isinstance(expression, exp.Table):
return parse_one(
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}" f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
)
return expression
for sql, _ in self.sqls[0:3]: for i, (sql, _) in enumerate(self.sqls[0:7]):
with self.subTest(f"tpch-h {i + 1}"):
a = self.cached_execute(sql)
sql = parse_one(sql).transform(to_csv).sql(pretty=True)
table = execute(sql, TPCH_SCHEMA)
b = pd.DataFrame(table.rows, columns=table.columns)
assert_frame_equal(a, b, check_dtype=False)
def test_execute_callable(self):
tables = {
"x": [
{"a": "a", "b": "d"},
{"a": "b", "b": "e"},
{"a": "c", "b": "f"},
],
"y": [
{"b": "d", "c": "g"},
{"b": "e", "c": "h"},
{"b": "f", "c": "i"},
],
"z": [],
}
schema = {
"x": {
"a": "VARCHAR",
"b": "VARCHAR",
},
"y": {
"b": "VARCHAR",
"c": "VARCHAR",
},
"z": {"d": "VARCHAR"},
}
for sql, cols, rows in [
("SELECT * FROM x", ["a", "b"], [("a", "d"), ("b", "e"), ("c", "f")]),
(
"SELECT * FROM x JOIN y ON x.b = y.b",
["a", "b", "b", "c"],
[("a", "d", "d", "g"), ("b", "e", "e", "h"), ("c", "f", "f", "i")],
),
(
"SELECT j.c AS d FROM x AS i JOIN y AS j ON i.b = j.b",
["d"],
[("g",), ("h",), ("i",)],
),
(
"SELECT CONCAT(x.a, y.c) FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'",
["_col_0"],
[("bh",)],
),
(
"SELECT * FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'",
["a", "b", "b", "c"],
[("b", "e", "e", "h")],
),
(
"SELECT * FROM z",
["d"],
[],
),
(
"SELECT d FROM z ORDER BY d",
["d"],
[],
),
(
"SELECT a FROM x WHERE x.a <> 'b'",
["a"],
[("a",), ("c",)],
),
(
"SELECT a AS i FROM x ORDER BY a",
["i"],
[("a",), ("b",), ("c",)],
),
(
"SELECT a AS i FROM x ORDER BY i",
["i"],
[("a",), ("b",), ("c",)],
),
(
"SELECT 100 - ORD(a) AS a, a AS i FROM x ORDER BY a",
["a", "i"],
[(1, "c"), (2, "b"), (3, "a")],
),
(
"SELECT a /* test */ FROM x LIMIT 1",
["a"],
[("a",)],
),
]:
with self.subTest(sql):
result = execute(sql, schema=schema, tables=tables)
self.assertEqual(result.columns, tuple(cols))
self.assertEqual(result.rows, rows)
def test_set_operations(self):
tables = {
"x": [
{"a": "a"},
{"a": "b"},
{"a": "c"},
],
"y": [
{"a": "b"},
{"a": "c"},
{"a": "d"},
],
}
schema = {
"x": {
"a": "VARCHAR",
},
"y": {
"a": "VARCHAR",
},
}
for sql, cols, rows in [
(
"SELECT a FROM x UNION ALL SELECT a FROM y",
["a"],
[("a",), ("b",), ("c",), ("b",), ("c",), ("d",)],
),
(
"SELECT a FROM x UNION SELECT a FROM y",
["a"],
[("a",), ("b",), ("c",), ("d",)],
),
(
"SELECT a FROM x EXCEPT SELECT a FROM y",
["a"],
[("a",)],
),
(
"SELECT a FROM x INTERSECT SELECT a FROM y",
["a"],
[("b",), ("c",)],
),
(
"""SELECT i.a
FROM (
SELECT a FROM x UNION SELECT a FROM y
) AS i
JOIN (
SELECT a FROM x UNION SELECT a FROM y
) AS j
ON i.a = j.a""",
["a"],
[("a",), ("b",), ("c",), ("d",)],
),
(
"SELECT 1 AS a UNION SELECT 2 AS a UNION SELECT 3 AS a",
["a"],
[(1,), (2,), (3,)],
),
]:
with self.subTest(sql):
result = execute(sql, schema=schema, tables=tables)
self.assertEqual(result.columns, tuple(cols))
self.assertEqual(set(result.rows), set(rows))
def test_execute_catalog_db_table(self):
tables = {
"catalog": {
"db": {
"x": [
{"a": "a"},
{"a": "b"},
{"a": "c"},
],
}
}
}
schema = {
"catalog": {
"db": {
"x": {
"a": "VARCHAR",
}
}
}
}
result1 = execute("SELECT * FROM x", schema=schema, tables=tables)
result2 = execute("SELECT * FROM catalog.db.x", schema=schema, tables=tables)
assert result1.columns == result2.columns
assert result1.rows == result2.rows
def test_execute_tables(self):
tables = {
"sushi": [
{"id": 1, "price": 1.0},
{"id": 2, "price": 2.0},
{"id": 3, "price": 3.0},
],
"order_items": [
{"sushi_id": 1, "order_id": 1},
{"sushi_id": 1, "order_id": 1},
{"sushi_id": 2, "order_id": 1},
{"sushi_id": 3, "order_id": 2},
],
"orders": [
{"id": 1, "user_id": 1},
{"id": 2, "user_id": 2},
],
}
self.assertEqual(
execute(
"""
SELECT
o.user_id,
SUM(s.price) AS price
FROM orders o
JOIN order_items i
ON o.id = i.order_id
JOIN sushi s
ON i.sushi_id = s.id
GROUP BY o.user_id
""",
tables=tables,
).rows,
[
(1, 4.0),
(2, 3.0),
],
)
self.assertEqual(
execute(
"""
SELECT
o.id, x.*
FROM orders o
LEFT JOIN (
SELECT
1 AS id, 'b' AS x
UNION ALL
SELECT
3 AS id, 'c' AS x
) x
ON o.id = x.id
""",
tables=tables,
).rows,
[(1, 1, "b"), (2, None, None)],
)
self.assertEqual(
execute(
"""
SELECT
o.id, x.*
FROM orders o
RIGHT JOIN (
SELECT
1 AS id,
'b' AS x
UNION ALL
SELECT
3 AS id, 'c' AS x
) x
ON o.id = x.id
""",
tables=tables,
).rows,
[
(1, 1, "b"),
(None, 3, "c"),
],
)
def test_table_depth_mismatch(self):
tables = {"table": []}
schema = {"db": {"table": {"col": "VARCHAR"}}}
with self.assertRaises(ExecuteError):
execute("SELECT * FROM table", schema=schema, tables=tables)
def test_tables(self):
tables = ensure_tables(
{
"catalog1": {
"db1": {
"t1": [
{"a": 1},
],
"t2": [
{"a": 1},
],
},
"db2": {
"t3": [
{"a": 1},
],
"t4": [
{"a": 1},
],
},
},
"catalog2": {
"db3": {
"t5": Table(columns=("a",), rows=[(1,)]),
"t6": Table(columns=("a",), rows=[(1,)]),
},
"db4": {
"t7": Table(columns=("a",), rows=[(1,)]),
"t8": Table(columns=("a",), rows=[(1,)]),
},
},
}
)
t1 = tables.find(exp.table_(table="t1", db="db1", catalog="catalog1"))
self.assertEqual(t1.columns, ("a",))
self.assertEqual(t1.rows, [(1,)])
t8 = tables.find(exp.table_(table="t8"))
self.assertEqual(t1.columns, t8.columns)
self.assertEqual(t1.rows, t8.rows)
def test_static_queries(self):
for sql, cols, rows in [
("SELECT 1", ["_col_0"], [(1,)]),
("SELECT 1 + 2 AS x", ["x"], [(3,)]),
("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
]:
result = execute(sql)
self.assertEqual(result.columns, tuple(cols))
self.assertEqual(result.rows, rows)
def test_aggregate_without_group_by(self):
result = execute("SELECT SUM(x) FROM t", tables={"t": [{"x": 1}, {"x": 2}]})
self.assertEqual(result.columns, ("_col_0",))
self.assertEqual(result.rows, [(3,)])
def test_scalar_functions(self):
for sql, expected in [
("CONCAT('a', 'b')", "ab"),
("CONCAT('a', NULL)", None),
("CONCAT_WS('_', 'a', 'b')", "a_b"),
("STR_POSITION('bar', 'foobarbar')", 4),
("STR_POSITION('bar', 'foobarbar', 5)", 7),
("STR_POSITION(NULL, 'foobarbar')", None),
("STR_POSITION('bar', NULL)", None),
("UPPER('foo')", "FOO"),
("UPPER(NULL)", None),
("LOWER('FOO')", "foo"),
("LOWER(NULL)", None),
("IFNULL('a', 'b')", "a"),
("IFNULL(NULL, 'b')", "b"),
("IFNULL(NULL, NULL)", None),
("SUBSTRING('12345')", "12345"),
("SUBSTRING('12345', 3)", "345"),
("SUBSTRING('12345', 3, 0)", ""),
("SUBSTRING('12345', 3, 1)", "3"),
("SUBSTRING('12345', 3, 2)", "34"),
("SUBSTRING('12345', 3, 3)", "345"),
("SUBSTRING('12345', 3, 4)", "345"),
("SUBSTRING('12345', -3)", "345"),
("SUBSTRING('12345', -3, 0)", ""),
("SUBSTRING('12345', -3, 1)", "3"),
("SUBSTRING('12345', -3, 2)", "34"),
("SUBSTRING('12345', 0)", ""),
("SUBSTRING('12345', 0, 1)", ""),
("SUBSTRING(NULL)", None),
("SUBSTRING(NULL, 1)", None),
("CAST(1 AS TEXT)", "1"),
("CAST('1' AS LONG)", 1),
("CAST('1.1' AS FLOAT)", 1.1),
("COALESCE(NULL)", None),
("COALESCE(NULL, NULL)", None),
("COALESCE(NULL, 'b')", "b"),
("COALESCE('a', 'b')", "a"),
("1 << 1", 2),
("1 >> 1", 0),
("1 & 1", 1),
("1 | 1", 1),
("1 < 1", False),
("1 <= 1", True),
("1 > 1", False),
("1 >= 1", True),
("1 + NULL", None),
("IF(true, 1, 0)", 1),
("IF(false, 1, 0)", 0),
("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"),
("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)),
]:
with self.subTest(sql):
result = execute(f"SELECT {sql}")
self.assertEqual(result.rows, [(expected,)])
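A standalone sketch of the scalar evaluation exercised by the last case above; the executor is expected to hand back a real datetime.date for date arithmetic:

from sqlglot.executor import execute

result = execute("SELECT CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY")
print(result.rows)  # [(datetime.date(2022, 1, 2),)]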

View file

@ -441,6 +441,9 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(parse_one("VARIANCE(a)"), exp.Variance) self.assertIsInstance(parse_one("VARIANCE(a)"), exp.Variance)
self.assertIsInstance(parse_one("VARIANCE_POP(a)"), exp.VariancePop) self.assertIsInstance(parse_one("VARIANCE_POP(a)"), exp.VariancePop)
self.assertIsInstance(parse_one("YEAR(a)"), exp.Year) self.assertIsInstance(parse_one("YEAR(a)"), exp.Year)
self.assertIsInstance(parse_one("BEGIN DEFERRED TRANSACTION"), exp.Transaction)
self.assertIsInstance(parse_one("COMMIT"), exp.Commit)
self.assertIsInstance(parse_one("ROLLBACK"), exp.Rollback)
def test_column(self): def test_column(self):
dot = parse_one("a.b.c") dot = parse_one("a.b.c")
@ -479,9 +482,9 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(column.text("expression"), "c") self.assertEqual(column.text("expression"), "c")
self.assertEqual(column.text("y"), "") self.assertEqual(column.text("y"), "")
self.assertEqual(parse_one("select * from x.y").find(exp.Table).text("db"), "x") self.assertEqual(parse_one("select * from x.y").find(exp.Table).text("db"), "x")
self.assertEqual(parse_one("select *").text("this"), "") self.assertEqual(parse_one("select *").name, "")
self.assertEqual(parse_one("1 + 1").text("this"), "1") self.assertEqual(parse_one("1 + 1").name, "1")
self.assertEqual(parse_one("'a'").text("this"), "a") self.assertEqual(parse_one("'a'").name, "a")
def test_alias(self): def test_alias(self):
self.assertEqual(alias("foo", "bar").sql(), "foo AS bar") self.assertEqual(alias("foo", "bar").sql(), "foo AS bar")
@ -538,8 +541,8 @@ class TestExpressions(unittest.TestCase):
this=exp.Literal.string("TABLE_FORMAT"), this=exp.Literal.string("TABLE_FORMAT"),
value=exp.to_identifier("test_format"), value=exp.to_identifier("test_format"),
), ),
exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL), exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.null()),
exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE), exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.true()),
] ]
), ),
) )
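A quick sketch of the new literal helpers used above; presumably exp.null() and exp.true() replace the old module-level constants so that each use builds its own expression node:

from sqlglot import exp

# exp.null() and exp.true() construct literal nodes (NULL and TRUE respectively).
print(exp.null().sql(), exp.true().sql())  # NULL TRUE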

View file

@ -29,6 +29,7 @@ class TestOptimizer(unittest.TestCase):
CREATE TABLE x (a INT, b INT); CREATE TABLE x (a INT, b INT);
CREATE TABLE y (b INT, c INT); CREATE TABLE y (b INT, c INT);
CREATE TABLE z (b INT, c INT); CREATE TABLE z (b INT, c INT);
CREATE TABLE w (d TEXT, e TEXT);
INSERT INTO x VALUES (1, 1); INSERT INTO x VALUES (1, 1);
INSERT INTO x VALUES (2, 2); INSERT INTO x VALUES (2, 2);
@ -47,6 +48,8 @@ class TestOptimizer(unittest.TestCase):
INSERT INTO y VALUES (4, 4); INSERT INTO y VALUES (4, 4);
INSERT INTO y VALUES (5, 5); INSERT INTO y VALUES (5, 5);
INSERT INTO y VALUES (null, null); INSERT INTO y VALUES (null, null);
INSERT INTO w VALUES ('a', 'b');
""" """
) )
@ -64,6 +67,10 @@ class TestOptimizer(unittest.TestCase):
"b": "INT", "b": "INT",
"c": "INT", "c": "INT",
}, },
"w": {
"d": "TEXT",
"e": "TEXT",
},
} }
def check_file(self, file, func, pretty=False, execute=False, **kwargs): def check_file(self, file, func, pretty=False, execute=False, **kwargs):
@ -224,6 +231,18 @@ class TestOptimizer(unittest.TestCase):
def test_eliminate_subqueries(self): def test_eliminate_subqueries(self):
self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries) self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries)
def test_canonicalize(self):
optimize = partial(
optimizer.optimize,
rules=[
optimizer.qualify_tables.qualify_tables,
optimizer.qualify_columns.qualify_columns,
annotate_types,
optimizer.canonicalize.canonicalize,
],
)
self.check_file("canonicalize", optimize, schema=self.schema)
def test_tpch(self): def test_tpch(self):
self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True) self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)

View file

@ -41,12 +41,41 @@ class TestParser(unittest.TestCase):
)
def test_command(self):
expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1") expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1", read="hive")
self.assertEqual(len(expressions), 3)
self.assertEqual(expressions[0].sql(), "SET x = 1")
self.assertEqual(expressions[1].sql(), "ADD JAR s3://a")
self.assertEqual(expressions[2].sql(), "SELECT 1")
def test_transactions(self):
expression = parse_one("BEGIN TRANSACTION")
self.assertIsNone(expression.this)
self.assertEqual(expression.args["modes"], [])
self.assertEqual(expression.sql(), "BEGIN")
expression = parse_one("START TRANSACTION", read="mysql")
self.assertIsNone(expression.this)
self.assertEqual(expression.args["modes"], [])
self.assertEqual(expression.sql(), "BEGIN")
expression = parse_one("BEGIN DEFERRED TRANSACTION")
self.assertEqual(expression.this, "DEFERRED")
self.assertEqual(expression.args["modes"], [])
self.assertEqual(expression.sql(), "BEGIN")
expression = parse_one(
"START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", read="presto"
)
self.assertIsNone(expression.this)
self.assertEqual(expression.args["modes"][0], "READ WRITE")
self.assertEqual(expression.args["modes"][1], "ISOLATION LEVEL SERIALIZABLE")
self.assertEqual(expression.sql(), "BEGIN")
expression = parse_one("BEGIN", read="bigquery")
self.assertNotIsInstance(expression, exp.Transaction)
self.assertIsNone(expression.expression)
self.assertEqual(expression.sql(), "BEGIN")
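A condensed sketch of what the new transaction parsing test asserts: BEGIN/START TRANSACTION produce a Transaction node, any modes are kept in its args, and generation normalizes back to a bare BEGIN:

from sqlglot import parse_one

expr = parse_one("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", read="presto")
print(expr.args["modes"])  # ['READ WRITE', 'ISOLATION LEVEL SERIALIZABLE']
print(expr.sql())          # BEGIN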
def test_identify(self): def test_identify(self):
expression = parse_one( expression = parse_one(
""" """
@ -55,14 +84,14 @@ class TestParser(unittest.TestCase):
""" """
) )
assert expression.expressions[0].text("this") == "a" assert expression.expressions[0].name == "a"
assert expression.expressions[1].text("this") == "b" assert expression.expressions[1].name == "b"
assert expression.expressions[2].text("alias") == "c" assert expression.expressions[2].alias == "c"
assert expression.expressions[3].text("alias") == "D" assert expression.expressions[3].alias == "D"
assert expression.expressions[4].text("alias") == "y|z'" assert expression.expressions[4].alias == "y|z'"
table = expression.args["from"].expressions[0] table = expression.args["from"].expressions[0]
assert table.args["this"].args["this"] == "z" assert table.this.name == "z"
assert table.args["db"].args["this"] == "y" assert table.args["db"].name == "y"
def test_multi(self): def test_multi(self):
expressions = parse( expressions = parse(
@ -72,8 +101,8 @@ class TestParser(unittest.TestCase):
) )
assert len(expressions) == 2 assert len(expressions) == 2
assert expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a" assert expressions[0].args["from"].expressions[0].this.name == "a"
assert expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b" assert expressions[1].args["from"].expressions[0].this.name == "b"
def test_expression(self): def test_expression(self):
ignore = Parser(error_level=ErrorLevel.IGNORE) ignore = Parser(error_level=ErrorLevel.IGNORE)
@ -200,7 +229,7 @@ class TestParser(unittest.TestCase):
@patch("sqlglot.parser.logger") @patch("sqlglot.parser.logger")
def test_comment_error_n(self, logger): def test_comment_error_n(self, logger):
parse_one( parse_one(
"""CREATE TABLE x """SUM
( (
-- test -- test
)""", )""",
@ -208,19 +237,19 @@ class TestParser(unittest.TestCase):
) )
assert_logger_contains( assert_logger_contains(
"Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 4, Col: 1.", "Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. Line 4, Col: 1.",
logger, logger,
) )
@patch("sqlglot.parser.logger") @patch("sqlglot.parser.logger")
def test_comment_error_r(self, logger): def test_comment_error_r(self, logger):
parse_one( parse_one(
"""CREATE TABLE x (-- test\r)""", """SUM(-- test\r)""",
error_level=ErrorLevel.WARN, error_level=ErrorLevel.WARN,
) )
assert_logger_contains( assert_logger_contains(
"Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 2, Col: 1.", "Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. Line 2, Col: 1.",
logger, logger,
) )

View file

@ -12,6 +12,7 @@ class TestTokens(unittest.TestCase):
("--comment\nfoo --test", "comment"), ("--comment\nfoo --test", "comment"),
("foo --comment", "comment"), ("foo --comment", "comment"),
("foo", None), ("foo", None),
("foo /*comment 1*/ /*comment 2*/", "comment 1"),
] ]
for sql, comment in sql_comment: for sql, comment in sql_comment:

View file

@ -20,6 +20,13 @@ class TestTranspile(unittest.TestCase):
self.assertEqual(transpile(sql, **kwargs)[0], target) self.assertEqual(transpile(sql, **kwargs)[0], target)
def test_alias(self): def test_alias(self):
self.assertEqual(transpile("SELECT 1 current_time")[0], "SELECT 1 AS current_time")
self.assertEqual(
transpile("SELECT 1 current_timestamp")[0], "SELECT 1 AS current_timestamp"
)
self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date")
self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
for key in ("union", "filter", "over", "from", "join"): for key in ("union", "filter", "over", "from", "join"):
with self.subTest(f"alias {key}"): with self.subTest(f"alias {key}"):
self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}") self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}")
@ -69,6 +76,10 @@ class TestTranspile(unittest.TestCase):
self.validate("SELECT 3>=3", "SELECT 3 >= 3") self.validate("SELECT 3>=3", "SELECT 3 >= 3")
def test_comments(self): def test_comments(self):
self.validate("SELECT */*comment*/", "SELECT * /* comment */")
self.validate(
"SELECT * FROM table /*comment 1*/ /*comment 2*/", "SELECT * FROM table /* comment 1 */"
)
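A standalone sketch of the comment handling checked above: comments are normalized to /* ... */ form and, when several trail the same expression, only the first is kept:

from sqlglot import transpile

print(transpile("SELECT 1 FROM foo -- comment")[0])                     # SELECT 1 FROM foo /* comment */
print(transpile("SELECT * FROM table /*comment 1*/ /*comment 2*/")[0])  # SELECT * FROM table /* comment 1 */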
self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */") self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo") self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo") self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo")