
Merging upstream version 10.0.8.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:54:32 +01:00
parent 407314e8d2
commit efc1e37108
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
67 changed files with 2461 additions and 840 deletions


@ -7,21 +7,37 @@ v10.0.0
Changes:
- Breaking: replaced SQLGlot annotations with comments. Now comments can be preserved after transpilation, and they can appear in other places besides SELECT's expressions.
- Breaking: renamed list_get to seq_get.
- Breaking: activated mypy type checking for SQLGlot.
- New: Azure Databricks support.
- New: placeholders can now be replaced in an expression.
- New: null safe equal operator (<=>).
- New: [SET statements](https://github.com/tobymao/sqlglot/pull/673) for MySQL.
- New: [SHOW commands](https://dev.mysql.com/doc/refman/8.0/en/show.html) for MySQL.
- New: [FORMAT function](https://www.w3schools.com/sql/func_sqlserver_format.asp) for TSQL.
- New: CROSS APPLY / OUTER APPLY [support](https://github.com/tobymao/sqlglot/pull/641) for TSQL.
- New: added formats for TSQL's [DATENAME/DATEPART functions](https://learn.microsoft.com/en-us/sql/t-sql/functions/datename-transact-sql?view=sql-server-ver16).
- New: added styles for TSQL's [CONVERT function](https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16).
- Improvement: [refactored the schema](https://github.com/tobymao/sqlglot/pull/668) to be more lenient; it previously required an exact db.table match, but now it finds the table as long as there is no ambiguity.
- Improvement: allow functions to [inherit](https://github.com/tobymao/sqlglot/pull/674) their arguments' types, so that annotating CASE, IF etc. is possible.
- Improvement: allow [joining with same names](https://github.com/tobymao/sqlglot/pull/660) in the python executor.
- Improvement: the "using" field can now be set for the [join expression builders](https://github.com/tobymao/sqlglot/pull/636).
- Improvement: qualify_columns [now qualifies](https://github.com/tobymao/sqlglot/pull/635) only non-alias columns in the having clause.
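Two of the changes above, comment preservation and the null safe equal operator, are easy to check end to end. A minimal sketch (outputs shown in comments are indicative):

```python
import sqlglot

# Comments are now attached to expressions, so they survive transpilation.
print(sqlglot.transpile("SELECT 1 /* one */")[0])
# SELECT 1 /* one */

# The null safe equal operator round-trips through the MySQL dialect.
print(sqlglot.transpile("SELECT a <=> b", read="mysql", write="mysql")[0])
# SELECT a <=> b
```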
v9.0.0
@ -37,6 +53,7 @@ v8.0.0
Changes:
- Breaking: new add_table method in the Schema ABC.
- New: SQLGlot now supports the [PySpark](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dataframe) dataframe API. This is still relatively experimental.
v7.1.0
@ -45,8 +62,11 @@ v7.1.0
Changes:
- Improvement: the pretty generator now takes max_text_width, which breaks long segments onto new lines.
- New: exp.to_table helper to turn table names into table expression objects.
- New: int[] type parsers.
- New: annotations are now generated in SQL.
v7.0.0


@ -21,7 +21,7 @@ Pull requests are the best way to propose changes to the codebase. We actively w
5. Issue that pull request and wait for it to be reviewed by a maintainer or contributor!
## Report bugs using Github's [issues](https://github.com/tobymao/sqlglot/issues)
We use GitHub issues to track public bugs. Report a bug by opening a new issue.
**Great Bug Reports** tend to have:


@ -90,7 +90,7 @@ sqlglot.transpile("SELECT STRFTIME(x, '%y-%-m-%S')", read="duckdb", write="hive"
"SELECT DATE_FORMAT(x, 'yy-M-ss')"
```
As another example, let's suppose that we want to read in a SQL query that contains a CTE and a cast to `REAL`, and then transpile it to Spark, which uses backticks for identifiers and `FLOAT` instead of `REAL`:
```python
import sqlglot
@ -376,12 +376,12 @@ print(Dialect["custom"])
[Benchmarks](benchmarks) run on Python 3.10.5 in seconds.
| Query | sqlglot | sqlfluff | sqltree | sqlparse | moz_sql_parser | sqloxide |
| --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- |
| tpch | 0.01308 (1.0) | 1.60626 (122.7) | 0.01168 (0.893) | 0.04958 (3.791) | 0.08543 (6.531) | 0.00136 (0.104) |
| short | 0.00109 (1.0) | 0.14134 (129.2) | 0.00099 (0.906) | 0.00342 (3.131) | 0.00652 (5.970) | 8.76621e-05 (0.080) |
| long | 0.01399 (1.0) | 2.12632 (151.9) | 0.01126 (0.805) | 0.04410 (3.151) | 0.06671 (4.767) | 0.00107 (0.076) |
| crazy | 0.03969 (1.0) | 24.3777 (614.1) | 0.03917 (0.987) | 11.7043 (294.8) | 1.03280 (26.02) | 0.00625 (0.157) |
## Optional Dependencies


@ -5,8 +5,10 @@ collections.Iterable = collections.abc.Iterable
import gc
import timeit
import numpy as np
import sqlfluff
import moz_sql_parser
import sqloxide
import sqlparse
import sqltree
@ -177,6 +179,10 @@ def sqloxide_parse(sql):
sqloxide.parse_sql(sql, dialect="ansi")
def sqlfluff_parse(sql):
sqlfluff.parse(sql)
def border(columns):
columns = " | ".join(columns)
return f"| {columns} |"
@ -193,6 +199,7 @@ def diff(row, column):
libs = [
"sqlglot",
"sqlfluff",
"sqltree",
"sqlparse",
"moz_sql_parser",
@ -206,7 +213,8 @@ for name, sql in {"tpch": tpch, "short": short, "long": long, "crazy": crazy}.it
for lib in libs:
try:
row[lib] = np.mean(timeit.repeat(lambda: globals()[lib + "_parse"](sql), number=3))
except Exception as e:
print(e)
row[lib] = "error"
columns = ["Query"] + libs


@ -30,7 +30,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "10.0.8"
pretty = False


@ -260,7 +260,10 @@ class Column:
"""
if isinstance(dataType, DataType):
dataType = dataType.simpleString()
new_expression = exp.Cast(
this=self.column_expression,
to=sqlglot.parse_one(dataType, into=exp.DataType, read="spark"), # type: ignore
)
return Column(new_expression)
def startswith(self, value: t.Union[str, Column]) -> Column:
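With this change, `cast` accepts full Spark type strings, because the argument is now parsed into an `exp.DataType` with the Spark dialect instead of being passed through verbatim. A rough usage sketch (the `array<int>` example is illustrative, not from this diff):

```python
from sqlglot.dataframe.sql import functions as F

# Nested type strings now resolve to a proper DataType node.
casted = F.col("x").cast("array<int>")
```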


@ -314,7 +314,13 @@ class DataFrame:
replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore
cache_table_name
)
sqlglot.schema.add_table(
cache_table_name,
{
expression.alias_or_name: expression.type.name
for expression in select_expression.expressions
},
)
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
exp.Literal.string("storageLevel"),


@ -757,11 +757,15 @@ def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
def decode(col: ColumnOrName, charset: str) -> Column:
return Column.invoke_expression_over_column(
col, glotexp.Decode, charset=glotexp.Literal.string(charset)
)
def encode(col: ColumnOrName, charset: str) -> Column:
return Column.invoke_expression_over_column(
col, glotexp.Encode, charset=glotexp.Literal.string(charset)
)
def format_number(col: ColumnOrName, d: int) -> Column:
@ -867,11 +871,11 @@ def bin(col: ColumnOrName) -> Column:
def hex(col: ColumnOrName) -> Column:
return Column.invoke_expression_over_column(col, glotexp.Hex)
def unhex(col: ColumnOrName) -> Column:
return Column.invoke_expression_over_column(col, glotexp.Unhex)
def length(col: ColumnOrName) -> Column:
@ -939,11 +943,7 @@ def array_join(
def concat(*cols: ColumnOrName) -> Column:
if len(cols) == 1:
return Column.invoke_anonymous_function(cols[0], "CONCAT")
return Column.invoke_expression_over_column(None, glotexp.Concat, expressions=cols)
def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:


@ -88,14 +88,14 @@ class SparkSession:
"expressions": sel_columns,
"from": exp.From(
expressions=[
exp.Values(
expressions=data_expressions,
alias=exp.TableAlias(
this=exp.to_identifier(self._auto_incrementing_name),
columns=[exp.to_identifier(col_name) for col_name in column_mapping],
),
)
],
),
}


@ -2,6 +2,7 @@ from sqlglot.dialects.bigquery import BigQuery
from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.dialects.drill import Drill
from sqlglot.dialects.duckdb import DuckDB
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL


@ -119,6 +119,8 @@ class BigQuery(Dialect):
"UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
}
KEYWORDS.pop("DIV")
@ -204,6 +206,15 @@ class BigQuery(Dialect):
EXPLICIT_UNION = True
def transaction_sql(self, *_):
return "BEGIN TRANSACTION"
def commit_sql(self, *_):
return "COMMIT TRANSACTION"
def rollback_sql(self, *_):
return "ROLLBACK TRANSACTION"
def in_unnest_op(self, unnest):
return self.sql(unnest)
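A quick smoke test of the new transaction handling; the expected output follows directly from the overrides above:

```python
import sqlglot

print(sqlglot.transpile("BEGIN TRANSACTION", read="bigquery", write="bigquery")[0])
# BEGIN TRANSACTION
```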


@ -32,6 +32,7 @@ class Dialects(str, Enum):
TRINO = "trino"
TSQL = "tsql"
DATABRICKS = "databricks"
DRILL = "drill"
class _Dialect(type):
@ -362,3 +363,18 @@ def parse_date_delta(exp_class, unit_mapping=None):
return exp_class(this=this, expression=expression, unit=unit)
return inner_func
def locate_to_strposition(args):
return exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
)
def strposition_to_local_sql(self, expression):
args = self.format_args(
expression.args.get("substr"), expression.this, expression.args.get("position")
)
return f"LOCATE({args})"

sqlglot/dialects/drill.py (new file, 174 lines)

@ -0,0 +1,174 @@
from __future__ import annotations
import re
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
create_with_partitions_sql,
format_time_lambda,
no_pivot_sql,
no_trycast_sql,
rename_func,
str_position_sql,
)
from sqlglot.dialects.postgres import _lateral_sql
def _to_timestamp(args):
# TO_TIMESTAMP accepts either a single double argument or (text, text)
if len(args) == 1 and args[0].is_number:
return exp.UnixToTime.from_arg_list(args)
return format_time_lambda(exp.StrToTime, "drill")(args)
def _str_to_time_sql(self, expression):
return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
def _ts_or_ds_to_date_sql(self, expression):
time_format = self.format_time(expression)
if time_format and time_format not in (Drill.time_format, Drill.date_format):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
return f"CAST({self.sql(expression, 'this')} AS DATE)"
def _date_add_sql(kind):
def func(self, expression):
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
expression = self.sql(expression, "expression")
return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})"
return func
def if_sql(self, expression):
"""
Drill requires backticks around certain SQL reserved words, IF being one of them. This function
adds the backticks around the keyword IF.
Args:
self: The Drill dialect
expression: The input IF expression
Returns: The expression with IF in backticks.
"""
expressions = self.format_args(
expression.this, expression.args.get("true"), expression.args.get("false")
)
return f"`IF`({expressions})"
def _str_to_date(self, expression):
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Drill.date_format:
return f"CAST({this} AS DATE)"
return f"TO_DATE({this}, {time_format})"
class Drill(Dialect):
normalize_functions = None
null_ordering = "nulls_are_last"
date_format = "'yyyy-MM-dd'"
dateint_format = "'yyyyMMdd'"
time_format = "'yyyy-MM-dd HH:mm:ss'"
time_mapping = {
"y": "%Y",
"Y": "%Y",
"YYYY": "%Y",
"yyyy": "%Y",
"YY": "%y",
"yy": "%y",
"MMMM": "%B",
"MMM": "%b",
"MM": "%m",
"M": "%-m",
"dd": "%d",
"d": "%-d",
"HH": "%H",
"H": "%-H",
"hh": "%I",
"h": "%-I",
"mm": "%M",
"m": "%-M",
"ss": "%S",
"s": "%-S",
"SSSSSS": "%f",
"a": "%p",
"DD": "%j",
"D": "%-j",
"E": "%a",
"EE": "%a",
"EEE": "%a",
"EEEE": "%A",
"''T''": "T",
}
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'"]
IDENTIFIERS = ["`"]
ESCAPES = ["\\"]
ENCODE = "utf-8"
class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
}
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
exp.DataType.Type.BINARY: "VARBINARY",
exp.DataType.Type.TEXT: "VARCHAR",
exp.DataType.Type.NCHAR: "VARCHAR",
exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.DATETIME: "TIMESTAMP",
}
ROOT_PROPERTIES = {exp.PartitionedByProperty}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.Lateral: _lateral_sql,
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"),
exp.Create: create_with_partitions_sql,
exp.DateAdd: _date_add_sql("ADD"),
exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.DateSub: _date_add_sql("SUB"),
exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
exp.If: if_sql,
exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
exp.Pivot: no_pivot_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.StrPosition: str_position_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)",
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}
def normalize_func(self, name):
return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"
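As a quick sanity check of the new dialect, `IF` should come back wrapped in backticks per `if_sql` above (a sketch; output indicative):

```python
import sqlglot

print(sqlglot.transpile("SELECT IF(a > 0, 1, 2) FROM t", read="hive", write="drill")[0])
# SELECT `IF`(a > 0, 1, 2) FROM t
```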


@ -55,13 +55,13 @@ def _array_sort_sql(self, expression):
def _sort_array_sql(self, expression):
this = self.sql(expression, "this")
if expression.args.get("asc") == exp.FALSE:
if expression.args.get("asc") == exp.false():
return f"ARRAY_REVERSE_SORT({this})"
return f"ARRAY_SORT({this})"
def _sort_array_reverse(args):
return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
def _struct_pack_sql(self, expression):


@ -7,16 +7,19 @@ from sqlglot.dialects.dialect import (
create_with_partitions_sql,
format_time_lambda,
if_sql,
locate_to_strposition,
no_ilike_sql,
no_recursive_cte_sql,
no_safe_divide_sql,
no_trycast_sql,
rename_func,
strposition_to_local_sql,
struct_extract_sql,
var_map_sql,
)
from sqlglot.helper import seq_get
from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType
# (FuncType, Multiplier)
DATE_DELTA_INTERVAL = {
@ -181,6 +184,15 @@ class Hive(Dialect):
"F": "FLOAT",
"BD": "DECIMAL",
}
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ADD ARCHIVE": TokenType.COMMAND,
"ADD ARCHIVES": TokenType.COMMAND,
"ADD FILE": TokenType.COMMAND,
"ADD FILES": TokenType.COMMAND,
"ADD JAR": TokenType.COMMAND,
"ADD JARS": TokenType.COMMAND,
}
class Parser(parser.Parser):
STRICT_CAST = False
@ -210,11 +222,7 @@ class Hive(Dialect):
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
"LOCATE": lambda args: exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
),
"LOCATE": locate_to_strposition,
"LOG": (
lambda args: exp.Log.from_arg_list(args)
if len(args) > 1
@ -272,7 +280,7 @@ class Hive(Dialect):
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.SetAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
exp.StrPosition: strposition_to_local_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix,


@ -5,10 +5,12 @@ import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
locate_to_strposition,
no_ilike_sql,
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
strposition_to_local_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@ -120,6 +122,7 @@ class MySQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"START": TokenType.BEGIN,
"SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
@ -172,13 +175,18 @@ class MySQL(Dialect):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
class Parser(parser.Parser):
STRICT_CAST = False
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA}
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date,
"LOCATE": locate_to_strposition,
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
"LEFT": lambda args: exp.Substring(
this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
),
}
FUNCTION_PARSERS = {
@ -264,6 +272,7 @@ class MySQL(Dialect):
"CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"NAMES": lambda self: self._parse_set_item_names(),
"TRANSACTION": lambda self: self._parse_set_transaction(),
}
PROFILE_TYPES = {
@ -278,39 +287,48 @@ class MySQL(Dialect):
"SWAPS",
}
TRANSACTION_CHARACTERISTICS = {
"ISOLATION LEVEL REPEATABLE READ",
"ISOLATION LEVEL READ COMMITTED",
"ISOLATION LEVEL READ UNCOMMITTED",
"ISOLATION LEVEL SERIALIZABLE",
"READ WRITE",
"READ ONLY",
}
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
if target:
if isinstance(target, str):
self._match_text_seq(target)
target_id = self._parse_id_var()
else:
target_id = None
log = self._parse_string() if self._match_text_seq("IN") else None
if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
position = self._parse_number() if self._match_text_seq("FROM") else None
db = None
else:
position = None
db = self._parse_id_var() if self._match_text_seq("FROM") else None
channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None
like = self._parse_string() if self._match_text_seq("LIKE") else None
where = self._parse_where()
if this == "PROFILE":
types = self._parse_csv(lambda: self._parse_var_from_options(self.PROFILE_TYPES))
query = self._parse_number() if self._match_text_seq("FOR", "QUERY") else None
offset = self._parse_number() if self._match_text_seq("OFFSET") else None
limit = self._parse_number() if self._match_text_seq("LIMIT") else None
else:
types, query = None, None
offset, limit = self._parse_oldstyle_limit()
mutex = True if self._match_text_seq("MUTEX") else None
mutex = False if self._match_text_seq("STATUS") else mutex
return self.expression(
exp.Show,
@ -331,16 +349,16 @@ class MySQL(Dialect):
**{"global": global_},
)
def _parse_var_from_options(self, options):
for option in options:
if self._match_text_seq(*option.split(" ")):
return exp.Var(this=option)
return None
def _parse_oldstyle_limit(self):
limit = None
offset = None
if self._match_text_seq("LIMIT"):
parts = self._parse_csv(self._parse_number)
if len(parts) == 1:
limit = parts[0]
@ -353,6 +371,9 @@ class MySQL(Dialect):
return self._parse_set_item_assignment(kind=None)
def _parse_set_item_assignment(self, kind):
if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"):
return self._parse_set_transaction(global_=kind == "GLOBAL")
left = self._parse_primary() or self._parse_id_var()
if not self._match(TokenType.EQ):
self.raise_error("Expected =")
@ -381,7 +402,7 @@ class MySQL(Dialect):
def _parse_set_item_names(self):
charset = self._parse_string() or self._parse_id_var()
if self._match_text_seq("COLLATE"):
collate = self._parse_string() or self._parse_id_var()
else:
collate = None
@ -392,6 +413,18 @@ class MySQL(Dialect):
kind="NAMES",
)
def _parse_set_transaction(self, global_=False):
self._match_text_seq("TRANSACTION")
characteristics = self._parse_csv(
lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS)
)
return self.expression(
exp.SetItem,
expressions=characteristics,
kind="TRANSACTION",
**{"global": global_},
)
class Generator(generator.Generator):
NULL_ORDERING_SUPPORTED = False
@ -411,6 +444,7 @@ class MySQL(Dialect):
exp.Trim: _trim_sql,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.StrPosition: strposition_to_local_sql,
}
ROOT_PROPERTIES = {
@ -481,9 +515,11 @@ class MySQL(Dialect):
kind = self.sql(expression, "kind")
kind = f"{kind} " if kind else ""
this = self.sql(expression, "this")
expressions = self.expressions(expression)
collate = self.sql(expression, "collate")
collate = f" COLLATE {collate}" if collate else ""
return f"{kind}{this}{collate}"
global_ = "GLOBAL " if expression.args.get("global") else ""
return f"{global_}{kind}{this}{expressions}{collate}"
def set_sql(self, expression):
return f"SET {self.expressions(expression)}"


@ -91,6 +91,7 @@ class Oracle(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"START": TokenType.BEGIN,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,


@ -164,11 +164,34 @@ class Postgres(Dialect):
BIT_STRINGS = [("b'", "'"), ("B'", "'")]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
CREATABLES = (
"AGGREGATE",
"CAST",
"CONVERSION",
"COLLATION",
"DEFAULT CONVERSION",
"CONSTRAINT",
"DOMAIN",
"EXTENSION",
"FOREIGN",
"FUNCTION",
"OPERATOR",
"POLICY",
"ROLE",
"RULE",
"SEQUENCE",
"TEXT",
"TRIGGER",
"TYPE",
"UNLOGGED",
"USER",
)
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ALWAYS": TokenType.ALWAYS,
"BY DEFAULT": TokenType.BY_DEFAULT,
"COMMENT ON": TokenType.COMMENT_ON,
"IDENTITY": TokenType.IDENTITY,
"GENERATED": TokenType.GENERATED,
"DOUBLE PRECISION": TokenType.DOUBLE,
@ -176,6 +199,19 @@ class Postgres(Dialect):
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
"UUID": TokenType.UUID,
"TEMP": TokenType.TEMPORARY,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BEGIN": TokenType.COMMAND,
"COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
"REFRESH": TokenType.COMMAND,
"REINDEX": TokenType.COMMAND,
"RESET": TokenType.COMMAND,
"REVOKE": TokenType.COMMAND,
"GRANT": TokenType.COMMAND,
**{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
**{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
}
QUOTES = ["'", "$$"]
SINGLE_TOKENS = {


@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.errors import UnsupportedError
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@ -61,8 +62,18 @@ def _initcap_sql(self, expression):
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
def _decode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
return f"FROM_UTF8({self.sql(expression, 'this')})"
def _encode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
return f"TO_UTF8({self.sql(expression, 'this')})"
def _no_sort_array(self, expression):
if expression.args.get("asc") == exp.FALSE:
if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
comparator = None
@ -72,7 +83,7 @@ def _no_sort_array(self, expression):
def _schema_sql(self, expression):
if isinstance(expression.parent, exp.Property):
columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
return f"ARRAY[{columns}]"
for schema in expression.parent.find_all(exp.Schema):
@ -106,6 +117,11 @@ def _ts_or_ds_add_sql(self, expression):
return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"
def _ensure_utf8(charset):
if charset.name.lower() != "utf-8":
raise UnsupportedError(f"Unsupported charset {charset}")
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
@ -115,6 +131,7 @@ class Presto(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"START": TokenType.BEGIN,
"ROW": TokenType.STRUCT,
}
@ -140,6 +157,14 @@ class Presto(Dialect):
"STRPOS": exp.StrPosition.from_arg_list,
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"FROM_HEX": exp.Unhex.from_arg_list,
"TO_HEX": exp.Hex.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
"FROM_UTF8": lambda args: exp.Decode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
}
class Generator(generator.Generator):
@ -187,7 +212,10 @@ class Presto(Dialect):
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
exp.Decode: _decode_sql,
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.Encode: _encode_sql,
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
@ -212,7 +240,13 @@ class Presto(Dialect):
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.Unhex: rename_func("FROM_HEX"),
exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
}
def transaction_sql(self, expression):
modes = expression.args.get("modes")
modes = f" {', '.join(modes)}" if modes else ""
return f"START TRANSACTION{modes}"


@ -148,6 +148,7 @@ class Snowflake(Dialect):
**parser.Parser.FUNCTION_PARSERS,
"DATE_PART": _parse_date_part,
}
FUNCTION_PARSERS.pop("TRIM")
FUNC_TOKENS = {
*parser.Parser.FUNC_TOKENS,
@ -203,6 +204,7 @@ class Snowflake(Dialect):
exp.StrPosition: rename_func("POSITION"),
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
}
TYPE_MAPPING = {


@ -63,3 +63,8 @@ class SQLite(Dialect):
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
}
def transaction_sql(self, expression):
this = expression.this
this = f" {this}" if this else ""
return f"BEGIN{this} TRANSACTION"


@ -248,7 +248,7 @@ class TSQL(Dialect):
def _parse_convert(self, strict):
to = self._parse_types()
self._match(TokenType.COMMA)
this = self._parse_conjunction()
# Retrieve length of datatype and override to default if not specified
if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:


@ -1,3 +1,6 @@
from __future__ import annotations
import typing as t
from collections import defaultdict
from dataclasses import dataclass
from heapq import heappop, heappush
@ -6,6 +9,10 @@ from sqlglot import Dialect
from sqlglot import expressions as exp
from sqlglot.helper import ensure_collection
if t.TYPE_CHECKING:
T = t.TypeVar("T")
Edit = t.Union[Insert, Remove, Move, Update, Keep]
@dataclass(frozen=True)
class Insert:
@ -44,7 +51,7 @@ class Keep:
target: exp.Expression
def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
"""
Returns the list of changes between the source and the target expressions.
@ -89,25 +96,25 @@ class ChangeDistiller:
Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf.
"""
def __init__(self, f: float = 0.6, t: float = 0.6) -> None:
self.f = f
self.t = t
self._sql_generator = Dialect().generator()
def diff(self, source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
self._source = source
self._target = target
self._source_index = {id(n[0]): n[0] for n in source.bfs()}
self._target_index = {id(n[0]): n[0] for n in target.bfs()}
self._unmatched_source_nodes = set(self._source_index)
self._unmatched_target_nodes = set(self._target_index)
self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
matching_set = self._compute_matching_set()
return self._generate_edit_script(matching_set)
def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.List[Edit]:
edit_script: t.List[Edit] = []
for removed_node_id in self._unmatched_source_nodes:
edit_script.append(Remove(self._source_index[removed_node_id]))
for inserted_node_id in self._unmatched_target_nodes:
@ -125,7 +132,9 @@ class ChangeDistiller:
return edit_script
def _generate_move_edits(
self, source: exp.Expression, target: exp.Expression, matching_set: t.Set[t.Tuple[int, int]]
) -> t.List[Move]:
source_args = [id(e) for e in _expression_only_args(source)]
target_args = [id(e) for e in _expression_only_args(target)]
@ -138,7 +147,7 @@ class ChangeDistiller:
return move_edits
def _compute_matching_set(self) -> t.Set[t.Tuple[int, int]]:
leaves_matching_set = self._compute_leaf_matching_set()
matching_set = leaves_matching_set.copy()
@ -183,8 +192,8 @@ class ChangeDistiller:
return matching_set
def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]:
candidate_matchings: t.List[t.Tuple[float, int, exp.Expression, exp.Expression]] = []
source_leaves = list(_get_leaves(self._source))
target_leaves = list(_get_leaves(self._target))
for source_leaf in source_leaves:
@ -216,7 +225,7 @@ class ChangeDistiller:
return matching_set
def _dice_coefficient(self, source: exp.Expression, target: exp.Expression) -> float:
source_histo = self._bigram_histo(source)
target_histo = self._bigram_histo(target)
@ -231,13 +240,13 @@ class ChangeDistiller:
return 2 * overlap_len / total_grams
def _bigram_histo(self, expression: exp.Expression) -> t.DefaultDict[str, int]:
if id(expression) in self._bigram_histo_cache:
return self._bigram_histo_cache[id(expression)]
expression_str = self._sql_generator.generate(expression)
count = max(0, len(expression_str) - 1)
bigram_histo: t.DefaultDict[str, int] = defaultdict(int)
for i in range(count):
bigram_histo[expression_str[i : i + 2]] += 1
@ -245,7 +254,7 @@ class ChangeDistiller:
return bigram_histo
def _get_leaves(expression: exp.Expression) -> t.Generator[exp.Expression, None, None]:
has_child_exprs = False
for a in expression.args.values():
@ -258,7 +267,7 @@ def _get_leaves(expression):
yield expression
def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
if type(source) is type(target):
if isinstance(source, exp.Join):
return source.args.get("side") == target.args.get("side")
@ -271,15 +280,17 @@ def _is_same_type(source, target):
return False
def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
args: t.List[t.Union[exp.Expression, t.List]] = []
if expression:
for a in expression.args.values():
args.extend(ensure_collection(a))
return [a for a in args if isinstance(a, exp.Expression)]
def _lcs(
seq_a: t.Sequence[T], seq_b: t.Sequence[T], equal: t.Callable[[T, T], bool]
) -> t.Sequence[t.Optional[T]]:
"""Calculates the longest common subsequence"""
len_a = len(seq_a)
@ -289,14 +300,14 @@ def _lcs(seq_a, seq_b, equal):
for i in range(len_a + 1):
for j in range(len_b + 1):
if i == 0 or j == 0:
lcs_result[i][j] = [] # type: ignore
elif equal(seq_a[i - 1], seq_b[j - 1]):
lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]] # type: ignore
else:
lcs_result[i][j] = (
lcs_result[i - 1][j]
if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1]) # type: ignore
else lcs_result[i][j - 1]
)
return lcs_result[len_a][len_b] # type: ignore
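For reference, the now fully typed entry point is used like this; the shape of the returned edit list is indicative:

```python
from sqlglot import parse_one
from sqlglot.diff import diff

edits = diff(parse_one("SELECT a + b"), parse_one("SELECT a - b"))
# e.g. [Remove(expression=Add(...)), Insert(expression=Sub(...)), Keep(...), ...]
```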


@ -37,6 +37,10 @@ class SchemaError(SqlglotError):
pass
class ExecuteError(SqlglotError):
pass
def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
msg = [str(e) for e in errors[:maximum]]
remaining = len(errors) - maximum


@ -1,20 +1,23 @@
import logging
import time
from sqlglot import maybe_parse
from sqlglot.errors import ExecuteError
from sqlglot.executor.python import PythonExecutor
from sqlglot.executor.table import Table, ensure_tables
from sqlglot.optimizer import optimize
from sqlglot.planner import Plan
from sqlglot.schema import ensure_schema
logger = logging.getLogger("sqlglot")
def execute(sql, schema=None, read=None, tables=None):
"""
Run a sql query against data.
Args:
sql (str|sqlglot.Expression): a sql statement
schema (dict|sqlglot.optimizer.Schema): database schema.
This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
the following forms:
@ -23,10 +26,20 @@ def execute(sql, schema, read=None):
3. {catalog: {db: {table: {col: type}}}}
read (str): the SQL dialect to apply during parsing
(eg. "spark", "hive", "presto", "mysql").
tables (dict): additional tables to register.
Returns:
sqlglot.executor.Table: Simple columnar data structure.
"""
tables = ensure_tables(tables)
if not schema:
schema = {
name: {column: type(table[0][column]).__name__ for column in table.columns}
for name, table in tables.mapping.items()
}
schema = ensure_schema(schema)
if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args:
raise ExecuteError("Tables must support the same table args as schema")
expression = maybe_parse(sql, dialect=read)
now = time.time()
expression = optimize(expression, schema, leave_tables_isolated=True)
logger.debug("Optimization finished: %f", time.time() - now)
@ -34,6 +47,6 @@ def execute(sql, schema, read=None):
plan = Plan(expression)
logger.debug("Logical Plan: %s", plan)
now = time.time()
result = PythonExecutor(tables=tables).execute(plan)
logger.debug("Query finished: %f", time.time() - now)
return result
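`schema` is now optional: when only `tables` is given, a single-level schema is inferred from the Python types of the first row of each table. A sketch of the new call shape (data illustrative):

```python
from sqlglot.executor import execute

result = execute(
    "SELECT name, SUM(price) AS total FROM sushi GROUP BY name",
    tables={"sushi": [{"name": "maguro", "price": 1.5}, {"name": "maguro", "price": 2.0}]},
)
# result is a sqlglot.executor.Table, e.g. rows [('maguro', 3.5)]
```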


@ -1,5 +1,12 @@
from __future__ import annotations
import typing as t
from sqlglot.executor.env import ENV
if t.TYPE_CHECKING:
from sqlglot.executor.table import Table, TableIter
class Context:
"""
@ -12,14 +19,14 @@ class Context:
evaluation of aggregation functions.
"""
def __init__(self, tables: t.Dict[str, Table], env: t.Optional[t.Dict] = None) -> None:
"""
Args
tables: representing the scope of the current execution context.
env: dictionary of functions within the execution context.
"""
self.tables = tables
self._table: t.Optional[Table] = None
self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers}
@ -31,7 +38,7 @@ class Context:
return tuple(self.eval(code) for code in codes)
@property
def table(self) -> Table:
if self._table is None:
self._table = list(self.tables.values())[0]
for other in self.tables.values():
@ -41,8 +48,12 @@ class Context:
raise Exception(f"Rows are different.")
return self._table
def add_columns(self, *columns: str) -> None:
for table in self.tables.values():
table.add_columns(*columns)
@property
def columns(self) -> t.Tuple:
return self.table.columns
def __iter__(self):
@ -52,35 +63,39 @@ class Context:
reader = table[i]
yield reader, self
def table_iter(self, table: str) -> t.Generator[t.Tuple[TableIter, Context], None, None]:
self.env["scope"] = self.row_readers
for reader in self.tables[table]:
yield reader, self
def filter(self, condition) -> None:
rows = [reader.row for reader, _ in self if self.eval(condition)]
for table in self.tables.values():
table.rows = rows
def sort(self, key) -> None:
def sort_key(row: t.Tuple) -> t.Tuple:
self.set_row(row)
return self.eval_tuple(key)
self.table.rows.sort(key=sort_key)
def set_row(self, row: t.Tuple) -> None:
for table in self.tables.values():
table.reader.row = row
self.env["scope"] = self.row_readers
def set_index(self, index: int) -> None:
for table in self.tables.values():
table[index]
self.env["scope"] = self.row_readers
def set_range(self, start: int, end: int) -> None:
for name in self.tables:
self.range_readers[name].range = range(start, end)
self.env["scope"] = self.range_readers
def __contains__(self, table: str) -> bool:
return table in self.tables


@ -1,7 +1,10 @@
import datetime
import inspect
import re
import statistics
from functools import wraps
from sqlglot import exp
from sqlglot.helper import PYTHON_VERSION
@ -16,20 +19,153 @@ class reverse_key:
return other.obj < self.obj
def filter_nulls(func):
@wraps(func)
def _func(values):
return func(v for v in values if v is not None)
return _func
def null_if_any(*required):
"""
Decorator that makes a function return `None` if any of the `required` arguments are `None`.
This also supports decoration with no arguments, e.g.:
@null_if_any
def foo(a, b): ...
In which case all arguments are required.
"""
f = None
if len(required) == 1 and callable(required[0]):
f = required[0]
required = ()
def decorator(func):
if required:
required_indices = [
i for i, param in enumerate(inspect.signature(func).parameters) if param in required
]
def predicate(*args):
return any(args[i] is None for i in required_indices)
else:
def predicate(*args):
return any(a is None for a in args)
@wraps(func)
def _func(*args):
if predicate(*args):
return None
return func(*args)
return _func
if f:
return decorator(f)
return decorator
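Both decoration forms described in the docstring, as a minimal sketch reusing the `null_if_any` defined above:

```python
@null_if_any
def add(a, b):
    return a + b

add(1, None)  # -> None: every argument is required

@null_if_any("a")
def first(a, b):
    return a

first(1, None)  # -> 1: only `a` is required
first(None, 2)  # -> None
```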
@null_if_any("substr", "this")
def str_position(substr, this, position=None):
position = position - 1 if position is not None else position
return this.find(substr, position) + 1
@null_if_any("this")
def substring(this, start=None, length=None):
if start is None:
return this
elif start == 0:
return ""
elif start < 0:
start = len(this) + start
else:
start -= 1
end = None if length is None else start + length
return this[start:end]
@null_if_any
def cast(this, to):
if to == exp.DataType.Type.DATE:
return datetime.date.fromisoformat(this)
if to == exp.DataType.Type.DATETIME:
return datetime.datetime.fromisoformat(this)
if to in exp.DataType.TEXT_TYPES:
return str(this)
if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
return float(this)
if to in exp.DataType.NUMERIC_TYPES:
return int(this)
raise NotImplementedError(f"Casting to '{to}' not implemented.")
def ordered(this, desc, nulls_first):
if desc:
return reverse_key(this)
return this
@null_if_any
def interval(this, unit):
if unit == "DAY":
return datetime.timedelta(days=float(this))
raise NotImplementedError
ENV = {
"__builtins__": {},
"datetime": datetime,
"locals": locals,
"re": re,
"bool": bool,
"float": float,
"int": int,
"str": str,
"desc": reverse_key,
"SUM": sum,
"AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean, # type: ignore
"COUNT": lambda acc: sum(1 for e in acc if e is not None),
"MAX": max,
"MIN": min,
"exp": exp,
# aggs
"SUM": filter_nulls(sum),
"AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore
"COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)),
"MAX": filter_nulls(max),
"MIN": filter_nulls(min),
# scalar functions
"ABS": null_if_any(lambda this: abs(this)),
"ADD": null_if_any(lambda e, this: e + this),
"BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
"BITWISEAND": null_if_any(lambda this, e: this & e),
"BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
"BITWISEOR": null_if_any(lambda this, e: this | e),
"BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e),
"BITWISEXOR": null_if_any(lambda this, e: this ^ e),
"CAST": cast,
"COALESCE": lambda *args: next((a for a in args if a is not None), None),
"CONCAT": null_if_any(lambda *args: "".join(args)),
"CONCATWS": null_if_any(lambda this, *args: this.join(args)),
"DIV": null_if_any(lambda e, this: e / this),
"EQ": null_if_any(lambda this, e: this == e),
"EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
"GT": null_if_any(lambda this, e: this > e),
"GTE": null_if_any(lambda this, e: this >= e),
"IFNULL": lambda e, alt: alt if e is None else e,
"IF": lambda predicate, true, false: true if predicate else false,
"INTDIV": null_if_any(lambda e, this: e // this),
"INTERVAL": interval,
"LIKE": null_if_any(
lambda this, e: bool(re.match(e.replace("_", ".").replace("%", ".*"), this))
),
"LOWER": null_if_any(lambda arg: arg.lower()),
"LT": null_if_any(lambda this, e: this < e),
"LTE": null_if_any(lambda this, e: this <= e),
"MOD": null_if_any(lambda e, this: e % this),
"MUL": null_if_any(lambda e, this: e * this),
"NEQ": null_if_any(lambda this, e: this != e),
"ORD": null_if_any(ord),
"ORDERED": ordered,
"POW": pow,
"STRPOSITION": str_position,
"SUB": null_if_any(lambda e, this: e - this),
"SUBSTRING": substring,
"UPPER": null_if_any(lambda arg: arg.upper()),
}


@ -5,16 +5,18 @@ import math
from sqlglot import exp, generator, planner, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql
from sqlglot.errors import ExecuteError
from sqlglot.executor.context import Context
from sqlglot.executor.env import ENV
from sqlglot.executor.table import RowReader, Table
from sqlglot.helper import csv_reader, subclasses
class PythonExecutor:
def __init__(self, env=None, tables=None):
self.generator = Python().generator(identify=True, comments=False)
self.env = {**ENV, **(env or {})}
self.tables = tables or {}
def execute(self, plan):
running = set()
@ -24,6 +26,7 @@ class PythonExecutor:
while queue:
node = queue.pop()
try:
context = self.context(
{
name: table
@ -41,6 +44,8 @@ class PythonExecutor:
contexts[node] = self.join(node, context)
elif isinstance(node, planner.Sort):
contexts[node] = self.sort(node, context)
elif isinstance(node, planner.SetOperation):
contexts[node] = self.set_operation(node, context)
else:
raise NotImplementedError
@ -54,6 +59,8 @@ class PythonExecutor:
for dep in node.dependencies:
if all(d in finished for d in dep.dependents):
contexts.pop(dep)
except Exception as e:
raise ExecuteError(f"Step '{node.id}' failed: {e}") from e
root = plan.root
return contexts[root].tables[root.name]
@ -76,38 +83,43 @@ class PythonExecutor:
return Context(tables, env=self.env)
def table(self, expressions):
return Table(
expression.alias_or_name if isinstance(expression, exp.Expression) else expression
for expression in expressions
)
def scan(self, step, context):
source = step.source
if source and isinstance(source, exp.Expression):
source = source.name or source.alias
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
if source is None:
context, table_iter = self.static()
elif source in context:
if not projections and not condition:
return self.context({step.name: context.tables[source]})
table_iter = context.table_iter(source)
elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV):
table_iter = self.scan_csv(step)
context = next(table_iter)
else:
context, table_iter = self.scan_table(step)
if projections:
sink = self.table(step.projections)
else:
sink = self.table(context.columns)
for reader in table_iter:
if condition and not context.eval(condition):
continue
if projections:
sink.append(context.eval_tuple(projections))
else:
sink.append(reader.row)
@ -116,14 +128,23 @@ class PythonExecutor:
return self.context({step.name: sink})
def static(self):
return self.context({}), [RowReader(())]
def scan_table(self, step):
table = self.tables.find(step.source)
context = self.context({step.source.alias_or_name: table})
return context, iter(table)
def scan_csv(self, step):
alias = step.source.alias
source = step.source.this
with csv_reader(source) as reader:
columns = next(reader)
table = Table(columns)
context = self.context({alias: table})
yield context
types = []
for row in reader:
@ -134,7 +155,7 @@ class PythonExecutor:
except (ValueError, SyntaxError):
types.append(str)
context.set_row(tuple(t(v) for t, v in zip(types, row)))
yield context.table.reader
def join(self, step, context):
source = step.name
@ -160,16 +181,19 @@ class PythonExecutor:
for name, column_range in column_ranges.items()
}
)
condition = self.generate(join["condition"])
if condition:
source_context.filter(condition)
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
if not condition and not projections:
return source_context
sink = self.table(step.projections if projections else source_context.columns)
for reader, ctx in source_context:
if condition and not ctx.eval(condition):
continue
@ -181,7 +205,15 @@ class PythonExecutor:
if len(sink) >= step.limit:
break
if projections:
return self.context({step.name: sink})
else:
return self.context(
{
name: Table(table.columns, sink.rows, table.column_range)
for name, table in source_context.tables.items()
}
)
def nested_loop_join(self, _join, source_context, join_context):
table = Table(source_context.columns + join_context.columns)
@ -195,6 +227,8 @@ class PythonExecutor:
def hash_join(self, join, source_context, join_context):
source_key = self.generate_tuple(join["source_key"])
join_key = self.generate_tuple(join["join_key"])
left = join.get("side") == "LEFT"
right = join.get("side") == "RIGHT"
results = collections.defaultdict(lambda: ([], []))
@ -204,28 +238,47 @@ class PythonExecutor:
results[ctx.eval_tuple(join_key)][1].append(reader.row)
table = Table(source_context.columns + join_context.columns)
nulls = [(None,) * len(join_context.columns if left else source_context.columns)]
for a_group, b_group in results.values():
if left:
b_group = b_group or nulls
elif right:
a_group = a_group or nulls
for a_row, b_row in itertools.product(a_group, b_group):
table.append(a_row + b_row)
return table
def aggregate(self, step, context):
source = step.source
group_by = self.generate_tuple(step.group.values())
aggregations = self.generate_tuple(step.aggregations)
operands = self.generate_tuple(step.operands)
if operands:
operand_table = Table(self.table(step.operands).columns)
for reader, ctx in context:
operand_table.append(ctx.eval_tuple(operands))
for i, (a, b) in enumerate(zip(context.table.rows, operand_table.rows)):
context.table.rows[i] = a + b
width = len(context.columns)
context.add_columns(*operand_table.columns)
operand_table = Table(
context.columns,
context.table.rows,
range(width, width + len(operand_table.columns)),
)
context = self.context(
{
None: operand_table,
**context.tables,
}
)
context.sort(group_by)
@ -233,25 +286,22 @@ class PythonExecutor:
group = None
start = 0
end = 1
length = len(context.table)
table = self.table(list(step.group) + step.aggregations)
for i in range(length):
context.set_index(i)
key = context.eval_tuple(group_by)
group = key if group is None else group
end += 1
if key != group:
context.set_range(start, end - 2)
else:
continue
table.append(group + context.eval_tuple(aggregations))
group = key
start = end - 2
if i == length - 1:
context.set_range(start, end - 1)
table.append(group + context.eval_tuple(aggregations))
context = self.context({step.name: table, **{name: table for name in context.tables}})
@ -262,87 +312,67 @@ class PythonExecutor:
def sort(self, step, context):
projections = self.generate_tuple(step.projections)
projection_columns = [p.alias_or_name for p in step.projections]
all_columns = list(context.columns) + projection_columns
sink = self.table(all_columns)
for reader, ctx in context:
sink.append(reader.row + ctx.eval_tuple(projections))
sort_ctx = self.context(
{
None: sink,
**{table: sink for table in context.tables},
}
)
sort_ctx.sort(self.generate_tuple(step.key))
if not math.isinf(step.limit):
sort_ctx.table.rows = sort_ctx.table.rows[0 : step.limit]
output = Table(
projection_columns,
rows=[r[len(context.columns) : len(all_columns)] for r in sort_ctx.table.rows],
)
return self.context({step.name: output})
def set_operation(self, step, context):
left = context.tables[step.left]
right = context.tables[step.right]
sink = self.table(left.columns)
if issubclass(step.op, exp.Intersect):
sink.rows = list(set(left.rows).intersection(set(right.rows)))
elif issubclass(step.op, exp.Except):
sink.rows = list(set(left.rows).difference(set(right.rows)))
elif issubclass(step.op, exp.Union) and step.distinct:
sink.rows = list(set(left.rows).union(set(right.rows)))
else:
sink.rows = left.rows + right.rows
return self.context({step.name: sink})
def _ordered_py(self, expression):
this = self.sql(expression, "this")
desc = "True" if expression.args.get("desc") else "False"
nulls_first = "True" if expression.args.get("nulls_first") else "False"
return f"ORDERED({this}, {desc}, {nulls_first})"
def _rename(self, e):
try:
if "expressions" in e.args:
this = self.sql(e, "this")
this = f"{this}, " if this else ""
return f"{e.key.upper()}({this}{self.expressions(e)})"
return f"{e.key.upper()}({self.format_args(*e.args.values())})"
except Exception as ex:
raise Exception(f"Could not rename {repr(e)}") from ex
def _case_sql(self, expression):
this = self.sql(expression, "this")
chain = self.sql(expression, "default") or "None"
@ -353,3 +383,30 @@ class Python(Dialect):
chain = f"{true} if {condition} else ({chain})"
return chain
class Python(Dialect):
class Tokenizer(tokens.Tokenizer):
ESCAPES = ["\\"]
class Generator(generator.Generator):
TRANSFORMS = {
**{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)},
**{klass: _rename for klass in exp.ALL_FUNCTIONS},
exp.Case: _case_sql,
exp.Alias: lambda self, e: self.sql(e.this),
exp.Array: inline_array_sql,
exp.And: lambda self, e: self.binary(e, "and"),
exp.Between: _rename,
exp.Boolean: lambda self, e: "True" if e.this else "False",
exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
exp.Is: lambda self, e: self.binary(e, "is"),
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
exp.Or: lambda self, e: self.binary(e, "or"),
exp.Ordered: _ordered_py,
exp.Star: lambda *_: "1",
}
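As a sanity check, here is a hedged sketch of how the new set-operation step surfaces through the executor. The sample data is made up, and it assumes the executor's execute entry point accepts a tables mapping, which the new ensure_tables helper suggests:

from sqlglot.executor import execute

tables = {
    "x": [{"a": 1}, {"a": 2}],
    "y": [{"a": 2}, {"a": 3}],
}

# UNION is planned as a SetOperation step; distinct de-duplicates the rows
result = execute("SELECT a FROM x UNION SELECT a FROM y", tables=tables)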

View file

@ -1,14 +1,27 @@
from __future__ import annotations
from sqlglot.helper import dict_depth
from sqlglot.schema import AbstractMappingSchema
class Table:
def __init__(self, columns, rows=None, column_range=None):
self.columns = tuple(columns)
self.column_range = column_range
self.reader = RowReader(self.columns, self.column_range)
self.rows = rows or []
if rows:
assert len(rows[0]) == len(self.columns)
self.range_reader = RangeReader(self)
def add_columns(self, *columns: str) -> None:
self.columns += columns
if self.column_range:
self.column_range = range(
self.column_range.start, self.column_range.stop + len(columns)
)
self.reader = RowReader(self.columns, self.column_range)
def append(self, row):
assert len(row) == len(self.columns)
self.rows.append(row)
@ -87,3 +100,31 @@ class RowReader:
def __getitem__(self, column):
return self.row[self.columns[column]]
class Tables(AbstractMappingSchema[Table]):
pass
def ensure_tables(d: dict | None) -> Tables:
return Tables(_ensure_tables(d))
def _ensure_tables(d: dict | None) -> dict:
if not d:
return {}
depth = dict_depth(d)
if depth > 1:
return {k: _ensure_tables(v) for k, v in d.items()}
result = {}
for name, table in d.items():
if isinstance(table, Table):
result[name] = table
else:
columns = tuple(table[0]) if table else ()
rows = [tuple(row[c] for c in columns) for row in table]
result[name] = Table(columns=columns, rows=rows)
return result
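For reference, a sketch of what the new ensure_tables helper normalizes, based on the code above (the data is illustrative):

# a mapping of table name -> list of row dicts becomes a Tables mapping of Table objects
tables = ensure_tables({"t": [{"a": 1, "b": 2}, {"a": 3, "b": 4}]})
table = tables.mapping["t"]
assert table.columns == ("a", "b")
assert table.rows == [(1, 2), (3, 4)]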

View file

@ -641,9 +641,11 @@ class Set(Expression):
class SetItem(Expression):
arg_types = {
"this": True,
"this": False,
"expressions": False,
"kind": False,
"collate": False, # MySQL SET NAMES statement
"global": False,
}
@ -787,6 +789,7 @@ class Drop(Expression):
"exists": False,
"temporary": False,
"materialized": False,
"cascade": False,
}
@ -1073,6 +1076,18 @@ class FileFormatProperty(Property):
pass
class DistKeyProperty(Property):
pass
class SortKeyProperty(Property):
pass
class DistStyleProperty(Property):
pass
class LocationProperty(Property):
pass
@ -1130,6 +1145,9 @@ class Properties(Expression):
"LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty,
"TABLE_FORMAT": TableFormatProperty,
"DISTKEY": DistKeyProperty,
"DISTSTYLE": DistStyleProperty,
"SORTKEY": SortKeyProperty,
}
@classmethod
@ -1356,7 +1374,7 @@ class Var(Expression):
class Schema(Expression):
arg_types = {"this": False, "expressions": True}
arg_types = {"this": False, "expressions": False}
class Select(Subqueryable):
@ -1741,7 +1759,7 @@ class Select(Subqueryable):
)
if join_alias:
join.set("this", alias_(join.args["this"], join_alias, table=True))
join.set("this", alias_(join.this, join_alias, table=True))
return _apply_list_builder(
join,
instance=self,
@ -1884,6 +1902,7 @@ class Subquery(DerivedTable, Unionable):
arg_types = {
"this": True,
"alias": False,
"with": False,
**QUERY_MODIFIERS,
}
@ -2025,6 +2044,31 @@ class DataType(Expression):
NULL = auto()
UNKNOWN = auto() # Sentinel value, useful for type annotation
TEXT_TYPES = {
Type.CHAR,
Type.NCHAR,
Type.VARCHAR,
Type.NVARCHAR,
Type.TEXT,
}
NUMERIC_TYPES = {
Type.INT,
Type.TINYINT,
Type.SMALLINT,
Type.BIGINT,
Type.FLOAT,
Type.DOUBLE,
}
TEMPORAL_TYPES = {
Type.TIMESTAMP,
Type.TIMESTAMPTZ,
Type.TIMESTAMPLTZ,
Type.DATE,
Type.DATETIME,
}
@classmethod
def build(cls, dtype, **kwargs) -> DataType:
return DataType(
@ -2054,16 +2098,25 @@ class Exists(SubqueryPredicate):
pass
# Commands to interact with the databases or engines
# These expressions don't truly parse the expression and consume
# whatever exists as a string until the end or a semicolon
# Commands to interact with the databases or engines. For most of the command
# expressions we parse whatever comes after the command's name as a string.
class Command(Expression):
arg_types = {"this": True, "expression": False}
# Binary Expressions
# (ADD a b)
# (FROM table selects)
class Transaction(Command):
arg_types = {"this": False, "modes": False}
class Commit(Command):
arg_types = {} # type: ignore
class Rollback(Command):
arg_types = {"savepoint": False}
# Binary expressions like (ADD a b)
class Binary(Expression):
arg_types = {"this": True, "expression": True}
@ -2215,7 +2268,7 @@ class Not(Unary, Condition):
class Paren(Unary, Condition):
pass
arg_types = {"this": True, "with": False}
class Neg(Unary):
@ -2428,6 +2481,10 @@ class Cast(Func):
return self.args["to"]
class Collate(Binary):
pass
class TryCast(Cast):
pass
@ -2442,13 +2499,17 @@ class Coalesce(Func):
is_var_len_args = True
class ConcatWs(Func):
arg_types = {"expressions": False}
class Concat(Func):
arg_types = {"expressions": True}
is_var_len_args = True
class ConcatWs(Concat):
_sql_names = ["CONCAT_WS"]
class Count(AggFunc):
pass
arg_types = {"this": False}
class CurrentDate(Func):
@ -2556,10 +2617,18 @@ class Day(Func):
pass
class Decode(Func):
arg_types = {"this": True, "charset": True}
class DiToDate(Func):
pass
class Encode(Func):
arg_types = {"this": True, "charset": True}
class Exp(Func):
pass
@ -2581,6 +2650,10 @@ class GroupConcat(Func):
arg_types = {"this": True, "separator": False}
class Hex(Func):
pass
class If(Func):
arg_types = {"this": True, "true": True, "false": False}
@ -2641,7 +2714,7 @@ class Log10(Func):
class Lower(Func):
pass
_sql_names = ["LOWER", "LCASE"]
class Map(Func):
@ -2686,6 +2759,12 @@ class ApproxQuantile(Quantile):
arg_types = {"this": True, "quantile": True, "accuracy": False}
class ReadCSV(Func):
_sql_names = ["READ_CSV"]
is_var_len_args = True
arg_types = {"this": True, "expressions": False}
class Reduce(Func):
arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
@ -2804,8 +2883,8 @@ class TimeStrToUnix(Func):
class Trim(Func):
arg_types = {
"this": True,
"position": False,
"expression": False,
"position": False,
"collation": False,
}
@ -2826,6 +2905,10 @@ class TsOrDiToDi(Func):
pass
class Unhex(Func):
pass
class UnixToStr(Func):
arg_types = {"this": True, "format": False}
@ -2843,7 +2926,7 @@ class UnixToTimeStr(Func):
class Upper(Func):
pass
_sql_names = ["UPPER", "UCASE"]
class Variance(AggFunc):
@ -3701,6 +3784,19 @@ def replace_placeholders(expression, *args, **kwargs):
return expression.transform(_replace_placeholders, iter(args), **kwargs)
def true():
return Boolean(this=True)
def false():
return Boolean(this=False)
def null():
return Null()
# TODO: deprecate this
TRUE = Boolean(this=True)
FALSE = Boolean(this=False)
NULL = Null()
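The factory functions return a fresh node per call, so tree rewrites can no longer mutate a shared singleton; a minimal sketch:

from sqlglot import exp

assert exp.true() == exp.Boolean(this=True)
assert exp.null() == exp.Null()
# every call builds a new node, unlike the deprecated TRUE/FALSE/NULL constants
assert exp.true() is not exp.true()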

View file

@ -67,7 +67,7 @@ class Generator:
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.VolatilityProperty: lambda self, e: self.sql(e.name),
exp.VolatilityProperty: lambda self, e: e.name,
}
# Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
@ -94,6 +94,9 @@ class Generator:
ROOT_PROPERTIES = {
exp.ReturnsProperty,
exp.LanguageProperty,
exp.DistStyleProperty,
exp.DistKeyProperty,
exp.SortKeyProperty,
}
WITH_PROPERTIES = {
@ -241,7 +244,7 @@ class Generator:
if not NEWLINE_RE.search(comment):
return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
return f"/*{comment}*/\n{sql}"
return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/"
def wrap(self, expression):
this_sql = self.indent(
@ -475,7 +478,8 @@ class Generator:
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}"
cascade = " CASCADE" if expression.args.get("cascade") else ""
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}"
def except_sql(self, expression):
return self.prepend_ctes(
@ -915,13 +919,15 @@ class Generator:
def subquery_sql(self, expression):
alias = self.sql(expression, "alias")
return self.query_modifiers(
sql = self.query_modifiers(
expression,
self.wrap(expression),
self.expressions(expression, key="pivots", sep=" "),
f" AS {alias}" if alias else "",
)
return self.prepend_ctes(expression, sql)
def qualify_sql(self, expression):
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('QUALIFY')}{self.sep()}{this}"
@ -1111,9 +1117,12 @@ class Generator:
def paren_sql(self, expression):
if isinstance(expression.unnest(), exp.Select):
return self.wrap(expression)
sql = self.wrap(expression)
else:
sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
return f"({sql}{self.seg(')', sep='')}"
sql = f"({sql}{self.seg(')', sep='')}"
return self.prepend_ctes(expression, sql)
def neg_sql(self, expression):
return f"-{self.sql(expression, 'this')}"
@ -1173,9 +1182,23 @@ class Generator:
zone = self.sql(expression, "this")
return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE"
def collate_sql(self, expression):
return self.binary(expression, "COLLATE")
def command_sql(self, expression):
return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"
def transaction_sql(self, *_):
return "BEGIN"
def commit_sql(self, *_):
return "COMMIT"
def rollback_sql(self, expression):
savepoint = expression.args.get("savepoint")
savepoint = f" TO {savepoint}" if savepoint else ""
return f"ROLLBACK{savepoint}"
def distinct_sql(self, expression):
this = self.expressions(expression, flat=True)
this = f" {this}" if this else ""
@ -1193,10 +1216,7 @@ class Generator:
def intdiv_sql(self, expression):
return self.sql(
exp.Cast(
this=exp.Div(
this=expression.args["this"],
expression=expression.args["expression"],
),
this=exp.Div(this=expression.this, expression=expression.expression),
to=exp.DataType(this=exp.DataType.Type.INT),
)
)
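A short, unverified sketch of the generator changes above, with expected outputs inferred from the code rather than tested against a release:

import sqlglot

# DROP ... CASCADE now survives a round trip
sqlglot.transpile("DROP TABLE t CASCADE")[0]      # 'DROP TABLE t CASCADE'

# transactions go through the new transaction/commit/rollback hooks
sqlglot.transpile("ROLLBACK TO SAVEPOINT s1")[0]  # 'ROLLBACK TO s1'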

View file

@ -11,7 +11,8 @@ from copy import copy
from enum import Enum
if t.TYPE_CHECKING:
from sqlglot.expressions import Expression, Table
from sqlglot import exp
from sqlglot.expressions import Expression
T = t.TypeVar("T")
E = t.TypeVar("E", bound=Expression)
@ -150,7 +151,7 @@ def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
if expression.is_int:
expression = expression.copy()
logger.warning("Applying array index offset (%s)", offset)
expression.args["this"] = str(int(expression.args["this"]) + offset)
expression.args["this"] = str(int(expression.this) + offset)
return [expression]
return expressions
@ -228,19 +229,18 @@ def open_file(file_name: str) -> t.TextIO:
@contextmanager
def csv_reader(table: Table) -> t.Any:
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
"""
Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
Args:
table: a `Table` expression with an anonymous function `READ_CSV` in it.
read_csv: a `ReadCSV` function call
Yields:
A python csv reader.
"""
file, *args = table.this.expressions
file = file.name
file = open_file(file)
args = read_csv.expressions
file = open_file(read_csv.name)
delimiter = ","
args = iter(arg.name for arg in args)
@ -354,3 +354,34 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any,
yield from flatten(value)
else:
yield value
def dict_depth(d: t.Dict) -> int:
"""
Get the nesting depth of a dictionary.
For example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
Args:
d (dict): dictionary
Returns:
int: depth
"""
try:
return 1 + dict_depth(next(iter(d.values())))
except AttributeError:
# d doesn't have attribute "values"
return 0
except StopIteration:
# d.values() returns an empty sequence
return 1
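A hedged sketch of the reworked csv_reader contract above, assuming READ_CSV parses into the new exp.ReadCSV node (the file path is hypothetical):

from sqlglot import exp, parse_one
from sqlglot.helper import csv_reader

read_csv = parse_one("SELECT * FROM READ_CSV('data.csv', 'delimiter', '|')").find(exp.ReadCSV)
with csv_reader(read_csv) as reader:  # opens 'data.csv' and applies the '|' delimiter
    for row in reader:
        print(row)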

View file

@ -245,23 +245,31 @@ class TypeAnnotator:
def annotate(self, expression):
if isinstance(expression, self.TRAVERSABLES):
for scope in traverse_scope(expression):
subscope_selects = {
name: {select.alias_or_name: select for select in source.selects}
for name, source in scope.sources.items()
if isinstance(source, Scope)
selects = {}
for name, source in scope.sources.items():
if not isinstance(source, Scope):
continue
if isinstance(source.expression, exp.Values):
selects[name] = {
alias: column
for alias, column in zip(
source.expression.alias_column_names,
source.expression.expressions[0].expressions,
)
}
else:
selects[name] = {
select.alias_or_name: select for select in source.expression.selects
}
# First annotate the current scope's column references
for col in scope.columns:
source = scope.sources[col.table]
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
else:
col.type = subscope_selects[col.table][col.name].type
col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def _maybe_annotate(self, expression):
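With the Values branch above, column types can now be inferred through VALUES sources; a hedged sketch, assuming the derived VALUES table is scoped under its alias:

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

expression = sqlglot.parse_one("SELECT v.a FROM (VALUES (1)) AS v(a)")
annotate_types(expression)  # v.a should pick up the INT type of the literal 1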

View file

@ -0,0 +1,48 @@
import itertools
from sqlglot import exp
def canonicalize(expression: exp.Expression) -> exp.Expression:
"""Converts a sql expression into a standard form.
This method relies on annotate_types because many of the
conversions rely on type inference.
Args:
expression: The expression to canonicalize.
"""
exp.replace_children(expression, canonicalize)
expression = add_text_to_concat(expression)
expression = coerce_type(expression)
return expression
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES:
node = exp.Concat(this=node.this, expression=node.expression)
return node
def coerce_type(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Binary):
_coerce_date(node.left, node.right)
elif isinstance(node, exp.Between):
_coerce_date(node.this, node.args["low"])
elif isinstance(node, exp.Extract):
if node.expression.type not in exp.DataType.TEMPORAL_TYPES:
_replace_cast(node.expression, "datetime")
return node
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE:
_replace_cast(b, "date")
def _replace_cast(node: exp.Expression, to: str) -> None:
data_type = exp.DataType.build(to)
cast = exp.Cast(this=node.copy(), to=data_type)
cast.type = data_type
node.replace(cast)
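A minimal sketch of the rule in isolation, assuming the annotator types the addition of two string literals as text:

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize

expression = annotate_types(sqlglot.parse_one("SELECT 'a' + 'b'"))
canonicalize(expression).sql()  # expected: "SELECT CONCAT('a', 'b')"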

View file

@ -128,8 +128,8 @@ def join_condition(join):
Tuple of (source key, join key, remaining predicate)
"""
name = join.this.alias_or_name
on = join.args.get("on") or exp.TRUE
on = on.copy()
on = (join.args.get("on") or exp.true()).copy()
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
source_key = []
join_key = []
@ -141,7 +141,7 @@ def join_condition(join):
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
for condition in on.flatten() if isinstance(on, exp.And) else [on]:
for condition in on.flatten():
if isinstance(condition, exp.EQ):
left, right = condition.unnest_operands()
left_tables = exp.column_table_names(left)
@ -150,13 +150,12 @@ def join_condition(join):
if name in left_tables and name not in right_tables:
join_key.append(left)
source_key.append(right)
condition.replace(exp.TRUE)
condition.replace(exp.true())
elif name in right_tables and name not in left_tables:
join_key.append(right)
source_key.append(left)
condition.replace(exp.TRUE)
condition.replace(exp.true())
on = simplify(on)
remaining_condition = None if on == exp.TRUE else on
remaining_condition = None if on == exp.true() else on
return source_key, join_key, remaining_condition

View file

@ -29,7 +29,7 @@ def optimize_joins(expression):
if isinstance(on, exp.Connector):
for predicate in on.flatten():
if name in exp.column_table_names(predicate):
predicate.replace(exp.TRUE)
predicate.replace(exp.true())
join.on(predicate, copy=False)
expression = reorder_joins(expression)
@ -70,6 +70,6 @@ def normalize(expression):
def other_table_names(join, exclude):
return [
name
for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
for name in (exp.column_table_names(join.args.get("on") or exp.true()))
if name != exclude
]

View file

@ -1,4 +1,6 @@
import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
@ -28,6 +30,8 @@ RULES = (
merge_subqueries,
eliminate_joins,
eliminate_ctes,
annotate_types,
canonicalize,
quote_identities,
)
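Because annotate_types and canonicalize now sit in the default rule set, a plain optimize call applies them; a hedged sketch (the schema is illustrative):

import sqlglot
from sqlglot.optimizer import optimize

optimized = optimize(
    sqlglot.parse_one("SELECT x.a + x.b FROM x"),
    schema={"x": {"a": "VARCHAR", "b": "VARCHAR"}},
)
# the text-typed addition should come out as CONCAT after canonicalization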

View file

@ -64,11 +64,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count):
for predicate in predicates:
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
if isinstance(node, exp.Join):
predicate.replace(exp.TRUE)
predicate.replace(exp.true())
node.on(predicate, copy=False)
break
if isinstance(node, exp.Select):
predicate.replace(exp.TRUE)
predicate.replace(exp.true())
node.where(replace_aliases(node, predicate), copy=False)

View file

@ -382,9 +382,7 @@ class _Resolver:
raise OptimizeError(str(e)) from e
if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
values_alias = source.expression.parent
if hasattr(values_alias, "alias_column_names"):
return values_alias.alias_column_names
return source.expression.alias_column_names
# Otherwise, if referencing another scope, return that scope's named selects
return source.expression.named_selects

View file

@ -1,10 +1,11 @@
import itertools
from sqlglot import alias, exp
from sqlglot.helper import csv_reader
from sqlglot.optimizer.scope import traverse_scope
def qualify_tables(expression, db=None, catalog=None):
def qualify_tables(expression, db=None, catalog=None, schema=None):
"""
Rewrite sqlglot AST to have fully qualified tables.
@ -18,6 +19,7 @@ def qualify_tables(expression, db=None, catalog=None):
expression (sqlglot.Expression): expression to qualify
db (str): Database name
catalog (str): Catalog name
schema: A schema to populate with tables discovered while qualifying (currently columns inferred from READ_CSV sources)
Returns:
sqlglot.Expression: qualified expression
"""
@ -41,7 +43,7 @@ def qualify_tables(expression, db=None, catalog=None):
source.set("catalog", exp.to_identifier(catalog))
if not source.alias:
source.replace(
source = source.replace(
alias(
source.copy(),
source.this if identifier else f"_q_{next(sequence)}",
@ -49,4 +51,12 @@ def qualify_tables(expression, db=None, catalog=None):
)
)
if schema and isinstance(source.this, exp.ReadCSV):
with csv_reader(source.this) as reader:
header = next(reader)
columns = next(reader)
schema.add_table(
source, {k: type(v).__name__ for k, v in zip(header, columns)}
)
return expression
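A quick sketch of the existing db qualification path, with the output inferred from the code:

import sqlglot
from sqlglot.optimizer.qualify_tables import qualify_tables

qualify_tables(sqlglot.parse_one("SELECT 1 FROM t"), db="db").sql()
# 'SELECT 1 FROM db.t AS t'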

View file

@ -189,11 +189,11 @@ def absorb_and_eliminate(expression):
# absorb
if is_complement(b, aa):
aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
aa.replace(exp.true() if kind == exp.And else exp.false())
elif is_complement(b, ab):
ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
ab.replace(exp.true() if kind == exp.And else exp.false())
elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
a.replace(exp.false() if kind == exp.And else exp.true())
elif isinstance(b, kind):
# eliminate
rhs = b.unnest_operands()

View file

@ -169,7 +169,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
select.parent.replace(alias)
for key, column, predicate in keys:
predicate.replace(exp.TRUE)
predicate.replace(exp.true())
nested = exp.column(key_aliases[key], table_alias)
if key in group_by:

View file

@ -141,26 +141,29 @@ class Parser(metaclass=_Parser):
ID_VAR_TOKENS = {
TokenType.VAR,
TokenType.ALTER,
TokenType.ALWAYS,
TokenType.ANTI,
TokenType.APPLY,
TokenType.AUTO_INCREMENT,
TokenType.BEGIN,
TokenType.BOTH,
TokenType.BUCKET,
TokenType.CACHE,
TokenType.CALL,
TokenType.CASCADE,
TokenType.COLLATE,
TokenType.COMMAND,
TokenType.COMMIT,
TokenType.CONSTRAINT,
TokenType.CURRENT_TIME,
TokenType.DEFAULT,
TokenType.DELETE,
TokenType.DESCRIBE,
TokenType.DETERMINISTIC,
TokenType.DISTKEY,
TokenType.DISTSTYLE,
TokenType.EXECUTE,
TokenType.ENGINE,
TokenType.ESCAPE,
TokenType.EXPLAIN,
TokenType.FALSE,
TokenType.FIRST,
TokenType.FOLLOWING,
@ -182,7 +185,6 @@ class Parser(metaclass=_Parser):
TokenType.NATURAL,
TokenType.NEXT,
TokenType.ONLY,
TokenType.OPTIMIZE,
TokenType.OPTIONS,
TokenType.ORDINALITY,
TokenType.PARTITIONED_BY,
@ -199,6 +201,7 @@ class Parser(metaclass=_Parser):
TokenType.SEMI,
TokenType.SET,
TokenType.SHOW,
TokenType.SORTKEY,
TokenType.STABLE,
TokenType.STORED,
TokenType.TABLE,
@ -207,7 +210,6 @@ class Parser(metaclass=_Parser):
TokenType.TRANSIENT,
TokenType.TOP,
TokenType.TRAILING,
TokenType.TRUNCATE,
TokenType.TRUE,
TokenType.UNBOUNDED,
TokenType.UNIQUE,
@ -217,6 +219,7 @@ class Parser(metaclass=_Parser):
TokenType.VOLATILE,
*SUBQUERY_PREDICATES,
*TYPE_TOKENS,
*NO_PAREN_FUNCTIONS,
}
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY}
@ -231,6 +234,7 @@ class Parser(metaclass=_Parser):
TokenType.FILTER,
TokenType.FIRST,
TokenType.FORMAT,
TokenType.IDENTIFIER,
TokenType.ISNULL,
TokenType.OFFSET,
TokenType.PRIMARY_KEY,
@ -242,6 +246,7 @@ class Parser(metaclass=_Parser):
TokenType.RIGHT,
TokenType.DATE,
TokenType.DATETIME,
TokenType.TABLE,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
*TYPE_TOKENS,
@ -277,6 +282,7 @@ class Parser(metaclass=_Parser):
TokenType.DASH: exp.Sub,
TokenType.PLUS: exp.Add,
TokenType.MOD: exp.Mod,
TokenType.COLLATE: exp.Collate,
}
FACTOR = {
@ -391,7 +397,10 @@ class Parser(metaclass=_Parser):
TokenType.DELETE: lambda self: self._parse_delete(),
TokenType.CACHE: lambda self: self._parse_cache(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
TokenType.USE: lambda self: self._parse_use(),
TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
TokenType.BEGIN: lambda self: self._parse_transaction(),
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
}
PRIMARY_PARSERS = {
@ -402,7 +411,8 @@ class Parser(metaclass=_Parser):
exp.Literal, this=token.text, is_string=False
),
TokenType.STAR: lambda self, _: self.expression(
exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
exp.Star,
**{"except": self._parse_except(), "replace": self._parse_replace()},
),
TokenType.NULL: lambda self, _: self.expression(exp.Null),
TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
@ -446,6 +456,9 @@ class Parser(metaclass=_Parser):
TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(),
TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(),
TokenType.STORED: lambda self: self._parse_stored(),
TokenType.DISTKEY: lambda self: self._parse_distkey(),
TokenType.DISTSTYLE: lambda self: self._parse_diststyle(),
TokenType.SORTKEY: lambda self: self._parse_sortkey(),
TokenType.RETURNS: lambda self: self._parse_returns(),
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
@ -471,7 +484,9 @@ class Parser(metaclass=_Parser):
}
CONSTRAINT_PARSERS = {
TokenType.CHECK: lambda self: self._parse_check(),
TokenType.CHECK: lambda self: self.expression(
exp.Check, this=self._parse_wrapped(self._parse_conjunction)
),
TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(),
TokenType.UNIQUE: lambda self: self._parse_unique(),
}
@ -521,6 +536,8 @@ class Parser(metaclass=_Parser):
TokenType.SCHEMA,
}
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
STRICT_CAST = True
__slots__ = (
@ -740,6 +757,7 @@ class Parser(metaclass=_Parser):
kind=kind,
temporary=temporary,
materialized=materialized,
cascade=self._match(TokenType.CASCADE),
)
def _parse_exists(self, not_=False):
@ -777,7 +795,11 @@ class Parser(metaclass=_Parser):
expression = self._parse_select_or_expression()
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index()
elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW, TokenType.SCHEMA):
elif create_token.token_type in (
TokenType.TABLE,
TokenType.VIEW,
TokenType.SCHEMA,
):
this = self._parse_table(schema=True)
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
@ -834,7 +856,38 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.FileFormatProperty,
this=exp.Literal.string("FORMAT"),
value=exp.Literal.string(self._parse_var().name),
value=exp.Literal.string(self._parse_var_or_string().name),
)
def _parse_distkey(self):
self._match_l_paren()
this = exp.Literal.string("DISTKEY")
value = exp.Literal.string(self._parse_var().name)
self._match_r_paren()
return self.expression(
exp.DistKeyProperty,
this=this,
value=value,
)
def _parse_sortkey(self):
self._match_l_paren()
this = exp.Literal.string("SORTKEY")
value = exp.Literal.string(self._parse_var().name)
self._match_r_paren()
return self.expression(
exp.SortKeyProperty,
this=this,
value=value,
)
def _parse_diststyle(self):
this = exp.Literal.string("DISTSTYLE")
value = exp.Literal.string(self._parse_var().name)
return self.expression(
exp.DistStyleProperty,
this=this,
value=value,
)
def _parse_auto_increment(self):
@ -842,7 +895,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.AutoIncrementProperty,
this=exp.Literal.string("AUTO_INCREMENT"),
value=self._parse_var() or self._parse_number(),
value=self._parse_number(),
)
def _parse_schema_comment(self):
@ -898,13 +951,10 @@ class Parser(metaclass=_Parser):
while True:
if self._match(TokenType.WITH):
self._match_l_paren()
properties.extend(self._parse_csv(lambda: self._parse_property()))
self._match_r_paren()
properties.extend(self._parse_wrapped_csv(self._parse_property))
elif self._match(TokenType.PROPERTIES):
self._match_l_paren()
properties.extend(
self._parse_csv(
self._parse_wrapped_csv(
lambda: self.expression(
exp.AnonymousProperty,
this=self._parse_string(),
@ -912,25 +962,24 @@ class Parser(metaclass=_Parser):
)
)
)
self._match_r_paren()
else:
identified_property = self._parse_property()
if not identified_property:
break
properties.append(identified_property)
if properties:
return self.expression(exp.Properties, expressions=properties)
return None
def _parse_describe(self):
self._match(TokenType.TABLE)
return self.expression(exp.Describe, this=self._parse_id_var())
def _parse_insert(self):
overwrite = self._match(TokenType.OVERWRITE)
local = self._match(TokenType.LOCAL)
if self._match_text("DIRECTORY"):
if self._match_text_seq("DIRECTORY"):
this = self.expression(
exp.Directory,
this=self._parse_var_or_string(),
@ -954,27 +1003,27 @@ class Parser(metaclass=_Parser):
if not self._match_pair(TokenType.ROW, TokenType.FORMAT):
return None
self._match_text("DELIMITED")
self._match_text_seq("DELIMITED")
kwargs = {}
if self._match_text("FIELDS", "TERMINATED", "BY"):
if self._match_text_seq("FIELDS", "TERMINATED", "BY"):
kwargs["fields"] = self._parse_string()
if self._match_text("ESCAPED", "BY"):
if self._match_text_seq("ESCAPED", "BY"):
kwargs["escaped"] = self._parse_string()
if self._match_text("COLLECTION", "ITEMS", "TERMINATED", "BY"):
if self._match_text_seq("COLLECTION", "ITEMS", "TERMINATED", "BY"):
kwargs["collection_items"] = self._parse_string()
if self._match_text("MAP", "KEYS", "TERMINATED", "BY"):
if self._match_text_seq("MAP", "KEYS", "TERMINATED", "BY"):
kwargs["map_keys"] = self._parse_string()
if self._match_text("LINES", "TERMINATED", "BY"):
if self._match_text_seq("LINES", "TERMINATED", "BY"):
kwargs["lines"] = self._parse_string()
if self._match_text("NULL", "DEFINED", "AS"):
if self._match_text_seq("NULL", "DEFINED", "AS"):
kwargs["null"] = self._parse_string()
return self.expression(exp.RowFormat, **kwargs)
def _parse_load_data(self):
local = self._match(TokenType.LOCAL)
self._match_text("INPATH")
self._match_text_seq("INPATH")
inpath = self._parse_string()
overwrite = self._match(TokenType.OVERWRITE)
self._match_pair(TokenType.INTO, TokenType.TABLE)
@ -986,8 +1035,8 @@ class Parser(metaclass=_Parser):
overwrite=overwrite,
inpath=inpath,
partition=self._parse_partition(),
input_format=self._match_text("INPUTFORMAT") and self._parse_string(),
serde=self._match_text("SERDE") and self._parse_string(),
input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(),
serde=self._match_text_seq("SERDE") and self._parse_string(),
)
def _parse_delete(self):
@ -996,9 +1045,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Delete,
this=self._parse_table(schema=True),
using=self._parse_csv(
lambda: self._match(TokenType.USING) and self._parse_table(schema=True)
),
using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()),
where=self._parse_where(),
)
@ -1029,12 +1076,7 @@ class Parser(metaclass=_Parser):
options = []
if self._match(TokenType.OPTIONS):
self._match_l_paren()
k = self._parse_string()
self._match(TokenType.EQ)
v = self._parse_string()
options = [k, v]
self._match_r_paren()
options = self._parse_wrapped_csv(self._parse_string, sep=TokenType.EQ)
self._match(TokenType.ALIAS)
return self.expression(
@ -1050,27 +1092,13 @@ class Parser(metaclass=_Parser):
return None
def parse_values():
    key = self._parse_var()
    value = None
    if self._match(TokenType.EQ):
        value = self._parse_string()
    return exp.Property(this=key, value=value)

self._match_l_paren()
values = self._parse_csv(parse_values)
self._match_r_paren()
return self.expression(
    exp.Partition,
    this=values,
)

def parse_values():
    props = self._parse_csv(self._parse_var_or_string, sep=TokenType.EQ)
    return exp.Property(this=seq_get(props, 0), value=seq_get(props, 1))
return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values))
def _parse_value(self):
self._match_l_paren()
expressions = self._parse_csv(self._parse_conjunction)
self._match_r_paren()
expressions = self._parse_wrapped_csv(self._parse_conjunction)
return self.expression(exp.Tuple, expressions=expressions)
def _parse_select(self, nested=False, table=False):
@ -1124,10 +1152,11 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
this = self._parse_subquery(this)
elif self._match(TokenType.VALUES):
this = self.expression(exp.Values, expressions=self._parse_csv(self._parse_value))
alias = self._parse_table_alias()
if alias:
this = self.expression(exp.Subquery, this=this, alias=alias)
this = self.expression(
exp.Values,
expressions=self._parse_csv(self._parse_value),
alias=self._parse_table_alias(),
)
else:
this = None
@ -1140,7 +1169,6 @@ class Parser(metaclass=_Parser):
recursive = self._match(TokenType.RECURSIVE)
expressions = []
while True:
expressions.append(self._parse_cte())
@ -1149,11 +1177,7 @@ class Parser(metaclass=_Parser):
else:
self._match(TokenType.WITH)
return self.expression(
exp.With,
expressions=expressions,
recursive=recursive,
)
return self.expression(exp.With, expressions=expressions, recursive=recursive)
def _parse_cte(self):
alias = self._parse_table_alias()
@ -1163,13 +1187,9 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.ALIAS):
self.raise_error("Expected AS in CTE")
self._match_l_paren()
expression = self._parse_statement()
self._match_r_paren()
return self.expression(
exp.CTE,
this=expression,
this=self._parse_wrapped(self._parse_statement),
alias=alias,
)
@ -1223,7 +1243,7 @@ class Parser(metaclass=_Parser):
def _parse_hint(self):
if self._match(TokenType.HINT):
hints = self._parse_csv(self._parse_function)
if not self._match(TokenType.HINT):
if not self._match_pair(TokenType.STAR, TokenType.SLASH):
self.raise_error("Expected */ after HINT")
return self.expression(exp.Hint, expressions=hints)
return None
@ -1259,26 +1279,18 @@ class Parser(metaclass=_Parser):
columns = self._parse_csv(self._parse_id_var)
elif self._match(TokenType.L_PAREN):
columns = self._parse_csv(self._parse_id_var)
self._match(TokenType.R_PAREN)
self._match_r_paren()
expression = self.expression(
exp.Lateral,
this=this,
view=view,
outer=outer,
alias=self.expression(
exp.TableAlias,
this=table_alias,
columns=columns,
),
alias=self.expression(exp.TableAlias, this=table_alias, columns=columns),
)
if outer_apply or cross_apply:
return self.expression(
exp.Join,
this=expression,
side=None if cross_apply else "LEFT",
)
return self.expression(exp.Join, this=expression, side=None if cross_apply else "LEFT")
return expression
@ -1387,12 +1399,8 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.UNNEST):
return None
self._match_l_paren()
expressions = self._parse_csv(self._parse_column)
self._match_r_paren()
expressions = self._parse_wrapped_csv(self._parse_column)
ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY))
alias = self._parse_table_alias()
if alias and self.unnest_column_only:
@ -1402,10 +1410,7 @@ class Parser(metaclass=_Parser):
alias.set("this", None)
return self.expression(
exp.Unnest,
expressions=expressions,
ordinality=ordinality,
alias=alias,
exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias
)
def _parse_derived_table_values(self):
@ -1418,13 +1423,7 @@ class Parser(metaclass=_Parser):
if is_derived:
self._match_r_paren()
alias = self._parse_table_alias()
return self.expression(
exp.Values,
expressions=expressions,
alias=alias,
)
return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias())
def _parse_table_sample(self):
if not self._match(TokenType.TABLE_SAMPLE):
@ -1460,9 +1459,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
if self._match(TokenType.SEED):
self._match_l_paren()
seed = self._parse_number()
self._match_r_paren()
seed = self._parse_wrapped(self._parse_number)
return self.expression(
exp.TableSample,
@ -1513,12 +1510,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
return self.expression(
exp.Pivot,
expressions=expressions,
field=field,
unpivot=unpivot,
)
return self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot)
def _parse_where(self, skip_where_token=False):
if not skip_where_token and not self._match(TokenType.WHERE):
@ -1539,11 +1531,7 @@ class Parser(metaclass=_Parser):
def _parse_grouping_sets(self):
if not self._match(TokenType.GROUPING_SETS):
return None
self._match_l_paren()
grouping_sets = self._parse_csv(self._parse_grouping_set)
self._match_r_paren()
return grouping_sets
return self._parse_wrapped_csv(self._parse_grouping_set)
def _parse_grouping_set(self):
if self._match(TokenType.L_PAREN):
@ -1573,7 +1561,6 @@ class Parser(metaclass=_Parser):
def _parse_sort(self, token_type, exp_class):
if not self._match(token_type):
return None
return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
def _parse_ordered(self):
@ -1602,9 +1589,12 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.TOP if top else TokenType.LIMIT):
limit_paren = self._match(TokenType.L_PAREN)
limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number())
if limit_paren:
self._match(TokenType.R_PAREN)
self._match_r_paren()
return limit_exp
if self._match(TokenType.FETCH):
direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
direction = self._prev.text if direction else "FIRST"
@ -1612,11 +1602,13 @@ class Parser(metaclass=_Parser):
self._match_set((TokenType.ROW, TokenType.ROWS))
self._match(TokenType.ONLY)
return self.expression(exp.Fetch, direction=direction, count=count)
return this
def _parse_offset(self, this=None):
if not self._match_set((TokenType.OFFSET, TokenType.COMMA)):
return this
count = self._parse_number()
self._match_set((TokenType.ROW, TokenType.ROWS))
return self.expression(exp.Offset, this=this, expression=count)
@ -1678,6 +1670,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.DISTINCT_FROM):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
return self.expression(klass, this=this, expression=self._parse_expression())
this = self.expression(
exp.Is,
this=this,
@ -1754,11 +1747,7 @@ class Parser(metaclass=_Parser):
def _parse_type(self):
if self._match(TokenType.INTERVAL):
return self.expression(
exp.Interval,
this=self._parse_term(),
unit=self._parse_var(),
)
return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_var())
index = self._index
type_token = self._parse_types(check_func=True)
@ -1824,30 +1813,18 @@ class Parser(metaclass=_Parser):
value = None
if type_token in self.TIMESTAMPS:
if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
value = exp.DataType(
this=exp.DataType.Type.TIMESTAMPTZ,
expressions=expressions,
)
value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
elif (
self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
):
value = exp.DataType(
this=exp.DataType.Type.TIMESTAMPLTZ,
expressions=expressions,
)
value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
elif self._match(TokenType.WITHOUT_TIME_ZONE):
value = exp.DataType(
this=exp.DataType.Type.TIMESTAMP,
expressions=expressions,
)
value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
maybe_func = maybe_func and value is None
if value is None:
value = exp.DataType(
this=exp.DataType.Type.TIMESTAMP,
expressions=expressions,
)
value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
if maybe_func and check_func:
index2 = self._index
@ -1872,6 +1849,7 @@ class Parser(metaclass=_Parser):
this = self._parse_id_var()
self._match(TokenType.COLON)
data_type = self._parse_types()
if not data_type:
return None
return self.expression(exp.StructKwarg, this=this, expression=data_type)
@ -1879,7 +1857,6 @@ class Parser(metaclass=_Parser):
def _parse_at_time_zone(self, this):
if not self._match(TokenType.AT_TIME_ZONE):
return this
return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary())
def _parse_column(self):
@ -1984,16 +1961,14 @@ class Parser(metaclass=_Parser):
else:
subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type)
if subquery_predicate and self._curr.token_type in (
TokenType.SELECT,
TokenType.WITH,
):
if subquery_predicate and self._curr.token_type in (TokenType.SELECT, TokenType.WITH):
this = self.expression(subquery_predicate, this=self._parse_select())
self._match_r_paren()
return this
if functions is None:
functions = self.FUNCTIONS
function = functions.get(upper)
args = self._parse_csv(self._parse_lambda)
@ -2014,6 +1989,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.L_PAREN):
return this
expressions = self._parse_csv(self._parse_udf_kwarg)
self._match_r_paren()
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
@ -2021,25 +1997,19 @@ class Parser(metaclass=_Parser):
def _parse_introducer(self, token):
literal = self._parse_primary()
if literal:
return self.expression(
exp.Introducer,
this=token.text,
expression=literal,
)
return self.expression(exp.Introducer, this=token.text, expression=literal)
return self.expression(exp.Identifier, this=token.text)
def _parse_session_parameter(self):
kind = None
this = self._parse_id_var() or self._parse_primary()
if self._match(TokenType.DOT):
kind = this.name
this = self._parse_var() or self._parse_primary()
return self.expression(
exp.SessionParameter,
this=this,
kind=kind,
)
return self.expression(exp.SessionParameter, this=this, kind=kind)
def _parse_udf_kwarg(self):
this = self._parse_id_var()
@ -2106,7 +2076,10 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints)
def _parse_column_constraint(self):
this = None
this = self._parse_references()
if this:
return this
if self._match(TokenType.CONSTRAINT):
this = self._parse_id_var()
@ -2114,13 +2087,12 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.AUTO_INCREMENT):
kind = exp.AutoIncrementColumnConstraint()
elif self._match(TokenType.CHECK):
self._match_l_paren()
kind = self.expression(exp.CheckColumnConstraint, this=self._parse_conjunction())
self._match_r_paren()
constraint = self._parse_wrapped(self._parse_conjunction)
kind = self.expression(exp.CheckColumnConstraint, this=constraint)
elif self._match(TokenType.COLLATE):
kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
elif self._match(TokenType.DEFAULT):
kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_field())
kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction())
elif self._match_pair(TokenType.NOT, TokenType.NULL):
kind = exp.NotNullColumnConstraint()
elif self._match(TokenType.SCHEMA_COMMENT):
@ -2137,7 +2109,7 @@ class Parser(metaclass=_Parser):
kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
self._match_pair(TokenType.ALIAS, TokenType.IDENTITY)
else:
return None
return this
return self.expression(exp.ColumnConstraint, this=this, kind=kind)
@ -2159,37 +2131,29 @@ class Parser(metaclass=_Parser):
def _parse_unnamed_constraint(self):
if not self._match_set(self.CONSTRAINT_PARSERS):
return None
return self.CONSTRAINT_PARSERS[self._prev.token_type](self)
def _parse_check(self):
self._match(TokenType.CHECK)
self._match_l_paren()
expression = self._parse_conjunction()
self._match_r_paren()
return self.expression(exp.Check, this=expression)
def _parse_unique(self):
self._match(TokenType.UNIQUE)
columns = self._parse_wrapped_id_vars()
return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars())
return self.expression(exp.Unique, expressions=columns)
def _parse_foreign_key(self):
self._match(TokenType.FOREIGN_KEY)
expressions = self._parse_wrapped_id_vars()
reference = self._match(TokenType.REFERENCES) and self.expression(
def _parse_references(self):
if not self._match(TokenType.REFERENCES):
return None
return self.expression(
exp.Reference,
this=self._parse_id_var(),
expressions=self._parse_wrapped_id_vars(),
)
def _parse_foreign_key(self):
expressions = self._parse_wrapped_id_vars()
reference = self._parse_references()
options = {}
while self._match(TokenType.ON):
if not self._match_set((TokenType.DELETE, TokenType.UPDATE)):
self.raise_error("Expected DELETE or UPDATE")
kind = self._prev.text.lower()
if self._match(TokenType.NO_ACTION):
@ -2200,6 +2164,7 @@ class Parser(metaclass=_Parser):
else:
self._advance()
action = self._prev.text.upper()
options[kind] = action
return self.expression(
@ -2363,20 +2328,14 @@ class Parser(metaclass=_Parser):
def _parse_window(self, this, alias=False):
if self._match(TokenType.FILTER):
self._match_l_paren()
this = self.expression(exp.Filter, this=this, expression=self._parse_where())
self._match_r_paren()
where = self._parse_wrapped(self._parse_where)
this = self.expression(exp.Filter, this=this, expression=where)
# T-SQL allows the OVER (...) syntax after WITHIN GROUP.
# https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16
if self._match(TokenType.WITHIN_GROUP):
self._match_l_paren()
this = self.expression(
exp.WithinGroup,
this=this,
expression=self._parse_order(),
)
self._match_r_paren()
order = self._parse_wrapped(self._parse_order)
this = self.expression(exp.WithinGroup, this=this, expression=order)
# SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER
# Some dialects choose to implement and some do not.
@ -2404,18 +2363,11 @@ class Parser(metaclass=_Parser):
return this
if not self._match(TokenType.L_PAREN):
alias = self._parse_id_var(False)
return self.expression(exp.Window, this=this, alias=self._parse_id_var(False))
return self.expression(
exp.Window,
this=this,
alias=alias,
)
alias = self._parse_id_var(False)
partition = None
alias = self._parse_id_var(False)
if self._match(TokenType.PARTITION_BY):
partition = self._parse_csv(self._parse_conjunction)
@ -2552,17 +2504,13 @@ class Parser(metaclass=_Parser):
def _parse_replace(self):
if not self._match(TokenType.REPLACE):
return None
return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression()))
self._match_l_paren()
columns = self._parse_csv(lambda: self._parse_alias(self._parse_expression()))
self._match_r_paren()
return columns
def _parse_csv(self, parse_method):
def _parse_csv(self, parse_method, sep=TokenType.COMMA):
parse_result = parse_method()
items = [parse_result] if parse_result is not None else []
while self._match(TokenType.COMMA):
while self._match(sep):
if parse_result and self._prev_comment is not None:
parse_result.comment = self._prev_comment
@ -2583,16 +2531,53 @@ class Parser(metaclass=_Parser):
return this
def _parse_wrapped_id_vars(self):
return self._parse_wrapped_csv(self._parse_id_var)
def _parse_wrapped_csv(self, parse_method, sep=TokenType.COMMA):
return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep))
def _parse_wrapped(self, parse_method):
self._match_l_paren()
expressions = self._parse_csv(self._parse_id_var)
parse_result = parse_method()
self._match_r_paren()
return expressions
return parse_result
def _parse_select_or_expression(self):
return self._parse_select() or self._parse_expression()
def _parse_use(self):
return self.expression(exp.Use, this=self._parse_id_var())
def _parse_transaction(self):
this = None
if self._match_texts(self.TRANSACTION_KIND):
this = self._prev.text
self._match_texts({"TRANSACTION", "WORK"})
modes = []
while True:
mode = []
while self._match(TokenType.VAR):
mode.append(self._prev.text)
if mode:
modes.append(" ".join(mode))
if not self._match(TokenType.COMMA):
break
return self.expression(exp.Transaction, this=this, modes=modes)
def _parse_commit_or_rollback(self):
savepoint = None
is_rollback = self._prev.token_type == TokenType.ROLLBACK
self._match_texts({"TRANSACTION", "WORK"})
if self._match_text_seq("TO"):
self._match_text_seq("SAVEPOINT")
savepoint = self._parse_id_var()
if is_rollback:
return self.expression(exp.Rollback, savepoint=savepoint)
return self.expression(exp.Commit)
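A few inputs the new transaction parsing should accept, with the resulting nodes inferred from the code above:

import sqlglot

sqlglot.parse_one("BEGIN IMMEDIATE TRANSACTION")  # exp.Transaction(this='IMMEDIATE')
sqlglot.parse_one("COMMIT")                       # exp.Commit()
sqlglot.parse_one("ROLLBACK TO SAVEPOINT s1")     # exp.Rollback(savepoint=s1)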
def _parse_show(self):
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
@ -2675,7 +2660,13 @@ class Parser(metaclass=_Parser):
if expression and self._prev_comment:
expression.comment = self._prev_comment
def _match_text(self, *texts):
def _match_texts(self, texts):
if self._curr and self._curr.text.upper() in texts:
self._advance()
return True
return False
def _match_text_seq(self, *texts):
index = self._index
for text in texts:
if self._curr and self._curr.text.upper() == text:

View file

@ -1,5 +1,8 @@
from __future__ import annotations
import itertools
import math
import typing as t
from sqlglot import alias, exp
from sqlglot.errors import UnsupportedError
@ -7,15 +10,15 @@ from sqlglot.optimizer.eliminate_joins import join_condition
class Plan:
def __init__(self, expression):
self.expression = expression
def __init__(self, expression: exp.Expression) -> None:
self.expression = expression.copy()
self.root = Step.from_expression(self.expression)
self._dag = {}
self._dag: t.Dict[Step, t.Set[Step]] = {}
@property
def dag(self):
def dag(self) -> t.Dict[Step, t.Set[Step]]:
if not self._dag:
dag = {}
dag: t.Dict[Step, t.Set[Step]] = {}
nodes = {self.root}
while nodes:
@ -29,32 +32,64 @@ class Plan:
return self._dag
@property
def leaves(self):
def leaves(self) -> t.Generator[Step, None, None]:
return (node for node, deps in self.dag.items() if not deps)
def __repr__(self) -> str:
return f"Plan\n----\n{repr(self.root)}"
class Step:
@classmethod
def from_expression(cls, expression, ctes=None):
def from_expression(
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
"""
Build a DAG of Steps from a SQL expression.
Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
Note: the expression's tables and subqueries must be aliased for this method to work. For
example, given the following expression:
Given an expression like:
SELECT x.a, SUM(x.b)
FROM x
JOIN y
SELECT
x.a,
SUM(x.b)
FROM x AS x
JOIN y AS y
ON x.a = y.a
GROUP BY x.a
Transform it into a DAG of the form:
the following DAG is produced (the expression IDs might differ per execution):
Aggregate(x.a, SUM(x.b))
Join(y)
Scan(x)
Scan(y)
- Aggregate: x (4347984624)
Context:
Aggregations:
- SUM(x.b)
Group:
- x.a
Projections:
- x.a
- "x".""
Dependencies:
- Join: x (4347985296)
Context:
y:
On: x.a = y.a
Projections:
Dependencies:
- Scan: x (4347983136)
Context:
Source: x AS x
Projections:
- Scan: y (4343416624)
Context:
Source: y AS y
Projections:
This can then be executed more easily by an engine.
Args:
expression: the expression to build the DAG from.
ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
Returns:
A Step DAG corresponding to `expression`.
"""
ctes = ctes or {}
with_ = expression.args.get("with")
@ -65,11 +100,11 @@ class Step:
for cte in with_.expressions:
step = Step.from_expression(cte.this, ctes)
step.name = cte.alias
ctes[step.name] = step
ctes[step.name] = step # type: ignore
from_ = expression.args.get("from")
if from_:
if isinstance(expression, exp.Select) and from_:
from_ = from_.expressions
if len(from_) > 1:
raise UnsupportedError(
@ -77,8 +112,10 @@ class Step:
)
step = Scan.from_expression(from_[0], ctes)
elif isinstance(expression, exp.Union):
step = SetOperation.from_expression(expression, ctes)
else:
raise UnsupportedError("Static selects are unsupported.")
step = Scan()
joins = expression.args.get("joins")
@ -115,7 +152,7 @@ class Step:
group = expression.args.get("group")
if group:
if group or aggregations:
aggregate = Aggregate()
aggregate.source = step.name
aggregate.name = step.name
@ -123,7 +160,15 @@ class Step:
alias(operand, alias_) for operand, alias_ in operands.items()
)
aggregate.aggregations = aggregations
aggregate.group = group.expressions
# give aggregates names and replace projections with references to them
aggregate.group = {
f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
}
for projection in projections:
for i, e in aggregate.group.items():
for child, _, _ in projection.walk():
if child == e:
child.replace(exp.column(i, step.name))
aggregate.add_dependency(step)
step = aggregate
@ -150,22 +195,22 @@ class Step:
return step
def __init__(self):
self.name = None
self.dependencies = set()
self.dependents = set()
self.projections = []
self.limit = math.inf
self.condition = None
def __init__(self) -> None:
self.name: t.Optional[str] = None
self.dependencies: t.Set[Step] = set()
self.dependents: t.Set[Step] = set()
self.projections: t.Sequence[exp.Expression] = []
self.limit: float = math.inf
self.condition: t.Optional[exp.Expression] = None
def add_dependency(self, dependency):
def add_dependency(self, dependency: Step) -> None:
self.dependencies.add(dependency)
dependency.dependents.add(self)
def __repr__(self):
def __repr__(self) -> str:
return self.to_s()
def to_s(self, level=0):
def to_s(self, level: int = 0) -> str:
indent = " " * level
nested = f"{indent} "
@ -175,7 +220,7 @@ class Step:
context = [f"{nested}Context:"] + context
lines = [
f"{indent}- {self.__class__.__name__}: {self.name}",
f"{indent}- {self.id}",
*context,
f"{nested}Projections:",
]
@ -193,13 +238,25 @@ class Step:
return "\n".join(lines)
def _to_s(self, _indent):
@property
def type_name(self) -> str:
return self.__class__.__name__
@property
def id(self) -> str:
name = self.name
name = f" {name}" if name else ""
return f"{self.type_name}:{name} ({id(self)})"
def _to_s(self, _indent: str) -> t.List[str]:
return []
class Scan(Step):
@classmethod
def from_expression(cls, expression, ctes=None):
def from_expression(
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
table = expression
alias_ = expression.alias
@ -217,26 +274,24 @@ class Scan(Step):
step = Scan()
step.name = alias_
step.source = expression
if table.name in ctes:
if ctes and table.name in ctes:
step.add_dependency(ctes[table.name])
return step
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.source = None
self.source: t.Optional[exp.Expression] = None
def _to_s(self, indent):
return [f"{indent}Source: {self.source.sql()}"]
class Write(Step):
pass
def _to_s(self, indent: str) -> t.List[str]:
return [f"{indent}Source: {self.source.sql() if self.source else '-static-'}"] # type: ignore
class Join(Step):
@classmethod
def from_joins(cls, joins, ctes=None):
def from_joins(
cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
step = Join()
for join in joins:
@ -252,28 +307,28 @@ class Join(Step):
return step
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.joins = {}
self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {}
def _to_s(self, indent):
def _to_s(self, indent: str) -> t.List[str]:
lines = []
for name, join in self.joins.items():
lines.append(f"{indent}{name}: {join['side']}")
if join.get("condition"):
lines.append(f"{indent}On: {join['condition'].sql()}")
lines.append(f"{indent}On: {join['condition'].sql()}") # type: ignore
return lines
class Aggregate(Step):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.aggregations = []
self.operands = []
self.group = []
self.source = None
self.aggregations: t.List[exp.Expression] = []
self.operands: t.Tuple[exp.Expression, ...] = ()
self.group: t.Dict[str, exp.Expression] = {}
self.source: t.Optional[str] = None
def _to_s(self, indent):
def _to_s(self, indent: str) -> t.List[str]:
lines = [f"{indent}Aggregations:"]
for expression in self.aggregations:
@ -281,7 +336,7 @@ class Aggregate(Step):
if self.group:
lines.append(f"{indent}Group:")
for expression in self.group:
for expression in self.group.values():
lines.append(f"{indent} - {expression.sql()}")
if self.operands:
lines.append(f"{indent}Operands:")
@ -292,14 +347,56 @@ class Aggregate(Step):
class Sort(Step):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.key = None
def _to_s(self, indent):
def _to_s(self, indent: str) -> t.List[str]:
lines = [f"{indent}Key:"]
for expression in self.key:
for expression in self.key: # type: ignore
lines.append(f"{indent} - {expression.sql()}")
return lines
class SetOperation(Step):
def __init__(
self,
op: t.Type[exp.Expression],
left: str | None,
right: str | None,
distinct: bool = False,
) -> None:
super().__init__()
self.op = op
self.left = left
self.right = right
self.distinct = distinct
@classmethod
def from_expression(
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
assert isinstance(expression, exp.Union)
left = Step.from_expression(expression.left, ctes)
right = Step.from_expression(expression.right, ctes)
step = cls(
op=expression.__class__,
left=left.name,
right=right.name,
distinct=expression.args.get("distinct"),
)
step.add_dependency(left)
step.add_dependency(right)
return step
def _to_s(self, indent: str) -> t.List[str]:
lines = []
if self.distinct:
lines.append(f"{indent}Distinct: {self.distinct}")
return lines
@property
def type_name(self) -> str:
return self.op.__name__
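A hedged sketch of planning a set operation with the new SetOperation step. The schema is made up, and the input is optimized first so that tables are aliased, as from_expression requires:

import sqlglot
from sqlglot.optimizer import optimize
from sqlglot.planner import Plan

expression = optimize(
    sqlglot.parse_one("SELECT a FROM x UNION SELECT a FROM y"),
    schema={"x": {"a": "INT"}, "y": {"a": "INT"}},
)
plan = Plan(expression)
print(plan.root)  # a SetOperation step depending on the two SELECT sub-DAGs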

View file

@ -5,7 +5,7 @@ import typing as t
from sqlglot import expressions as exp
from sqlglot.errors import SchemaError
from sqlglot.helper import csv_reader
from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie
if t.TYPE_CHECKING:
@ -15,6 +15,8 @@ if t.TYPE_CHECKING:
TABLE_ARGS = ("this", "db", "catalog")
T = t.TypeVar("T")
class Schema(abc.ABC):
"""Abstract base class for database schemas"""
@ -57,8 +59,81 @@ class Schema(abc.ABC):
The resulting column type.
"""
@property
def supported_table_args(self) -> t.Tuple[str, ...]:
"""
Table arguments this schema supports, e.g. `("this", "db", "catalog")`
"""
raise NotImplementedError
class MappingSchema(Schema):
class AbstractMappingSchema(t.Generic[T]):
def __init__(
self,
mapping: dict | None = None,
) -> None:
self.mapping = mapping or {}
self.mapping_trie = self._build_trie(self.mapping)
self._supported_table_args: t.Tuple[str, ...] = tuple()
def _build_trie(self, schema: t.Dict) -> t.Dict:
return new_trie(tuple(reversed(t)) for t in flatten_schema(schema, depth=self._depth()))
def _depth(self) -> int:
return dict_depth(self.mapping)
@property
def supported_table_args(self) -> t.Tuple[str, ...]:
if not self._supported_table_args and self.mapping:
depth = self._depth()
if not depth: # None
self._supported_table_args = tuple()
elif 1 <= depth <= 3:
self._supported_table_args = TABLE_ARGS[:depth]
else:
raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
return self._supported_table_args
def table_parts(self, table: exp.Table) -> t.List[str]:
if isinstance(table.this, exp.ReadCSV):
return [table.this.name]
return [table.text(part) for part in TABLE_ARGS if table.text(part)]
def find(
self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
) -> t.Optional[T]:
parts = self.table_parts(table)[0 : len(self.supported_table_args)]
value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
if value == 0:
if raise_on_missing:
raise SchemaError(f"Cannot find mapping for {table}.")
else:
return None
elif value == 1:
possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
if len(possibilities) == 1:
parts.extend(possibilities[0])
else:
message = ", ".join(".".join(parts) for parts in possibilities)
if raise_on_missing:
raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
return None
return self._nested_get(parts, raise_on_missing=raise_on_missing)
def _nested_get(
self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
) -> t.Optional[t.Any]:
return _nested_get(
d or self.mapping,
*zip(self.supported_table_args, reversed(parts)),
raise_on_missing=raise_on_missing,
)
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
"""
Schema based on a nested mapping.
@ -82,17 +157,17 @@ class MappingSchema(Schema):
visible: t.Optional[t.Dict] = None,
dialect: t.Optional[str] = None,
) -> None:
self.schema = schema or {}
super().__init__(schema)
self.visible = visible or {}
self.schema_trie = self._build_trie(self.schema)
self.dialect = dialect
self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {}
self._supported_table_args: t.Tuple[str, ...] = tuple()
self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {
"STR": exp.DataType.Type.TEXT,
}
@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
return MappingSchema(
schema=mapping_schema.schema,
schema=mapping_schema.mapping,
visible=mapping_schema.visible,
dialect=mapping_schema.dialect,
)
@ -100,27 +175,13 @@ class MappingSchema(Schema):
def copy(self, **kwargs) -> MappingSchema:
return MappingSchema(
**{ # type: ignore
"schema": self.schema.copy(),
"schema": self.mapping.copy(),
"visible": self.visible.copy(),
"dialect": self.dialect,
**kwargs,
}
)
@property
def supported_table_args(self):
if not self._supported_table_args and self.schema:
depth = _dict_depth(self.schema)
if not depth or depth == 1: # {}
self._supported_table_args = tuple()
elif 2 <= depth <= 4:
self._supported_table_args = TABLE_ARGS[: depth - 1]
else:
raise SchemaError(f"Invalid schema shape. Depth: {depth}")
return self._supported_table_args
def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
@ -133,17 +194,21 @@ class MappingSchema(Schema):
"""
table_ = self._ensure_table(table)
column_mapping = ensure_column_mapping(column_mapping)
schema = self.find_schema(table_, raise_on_missing=False)
schema = self.find(table_, raise_on_missing=False)
if schema and not column_mapping:
return
_nested_set(
self.schema,
self.mapping,
list(reversed(self.table_parts(table_))),
column_mapping,
)
self.schema_trie = self._build_trie(self.schema)
self.mapping_trie = self._build_trie(self.mapping)
def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those
return super()._depth() - 1
def _ensure_table(self, table: exp.Table | str) -> exp.Table:
table_ = exp.to_table(table)
@ -153,16 +218,9 @@ class MappingSchema(Schema):
return table_
def table_parts(self, table: exp.Table) -> t.List[str]:
return [table.text(part) for part in TABLE_ARGS if table.text(part)]
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
table_ = self._ensure_table(table)
if not isinstance(table_.this, exp.Identifier):
return fs_get(table) # type: ignore
schema = self.find_schema(table_)
schema = self.find(table_)
if schema is None:
raise SchemaError(f"Could not find table schema {table}")
@ -173,36 +231,13 @@ class MappingSchema(Schema):
visible = self._nested_get(self.table_parts(table_), self.visible)
return [col for col in schema if col in visible] # type: ignore
def find_schema(
self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
) -> t.Optional[t.Dict[str, str]]:
parts = self.table_parts(table)[0 : len(self.supported_table_args)]
value, trie = in_trie(self.schema_trie if trie is None else trie, parts)
if value == 0:
if raise_on_missing:
raise SchemaError(f"Cannot find schema for {table}.")
else:
return None
elif value == 1:
possibilities = flatten_schema(trie)
if len(possibilities) == 1:
parts.extend(possibilities[0])
else:
message = ", ".join(".".join(parts) for parts in possibilities)
if raise_on_missing:
raise SchemaError(f"Ambiguous schema for {table}: {message}.")
return None
return self._nested_get(parts, raise_on_missing=raise_on_missing)
def get_column_type(
self, table: exp.Table | str, column: exp.Column | str
) -> exp.DataType.Type:
column_name = column if isinstance(column, str) else column.name
table_ = exp.to_table(table)
if table_:
table_schema = self.find_schema(table_)
table_schema = self.find(table_)
schema_type = table_schema.get(column_name).upper() # type: ignore
return self._convert_type(schema_type)
raise SchemaError(f"Could not convert table '{table}'")
@ -228,18 +263,6 @@ class MappingSchema(Schema):
return self._type_mapping_cache[schema_type]
def _build_trie(self, schema: t.Dict):
return new_trie(tuple(reversed(t)) for t in flatten_schema(schema))
def _nested_get(
self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
) -> t.Optional[t.Any]:
return _nested_get(
d or self.schema,
*zip(self.supported_table_args, reversed(parts)),
raise_on_missing=raise_on_missing,
)
def ensure_schema(schema: t.Any) -> Schema:
if isinstance(schema, Schema):
@ -267,29 +290,20 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(schema: t.Dict, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]:
def flatten_schema(
schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
tables = []
keys = keys or []
depth = _dict_depth(schema)
for k, v in schema.items():
if depth >= 3:
tables.extend(flatten_schema(v, keys + [k]))
elif depth == 2:
if depth >= 2:
tables.extend(flatten_schema(v, depth - 1, keys + [k]))
elif depth == 1:
tables.append(keys + [k])
return tables
def fs_get(table: exp.Table) -> t.List[str]:
name = table.this.name
if name.upper() == "READ_CSV":
with csv_reader(table) as reader:
return next(reader)
raise ValueError(f"Cannot read schema for {table}")
def _nested_get(
d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
@ -310,7 +324,7 @@ def _nested_get(
if d is None:
if raise_on_missing:
name = "table" if name == "this" else name
raise ValueError(f"Unknown {name}")
raise ValueError(f"Unknown {name}: {key}")
return None
return d
@ -350,34 +364,3 @@ def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
subd[keys[-1]] = value
return d
def _dict_depth(d: t.Dict) -> int:
"""
Get the nesting depth of a dictionary.
For example:
>>> _dict_depth(None)
0
>>> _dict_depth({})
1
>>> _dict_depth({"a": "b"})
1
>>> _dict_depth({"a": {}})
2
>>> _dict_depth({"a": {"b": {}}})
3
Args:
d (dict): dictionary
Returns:
int: depth
"""
try:
return 1 + _dict_depth(next(iter(d.values())))
except AttributeError:
# d doesn't have attribute "values"
return 0
except StopIteration:
# d.values() returns an empty sequence
return 1
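
The net effect of the AbstractMappingSchema refactor is that lookups walk a trie built from reversed table parts, so an unambiguous partial name now resolves instead of requiring an exact db.table match. A small sketch, assuming the module above is sqlglot.schema:

from sqlglot.schema import MappingSchema

schema = MappingSchema({"db": {"tbl": {"col": "INT"}}})

print(schema.column_names("tbl"))               # ['col'] -- resolved despite the missing db part
print(schema.get_column_type("db.tbl", "col"))  # exp.DataType.Type.INT

If two databases both contained a tbl, the same lookup would instead raise the "Ambiguous mapping" SchemaError shown above.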


@ -105,12 +105,9 @@ class TokenType(AutoName):
OBJECT = auto()
# keywords
ADD_FILE = auto()
ALIAS = auto()
ALWAYS = auto()
ALL = auto()
ALTER = auto()
ANALYZE = auto()
ANTI = auto()
ANY = auto()
APPLY = auto()
@ -124,14 +121,14 @@ class TokenType(AutoName):
BUCKET = auto()
BY_DEFAULT = auto()
CACHE = auto()
CALL = auto()
CASCADE = auto()
CASE = auto()
CHARACTER_SET = auto()
CHECK = auto()
CLUSTER_BY = auto()
COLLATE = auto()
COMMAND = auto()
COMMENT = auto()
COMMENT_ON = auto()
COMMIT = auto()
CONSTRAINT = auto()
CREATE = auto()
@ -149,7 +146,9 @@ class TokenType(AutoName):
DETERMINISTIC = auto()
DISTINCT = auto()
DISTINCT_FROM = auto()
DISTKEY = auto()
DISTRIBUTE_BY = auto()
DISTSTYLE = auto()
DIV = auto()
DROP = auto()
ELSE = auto()
@ -159,7 +158,6 @@ class TokenType(AutoName):
EXCEPT = auto()
EXECUTE = auto()
EXISTS = auto()
EXPLAIN = auto()
FALSE = auto()
FETCH = auto()
FILTER = auto()
@ -216,7 +214,6 @@ class TokenType(AutoName):
OFFSET = auto()
ON = auto()
ONLY = auto()
OPTIMIZE = auto()
OPTIONS = auto()
ORDER_BY = auto()
ORDERED = auto()
@ -258,6 +255,7 @@ class TokenType(AutoName):
SHOW = auto()
SIMILAR_TO = auto()
SOME = auto()
SORTKEY = auto()
SORT_BY = auto()
STABLE = auto()
STORED = auto()
@ -268,9 +266,8 @@ class TokenType(AutoName):
TRANSIENT = auto()
TOP = auto()
THEN = auto()
TRUE = auto()
TRAILING = auto()
TRUNCATE = auto()
TRUE = auto()
UNBOUNDED = auto()
UNCACHE = auto()
UNION = auto()
@ -280,7 +277,6 @@ class TokenType(AutoName):
USE = auto()
USING = auto()
VALUES = auto()
VACUUM = auto()
VIEW = auto()
VOLATILE = auto()
WHEN = auto()
@ -420,7 +416,6 @@ class Tokenizer(metaclass=_Tokenizer):
KEYWORDS = {
"/*+": TokenType.HINT,
"*/": TokenType.HINT,
"==": TokenType.EQ,
"::": TokenType.DCOLON,
"||": TokenType.DPIPE,
@ -435,15 +430,7 @@ class Tokenizer(metaclass=_Tokenizer):
"#>": TokenType.HASH_ARROW,
"#>>": TokenType.DHASH_ARROW,
"<->": TokenType.LR_ARROW,
"ADD ARCHIVE": TokenType.ADD_FILE,
"ADD ARCHIVES": TokenType.ADD_FILE,
"ADD FILE": TokenType.ADD_FILE,
"ADD FILES": TokenType.ADD_FILE,
"ADD JAR": TokenType.ADD_FILE,
"ADD JARS": TokenType.ADD_FILE,
"ALL": TokenType.ALL,
"ALTER": TokenType.ALTER,
"ANALYZE": TokenType.ANALYZE,
"AND": TokenType.AND,
"ANTI": TokenType.ANTI,
"ANY": TokenType.ANY,
@ -455,10 +442,10 @@ class Tokenizer(metaclass=_Tokenizer):
"BETWEEN": TokenType.BETWEEN,
"BOTH": TokenType.BOTH,
"BUCKET": TokenType.BUCKET,
"CALL": TokenType.CALL,
"CACHE": TokenType.CACHE,
"UNCACHE": TokenType.UNCACHE,
"CASE": TokenType.CASE,
"CASCADE": TokenType.CASCADE,
"CHARACTER SET": TokenType.CHARACTER_SET,
"CHECK": TokenType.CHECK,
"CLUSTER BY": TokenType.CLUSTER_BY,
@ -479,7 +466,9 @@ class Tokenizer(metaclass=_Tokenizer):
"DETERMINISTIC": TokenType.DETERMINISTIC,
"DISTINCT": TokenType.DISTINCT,
"DISTINCT FROM": TokenType.DISTINCT_FROM,
"DISTKEY": TokenType.DISTKEY,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
"DISTSTYLE": TokenType.DISTSTYLE,
"DIV": TokenType.DIV,
"DROP": TokenType.DROP,
"ELSE": TokenType.ELSE,
@ -489,7 +478,6 @@ class Tokenizer(metaclass=_Tokenizer):
"EXCEPT": TokenType.EXCEPT,
"EXECUTE": TokenType.EXECUTE,
"EXISTS": TokenType.EXISTS,
"EXPLAIN": TokenType.EXPLAIN,
"FALSE": TokenType.FALSE,
"FETCH": TokenType.FETCH,
"FILTER": TokenType.FILTER,
@ -541,7 +529,6 @@ class Tokenizer(metaclass=_Tokenizer):
"OFFSET": TokenType.OFFSET,
"ON": TokenType.ON,
"ONLY": TokenType.ONLY,
"OPTIMIZE": TokenType.OPTIMIZE,
"OPTIONS": TokenType.OPTIONS,
"OR": TokenType.OR,
"ORDER BY": TokenType.ORDER_BY,
@ -579,6 +566,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SET": TokenType.SET,
"SHOW": TokenType.SHOW,
"SOME": TokenType.SOME,
"SORTKEY": TokenType.SORTKEY,
"SORT BY": TokenType.SORT_BY,
"STABLE": TokenType.STABLE,
"STORED": TokenType.STORED,
@ -592,7 +580,6 @@ class Tokenizer(metaclass=_Tokenizer):
"THEN": TokenType.THEN,
"TRUE": TokenType.TRUE,
"TRAILING": TokenType.TRAILING,
"TRUNCATE": TokenType.TRUNCATE,
"UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION,
"UNPIVOT": TokenType.UNPIVOT,
@ -600,7 +587,6 @@ class Tokenizer(metaclass=_Tokenizer):
"UPDATE": TokenType.UPDATE,
"USE": TokenType.USE,
"USING": TokenType.USING,
"VACUUM": TokenType.VACUUM,
"VALUES": TokenType.VALUES,
"VIEW": TokenType.VIEW,
"VOLATILE": TokenType.VOLATILE,
@ -659,6 +645,14 @@ class Tokenizer(metaclass=_Tokenizer):
"UNIQUE": TokenType.UNIQUE,
"STRUCT": TokenType.STRUCT,
"VARIANT": TokenType.VARIANT,
"ALTER": TokenType.COMMAND,
"ANALYZE": TokenType.COMMAND,
"CALL": TokenType.COMMAND,
"EXPLAIN": TokenType.COMMAND,
"OPTIMIZE": TokenType.COMMAND,
"PREPARE": TokenType.COMMAND,
"TRUNCATE": TokenType.COMMAND,
"VACUUM": TokenType.COMMAND,
}
WHITE_SPACE = {
@ -670,20 +664,11 @@ class Tokenizer(metaclass=_Tokenizer):
}
COMMANDS = {
TokenType.ALTER,
TokenType.ADD_FILE,
TokenType.ANALYZE,
TokenType.BEGIN,
TokenType.CALL,
TokenType.COMMENT_ON,
TokenType.COMMIT,
TokenType.EXPLAIN,
TokenType.OPTIMIZE,
TokenType.COMMAND,
TokenType.EXECUTE,
TokenType.FETCH,
TokenType.SET,
TokenType.SHOW,
TokenType.TRUNCATE,
TokenType.VACUUM,
TokenType.ROLLBACK,
}
# handle numeric literals like in hive (3L = BIGINT)
@ -885,6 +870,7 @@ class Tokenizer(metaclass=_Tokenizer):
if comment_start_line == self._prev_token_line:
if self._prev_token_comment is None:
self.tokens[-1].comment = self._comment
self._prev_token_comment = self._comment
self._comment = None
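
With the keyword consolidation above, statements such as ALTER, OPTIMIZE, VACUUM, or EXPLAIN all tokenize to the single generic COMMAND type and are swallowed as opaque commands. A quick sketch:

from sqlglot.tokens import Tokenizer, TokenType

tokens = Tokenizer().tokenize("VACUUM my_table")
assert tokens[0].token_type == TokenType.COMMAND  # no dedicated VACUUM token anymore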


@ -4,6 +4,8 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataframe(DataFrameSQLValidator):
maxDiff = None
def test_hash_select_expression(self):
expression = exp.select("cola").from_("table")
self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))
@ -16,26 +18,26 @@ class TestDataframe(DataFrameSQLValidator):
def test_cache(self):
df = self.df_employee.select("fname").cache()
expected_statements = [
"DROP VIEW IF EXISTS t11623",
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
"DROP VIEW IF EXISTS t31563",
"CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
]
self.compare_sql(df, expected_statements)
def test_persist_default(self):
df = self.df_employee.select("fname").persist()
expected_statements = [
"DROP VIEW IF EXISTS t11623",
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
"DROP VIEW IF EXISTS t31563",
"CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
]
self.compare_sql(df, expected_statements)
def test_persist_storagelevel(self):
df = self.df_employee.select("fname").persist("DISK_ONLY_2")
expected_statements = [
"DROP VIEW IF EXISTS t11623",
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
"DROP VIEW IF EXISTS t31563",
"CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
]
self.compare_sql(df, expected_statements)


@ -6,39 +6,41 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataFrameWriter(DataFrameSQLValidator):
maxDiff = None
def test_insertInto_full_path(self):
df = self.df_employee.write.insertInto("catalog.db.table_name")
expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_db_table(self):
df = self.df_employee.write.insertInto("db.table_name")
expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_table(self):
df = self.df_employee.write.insertInto("table_name")
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_overwrite(self):
df = self.df_employee.write.insertInto("table_name", overwrite=True)
expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
def test_insertInto_byName(self):
sqlglot.schema.add_table("table_name", {"employee_id": "INT"})
df = self.df_employee.write.byName.insertInto("table_name")
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_cache(self):
df = self.df_employee.cache().write.insertInto("table_name")
expected_statements = [
"DROP VIEW IF EXISTS t35612",
"CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"INSERT INTO table_name SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
"DROP VIEW IF EXISTS t37164",
"CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"INSERT INTO table_name SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
]
self.compare_sql(df, expected_statements)
@ -48,39 +50,39 @@ class TestDataFrameWriter(DataFrameSQLValidator):
def test_saveAsTable_append(self):
df = self.df_employee.write.saveAsTable("table_name", mode="append")
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_overwrite(self):
df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_error(self):
df = self.df_employee.write.saveAsTable("table_name", mode="error")
expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_ignore(self):
df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_standalone(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_override(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_cache(self):
df = self.df_employee.cache().write.saveAsTable("table_name")
expected_statements = [
"DROP VIEW IF EXISTS t35612",
"CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"CREATE TABLE table_name AS SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
"DROP VIEW IF EXISTS t37164",
"CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"CREATE TABLE table_name AS SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
]
self.compare_sql(df, expected_statements)
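
All of the rewritten dataframe expectations in these test files trace back to one generator change: Spark VALUES clauses lose their wrapping parentheses and data types now render upper-case. A sketch with the bundled dataframe API (the `a2` alias is generated internally and may differ):

from sqlglot.dataframe.sql import SparkSession

spark = SparkSession()
df = spark.createDataFrame([[1, 2]], ["cola", "colb"])
print(df.sql()[0])
# SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2) AS `a2`(`cola`, `colb`)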


@ -11,32 +11,32 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataframeSession(DataFrameSQLValidator):
def test_cdf_one_row(self):
df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2)) AS `a2`(`cola`, `colb`)"
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_multiple_rows(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`cola`, `colb`)"
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_no_schema(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`_1`, `_2`)"
self.compare_sql(df, expected)
def test_cdf_row_mixed_primitives(self):
df = self.spark.createDataFrame([[1, 10.1, "test", False, None]])
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM (VALUES (1, 10.1, 'test', FALSE, NULL)) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM VALUES (1, 10.1, 'test', FALSE, NULL) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
self.compare_sql(df, expected)
def test_cdf_dict_rows(self):
df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 'test'), (2, 'test2')) AS `a2`(`cola`, `colb`)"
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 'test'), (2, 'test2') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_str_schema(self):
df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_basic(self):
@ -47,7 +47,7 @@ class TestDataframeSession(DataFrameSQLValidator):
]
)
df = self.spark.createDataFrame([[1, "test"]], schema)
expected = "SELECT CAST(`a2`.`cola` AS int) AS `cola`, CAST(`a2`.`colb` AS string) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_nested(self):
@ -65,7 +65,8 @@ class TestDataframeSession(DataFrameSQLValidator):
]
)
df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema)
expected = "SELECT CAST(`a2`.`cola` AS struct<sub_cola:int, sub_colb:string>) AS `cola` FROM (VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`))) AS `a2`(`cola`)"
expected = "SELECT CAST(`a2`.`cola` AS STRUCT<`sub_cola`: INT, `sub_colb`: STRING>) AS `cola` FROM VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`)) AS `a2`(`cola`)"
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())


@ -286,6 +286,10 @@ class TestBigQuery(Validator):
"bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))",
},
)
self.validate_identity("BEGIN A B C D E F")
self.validate_identity("BEGIN TRANSACTION")
self.validate_identity("COMMIT TRANSACTION")
self.validate_identity("ROLLBACK TRANSACTION")
def test_user_defined_functions(self):
self.validate_identity(


@ -69,6 +69,7 @@ class TestDialect(Validator):
write={
"bigquery": "CAST(a AS STRING)",
"clickhouse": "CAST(a AS TEXT)",
"drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS TEXT)",
"hive": "CAST(a AS STRING)",
@ -86,6 +87,7 @@ class TestDialect(Validator):
write={
"bigquery": "CAST(a AS BINARY(4))",
"clickhouse": "CAST(a AS BINARY(4))",
"drill": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS BINARY(4))",
"mysql": "CAST(a AS BINARY(4))",
"hive": "CAST(a AS BINARY(4))",
@ -146,6 +148,7 @@ class TestDialect(Validator):
"CAST(a AS STRING)",
write={
"bigquery": "CAST(a AS STRING)",
"drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS TEXT)",
"hive": "CAST(a AS STRING)",
@ -162,6 +165,7 @@ class TestDialect(Validator):
"CAST(a AS VARCHAR)",
write={
"bigquery": "CAST(a AS STRING)",
"drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS VARCHAR)",
"hive": "CAST(a AS STRING)",
@ -178,6 +182,7 @@ class TestDialect(Validator):
"CAST(a AS VARCHAR(3))",
write={
"bigquery": "CAST(a AS STRING(3))",
"drill": "CAST(a AS VARCHAR(3))",
"duckdb": "CAST(a AS TEXT(3))",
"mysql": "CAST(a AS VARCHAR(3))",
"hive": "CAST(a AS VARCHAR(3))",
@ -194,6 +199,7 @@ class TestDialect(Validator):
"CAST(a AS SMALLINT)",
write={
"bigquery": "CAST(a AS INT64)",
"drill": "CAST(a AS INTEGER)",
"duckdb": "CAST(a AS SMALLINT)",
"mysql": "CAST(a AS SMALLINT)",
"hive": "CAST(a AS SMALLINT)",
@ -215,6 +221,7 @@ class TestDialect(Validator):
},
write={
"duckdb": "TRY_CAST(a AS DOUBLE)",
"drill": "CAST(a AS DOUBLE)",
"postgres": "CAST(a AS DOUBLE PRECISION)",
"redshift": "CAST(a AS DOUBLE PRECISION)",
},
@ -225,6 +232,7 @@ class TestDialect(Validator):
write={
"bigquery": "CAST(a AS FLOAT64)",
"clickhouse": "CAST(a AS Float64)",
"drill": "CAST(a AS DOUBLE)",
"duckdb": "CAST(a AS DOUBLE)",
"mysql": "CAST(a AS DOUBLE)",
"hive": "CAST(a AS DOUBLE)",
@ -279,6 +287,7 @@ class TestDialect(Validator):
"duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')",
"drill": "TO_TIMESTAMP(x, 'yyyy-MM-dd''T''HH:mm:ss')",
"redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')",
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')",
},
@ -286,6 +295,7 @@ class TestDialect(Validator):
self.validate_all(
"STR_TO_TIME('2020-01-01', '%Y-%m-%d')",
write={
"drill": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
"duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')",
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
"oracle": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
@ -298,6 +308,7 @@ class TestDialect(Validator):
self.validate_all(
"STR_TO_TIME(x, '%y')",
write={
"drill": "TO_TIMESTAMP(x, 'yy')",
"duckdb": "STRPTIME(x, '%y')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%y')",
@ -319,6 +330,7 @@ class TestDialect(Validator):
self.validate_all(
"TIME_STR_TO_DATE('2020-01-01')",
write={
"drill": "CAST('2020-01-01' AS DATE)",
"duckdb": "CAST('2020-01-01' AS DATE)",
"hive": "TO_DATE('2020-01-01')",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
@ -328,6 +340,7 @@ class TestDialect(Validator):
self.validate_all(
"TIME_STR_TO_TIME('2020-01-01')",
write={
"drill": "CAST('2020-01-01' AS TIMESTAMP)",
"duckdb": "CAST('2020-01-01' AS TIMESTAMP)",
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
@ -344,6 +357,7 @@ class TestDialect(Validator):
self.validate_all(
"TIME_TO_STR(x, '%Y-%m-%d')",
write={
"drill": "TO_CHAR(x, 'yyyy-MM-dd')",
"duckdb": "STRFTIME(x, '%Y-%m-%d')",
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
"oracle": "TO_CHAR(x, 'YYYY-MM-DD')",
@ -355,6 +369,7 @@ class TestDialect(Validator):
self.validate_all(
"TIME_TO_TIME_STR(x)",
write={
"drill": "CAST(x AS VARCHAR)",
"duckdb": "CAST(x AS TEXT)",
"hive": "CAST(x AS STRING)",
"presto": "CAST(x AS VARCHAR)",
@ -364,6 +379,7 @@ class TestDialect(Validator):
self.validate_all(
"TIME_TO_UNIX(x)",
write={
"drill": "UNIX_TIMESTAMP(x)",
"duckdb": "EPOCH(x)",
"hive": "UNIX_TIMESTAMP(x)",
"presto": "TO_UNIXTIME(x)",
@ -425,6 +441,7 @@ class TestDialect(Validator):
self.validate_all(
"DATE_TO_DATE_STR(x)",
write={
"drill": "CAST(x AS VARCHAR)",
"duckdb": "CAST(x AS TEXT)",
"hive": "CAST(x AS STRING)",
"presto": "CAST(x AS VARCHAR)",
@ -433,6 +450,7 @@ class TestDialect(Validator):
self.validate_all(
"DATE_TO_DI(x)",
write={
"drill": "CAST(TO_DATE(x, 'yyyyMMdd') AS INT)",
"duckdb": "CAST(STRFTIME(x, '%Y%m%d') AS INT)",
"hive": "CAST(DATE_FORMAT(x, 'yyyyMMdd') AS INT)",
"presto": "CAST(DATE_FORMAT(x, '%Y%m%d') AS INT)",
@ -441,6 +459,7 @@ class TestDialect(Validator):
self.validate_all(
"DI_TO_DATE(x)",
write={
"drill": "TO_DATE(CAST(x AS VARCHAR), 'yyyyMMdd')",
"duckdb": "CAST(STRPTIME(CAST(x AS TEXT), '%Y%m%d') AS DATE)",
"hive": "TO_DATE(CAST(x AS STRING), 'yyyyMMdd')",
"presto": "CAST(DATE_PARSE(CAST(x AS VARCHAR), '%Y%m%d') AS DATE)",
@ -463,6 +482,7 @@ class TestDialect(Validator):
},
write={
"bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
"drill": "DATE_ADD(x, INTERVAL '1' DAY)",
"duckdb": "x + INTERVAL 1 day",
"hive": "DATE_ADD(x, 1)",
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
@ -477,6 +497,7 @@ class TestDialect(Validator):
"DATE_ADD(x, 1)",
write={
"bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
"drill": "DATE_ADD(x, INTERVAL '1' DAY)",
"duckdb": "x + INTERVAL 1 DAY",
"hive": "DATE_ADD(x, 1)",
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
@ -546,6 +567,7 @@ class TestDialect(Validator):
"starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
},
write={
"drill": "TO_DATE(x, 'yyyy-MM-dd''T''HH:mm:ss')",
"mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
"starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)",
@ -556,6 +578,7 @@ class TestDialect(Validator):
self.validate_all(
"STR_TO_DATE(x, '%Y-%m-%d')",
write={
"drill": "CAST(x AS DATE)",
"mysql": "STR_TO_DATE(x, '%Y-%m-%d')",
"starrocks": "STR_TO_DATE(x, '%Y-%m-%d')",
"hive": "CAST(x AS DATE)",
@ -566,6 +589,7 @@ class TestDialect(Validator):
self.validate_all(
"DATE_STR_TO_DATE(x)",
write={
"drill": "CAST(x AS DATE)",
"duckdb": "CAST(x AS DATE)",
"hive": "TO_DATE(x)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
@ -575,6 +599,7 @@ class TestDialect(Validator):
self.validate_all(
"TS_OR_DS_ADD('2021-02-01', 1, 'DAY')",
write={
"drill": "DATE_ADD(CAST('2021-02-01' AS DATE), INTERVAL '1' DAY)",
"duckdb": "CAST('2021-02-01' AS DATE) + INTERVAL 1 DAY",
"hive": "DATE_ADD('2021-02-01', 1)",
"presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2021-02-01', 1, 10), '%Y-%m-%d'))",
@ -584,6 +609,7 @@ class TestDialect(Validator):
self.validate_all(
"DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
write={
"drill": "DATE_ADD(CAST('2020-01-01' AS DATE), INTERVAL '1' DAY)",
"duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY",
"hive": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
"presto": "DATE_ADD('day', 1, CAST('2020-01-01' AS DATE))",
@ -593,6 +619,7 @@ class TestDialect(Validator):
self.validate_all(
"TIMESTAMP '2022-01-01'",
write={
"drill": "CAST('2022-01-01' AS TIMESTAMP)",
"mysql": "CAST('2022-01-01' AS TIMESTAMP)",
"starrocks": "CAST('2022-01-01' AS DATETIME)",
"hive": "CAST('2022-01-01' AS TIMESTAMP)",
@ -614,6 +641,7 @@ class TestDialect(Validator):
dialect: f"{unit}(x)"
for dialect in (
"bigquery",
"drill",
"duckdb",
"mysql",
"presto",
@ -624,6 +652,7 @@ class TestDialect(Validator):
dialect: f"{unit}(x)"
for dialect in (
"bigquery",
"drill",
"duckdb",
"mysql",
"presto",
@ -649,6 +678,7 @@ class TestDialect(Validator):
write={
"bigquery": "ARRAY_LENGTH(x)",
"duckdb": "ARRAY_LENGTH(x)",
"drill": "REPEATED_COUNT(x)",
"presto": "CARDINALITY(x)",
"spark": "SIZE(x)",
},
@ -736,6 +766,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT a FROM x CROSS JOIN UNNEST(y) AS t (a)",
write={
"drill": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
"presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
"spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
},
@ -743,6 +774,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t (a, b)",
write={
"drill": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
"presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
"spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b",
},
@ -775,6 +807,7 @@ class TestDialect(Validator):
},
write={
"bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b",
"drill": "SELECT * FROM a UNION SELECT * FROM b",
"duckdb": "SELECT * FROM a UNION SELECT * FROM b",
"presto": "SELECT * FROM a UNION SELECT * FROM b",
"spark": "SELECT * FROM a UNION SELECT * FROM b",
@ -887,6 +920,7 @@ class TestDialect(Validator):
write={
"bigquery": "LOWER(x) LIKE '%y'",
"clickhouse": "x ILIKE '%y'",
"drill": "x `ILIKE` '%y'",
"duckdb": "x ILIKE '%y'",
"hive": "LOWER(x) LIKE '%y'",
"mysql": "LOWER(x) LIKE '%y'",
@ -910,32 +944,38 @@ class TestDialect(Validator):
self.validate_all(
"POSITION(' ' in x)",
write={
"drill": "STRPOS(x, ' ')",
"duckdb": "STRPOS(x, ' ')",
"postgres": "STRPOS(x, ' ')",
"presto": "STRPOS(x, ' ')",
"spark": "LOCATE(' ', x)",
"clickhouse": "position(x, ' ')",
"snowflake": "POSITION(' ', x)",
"mysql": "LOCATE(' ', x)",
},
)
self.validate_all(
"STR_POSITION('a', x)",
write={
"drill": "STRPOS(x, 'a')",
"duckdb": "STRPOS(x, 'a')",
"postgres": "STRPOS(x, 'a')",
"presto": "STRPOS(x, 'a')",
"spark": "LOCATE('a', x)",
"clickhouse": "position(x, 'a')",
"snowflake": "POSITION('a', x)",
"mysql": "LOCATE('a', x)",
},
)
self.validate_all(
"POSITION('a', x, 3)",
write={
"drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
"presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
"spark": "LOCATE('a', x, 3)",
"clickhouse": "position(x, 'a', 3)",
"snowflake": "POSITION('a', x, 3)",
"mysql": "LOCATE('a', x, 3)",
},
)
self.validate_all(
@ -960,6 +1000,7 @@ class TestDialect(Validator):
self.validate_all(
"IF(x > 1, 1, 0)",
write={
"drill": "`IF`(x > 1, 1, 0)",
"duckdb": "CASE WHEN x > 1 THEN 1 ELSE 0 END",
"presto": "IF(x > 1, 1, 0)",
"hive": "IF(x > 1, 1, 0)",
@ -970,6 +1011,7 @@ class TestDialect(Validator):
self.validate_all(
"CASE WHEN 1 THEN x ELSE 0 END",
write={
"drill": "CASE WHEN 1 THEN x ELSE 0 END",
"duckdb": "CASE WHEN 1 THEN x ELSE 0 END",
"presto": "CASE WHEN 1 THEN x ELSE 0 END",
"hive": "CASE WHEN 1 THEN x ELSE 0 END",
@ -980,6 +1022,7 @@ class TestDialect(Validator):
self.validate_all(
"x[y]",
write={
"drill": "x[y]",
"duckdb": "x[y]",
"presto": "x[y]",
"hive": "x[y]",
@ -1000,6 +1043,7 @@ class TestDialect(Validator):
'true or null as "foo"',
write={
"bigquery": "TRUE OR NULL AS `foo`",
"drill": "TRUE OR NULL AS `foo`",
"duckdb": 'TRUE OR NULL AS "foo"',
"presto": 'TRUE OR NULL AS "foo"',
"hive": "TRUE OR NULL AS `foo`",
@ -1020,6 +1064,7 @@ class TestDialect(Validator):
"LEVENSHTEIN(col1, col2)",
write={
"duckdb": "LEVENSHTEIN(col1, col2)",
"drill": "LEVENSHTEIN_DISTANCE(col1, col2)",
"presto": "LEVENSHTEIN_DISTANCE(col1, col2)",
"hive": "LEVENSHTEIN(col1, col2)",
"spark": "LEVENSHTEIN(col1, col2)",
@ -1029,6 +1074,7 @@ class TestDialect(Validator):
"LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))",
write={
"duckdb": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
"drill": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
"presto": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
"hive": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
"spark": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
@ -1152,6 +1198,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT a AS b FROM x GROUP BY b",
write={
"drill": "SELECT a AS b FROM x GROUP BY b",
"duckdb": "SELECT a AS b FROM x GROUP BY b",
"presto": "SELECT a AS b FROM x GROUP BY 1",
"hive": "SELECT a AS b FROM x GROUP BY 1",
@ -1162,6 +1209,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT y x FROM my_table t",
write={
"drill": "SELECT y AS x FROM my_table AS t",
"hive": "SELECT y AS x FROM my_table AS t",
"oracle": "SELECT y AS x FROM my_table t",
"postgres": "SELECT y AS x FROM my_table AS t",
@ -1230,3 +1278,36 @@ SELECT
},
pretty=True,
)
def test_transactions(self):
self.validate_all(
"BEGIN TRANSACTION",
write={
"bigquery": "BEGIN TRANSACTION",
"mysql": "BEGIN",
"postgres": "BEGIN",
"presto": "START TRANSACTION",
"trino": "START TRANSACTION",
"redshift": "BEGIN",
"snowflake": "BEGIN",
"sqlite": "BEGIN TRANSACTION",
},
)
self.validate_all(
"BEGIN",
read={
"presto": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE",
"trino": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE",
},
)
self.validate_all(
"BEGIN",
read={
"presto": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ",
"trino": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ",
},
)
self.validate_all(
"BEGIN IMMEDIATE TRANSACTION",
write={"sqlite": "BEGIN IMMEDIATE TRANSACTION"},
)
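
These cases can be reproduced directly with transpile; per the expectations above:

import sqlglot

print(sqlglot.transpile("BEGIN TRANSACTION", write="presto")[0])
# START TRANSACTION
print(sqlglot.transpile("START TRANSACTION ISOLATION LEVEL REPEATABLE READ", read="presto")[0])
# BEGIN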


@ -0,0 +1,53 @@
from tests.dialects.test_dialect import Validator
class TestDrill(Validator):
dialect = "drill"
def test_string_literals(self):
self.validate_all(
"SELECT '2021-01-01' + INTERVAL 1 MONTH",
write={
"mysql": "SELECT '2021-01-01' + INTERVAL 1 MONTH",
},
)
def test_quotes(self):
self.validate_all(
"'\\''",
write={
"duckdb": "''''",
"presto": "''''",
"hive": "'\\''",
"spark": "'\\''",
},
)
self.validate_all(
"'\"x\"'",
write={
"duckdb": "'\"x\"'",
"presto": "'\"x\"'",
"hive": "'\"x\"'",
"spark": "'\"x\"'",
},
)
self.validate_all(
"'\\\\a'",
read={
"presto": "'\\a'",
},
write={
"duckdb": "'\\a'",
"presto": "'\\a'",
"hive": "'\\\\a'",
"spark": "'\\\\a'",
},
)
def test_table_function(self):
self.validate_all(
"SELECT * FROM table( dfs.`test_data.xlsx` (type => 'excel', sheetName => 'secondSheet'))",
write={
"drill": "SELECT * FROM table(dfs.`test_data.xlsx`(type => 'excel', sheetName => 'secondSheet'))",
},
)
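
A short smoke test of the new Drill dialect, with outputs taken from the dialect tests above:

import sqlglot

print(sqlglot.transpile("LEVENSHTEIN(col1, col2)", write="drill")[0])
# LEVENSHTEIN_DISTANCE(col1, col2)
print(sqlglot.transpile("TIME_TO_STR(x, '%Y-%m-%d')", write="drill")[0])
# TO_CHAR(x, 'yyyy-MM-dd')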


@ -58,6 +58,16 @@ class TestMySQL(Validator):
self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'")
self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci")
self.validate_identity("SET autocommit = ON")
self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL SERIALIZABLE")
self.validate_identity("SET TRANSACTION READ ONLY")
self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE")
self.validate_identity("SELECT SCHEMA()")
def test_canonical_functions(self):
self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)")
self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')")
self.validate_identity("SELECT UCASE('foo')", "SELECT UPPER('foo')")
self.validate_identity("SELECT LCASE('foo')", "SELECT LOWER('foo')")
def test_escape(self):
self.validate_all(


@ -177,6 +177,15 @@ class TestPresto(Validator):
"spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
},
)
self.validate_all(
"CREATE TABLE test STORED = 'PARQUET' AS SELECT 1",
write={
"duckdb": "CREATE TABLE test AS SELECT 1",
"presto": "CREATE TABLE test WITH (FORMAT='PARQUET') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
"spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
},
)
self.validate_all(
"CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1",
write={
@ -427,3 +436,69 @@ class TestPresto(Validator):
"spark": UnsupportedError,
},
)
self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE")
self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")
def test_encode_decode(self):
self.validate_all(
"TO_UTF8(x)",
write={
"spark": "ENCODE(x, 'utf-8')",
},
)
self.validate_all(
"FROM_UTF8(x)",
write={
"spark": "DECODE(x, 'utf-8')",
},
)
self.validate_all(
"ENCODE(x, 'utf-8')",
write={
"presto": "TO_UTF8(x)",
},
)
self.validate_all(
"DECODE(x, 'utf-8')",
write={
"presto": "FROM_UTF8(x)",
},
)
self.validate_all(
"ENCODE(x, 'invalid')",
write={
"presto": UnsupportedError,
},
)
self.validate_all(
"DECODE(x, 'invalid')",
write={
"presto": UnsupportedError,
},
)
def test_hex_unhex(self):
self.validate_all(
"TO_HEX(x)",
write={
"spark": "HEX(x)",
},
)
self.validate_all(
"FROM_HEX(x)",
write={
"spark": "UNHEX(x)",
},
)
self.validate_all(
"HEX(x)",
write={
"presto": "TO_HEX(x)",
},
)
self.validate_all(
"UNHEX(x)",
write={
"presto": "FROM_HEX(x)",
},
)
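
The byte-string helpers under test map both ways between Presto and Spark; for example:

import sqlglot

print(sqlglot.transpile("TO_UTF8(x)", read="presto", write="spark")[0])  # ENCODE(x, 'utf-8')
print(sqlglot.transpile("HEX(x)", read="spark", write="presto")[0])      # TO_HEX(x)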


@ -169,6 +169,17 @@ class TestSnowflake(Validator):
"snowflake": "SELECT a FROM test AS unpivot",
},
)
self.validate_all(
"trim(date_column, 'UTC')",
write={
"snowflake": "TRIM(date_column, 'UTC')",
"postgres": "TRIM('UTC' FROM date_column)",
},
)
self.validate_all(
"trim(date_column)",
write={"snowflake": "TRIM(date_column)"},
)
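
Snowflake's two-argument TRIM maps onto Postgres' TRIM ... FROM syntax, as exercised above:

import sqlglot

print(sqlglot.transpile("trim(date_column, 'UTC')", read="snowflake", write="postgres")[0])
# TRIM('UTC' FROM date_column)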
def test_null_treatment(self):
self.validate_all(


@ -122,13 +122,6 @@ x AT TIME ZONE 'UTC'
CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo'
SET x = 1
SET -v
ADD JAR s3://bucket
ADD JARS s3://bucket, c
ADD FILE s3://file
ADD FILES s3://file, s3://a
ADD ARCHIVE s3://file
ADD ARCHIVES s3://file, s3://a
BEGIN IMMEDIATE TRANSACTION
COMMIT
USE db
NOT 1
@ -278,6 +271,7 @@ SELECT CEIL(a, b) FROM test
SELECT COUNT(a) FROM test
SELECT COUNT(1) FROM test
SELECT COUNT(*) FROM test
SELECT COUNT() FROM test
SELECT COUNT(DISTINCT a) FROM test
SELECT EXP(a) FROM test
SELECT FLOOR(a) FROM test
@ -372,6 +366,8 @@ WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2
WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2
WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2
WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2
WITH sub_query AS (SELECT a FROM table) (SELECT a FROM sub_query)
WITH sub_query AS (SELECT a FROM table) ((((SELECT a FROM sub_query))))
(SELECT 1) UNION (SELECT 2)
(SELECT 1) UNION SELECT 2
SELECT 1 UNION (SELECT 2)
@ -463,6 +459,7 @@ CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECI
CREATE TABLE z (a INT(11) DEFAULT UUID())
CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id')
CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1)
CREATE TABLE z (a INT(11) NOT NULL DEFAULT -1)
CREATE TABLE z (a INT(11) NOT NULL COLLATE utf8_bin AUTO_INCREMENT)
CREATE TABLE z (a INT, PRIMARY KEY(a))
CREATE TABLE z WITH (FORMAT='parquet') AS SELECT 1
@ -476,6 +473,9 @@ CREATE TABLE z AS ((WITH cte AS (SELECT 1) SELECT * FROM cte))
CREATE TABLE z (a INT UNIQUE)
CREATE TABLE z (a INT AUTO_INCREMENT)
CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT)
CREATE TABLE z (a INT REFERENCES parent(b, c))
CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id))
CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c))
CREATE TEMPORARY FUNCTION f
CREATE TEMPORARY FUNCTION f AS 'g'
CREATE FUNCTION f
@ -514,17 +514,23 @@ DELETE FROM x WHERE y > 1
DELETE FROM y
DELETE FROM event USING sales WHERE event.eventid = sales.eventid
DELETE FROM event USING sales, USING bla WHERE event.eventid = sales.eventid
DELETE FROM event USING sales AS s WHERE event.eventid = s.eventid
PREPARE statement
EXECUTE statement
DROP TABLE a
DROP TABLE a.b
DROP TABLE IF EXISTS a
DROP TABLE IF EXISTS a.b
DROP TABLE a CASCADE
DROP VIEW a
DROP VIEW a.b
DROP VIEW IF EXISTS a
DROP VIEW IF EXISTS a.b
SHOW TABLES
USE db
BEGIN
ROLLBACK
ROLLBACK TO b
EXPLAIN SELECT * FROM x
INSERT INTO x SELECT * FROM y
INSERT INTO x (SELECT * FROM y)
@ -581,3 +587,4 @@ SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */
SELECT x FROM a.b.c /* x */, e.f.g /* x */
SELECT FOO(x /* c */) /* FOO */, b /* b */
SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */
SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b'
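
Each identity fixture is a round-trip check: the line must parse and regenerate unchanged with the default dialect, e.g.:

import sqlglot

sql = "CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c))"
assert sqlglot.transpile(sql)[0] == sql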


@ -0,0 +1,5 @@
SELECT w.d + w.e AS c FROM w AS w;
SELECT CONCAT(w.d, w.e) AS c FROM w AS w;
SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w;
SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w;


@ -119,7 +119,7 @@ GROUP BY
LIMIT 1;
# title: Root subquery is union
(SELECT b FROM x UNION SELECT b FROM y) LIMIT 1;
(SELECT b FROM x UNION SELECT b FROM y ORDER BY b) LIMIT 1;
(
SELECT
"x"."b" AS "b"
@ -128,6 +128,8 @@ LIMIT 1;
SELECT
"y"."b" AS "b"
FROM "y" AS "y"
ORDER BY
"b"
)
LIMIT 1;


@ -15,7 +15,7 @@ select
from
lineitem
where
CAST(l_shipdate AS DATE) <= date '1998-12-01' - interval '90' day
l_shipdate <= date '1998-12-01' - interval '90' day
group by
l_returnflag,
l_linestatus
@ -250,8 +250,8 @@ FROM "orders" AS "orders"
LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."l_orderkey" = "orders"."o_orderkey"
WHERE
"orders"."o_orderdate" < CAST('1993-10-01' AS DATE)
AND "orders"."o_orderdate" >= CAST('1993-07-01' AS DATE)
CAST("orders"."o_orderdate" AS DATE) < CAST('1993-10-01' AS DATE)
AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1993-07-01' AS DATE)
AND NOT "_u_0"."l_orderkey" IS NULL
GROUP BY
"orders"."o_orderpriority"
@ -293,8 +293,8 @@ SELECT
FROM "customer" AS "customer"
JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
AND "orders"."o_orderdate" < CAST('1995-01-01' AS DATE)
AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE)
AND CAST("orders"."o_orderdate" AS DATE) < CAST('1995-01-01' AS DATE)
AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1994-01-01' AS DATE)
JOIN "region" AS "region"
ON "region"."r_name" = 'ASIA'
JOIN "nation" AS "nation"
@ -328,8 +328,8 @@ FROM "lineitem" AS "lineitem"
WHERE
"lineitem"."l_discount" BETWEEN 0.05 AND 0.07
AND "lineitem"."l_quantity" < 24
AND "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE)
AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE);
AND CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE)
AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE);
--------------------------------------
-- TPC-H 7
@ -384,13 +384,13 @@ WITH "n1" AS (
SELECT
"n1"."n_name" AS "supp_nation",
"n2"."n_name" AS "cust_nation",
EXTRACT(year FROM "lineitem"."l_shipdate") AS "l_year",
EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) AS "l_year",
SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
)) AS "revenue"
FROM "supplier" AS "supplier"
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
ON CAST("lineitem"."l_shipdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
AND "supplier"."s_suppkey" = "lineitem"."l_suppkey"
JOIN "orders" AS "orders"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
@ -409,7 +409,7 @@ JOIN "n1" AS "n2"
GROUP BY
"n1"."n_name",
"n2"."n_name",
EXTRACT(year FROM "lineitem"."l_shipdate")
EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME))
ORDER BY
"supp_nation",
"cust_nation",
@ -456,7 +456,7 @@ group by
order by
o_year;
SELECT
EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
SUM(
CASE
WHEN "nation_2"."n_name" = 'BRAZIL'
@ -477,7 +477,7 @@ JOIN "customer" AS "customer"
ON "customer"."c_nationkey" = "nation"."n_nationkey"
JOIN "orders" AS "orders"
ON "orders"."o_custkey" = "customer"."c_custkey"
AND "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
AND CAST("orders"."o_orderdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "part"."p_partkey" = "lineitem"."l_partkey"
@ -488,7 +488,7 @@ JOIN "nation" AS "nation_2"
WHERE
"part"."p_type" = 'ECONOMY ANODIZED STEEL'
GROUP BY
EXTRACT(year FROM "orders"."o_orderdate")
EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
ORDER BY
"o_year";
@ -529,7 +529,7 @@ order by
o_year desc;
SELECT
"nation"."n_name" AS "nation",
EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
SUM(
"lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
@ -551,7 +551,7 @@ WHERE
"part"."p_name" LIKE '%green%'
GROUP BY
"nation"."n_name",
EXTRACT(year FROM "orders"."o_orderdate")
EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
ORDER BY
"nation",
"o_year" DESC;
@ -606,8 +606,8 @@ SELECT
FROM "customer" AS "customer"
JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE)
AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE)
AND CAST("orders"."o_orderdate" AS DATE) < CAST('1994-01-01' AS DATE)
AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1993-10-01' AS DATE)
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey" AND "lineitem"."l_returnflag" = 'R'
JOIN "nation" AS "nation"
@ -740,8 +740,8 @@ SELECT
FROM "orders" AS "orders"
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_commitdate" < "lineitem"."l_receiptdate"
AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE)
AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE)
AND CAST("lineitem"."l_receiptdate" AS DATE) < CAST('1995-01-01' AS DATE)
AND CAST("lineitem"."l_receiptdate" AS DATE) >= CAST('1994-01-01' AS DATE)
AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate"
AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP')
AND "orders"."o_orderkey" = "lineitem"."l_orderkey"
@ -832,8 +832,8 @@ FROM "lineitem" AS "lineitem"
JOIN "part" AS "part"
ON "lineitem"."l_partkey" = "part"."p_partkey"
WHERE
"lineitem"."l_shipdate" < CAST('1995-10-01' AS DATE)
AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE);
CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-10-01' AS DATE)
AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1995-09-01' AS DATE);
--------------------------------------
-- TPC-H 15
@ -876,8 +876,8 @@ WITH "revenue" AS (
)) AS "total_revenue"
FROM "lineitem" AS "lineitem"
WHERE
"lineitem"."l_shipdate" < CAST('1996-04-01' AS DATE)
AND "lineitem"."l_shipdate" >= CAST('1996-01-01' AS DATE)
CAST("lineitem"."l_shipdate" AS DATE) < CAST('1996-04-01' AS DATE)
AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE)
GROUP BY
"lineitem"."l_suppkey"
)
@ -1220,8 +1220,8 @@ WITH "_u_0" AS (
"lineitem"."l_suppkey" AS "_u_2"
FROM "lineitem" AS "lineitem"
WHERE
"lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE)
AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE)
CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE)
AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE)
GROUP BY
"lineitem"."l_partkey",
"lineitem"."l_suppkey"

View file

@ -315,3 +315,10 @@ FROM (
WHERE
id = 1
) /* x */;
SELECT * /* multi
line
comment */;
SELECT
* /* multi
line
comment */;

View file

@ -57,79 +57,79 @@ SKIP_INTEGRATION = string_to_bool(os.environ.get("SKIP_INTEGRATION", "0").lower(
TPCH_SCHEMA = {
"lineitem": {
"l_orderkey": "uint64",
"l_partkey": "uint64",
"l_suppkey": "uint64",
"l_linenumber": "uint64",
"l_quantity": "float64",
"l_extendedprice": "float64",
"l_discount": "float64",
"l_tax": "float64",
"l_orderkey": "bigint",
"l_partkey": "bigint",
"l_suppkey": "bigint",
"l_linenumber": "bigint",
"l_quantity": "double",
"l_extendedprice": "double",
"l_discount": "double",
"l_tax": "double",
"l_returnflag": "string",
"l_linestatus": "string",
"l_shipdate": "date32",
"l_commitdate": "date32",
"l_receiptdate": "date32",
"l_shipdate": "string",
"l_commitdate": "string",
"l_receiptdate": "string",
"l_shipinstruct": "string",
"l_shipmode": "string",
"l_comment": "string",
},
"orders": {
"o_orderkey": "uint64",
"o_custkey": "uint64",
"o_orderkey": "bigint",
"o_custkey": "bigint",
"o_orderstatus": "string",
"o_totalprice": "float64",
"o_orderdate": "date32",
"o_totalprice": "double",
"o_orderdate": "string",
"o_orderpriority": "string",
"o_clerk": "string",
"o_shippriority": "int32",
"o_shippriority": "int",
"o_comment": "string",
},
"customer": {
"c_custkey": "uint64",
"c_custkey": "bigint",
"c_name": "string",
"c_address": "string",
"c_nationkey": "uint64",
"c_nationkey": "bigint",
"c_phone": "string",
"c_acctbal": "float64",
"c_acctbal": "double",
"c_mktsegment": "string",
"c_comment": "string",
},
"part": {
"p_partkey": "uint64",
"p_partkey": "bigint",
"p_name": "string",
"p_mfgr": "string",
"p_brand": "string",
"p_type": "string",
"p_size": "int32",
"p_size": "int",
"p_container": "string",
"p_retailprice": "float64",
"p_retailprice": "double",
"p_comment": "string",
},
"supplier": {
"s_suppkey": "uint64",
"s_suppkey": "bigint",
"s_name": "string",
"s_address": "string",
"s_nationkey": "uint64",
"s_nationkey": "bigint",
"s_phone": "string",
"s_acctbal": "float64",
"s_acctbal": "double",
"s_comment": "string",
},
"partsupp": {
"ps_partkey": "uint64",
"ps_suppkey": "uint64",
"ps_availqty": "int32",
"ps_supplycost": "float64",
"ps_partkey": "bigint",
"ps_suppkey": "bigint",
"ps_availqty": "int",
"ps_supplycost": "double",
"ps_comment": "string",
},
"nation": {
"n_nationkey": "uint64",
"n_nationkey": "bigint",
"n_name": "string",
"n_regionkey": "uint64",
"n_regionkey": "bigint",
"n_comment": "string",
},
"region": {
"r_regionkey": "uint64",
"r_regionkey": "bigint",
"r_name": "string",
"r_comment": "string",
},
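The schema swaps the old pyarrow-flavored type names (uint64, float64, date32, int32) for SQL type names that sqlglot itself can parse, which is what lets the same mapping drive both the optimizer and the Python executor. A quick check of the names used here, assuming `exp.DataType.build`, the helper the schema machinery uses to normalize type strings:

```python
from sqlglot import exp

for name in ("bigint", "double", "int", "string"):
    dtype = exp.DataType.build(name)
    print(name, "->", dtype.this)  # the normalized DataType.Type member
```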

View file

@ -1,12 +1,15 @@
import unittest
from datetime import date
import duckdb
import pandas as pd
from pandas.testing import assert_frame_equal
from sqlglot import exp, parse_one
from sqlglot.errors import ExecuteError
from sqlglot.executor import execute
from sqlglot.executor.python import Python
from sqlglot.executor.table import Table, ensure_tables
from tests.helpers import (
FIXTURES_DIR,
SKIP_INTEGRATION,
@ -67,13 +70,399 @@ class TestExecutor(unittest.TestCase):
def to_csv(expression):
if isinstance(expression, exp.Table):
return parse_one(
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
)
return expression
for sql, _ in self.sqls[0:3]:
for i, (sql, _) in enumerate(self.sqls[0:7]):
with self.subTest(f"tpch-h {i + 1}"):
a = self.cached_execute(sql)
sql = parse_one(sql).transform(to_csv).sql(pretty=True)
table = execute(sql, TPCH_SCHEMA)
b = pd.DataFrame(table.rows, columns=table.columns)
assert_frame_equal(a, b, check_dtype=False)
def test_execute_callable(self):
tables = {
"x": [
{"a": "a", "b": "d"},
{"a": "b", "b": "e"},
{"a": "c", "b": "f"},
],
"y": [
{"b": "d", "c": "g"},
{"b": "e", "c": "h"},
{"b": "f", "c": "i"},
],
"z": [],
}
schema = {
"x": {
"a": "VARCHAR",
"b": "VARCHAR",
},
"y": {
"b": "VARCHAR",
"c": "VARCHAR",
},
"z": {"d": "VARCHAR"},
}
for sql, cols, rows in [
("SELECT * FROM x", ["a", "b"], [("a", "d"), ("b", "e"), ("c", "f")]),
(
"SELECT * FROM x JOIN y ON x.b = y.b",
["a", "b", "b", "c"],
[("a", "d", "d", "g"), ("b", "e", "e", "h"), ("c", "f", "f", "i")],
),
(
"SELECT j.c AS d FROM x AS i JOIN y AS j ON i.b = j.b",
["d"],
[("g",), ("h",), ("i",)],
),
(
"SELECT CONCAT(x.a, y.c) FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'",
["_col_0"],
[("bh",)],
),
(
"SELECT * FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'",
["a", "b", "b", "c"],
[("b", "e", "e", "h")],
),
(
"SELECT * FROM z",
["d"],
[],
),
(
"SELECT d FROM z ORDER BY d",
["d"],
[],
),
(
"SELECT a FROM x WHERE x.a <> 'b'",
["a"],
[("a",), ("c",)],
),
(
"SELECT a AS i FROM x ORDER BY a",
["i"],
[("a",), ("b",), ("c",)],
),
(
"SELECT a AS i FROM x ORDER BY i",
["i"],
[("a",), ("b",), ("c",)],
),
(
"SELECT 100 - ORD(a) AS a, a AS i FROM x ORDER BY a",
["a", "i"],
[(1, "c"), (2, "b"), (3, "a")],
),
(
"SELECT a /* test */ FROM x LIMIT 1",
["a"],
[("a",)],
),
]:
with self.subTest(sql):
result = execute(sql, schema=schema, tables=tables)
self.assertEqual(result.columns, tuple(cols))
self.assertEqual(result.rows, rows)
def test_set_operations(self):
tables = {
"x": [
{"a": "a"},
{"a": "b"},
{"a": "c"},
],
"y": [
{"a": "b"},
{"a": "c"},
{"a": "d"},
],
}
schema = {
"x": {
"a": "VARCHAR",
},
"y": {
"a": "VARCHAR",
},
}
for sql, cols, rows in [
(
"SELECT a FROM x UNION ALL SELECT a FROM y",
["a"],
[("a",), ("b",), ("c",), ("b",), ("c",), ("d",)],
),
(
"SELECT a FROM x UNION SELECT a FROM y",
["a"],
[("a",), ("b",), ("c",), ("d",)],
),
(
"SELECT a FROM x EXCEPT SELECT a FROM y",
["a"],
[("a",)],
),
(
"SELECT a FROM x INTERSECT SELECT a FROM y",
["a"],
[("b",), ("c",)],
),
(
"""SELECT i.a
FROM (
SELECT a FROM x UNION SELECT a FROM y
) AS i
JOIN (
SELECT a FROM x UNION SELECT a FROM y
) AS j
ON i.a = j.a""",
["a"],
[("a",), ("b",), ("c",), ("d",)],
),
(
"SELECT 1 AS a UNION SELECT 2 AS a UNION SELECT 3 AS a",
["a"],
[(1,), (2,), (3,)],
),
]:
with self.subTest(sql):
result = execute(sql, schema=schema, tables=tables)
self.assertEqual(result.columns, tuple(cols))
self.assertEqual(set(result.rows), set(rows))
def test_execute_catalog_db_table(self):
tables = {
"catalog": {
"db": {
"x": [
{"a": "a"},
{"a": "b"},
{"a": "c"},
],
}
}
}
schema = {
"catalog": {
"db": {
"x": {
"a": "VARCHAR",
}
}
}
}
result1 = execute("SELECT * FROM x", schema=schema, tables=tables)
result2 = execute("SELECT * FROM catalog.db.x", schema=schema, tables=tables)
assert result1.columns == result2.columns
assert result1.rows == result2.rows
def test_execute_tables(self):
tables = {
"sushi": [
{"id": 1, "price": 1.0},
{"id": 2, "price": 2.0},
{"id": 3, "price": 3.0},
],
"order_items": [
{"sushi_id": 1, "order_id": 1},
{"sushi_id": 1, "order_id": 1},
{"sushi_id": 2, "order_id": 1},
{"sushi_id": 3, "order_id": 2},
],
"orders": [
{"id": 1, "user_id": 1},
{"id": 2, "user_id": 2},
],
}
self.assertEqual(
execute(
"""
SELECT
o.user_id,
SUM(s.price) AS price
FROM orders o
JOIN order_items i
ON o.id = i.order_id
JOIN sushi s
ON i.sushi_id = s.id
GROUP BY o.user_id
""",
tables=tables,
).rows,
[
(1, 4.0),
(2, 3.0),
],
)
self.assertEqual(
execute(
"""
SELECT
o.id, x.*
FROM orders o
LEFT JOIN (
SELECT
1 AS id, 'b' AS x
UNION ALL
SELECT
3 AS id, 'c' AS x
) x
ON o.id = x.id
""",
tables=tables,
).rows,
[(1, 1, "b"), (2, None, None)],
)
self.assertEqual(
execute(
"""
SELECT
o.id, x.*
FROM orders o
RIGHT JOIN (
SELECT
1 AS id,
'b' AS x
UNION ALL
SELECT
3 AS id, 'c' AS x
) x
ON o.id = x.id
""",
tables=tables,
).rows,
[
(1, 1, "b"),
(None, 3, "c"),
],
)
def test_table_depth_mismatch(self):
tables = {"table": []}
schema = {"db": {"table": {"col": "VARCHAR"}}}
with self.assertRaises(ExecuteError):
execute("SELECT * FROM table", schema=schema, tables=tables)
def test_tables(self):
tables = ensure_tables(
{
"catalog1": {
"db1": {
"t1": [
{"a": 1},
],
"t2": [
{"a": 1},
],
},
"db2": {
"t3": [
{"a": 1},
],
"t4": [
{"a": 1},
],
},
},
"catalog2": {
"db3": {
"t5": Table(columns=("a",), rows=[(1,)]),
"t6": Table(columns=("a",), rows=[(1,)]),
},
"db4": {
"t7": Table(columns=("a",), rows=[(1,)]),
"t8": Table(columns=("a",), rows=[(1,)]),
},
},
}
)
t1 = tables.find(exp.table_(table="t1", db="db1", catalog="catalog1"))
self.assertEqual(t1.columns, ("a",))
self.assertEqual(t1.rows, [(1,)])
t8 = tables.find(exp.table_(table="t8"))
self.assertEqual(t1.columns, t8.columns)
self.assertEqual(t1.rows, t8.rows)
def test_static_queries(self):
for sql, cols, rows in [
("SELECT 1", ["_col_0"], [(1,)]),
("SELECT 1 + 2 AS x", ["x"], [(3,)]),
("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
]:
result = execute(sql)
self.assertEqual(result.columns, tuple(cols))
self.assertEqual(result.rows, rows)
def test_aggregate_without_group_by(self):
result = execute("SELECT SUM(x) FROM t", tables={"t": [{"x": 1}, {"x": 2}]})
self.assertEqual(result.columns, ("_col_0",))
self.assertEqual(result.rows, [(3,)])
def test_scalar_functions(self):
for sql, expected in [
("CONCAT('a', 'b')", "ab"),
("CONCAT('a', NULL)", None),
("CONCAT_WS('_', 'a', 'b')", "a_b"),
("STR_POSITION('bar', 'foobarbar')", 4),
("STR_POSITION('bar', 'foobarbar', 5)", 7),
("STR_POSITION(NULL, 'foobarbar')", None),
("STR_POSITION('bar', NULL)", None),
("UPPER('foo')", "FOO"),
("UPPER(NULL)", None),
("LOWER('FOO')", "foo"),
("LOWER(NULL)", None),
("IFNULL('a', 'b')", "a"),
("IFNULL(NULL, 'b')", "b"),
("IFNULL(NULL, NULL)", None),
("SUBSTRING('12345')", "12345"),
("SUBSTRING('12345', 3)", "345"),
("SUBSTRING('12345', 3, 0)", ""),
("SUBSTRING('12345', 3, 1)", "3"),
("SUBSTRING('12345', 3, 2)", "34"),
("SUBSTRING('12345', 3, 3)", "345"),
("SUBSTRING('12345', 3, 4)", "345"),
("SUBSTRING('12345', -3)", "345"),
("SUBSTRING('12345', -3, 0)", ""),
("SUBSTRING('12345', -3, 1)", "3"),
("SUBSTRING('12345', -3, 2)", "34"),
("SUBSTRING('12345', 0)", ""),
("SUBSTRING('12345', 0, 1)", ""),
("SUBSTRING(NULL)", None),
("SUBSTRING(NULL, 1)", None),
("CAST(1 AS TEXT)", "1"),
("CAST('1' AS LONG)", 1),
("CAST('1.1' AS FLOAT)", 1.1),
("COALESCE(NULL)", None),
("COALESCE(NULL, NULL)", None),
("COALESCE(NULL, 'b')", "b"),
("COALESCE('a', 'b')", "a"),
("1 << 1", 2),
("1 >> 1", 0),
("1 & 1", 1),
("1 | 1", 1),
("1 < 1", False),
("1 <= 1", True),
("1 > 1", False),
("1 >= 1", True),
("1 + NULL", None),
("IF(true, 1, 0)", 1),
("IF(false, 1, 0)", 0),
("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"),
("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)),
]:
with self.subTest(sql):
result = execute(f"SELECT {sql}")
self.assertEqual(result.rows, [(expected,)])
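The bulk of the new executor tests above run `execute` directly over in-memory tables, with an optional schema for column types. A minimal standalone usage sketch in the same style:

```python
from sqlglot.executor import execute

result = execute(
    "SELECT a, COUNT(*) AS n FROM t GROUP BY a",
    tables={"t": [{"a": "x"}, {"a": "x"}, {"a": "y"}]},
)
print(result.columns)  # ('a', 'n')
print(result.rows)     # expected [('x', 2), ('y', 1)]
```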

View file

@ -441,6 +441,9 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(parse_one("VARIANCE(a)"), exp.Variance)
self.assertIsInstance(parse_one("VARIANCE_POP(a)"), exp.VariancePop)
self.assertIsInstance(parse_one("YEAR(a)"), exp.Year)
self.assertIsInstance(parse_one("BEGIN DEFERRED TRANSACTION"), exp.Transaction)
self.assertIsInstance(parse_one("COMMIT"), exp.Commit)
self.assertIsInstance(parse_one("ROLLBACK"), exp.Rollback)
def test_column(self):
dot = parse_one("a.b.c")
@ -479,9 +482,9 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(column.text("expression"), "c")
self.assertEqual(column.text("y"), "")
self.assertEqual(parse_one("select * from x.y").find(exp.Table).text("db"), "x")
self.assertEqual(parse_one("select *").text("this"), "")
self.assertEqual(parse_one("1 + 1").text("this"), "1")
self.assertEqual(parse_one("'a'").text("this"), "a")
self.assertEqual(parse_one("select *").name, "")
self.assertEqual(parse_one("1 + 1").name, "1")
self.assertEqual(parse_one("'a'").name, "a")
def test_alias(self):
self.assertEqual(alias("foo", "bar").sql(), "foo AS bar")
@ -538,8 +541,8 @@ class TestExpressions(unittest.TestCase):
this=exp.Literal.string("TABLE_FORMAT"),
value=exp.to_identifier("test_format"),
),
exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL),
exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE),
exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.null()),
exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.true()),
]
),
)
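The switch from the `exp.NULL` / `exp.TRUE` constants to `exp.null()` / `exp.true()` calls means every use gets a fresh AST node instead of a shared singleton, which matters because nodes carry parent pointers and can be mutated in place. A small sketch of the distinction:

```python
from sqlglot import exp

a = exp.true()
b = exp.true()

assert a == b       # structurally equal Boolean nodes
assert a is not b   # but independent objects, safe to attach to different parents
```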

View file

@ -29,6 +29,7 @@ class TestOptimizer(unittest.TestCase):
CREATE TABLE x (a INT, b INT);
CREATE TABLE y (b INT, c INT);
CREATE TABLE z (b INT, c INT);
CREATE TABLE w (d TEXT, e TEXT);
INSERT INTO x VALUES (1, 1);
INSERT INTO x VALUES (2, 2);
@ -47,6 +48,8 @@ class TestOptimizer(unittest.TestCase):
INSERT INTO y VALUES (4, 4);
INSERT INTO y VALUES (5, 5);
INSERT INTO y VALUES (null, null);
INSERT INTO w VALUES ('a', 'b');
"""
)
@ -64,6 +67,10 @@ class TestOptimizer(unittest.TestCase):
"b": "INT",
"c": "INT",
},
"w": {
"d": "TEXT",
"e": "TEXT",
},
}
def check_file(self, file, func, pretty=False, execute=False, **kwargs):
@ -224,6 +231,18 @@ class TestOptimizer(unittest.TestCase):
def test_eliminate_subqueries(self):
self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries)
def test_canonicalize(self):
optimize = partial(
optimizer.optimize,
rules=[
optimizer.qualify_tables.qualify_tables,
optimizer.qualify_columns.qualify_columns,
annotate_types,
optimizer.canonicalize.canonicalize,
],
)
self.check_file("canonicalize", optimize, schema=self.schema)
def test_tpch(self):
self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)

View file

@ -41,12 +41,41 @@ class TestParser(unittest.TestCase):
)
def test_command(self):
expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1")
expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1", read="hive")
self.assertEqual(len(expressions), 3)
self.assertEqual(expressions[0].sql(), "SET x = 1")
self.assertEqual(expressions[1].sql(), "ADD JAR s3://a")
self.assertEqual(expressions[2].sql(), "SELECT 1")
def test_transactions(self):
expression = parse_one("BEGIN TRANSACTION")
self.assertIsNone(expression.this)
self.assertEqual(expression.args["modes"], [])
self.assertEqual(expression.sql(), "BEGIN")
expression = parse_one("START TRANSACTION", read="mysql")
self.assertIsNone(expression.this)
self.assertEqual(expression.args["modes"], [])
self.assertEqual(expression.sql(), "BEGIN")
expression = parse_one("BEGIN DEFERRED TRANSACTION")
self.assertEqual(expression.this, "DEFERRED")
self.assertEqual(expression.args["modes"], [])
self.assertEqual(expression.sql(), "BEGIN")
expression = parse_one(
"START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", read="presto"
)
self.assertIsNone(expression.this)
self.assertEqual(expression.args["modes"][0], "READ WRITE")
self.assertEqual(expression.args["modes"][1], "ISOLATION LEVEL SERIALIZABLE")
self.assertEqual(expression.sql(), "BEGIN")
expression = parse_one("BEGIN", read="bigquery")
self.assertNotIsInstance(expression, exp.Transaction)
self.assertIsNone(expression.expression)
self.assertEqual(expression.sql(), "BEGIN")
def test_identify(self):
expression = parse_one(
"""
@ -55,14 +84,14 @@ class TestParser(unittest.TestCase):
"""
)
assert expression.expressions[0].text("this") == "a"
assert expression.expressions[1].text("this") == "b"
assert expression.expressions[2].text("alias") == "c"
assert expression.expressions[3].text("alias") == "D"
assert expression.expressions[4].text("alias") == "y|z'"
assert expression.expressions[0].name == "a"
assert expression.expressions[1].name == "b"
assert expression.expressions[2].alias == "c"
assert expression.expressions[3].alias == "D"
assert expression.expressions[4].alias == "y|z'"
table = expression.args["from"].expressions[0]
assert table.args["this"].args["this"] == "z"
assert table.args["db"].args["this"] == "y"
assert table.this.name == "z"
assert table.args["db"].name == "y"
def test_multi(self):
expressions = parse(
@ -72,8 +101,8 @@ class TestParser(unittest.TestCase):
)
assert len(expressions) == 2
assert expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a"
assert expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b"
assert expressions[0].args["from"].expressions[0].this.name == "a"
assert expressions[1].args["from"].expressions[0].this.name == "b"
def test_expression(self):
ignore = Parser(error_level=ErrorLevel.IGNORE)
@ -200,7 +229,7 @@ class TestParser(unittest.TestCase):
@patch("sqlglot.parser.logger")
def test_comment_error_n(self, logger):
parse_one(
"""CREATE TABLE x
"""SUM
(
-- test
)""",
@ -208,19 +237,19 @@ class TestParser(unittest.TestCase):
)
assert_logger_contains(
"Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 4, Col: 1.",
"Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. Line 4, Col: 1.",
logger,
)
@patch("sqlglot.parser.logger")
def test_comment_error_r(self, logger):
parse_one(
"""CREATE TABLE x (-- test\r)""",
"""SUM(-- test\r)""",
error_level=ErrorLevel.WARN,
)
assert_logger_contains(
"Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 2, Col: 1.",
"Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. Line 2, Col: 1.",
logger,
)
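The reworked error tests use an incomplete SUM(...) call, whose required 'this' argument is missing, to exercise comment handling around \n and \r during error reporting. Under ErrorLevel.RAISE the same input raises instead of logging; a sketch, assuming ParseError carries the same message text:

```python
from sqlglot import parse_one
from sqlglot.errors import ErrorLevel, ParseError

try:
    parse_one("SUM(\n-- test\n)", error_level=ErrorLevel.RAISE)
except ParseError as e:
    # e.g. "Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. ..."
    print(e)
```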

View file

@ -12,6 +12,7 @@ class TestTokens(unittest.TestCase):
("--comment\nfoo --test", "comment"),
("foo --comment", "comment"),
("foo", None),
("foo /*comment 1*/ /*comment 2*/", "comment 1"),
]
for sql, comment in sql_comment:

View file

@ -20,6 +20,13 @@ class TestTranspile(unittest.TestCase):
self.assertEqual(transpile(sql, **kwargs)[0], target)
def test_alias(self):
self.assertEqual(transpile("SELECT 1 current_time")[0], "SELECT 1 AS current_time")
self.assertEqual(
transpile("SELECT 1 current_timestamp")[0], "SELECT 1 AS current_timestamp"
)
self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date")
self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
for key in ("union", "filter", "over", "from", "join"):
with self.subTest(f"alias {key}"):
self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}")
@ -69,6 +76,10 @@ class TestTranspile(unittest.TestCase):
self.validate("SELECT 3>=3", "SELECT 3 >= 3")
def test_comments(self):
self.validate("SELECT */*comment*/", "SELECT * /* comment */")
self.validate(
"SELECT * FROM table /*comment 1*/ /*comment 2*/", "SELECT * FROM table /* comment 1 */"
)
self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo")