
Merging upstream version 10.2.6.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:58:37 +01:00
parent 40155883c5
commit 17f6b2c749
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
36 changed files with 1281 additions and 493 deletions


@@ -1,6 +1,43 @@
 Changelog
 =========
+v10.2.0
+------
+Changes:
+- Breaking: types inferred from annotate_types are now DataType objects, instead of DataType.Type.
+- New: the optimizer can now simplify [BETWEEN expressions expressed as explicit comparisons](https://github.com/tobymao/sqlglot/commit/e24d0317dfa644104ff21d009b790224bf84d698).
+- New: the optimizer now removes redundant casts.
+- New: added support for Redshift's ENCODE/DECODE.
+- New: the optimizer now [treats identifiers as case-insensitive](https://github.com/tobymao/sqlglot/commit/638ed265f195219d7226f4fbae128f1805ae8988).
+- New: the optimizer now [handles nested CTEs](https://github.com/tobymao/sqlglot/commit/1bdd652792889a8aaffb1c6d2c8aa1fe4a066281).
+- New: the executor can now execute SELECT DISTINCT expressions.
+- New: added support for Redshift's COPY and UNLOAD commands.
+- New: added the ability to parse LIKE in CREATE TABLE statements.
+- New: the optimizer now [unnests scalar subqueries as cross joins](https://github.com/tobymao/sqlglot/commit/4373ad8518ede4ef1fda8b247b648c680a93d12d).
+- Improvement: fixed BigQuery's ARRAY function parsing, so that it can now handle a SELECT expression as an argument.
+- Improvement: improved Snowflake's [ARRAY and MAP constructs](https://github.com/tobymao/sqlglot/commit/0506657dba55fe71d004c81c907e23cdd2b37d82).
+- Improvement: fixed transpilation between STRING_AGG and GROUP_CONCAT.
+- Improvement: the INTO clause can now be parsed in SELECT expressions.
+- Improvement: improved the executor; it currently executes all TPC-H queries up to and including TPC-H 17.
+- Improvement: DISTINCT ON is now transpiled to a SELECT expression from a subquery for Redshift.
 v10.1.0
 ------
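Illustrative only, not part of the commit: a minimal sketch of the new comparison simplification described above, assuming the public simplify entry point shown in the optimizer diff further down. Overlapping comparisons on the same column are merged.

import sqlglot
from sqlglot.optimizer.simplify import simplify

# x > 1 is implied by x > 5, so the conjunction collapses to the
# stricter comparison.
expr = sqlglot.parse_one("x > 1 AND x > 5")
print(simplify(expr).sql())  # x > 5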


@@ -30,7 +30,7 @@ from sqlglot.parser import Parser
 from sqlglot.schema import MappingSchema
 from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "10.1.3"
+__version__ = "10.2.6"
 pretty = False


@@ -317,7 +317,7 @@ class DataFrame:
         sqlglot.schema.add_table(
             cache_table_name,
             {
-                expression.alias_or_name: expression.type.name
+                expression.alias_or_name: expression.type.sql("spark")
                 for expression in select_expression.expressions
             },
         )
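The cached table schema now records dialect-rendered type strings rather than enum names. A hedged one-liner showing the pieces this change relies on (DataType.build and Expression.sql, both visible in the diffs below):

from sqlglot import exp

# Build a DataType node from a type string and render it in Spark syntax.
print(exp.DataType.build("int").sql("spark"))  # INT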


@@ -110,17 +110,17 @@ class BigQuery(Dialect):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
+            "BEGIN": TokenType.COMMAND,
+            "BEGIN TRANSACTION": TokenType.BEGIN,
             "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
             "CURRENT_TIME": TokenType.CURRENT_TIME,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
-            "INT64": TokenType.BIGINT,
             "FLOAT64": TokenType.DOUBLE,
+            "INT64": TokenType.BIGINT,
+            "NOT DETERMINISTIC": TokenType.VOLATILE,
             "QUALIFY": TokenType.QUALIFY,
             "UNKNOWN": TokenType.NULL,
             "WINDOW": TokenType.WINDOW,
-            "NOT DETERMINISTIC": TokenType.VOLATILE,
-            "BEGIN": TokenType.COMMAND,
-            "BEGIN TRANSACTION": TokenType.BEGIN,
         }
         KEYWORDS.pop("DIV")
@@ -131,6 +131,7 @@ class BigQuery(Dialect):
             "DATE_ADD": _date_add(exp.DateAdd),
             "DATETIME_ADD": _date_add(exp.DatetimeAdd),
             "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
+            "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
             "TIME_ADD": _date_add(exp.TimeAdd),
             "TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
             "DATE_SUB": _date_add(exp.DateSub),
@@ -144,6 +145,7 @@ class BigQuery(Dialect):
         FUNCTION_PARSERS = {
             **parser.Parser.FUNCTION_PARSERS,
+            "ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]),
         }
         FUNCTION_PARSERS.pop("TRIM")
@@ -161,7 +163,6 @@ class BigQuery(Dialect):
     class Generator(generator.Generator):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
-            exp.Array: inline_array_sql,
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.DateAdd: _date_add_sql("DATE", "ADD"),
             exp.DateSub: _date_add_sql("DATE", "SUB"),
@@ -183,6 +184,7 @@ class BigQuery(Dialect):
             exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
             if e.name == "IMMUTABLE"
             else "NOT DETERMINISTIC",
+            exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
         }
         TYPE_MAPPING = {
@@ -210,24 +212,31 @@ class BigQuery(Dialect):
         EXPLICIT_UNION = True

+        def array_sql(self, expression: exp.Array) -> str:
+            first_arg = seq_get(expression.expressions, 0)
+            if isinstance(first_arg, exp.Subqueryable):
+                return f"ARRAY{self.wrap(self.sql(first_arg))}"
+
+            return inline_array_sql(self, expression)
+
-        def transaction_sql(self, *_):
+        def transaction_sql(self, *_) -> str:
             return "BEGIN TRANSACTION"

-        def commit_sql(self, *_):
+        def commit_sql(self, *_) -> str:
             return "COMMIT TRANSACTION"

-        def rollback_sql(self, *_):
+        def rollback_sql(self, *_) -> str:
             return "ROLLBACK TRANSACTION"

-        def in_unnest_op(self, unnest):
-            return self.sql(unnest)
+        def in_unnest_op(self, expression: exp.Unnest) -> str:
+            return self.sql(expression)

-        def except_op(self, expression):
+        def except_op(self, expression: exp.Except) -> str:
             if not expression.args.get("distinct", False):
                 self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery")
             return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"

-        def intersect_op(self, expression):
+        def intersect_op(self, expression: exp.Intersect) -> str:
             if not expression.args.get("distinct", False):
                 self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
             return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"


@@ -190,6 +190,7 @@ class Hive(Dialect):
             "ADD FILES": TokenType.COMMAND,
             "ADD JAR": TokenType.COMMAND,
             "ADD JARS": TokenType.COMMAND,
+            "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
         }
     class Parser(parser.Parser):
@@ -238,6 +239,13 @@ class Hive(Dialect):
             "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
         }

+        PROPERTY_PARSERS = {
+            **parser.Parser.PROPERTY_PARSERS,
+            TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties(
+                expressions=self._parse_wrapped_csv(self._parse_property)
+            ),
+        }
+
     class Generator(generator.Generator):
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
@@ -297,6 +305,8 @@ class Hive(Dialect):
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
             exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
+            exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}",
+            exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
             exp.NumberToStr: rename_func("FORMAT_NUMBER"),
         }
@@ -308,12 +318,15 @@ class Hive(Dialect):
             exp.SchemaCommentProperty,
             exp.LocationProperty,
             exp.TableFormatProperty,
+            exp.RowFormatDelimitedProperty,
+            exp.RowFormatSerdeProperty,
+            exp.SerdeProperties,
         }

         def with_properties(self, properties):
             return self.properties(
                 properties,
-                prefix="TBLPROPERTIES",
+                prefix=self.seg("TBLPROPERTIES"),
             )

         def datatype_sql(self, expression):
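A hedged sketch (not from the commit) of the SERDE support wired up above; with the tokenizer, parser, and generator pieces in place, ROW FORMAT SERDE plus WITH SERDEPROPERTIES should survive a Hive-to-Hive round trip:

import sqlglot

# The serde class name and property here are made-up placeholders.
sql = "CREATE TABLE t (x INT) ROW FORMAT SERDE 'org.example.Serde' WITH SERDEPROPERTIES ('p'='1')"
print(sqlglot.transpile(sql, read="hive", write="hive")[0])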


@@ -98,6 +98,7 @@ class Oracle(Dialect):
     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
+            "MINUS": TokenType.EXCEPT,
             "START": TokenType.BEGIN,
             "TOP": TokenType.TOP,
             "VARCHAR2": TokenType.VARCHAR,


@@ -1,6 +1,7 @@
 from __future__ import annotations
 from sqlglot import exp, transforms
+from sqlglot.dialects.dialect import rename_func
 from sqlglot.dialects.postgres import Postgres
 from sqlglot.tokens import TokenType
@@ -13,12 +14,20 @@ class Redshift(Postgres):
         "HH": "%H",
     }

+    class Parser(Postgres.Parser):
+        FUNCTIONS = {
+            **Postgres.Parser.FUNCTIONS,  # type: ignore
+            "DECODE": exp.Matches.from_arg_list,
+            "NVL": exp.Coalesce.from_arg_list,
+        }
+
     class Tokenizer(Postgres.Tokenizer):
         ESCAPES = ["\\"]
         KEYWORDS = {
             **Postgres.Tokenizer.KEYWORDS,  # type: ignore
             "COPY": TokenType.COMMAND,
+            "ENCODE": TokenType.ENCODE,
             "GEOMETRY": TokenType.GEOMETRY,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
             "HLLSKETCH": TokenType.HLLSKETCH,
@@ -50,4 +59,5 @@ class Redshift(Postgres):
             exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
             exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
             exp.DistStyleProperty: lambda self, e: self.naked_property(e),
+            exp.Matches: rename_func("DECODE"),
         }
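A small sketch (not from the commit) of the new Redshift function mappings: NVL parses to exp.Coalesce, and DECODE parses to exp.Matches, which the generator renders back as DECODE.

import sqlglot

print(sqlglot.parse_one("NVL(a, b)", read="redshift").sql())  # COALESCE(a, b)
print(sqlglot.parse_one("DECODE(x, 1, 'one')", read="redshift").sql("redshift"))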


@@ -198,6 +198,7 @@ class Snowflake(Dialect):
             "TIMESTAMP_NTZ": TokenType.TIMESTAMP,
             "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
             "TIMESTAMPNTZ": TokenType.TIMESTAMP,
+            "MINUS": TokenType.EXCEPT,
             "SAMPLE": TokenType.TABLE_SAMPLE,
         }


@@ -19,10 +19,13 @@ class reverse_key:
         return other.obj < self.obj

-def filter_nulls(func):
+def filter_nulls(func, empty_null=True):
     @wraps(func)
     def _func(values):
-        return func(v for v in values if v is not None)
+        filtered = tuple(v for v in values if v is not None)
+        if not filtered and empty_null:
+            return None
+        return func(filtered)

     return _func
@@ -126,7 +129,7 @@ ENV = {
     # aggs
     "SUM": filter_nulls(sum),
     "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean),  # type: ignore
-    "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)),
+    "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
     "MAX": filter_nulls(max),
     "MIN": filter_nulls(min),
     # scalar functions
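A standalone demo of the empty_null flag added above, reusing the decorator from the diff: aggregates other than COUNT return NULL (None) when every input is NULL, while COUNT of an all-NULL column stays 0.

from functools import wraps

def filter_nulls(func, empty_null=True):
    @wraps(func)
    def _func(values):
        # Drop NULLs before aggregating; optionally map "no rows left" to NULL.
        filtered = tuple(v for v in values if v is not None)
        if not filtered and empty_null:
            return None
        return func(filtered)
    return _func

print(filter_nulls(sum)([None, None]))                               # None
print(filter_nulls(lambda acc: sum(1 for _ in acc), False)([None]))  # 0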


@@ -310,9 +310,9 @@ class PythonExecutor:
             if i == length - 1:
                 context.set_range(start, end - 1)
                 add_row()
-            elif step.limit > 0:
+            elif step.limit > 0 and not group_by:
                 context.set_range(0, 0)
-                table.append(context.eval_tuple(group_by) + context.eval_tuple(aggregations))
+                table.append(context.eval_tuple(aggregations))
         context = self.context({step.name: table, **{name: table for name in context.tables}})


@@ -43,14 +43,14 @@ class Expression(metaclass=_Expression):
     key = "Expression"
     arg_types = {"this": True}
-    __slots__ = ("args", "parent", "arg_key", "type", "comments")
+    __slots__ = ("args", "parent", "arg_key", "comments", "_type")

     def __init__(self, **args):
         self.args = args
         self.parent = None
         self.arg_key = None
-        self.type = None
         self.comments = None
+        self._type: t.Optional[DataType] = None

         for arg_key, value in self.args.items():
             self._set_parent(arg_key, value)
@@ -122,6 +122,16 @@ class Expression(metaclass=_Expression):
             return "NULL"
         return self.alias or self.name

+    @property
+    def type(self) -> t.Optional[DataType]:
+        return self._type
+
+    @type.setter
+    def type(self, dtype: t.Optional[DataType | DataType.Type | str]) -> None:
+        if dtype and not isinstance(dtype, DataType):
+            dtype = DataType.build(dtype)
+        self._type = dtype  # type: ignore
+
     def __deepcopy__(self, memo):
         copy = self.__class__(**deepcopy(self.args))
         copy.comments = self.comments
@@ -348,7 +358,7 @@ class Expression(metaclass=_Expression):
         indent += "".join([" "] * level)
         left = f"({self.key.upper()} "

-        args = {
+        args: t.Dict[str, t.Any] = {
             k: ", ".join(
                 v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
                 for v in ensure_collection(vs)
@@ -612,6 +622,7 @@ class Create(Expression):
         "properties": False,
         "temporary": False,
         "transient": False,
+        "external": False,
         "replace": False,
         "unique": False,
         "materialized": False,
@@ -744,13 +755,17 @@ class DefaultColumnConstraint(ColumnConstraintKind):
     pass

+class EncodeColumnConstraint(ColumnConstraintKind):
+    pass
+
 class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
     # this: True -> ALWAYS, this: False -> BY DEFAULT
     arg_types = {"this": True, "expression": False}

 class NotNullColumnConstraint(ColumnConstraintKind):
-    pass
+    arg_types = {"allow_null": False}

 class PrimaryKeyColumnConstraint(ColumnConstraintKind):
@@ -766,7 +781,7 @@ class Constraint(Expression):

 class Delete(Expression):
-    arg_types = {"with": False, "this": True, "using": False, "where": False}
+    arg_types = {"with": False, "this": False, "using": False, "where": False}

 class Drop(Expression):
@@ -850,7 +865,7 @@ class Insert(Expression):
     arg_types = {
         "with": False,
         "this": True,
-        "expression": True,
+        "expression": False,
         "overwrite": False,
         "exists": False,
         "partition": False,
@@ -1125,6 +1140,27 @@ class VolatilityProperty(Property):
     arg_types = {"this": True}

+class RowFormatDelimitedProperty(Property):
+    # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
+    arg_types = {
+        "fields": False,
+        "escaped": False,
+        "collection_items": False,
+        "map_keys": False,
+        "lines": False,
+        "null": False,
+        "serde": False,
+    }
+
+class RowFormatSerdeProperty(Property):
+    arg_types = {"this": True}
+
+class SerdeProperties(Property):
+    arg_types = {"expressions": True}
+
 class Properties(Expression):
     arg_types = {"expressions": True}
@@ -1169,18 +1205,6 @@ class Reference(Expression):
     arg_types = {"this": True, "expressions": True}

-class RowFormat(Expression):
-    # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
-    arg_types = {
-        "fields": False,
-        "escaped": False,
-        "collection_items": False,
-        "map_keys": False,
-        "lines": False,
-        "null": False,
-    }
-
 class Tuple(Expression):
     arg_types = {"expressions": False}
@@ -1208,6 +1232,9 @@ class Subqueryable(Unionable):
             alias=TableAlias(this=to_identifier(alias)),
         )

+    def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
+        raise NotImplementedError
+
     @property
     def ctes(self):
         with_ = self.args.get("with")
@@ -1320,6 +1347,32 @@ class Union(Subqueryable):
         **QUERY_MODIFIERS,
     }

+    def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
+        """
+        Set the LIMIT expression.
+
+        Example:
+            >>> select("1").union(select("1")).limit(1).sql()
+            'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1'
+
+        Args:
+            expression (str | int | Expression): the SQL code string to parse.
+                This can also be an integer.
+                If a `Limit` instance is passed, this is used as-is.
+                If another `Expression` instance is passed, it will be wrapped in a `Limit`.
+            dialect (str): the dialect used to parse the input expression.
+            copy (bool): if `False`, modify this expression instance in-place.
+            opts (kwargs): other options to use to parse the input expressions.
+
+        Returns:
+            Select: The limited subqueryable.
+        """
+        return (
+            select("*")
+            .from_(self.subquery(alias="_l_0", copy=copy))
+            .limit(expression, dialect=dialect, copy=False, **opts)
+        )
+
     @property
     def named_selects(self):
         return self.this.unnest().named_selects
@@ -1356,7 +1409,7 @@ class Unnest(UDTF):

 class Update(Expression):
     arg_types = {
         "with": False,
-        "this": True,
+        "this": False,
         "expressions": True,
         "from": False,
         "where": False,
@@ -2057,15 +2110,20 @@ class DataType(Expression):
         Type.TEXT,
     }

-    NUMERIC_TYPES = {
+    INTEGER_TYPES = {
         Type.INT,
         Type.TINYINT,
         Type.SMALLINT,
         Type.BIGINT,
+    }
+
+    FLOAT_TYPES = {
         Type.FLOAT,
         Type.DOUBLE,
     }

+    NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES}
+
     TEMPORAL_TYPES = {
         Type.TIMESTAMP,
         Type.TIMESTAMPTZ,
@@ -2968,6 +3026,14 @@ class Use(Expression):
     pass

+class Merge(Expression):
+    arg_types = {"this": True, "using": True, "on": True, "expressions": True}
+
+class When(Func):
+    arg_types = {"this": True, "then": True}
+
 def _norm_args(expression):
     args = {}
File diff suppressed because it is too large.


@@ -385,3 +385,11 @@ def dict_depth(d: t.Dict) -> int:
     except StopIteration:
         # d.values() returns an empty sequence
         return 1

+def first(it: t.Iterable[T]) -> T:
+    """Returns the first element from an iterable.
+
+    Useful for sets.
+    """
+    return next(i for i in it)


@@ -14,7 +14,7 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
         >>> schema = {"y": {"cola": "SMALLINT"}}
         >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
         >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
-        >>> annotated_expr.expressions[0].type  # Get the type of "x.cola + 2.5 AS cola"
+        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
         <Type.DOUBLE: 'DOUBLE'>
     Args:
@@ -41,9 +41,12 @@ class TypeAnnotator:
             expr_type: lambda self, expr: self._annotate_binary(expr)
             for expr_type in subclasses(exp.__name__, exp.Binary)
         },
-        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this),
-        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this),
+        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
+        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
+        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr),
         exp.Alias: lambda self, expr: self._annotate_unary(expr),
+        exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+        exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
         exp.Literal: lambda self, expr: self._annotate_literal(expr),
         exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
         exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
@@ -52,6 +55,9 @@ class TypeAnnotator:
             expr, exp.DataType.Type.BIGINT
         ),
         exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"),
+        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"),
+        exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True),
        exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
         exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
         exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
@@ -263,10 +269,10 @@ class TypeAnnotator:
         }
         # First annotate the current scope's column references
         for col in scope.columns:
-            source = scope.sources[col.table]
+            source = scope.sources.get(col.table)
             if isinstance(source, exp.Table):
                 col.type = self.schema.get_column_type(source, col)
-            else:
+            elif source:
                 col.type = selects[col.table][col.name].type
         # Then (possibly) annotate the remaining expressions in the scope
         self._maybe_annotate(scope.expression)
@@ -280,6 +286,7 @@ class TypeAnnotator:
             return expression  # We've already inferred the expression's type
         annotator = self.annotators.get(expression.__class__)
+
         return (
             annotator(self, expression)
             if annotator
@@ -295,18 +302,23 @@ class TypeAnnotator:
     def _maybe_coerce(self, type1, type2):
         # We propagate the NULL / UNKNOWN types upwards if found
+        if isinstance(type1, exp.DataType):
+            type1 = type1.this
+        if isinstance(type2, exp.DataType):
+            type2 = type2.this
+
         if exp.DataType.Type.NULL in (type1, type2):
             return exp.DataType.Type.NULL
         if exp.DataType.Type.UNKNOWN in (type1, type2):
             return exp.DataType.Type.UNKNOWN
-        return type2 if type2 in self.coerces_to[type1] else type1
+        return type2 if type2 in self.coerces_to.get(type1, {}) else type1

     def _annotate_binary(self, expression):
         self._annotate_args(expression)
-        left_type = expression.left.type
-        right_type = expression.right.type
+        left_type = expression.left.type.this
+        right_type = expression.right.type.this

         if isinstance(expression, (exp.And, exp.Or)):
             if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
@@ -348,7 +360,7 @@ class TypeAnnotator:
         expression.type = target_type
         return self._annotate_args(expression)

-    def _annotate_by_args(self, expression, *args):
+    def _annotate_by_args(self, expression, *args, promote=False):
         self._annotate_args(expression)
         expressions = []
         for arg in args:
@@ -360,4 +372,11 @@ class TypeAnnotator:
             last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
         expression.type = last_datatype or exp.DataType.Type.UNKNOWN
+
+        if promote:
+            if expression.type.this in exp.DataType.INTEGER_TYPES:
+                expression.type = exp.DataType.Type.BIGINT
+            elif expression.type.this in exp.DataType.FLOAT_TYPES:
+                expression.type = exp.DataType.Type.DOUBLE
+
         return expression
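A hedged example of the breaking change above: expression.type is now a DataType node, so dialect-specific SQL can be rendered from it, while the raw enum lives at .type.this.

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

expr = annotate_types(sqlglot.parse_one("SELECT 1 + 2.5 AS x")).expressions[0]
print(expr.type.sql())  # DOUBLE
print(expr.type.this)   # Type.DOUBLE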


@@ -13,13 +13,16 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
         expression: The expression to canonicalize.
     """
     exp.replace_children(expression, canonicalize)
+
     expression = add_text_to_concat(expression)
     expression = coerce_type(expression)
+    expression = remove_redundant_casts(expression)
+
     return expression

 def add_text_to_concat(node: exp.Expression) -> exp.Expression:
-    if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES:
+    if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
         node = exp.Concat(this=node.this, expression=node.expression)
     return node
@@ -30,14 +33,30 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
     elif isinstance(node, exp.Between):
         _coerce_date(node.this, node.args["low"])
     elif isinstance(node, exp.Extract):
-        if node.expression.type not in exp.DataType.TEMPORAL_TYPES:
+        if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES:
             _replace_cast(node.expression, "datetime")
     return node

+def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
+    if (
+        isinstance(expression, exp.Cast)
+        and expression.to.type
+        and expression.this.type
+        and expression.to.type.this == expression.this.type.this
+    ):
+        return expression.this
+    return expression

 def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
     for a, b in itertools.permutations([a, b]):
-        if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE:
+        if (
+            a.type
+            and a.type.this == exp.DataType.Type.DATE
+            and b.type
+            and b.type.this != exp.DataType.Type.DATE
+        ):
             _replace_cast(b, "date")
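A sketch of remove_redundant_casts in use (not from the commit; it assumes the expression has been annotated first, since the check compares the cast target with the operand's inferred type):

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize

# The outer cast re-casts a value that is already typed DATE, so it is dropped.
expr = annotate_types(sqlglot.parse_one("CAST(CAST('2021-01-01' AS DATE) AS DATE)"))
print(canonicalize(expr).sql())  # CAST('2021-01-01' AS DATE)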


@@ -7,7 +7,7 @@ from decimal import Decimal
 from sqlglot import exp
 from sqlglot.expressions import FALSE, NULL, TRUE
 from sqlglot.generator import Generator
-from sqlglot.helper import while_changing
+from sqlglot.helper import first, while_changing

 GENERATOR = Generator(normalize=True, identify=True)
@@ -30,6 +30,7 @@ def simplify(expression):
     def _simplify(expression, root=True):
         node = expression
+        node = rewrite_between(node)
         node = uniq_sort(node)
         node = absorb_and_eliminate(node)
         exp.replace_children(node, lambda e: _simplify(e, False))
@@ -49,6 +50,19 @@ def simplify(expression):
     return expression

+def rewrite_between(expression: exp.Expression) -> exp.Expression:
+    """Rewrite x between y and z to x >= y AND x <= z.
+
+    This is done because comparison simplification is only done on lt/lte/gt/gte.
+    """
+    if isinstance(expression, exp.Between):
+        return exp.and_(
+            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
+            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
+        )
+    return expression
+
 def simplify_not(expression):
     """
     Demorgan's Law
@@ -57,7 +71,7 @@ def simplify_not(expression):
     """
     if isinstance(expression, exp.Not):
         if isinstance(expression.this, exp.Null):
-            return NULL
+            return exp.null()
         if isinstance(expression.this, exp.Paren):
             condition = expression.this.unnest()
             if isinstance(condition, exp.And):
@@ -65,11 +79,11 @@ def simplify_not(expression):
             if isinstance(condition, exp.Or):
                 return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
             if isinstance(condition, exp.Null):
-                return NULL
+                return exp.null()
         if always_true(expression.this):
-            return FALSE
+            return exp.false()
         if expression.this == FALSE:
-            return TRUE
+            return exp.true()
         if isinstance(expression.this, exp.Not):
             # double negation
             # NOT NOT x -> x
@@ -91,41 +105,120 @@ def flatten(expression):

 def simplify_connectors(expression):
+    def _simplify_connectors(expression, left, right):
         if isinstance(expression, exp.Connector):
-            left = expression.left
-            right = expression.right
             if left == right:
                 return left
             if isinstance(expression, exp.And):
                 if FALSE in (left, right):
-                    return FALSE
+                    return exp.false()
                 if NULL in (left, right):
-                    return NULL
+                    return exp.null()
                 if always_true(left) and always_true(right):
-                    return TRUE
+                    return exp.true()
                 if always_true(left):
                     return right
                 if always_true(right):
                     return left
+                return _simplify_comparison(expression, left, right)
             elif isinstance(expression, exp.Or):
                 if always_true(left) or always_true(right):
-                    return TRUE
+                    return exp.true()
                 if left == FALSE and right == FALSE:
-                    return FALSE
+                    return exp.false()
                 if (
                     (left == NULL and right == NULL)
                     or (left == NULL and right == FALSE)
                     or (left == FALSE and right == NULL)
                 ):
-                    return NULL
+                    return exp.null()
                 if left == FALSE:
                     return right
                 if right == FALSE:
                     return left
-    return expression
+                return _simplify_comparison(expression, left, right, or_=True)
+        return None
+
+    return _flat_simplify(expression, _simplify_connectors)
+
+LT_LTE = (exp.LT, exp.LTE)
+GT_GTE = (exp.GT, exp.GTE)
+
+COMPARISONS = (
+    *LT_LTE,
+    *GT_GTE,
+    exp.EQ,
+    exp.NEQ,
+)
+
+INVERSE_COMPARISONS = {
+    exp.LT: exp.GT,
+    exp.GT: exp.LT,
+    exp.LTE: exp.GTE,
+    exp.GTE: exp.LTE,
+}
+
+def _simplify_comparison(expression, left, right, or_=False):
+    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
+        ll, lr = left.args.values()
+        rl, rr = right.args.values()
+
+        largs = {ll, lr}
+        rargs = {rl, rr}
+
+        matching = largs & rargs
+        columns = {m for m in matching if isinstance(m, exp.Column)}
+
+        if matching and columns:
+            try:
+                l = first(largs - columns)
+                r = first(rargs - columns)
+            except StopIteration:
+                return expression
+
+            # make sure the comparison is always of the form x > 1 instead of 1 < x
+            if left.__class__ in INVERSE_COMPARISONS and l == ll:
+                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
+            if right.__class__ in INVERSE_COMPARISONS and r == rl:
+                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
+
+            if l.is_number and r.is_number:
+                l = float(l.name)
+                r = float(r.name)
+            elif l.is_string and r.is_string:
+                l = l.name
+                r = r.name
+            else:
+                return None
+
+            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
+                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
+                    return left if (av > bv if or_ else av <= bv) else right
+                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
+                    return left if (av < bv if or_ else av >= bv) else right
+
+                # we can't ever shortcut to true because the column could be null
+                if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
+                    if not or_ and av <= bv:
+                        return exp.false()
+                elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
+                    if not or_ and av >= bv:
+                        return exp.false()
+                elif isinstance(a, exp.EQ):
+                    if isinstance(b, exp.LT):
+                        return exp.false() if av >= bv else a
+                    if isinstance(b, exp.LTE):
+                        return exp.false() if av > bv else a
+                    if isinstance(b, exp.GT):
+                        return exp.false() if av <= bv else a
+                    if isinstance(b, exp.GTE):
+                        return exp.false() if av < bv else a
+                    if isinstance(b, exp.NEQ):
+                        return exp.false() if av == bv else a
+
+    return None

 def remove_compliments(expression):
     """
@@ -135,7 +228,7 @@ def remove_compliments(expression):
     A OR NOT A -> TRUE
     """
     if isinstance(expression, exp.Connector):
-        compliment = FALSE if isinstance(expression, exp.And) else TRUE
+        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()

         for a, b in itertools.permutations(expression.flatten(), 2):
             if is_complement(a, b):
@@ -211,27 +304,7 @@ def absorb_and_eliminate(expression):

 def simplify_literals(expression):
     if isinstance(expression, exp.Binary):
-        operands = []
-        queue = deque(expression.flatten(unnest=False))
-        size = len(queue)
-        while queue:
-            a = queue.popleft()
-            for b in queue:
-                result = _simplify_binary(expression, a, b)
-                if result:
-                    queue.remove(b)
-                    queue.append(result)
-                    break
-            else:
-                operands.append(a)
-        if len(operands) < size:
-            return functools.reduce(
-                lambda a, b: expression.__class__(this=a, expression=b), operands
-            )
+        return _flat_simplify(expression, _simplify_binary)
     elif isinstance(expression, exp.Neg):
         this = expression.this
         if this.is_number:
@@ -254,20 +327,13 @@ def _simplify_binary(expression, a, b):
     if c == NULL:
         if isinstance(a, exp.Literal):
-            return TRUE if not_ else FALSE
+            return exp.true() if not_ else exp.false()
         if a == NULL:
-            return FALSE if not_ else TRUE
-    elif isinstance(expression, exp.NullSafeEQ):
-        if a == b:
-            return TRUE
-    elif isinstance(expression, exp.NullSafeNEQ):
-        if a == b:
-            return FALSE
+            return exp.false() if not_ else exp.true()
+    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
+        return None
     elif NULL in (a, b):
-        return NULL
-
-    if isinstance(expression, exp.EQ) and a == b:
-        return TRUE
+        return exp.null()

     if a.is_number and b.is_number:
         a = int(a.name) if a.is_int else Decimal(a.name)
@@ -388,4 +454,27 @@ def date_literal(date):

 def boolean_literal(condition):
-    return TRUE if condition else FALSE
+    return exp.true() if condition else exp.false()
+
+def _flat_simplify(expression, simplifier):
+    operands = []
+    queue = deque(expression.flatten(unnest=False))
+    size = len(queue)
+
+    while queue:
+        a = queue.popleft()
+
+        for b in queue:
+            result = simplifier(expression, a, b)
+
+            if result:
+                queue.remove(b)
+                queue.append(result)
+                break
+        else:
+            operands.append(a)
+
+    if len(operands) < size:
+        return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
+
+    return expression
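Illustrative only: rewrite_between and the new comparison simplifier working together. BETWEEN is first expanded into explicit >=/<= comparisons, after which the redundant x >= 1 is absorbed.

import sqlglot
from sqlglot.optimizer.simplify import simplify

print(simplify(sqlglot.parse_one("x BETWEEN 5 AND 10 AND x >= 1")).sql())
# x >= 5 AND x <= 10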


@@ -185,6 +185,7 @@ class Parser(metaclass=_Parser):
         TokenType.LOCAL,
         TokenType.LOCATION,
         TokenType.MATERIALIZED,
+        TokenType.MERGE,
         TokenType.NATURAL,
         TokenType.NEXT,
         TokenType.ONLY,
@@ -211,7 +212,6 @@ class Parser(metaclass=_Parser):
         TokenType.TABLE,
         TokenType.TABLE_FORMAT,
         TokenType.TEMPORARY,
-        TokenType.TRANSIENT,
         TokenType.TOP,
         TokenType.TRAILING,
         TokenType.TRUE,
@@ -229,6 +229,8 @@ class Parser(metaclass=_Parser):
     TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY}

+    UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
+
     TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}

     FUNC_TOKENS = {
@@ -241,6 +243,7 @@ class Parser(metaclass=_Parser):
         TokenType.FORMAT,
         TokenType.IDENTIFIER,
         TokenType.ISNULL,
+        TokenType.MERGE,
         TokenType.OFFSET,
         TokenType.PRIMARY_KEY,
         TokenType.REPLACE,
@@ -407,6 +410,7 @@ class Parser(metaclass=_Parser):
         TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
         TokenType.END: lambda self: self._parse_commit_or_rollback(),
         TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
+        TokenType.MERGE: lambda self: self._parse_merge(),
     }

     UNARY_PARSERS = {
@@ -474,6 +478,7 @@ class Parser(metaclass=_Parser):
         TokenType.SORTKEY: lambda self: self._parse_sortkey(),
         TokenType.LIKE: lambda self: self._parse_create_like(),
         TokenType.RETURNS: lambda self: self._parse_returns(),
+        TokenType.ROW: lambda self: self._parse_row(),
         TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
         TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
         TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
@@ -495,6 +500,8 @@ class Parser(metaclass=_Parser):
         TokenType.VOLATILE: lambda self: self.expression(
             exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
         ),
+        TokenType.WITH: lambda self: self._parse_wrapped_csv(self._parse_property),
+        TokenType.PROPERTIES: lambda self: self._parse_wrapped_csv(self._parse_property),
     }

     CONSTRAINT_PARSERS = {
@@ -802,7 +809,8 @@ class Parser(metaclass=_Parser):
     def _parse_create(self):
         replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
         temporary = self._match(TokenType.TEMPORARY)
-        transient = self._match(TokenType.TRANSIENT)
+        transient = self._match_text_seq("TRANSIENT")
+        external = self._match_text_seq("EXTERNAL")
         unique = self._match(TokenType.UNIQUE)
         materialized = self._match(TokenType.MATERIALIZED)
@@ -846,6 +854,7 @@ class Parser(metaclass=_Parser):
             properties=properties,
             temporary=temporary,
             transient=transient,
+            external=external,
             replace=replace,
             unique=unique,
             materialized=materialized,
@@ -861,8 +870,12 @@ class Parser(metaclass=_Parser):
         if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
             return self._parse_sortkey(compound=True)

-        if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False):
-            key = self._parse_var()
+        assignment = self._match_pair(
+            TokenType.VAR, TokenType.EQ, advance=False
+        ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False)
+
+        if assignment:
+            key = self._parse_var() or self._parse_string()
             self._match(TokenType.EQ)
             return self.expression(exp.Property, this=key, value=self._parse_column())
@@ -871,7 +884,10 @@ class Parser(metaclass=_Parser):
     def _parse_property_assignment(self, exp_class):
         self._match(TokenType.EQ)
         self._match(TokenType.ALIAS)
-        return self.expression(exp_class, this=self._parse_var_or_string() or self._parse_number())
+        return self.expression(
+            exp_class,
+            this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
+        )

     def _parse_partitioned_by(self):
         self._match(TokenType.EQ)
@@ -881,7 +897,7 @@ class Parser(metaclass=_Parser):
         )

     def _parse_distkey(self):
-        return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_var))
+        return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))

     def _parse_create_like(self):
         table = self._parse_table(schema=True)
@@ -898,7 +914,7 @@ class Parser(metaclass=_Parser):
     def _parse_sortkey(self, compound=False):
         return self.expression(
-            exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_var), compound=compound
+            exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_id_var), compound=compound
         )

     def _parse_character_set(self, default=False):
@@ -929,23 +945,11 @@ class Parser(metaclass=_Parser):
         properties = []
         while True:
-            if self._match(TokenType.WITH):
-                properties.extend(self._parse_wrapped_csv(self._parse_property))
-            elif self._match(TokenType.PROPERTIES):
-                properties.extend(
-                    self._parse_wrapped_csv(
-                        lambda: self.expression(
-                            exp.Property,
-                            this=self._parse_string(),
-                            value=self._match(TokenType.EQ) and self._parse_string(),
-                        )
-                    )
-                )
-            else:
             identified_property = self._parse_property()
             if not identified_property:
                 break
-            properties.append(identified_property)
+            for p in ensure_collection(identified_property):
+                properties.append(p)

         if properties:
             return self.expression(exp.Properties, expressions=properties)
@@ -963,7 +967,7 @@ class Parser(metaclass=_Parser):
                 exp.Directory,
                 this=self._parse_var_or_string(),
                 local=local,
-                row_format=self._parse_row_format(),
+                row_format=self._parse_row_format(match_row=True),
             )
         else:
             self._match(TokenType.INTO)
@@ -978,9 +982,17 @@ class Parser(metaclass=_Parser):
             overwrite=overwrite,
         )

-    def _parse_row_format(self):
-        if not self._match_pair(TokenType.ROW, TokenType.FORMAT):
+    def _parse_row(self):
+        if not self._match(TokenType.FORMAT):
             return None
+        return self._parse_row_format()
+
+    def _parse_row_format(self, match_row=False):
+        if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT):
+            return None
+
+        if self._match_text_seq("SERDE"):
+            return self.expression(exp.RowFormatSerdeProperty, this=self._parse_string())

         self._match_text_seq("DELIMITED")
@@ -998,7 +1010,7 @@ class Parser(metaclass=_Parser):
             kwargs["lines"] = self._parse_string()
         if self._match_text_seq("NULL", "DEFINED", "AS"):
             kwargs["null"] = self._parse_string()
-        return self.expression(exp.RowFormat, **kwargs)
+        return self.expression(exp.RowFormatDelimitedProperty, **kwargs)

     def _parse_load_data(self):
         local = self._match(TokenType.LOCAL)
@@ -1032,7 +1044,7 @@ class Parser(metaclass=_Parser):
         return self.expression(
             exp.Update,
             **{
-                "this": self._parse_table(schema=True),
+                "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
                 "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
                 "from": self._parse_from(),
                 "where": self._parse_where(),
@@ -1183,9 +1195,11 @@ class Parser(metaclass=_Parser):
             alias=alias,
         )

-    def _parse_table_alias(self):
+    def _parse_table_alias(self, alias_tokens=None):
         any_token = self._match(TokenType.ALIAS)
-        alias = self._parse_id_var(any_token=any_token, tokens=self.TABLE_ALIAS_TOKENS)
+        alias = self._parse_id_var(
+            any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS
+        )
         columns = None

         if self._match(TokenType.L_PAREN):
@@ -1337,7 +1351,7 @@ class Parser(metaclass=_Parser):
             columns=self._parse_expression(),
         )

-    def _parse_table(self, schema=False):
+    def _parse_table(self, schema=False, alias_tokens=None):
         lateral = self._parse_lateral()

         if lateral:
@@ -1372,7 +1386,7 @@ class Parser(metaclass=_Parser):
             table = self._parse_id_var()

         if not table:
-            self.raise_error("Expected table name")
+            self.raise_error(f"Expected table name but got {self._curr}")

         this = self.expression(
             exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
@@ -1384,7 +1398,7 @@ class Parser(metaclass=_Parser):
         if self.alias_post_tablesample:
             table_sample = self._parse_table_sample()

-        alias = self._parse_table_alias()
+        alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
         if alias:
             this.set("alias", alias)
@@ -2092,10 +2106,14 @@ class Parser(metaclass=_Parser):
             kind = self.expression(exp.CheckColumnConstraint, this=constraint)
         elif self._match(TokenType.COLLATE):
             kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
+        elif self._match(TokenType.ENCODE):
+            kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var())
         elif self._match(TokenType.DEFAULT):
             kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction())
         elif self._match_pair(TokenType.NOT, TokenType.NULL):
             kind = exp.NotNullColumnConstraint()
+        elif self._match(TokenType.NULL):
+            kind = exp.NotNullColumnConstraint(allow_null=True)
         elif self._match(TokenType.SCHEMA_COMMENT):
             kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
         elif self._match(TokenType.PRIMARY_KEY):
@@ -2234,7 +2252,7 @@ class Parser(metaclass=_Parser):
         return self._parse_window(this)

     def _parse_extract(self):
-        this = self._parse_var() or self._parse_type()
+        this = self._parse_function() or self._parse_var() or self._parse_type()

         if self._match(TokenType.FROM):
             return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
@@ -2635,6 +2653,54 @@ class Parser(metaclass=_Parser):
         parser = self._find_parser(self.SET_PARSERS, self._set_trie)
         return parser(self) if parser else self._default_parse_set_item()

+    def _parse_merge(self):
+        self._match(TokenType.INTO)
+        target = self._parse_table(schema=True)
+
+        self._match(TokenType.USING)
+        using = self._parse_table()
+
+        self._match(TokenType.ON)
+        on = self._parse_conjunction()
+
+        whens = []
+        while self._match(TokenType.WHEN):
+            this = self._parse_conjunction()
+            self._match(TokenType.THEN)
+
+            if self._match(TokenType.INSERT):
+                _this = self._parse_star()
+                if _this:
+                    then = self.expression(exp.Insert, this=_this)
+                else:
+                    then = self.expression(
+                        exp.Insert,
+                        this=self._parse_value(),
+                        expression=self._match(TokenType.VALUES) and self._parse_value(),
+                    )
+            elif self._match(TokenType.UPDATE):
+                expressions = self._parse_star()
+                if expressions:
+                    then = self.expression(exp.Update, expressions=expressions)
+                else:
+                    then = self.expression(
+                        exp.Update,
+                        expressions=self._match(TokenType.SET)
+                        and self._parse_csv(self._parse_equality),
+                    )
+            elif self._match(TokenType.DELETE):
+                then = self.expression(exp.Var, this=self._prev.text)
+
+            whens.append(self.expression(exp.When, this=this, then=then))
+
+        return self.expression(
+            exp.Merge,
+            this=target,
+            using=using,
+            on=on,
+            expressions=whens,
+        )
+
     def _parse_set(self):
         return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
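A hedged sketch (not from the commit) of the new MERGE parsing: _parse_merge should yield an exp.Merge whose expressions are exp.When nodes, one per WHEN ... THEN branch.

import sqlglot
from sqlglot import exp

merge = sqlglot.parse_one(
    "MERGE INTO t USING s ON t.id = s.id "
    "WHEN MATCHED THEN UPDATE SET t.v = s.v "
    "WHEN NOT MATCHED THEN INSERT (id, v) VALUES (s.id, s.v)"
)
print(isinstance(merge, exp.Merge), len(merge.expressions))  # True 2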


@@ -47,7 +47,7 @@ class Schema(abc.ABC):
     """

     @abc.abstractmethod
-    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType.Type:
+    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
         """
         Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
@@ -160,8 +160,8 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
         super().__init__(schema)
         self.visible = visible or {}
         self.dialect = dialect
-        self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {
-            "STR": exp.DataType.Type.TEXT,
+        self._type_mapping_cache: t.Dict[str, exp.DataType] = {
+            "STR": exp.DataType.build("text"),
         }

     @classmethod
@@ -231,18 +231,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
         visible = self._nested_get(self.table_parts(table_), self.visible)
         return [col for col in schema if col in visible]  # type: ignore

-    def get_column_type(
-        self, table: exp.Table | str, column: exp.Column | str
-    ) -> exp.DataType.Type:
+    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
         column_name = column if isinstance(column, str) else column.name
         table_ = exp.to_table(table)
         if table_:
-            table_schema = self.find(table_)
+            table_schema = self.find(table_, raise_on_missing=False)
+            if table_schema:
                 schema_type = table_schema.get(column_name).upper()  # type: ignore
                 return self._convert_type(schema_type)
+            return exp.DataType(this=exp.DataType.Type.UNKNOWN)
         raise SchemaError(f"Could not convert table '{table}'")

-    def _convert_type(self, schema_type: str) -> exp.DataType.Type:
+    def _convert_type(self, schema_type: str) -> exp.DataType:
         """
         Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
@@ -257,7 +257,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
             expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
             if expression is None:
                 raise ValueError(f"Could not parse {schema_type}")
-            self._type_mapping_cache[schema_type] = expression.this
+            self._type_mapping_cache[schema_type] = expression  # type: ignore
         except AttributeError:
             raise SchemaError(f"Failed to convert type {schema_type}")
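A hedged example of the schema API change above: get_column_type now returns an exp.DataType node (which preserves parameters like lengths) instead of a bare DataType.Type enum value.

from sqlglot import exp
from sqlglot.schema import MappingSchema

dtype = MappingSchema({"t": {"x": "varchar(20)"}}).get_column_type("t", "x")
print(isinstance(dtype, exp.DataType), dtype.sql())  # True VARCHAR(20)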


@@ -49,6 +49,9 @@ class TokenType(AutoName):
     PARAMETER = auto()
     SESSION_PARAMETER = auto()

+    BLOCK_START = auto()
+    BLOCK_END = auto()
+
     SPACE = auto()
     BREAK = auto()
@@ -156,6 +159,7 @@ class TokenType(AutoName):
     DIV = auto()
     DROP = auto()
     ELSE = auto()
+    ENCODE = auto()
     END = auto()
     ENGINE = auto()
     ESCAPE = auto()
@@ -207,6 +211,7 @@ class TokenType(AutoName):
     LOCATION = auto()
     MAP = auto()
     MATERIALIZED = auto()
+    MERGE = auto()
     MOD = auto()
     NATURAL = auto()
     NEXT = auto()
@@ -255,6 +260,7 @@ class TokenType(AutoName):
     SELECT = auto()
     SEMI = auto()
     SEPARATOR = auto()
+    SERDE_PROPERTIES = auto()
     SET = auto()
     SHOW = auto()
     SIMILAR_TO = auto()
@@ -267,7 +273,6 @@ class TokenType(AutoName):
     TABLE_FORMAT = auto()
     TABLE_SAMPLE = auto()
     TEMPORARY = auto()
-    TRANSIENT = auto()
     TOP = auto()
     THEN = auto()
     TRAILING = auto()
@@ -420,6 +425,16 @@ class Tokenizer(metaclass=_Tokenizer):
     ESCAPES = ["'"]

     KEYWORDS = {
+        **{
+            f"{key}{postfix}": TokenType.BLOCK_START
+            for key in ("{{", "{%", "{#")
+            for postfix in ("", "+", "-")
+        },
+        **{
+            f"{prefix}{key}": TokenType.BLOCK_END
+            for key in ("}}", "%}", "#}")
+            for prefix in ("", "+", "-")
+        },
         "/*+": TokenType.HINT,
         "==": TokenType.EQ,
         "::": TokenType.DCOLON,
@@ -523,6 +538,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "LOCAL": TokenType.LOCAL,
         "LOCATION": TokenType.LOCATION,
         "MATERIALIZED": TokenType.MATERIALIZED,
+        "MERGE": TokenType.MERGE,
         "NATURAL": TokenType.NATURAL,
         "NEXT": TokenType.NEXT,
         "NO ACTION": TokenType.NO_ACTION,
@@ -582,7 +598,6 @@ class Tokenizer(metaclass=_Tokenizer):
         "TABLESAMPLE": TokenType.TABLE_SAMPLE,
         "TEMP": TokenType.TEMPORARY,
         "TEMPORARY": TokenType.TEMPORARY,
-        "TRANSIENT": TokenType.TRANSIENT,
         "THEN": TokenType.THEN,
         "TRUE": TokenType.TRUE,
         "TRAILING": TokenType.TRAILING,

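Note: the dict comprehensions above register the Jinja-style delimiters ({{, {%, {# plus whitespace-control variants such as {%- and -%}) as BLOCK_START/BLOCK_END tokens. A quick sketch of the effect, assuming the base Tokenizer is used directly:

    from sqlglot.tokens import Tokenizer, TokenType

    # A templated statement now tokenizes cleanly instead of failing on the braces.
    tokens = Tokenizer().tokenize("SELECT {{ x }} FROM t")
    kinds = [t.token_type for t in tokens]
    assert TokenType.BLOCK_START in kinds and TokenType.BLOCK_END in kinds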
View file

@@ -4,6 +4,7 @@ import unittest
 from sqlglot.dataframe.sql import types
 from sqlglot.dataframe.sql.dataframe import DataFrame
 from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.helper import ensure_list
 class DataFrameSQLValidator(unittest.TestCase):
@@ -33,9 +34,7 @@ class DataFrameSQLValidator(unittest.TestCase):
         self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
     ):
         actual_sqls = df.sql(pretty=pretty)
-        expected_statements = (
-            [expected_statements] if isinstance(expected_statements, str) else expected_statements
-        )
+        expected_statements = ensure_list(expected_statements)
         self.assertEqual(len(expected_statements), len(actual_sqls))
         for expected, actual in zip(expected_statements, actual_sqls):
             self.assertEqual(expected, actual)

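Note: ensure_list is a small helper that wraps a scalar in a list and passes lists through unchanged, which is exactly what the deleted inline conditional did:

    from sqlglot.helper import ensure_list

    ensure_list("a")    # ["a"]
    ensure_list(["a"])  # ["a"]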
View file

@@ -10,37 +10,37 @@ class TestDataFrameWriter(DataFrameSQLValidator):
     def test_insertInto_full_path(self):
         df = self.df_employee.write.insertInto("catalog.db.table_name")
-        expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT INTO catalog.db.table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)
     def test_insertInto_db_table(self):
         df = self.df_employee.write.insertInto("db.table_name")
-        expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT INTO db.table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)
     def test_insertInto_table(self):
         df = self.df_employee.write.insertInto("table_name")
-        expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)
     def test_insertInto_overwrite(self):
         df = self.df_employee.write.insertInto("table_name", overwrite=True)
-        expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT OVERWRITE TABLE table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)
     @mock.patch("sqlglot.schema", MappingSchema())
     def test_insertInto_byName(self):
         sqlglot.schema.add_table("table_name", {"employee_id": "INT"})
         df = self.df_employee.write.byName.insertInto("table_name")
-        expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)
     def test_insertInto_cache(self):
         df = self.df_employee.cache().write.insertInto("table_name")
         expected_statements = [
-            "DROP VIEW IF EXISTS t37164",
-            "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
-            "INSERT INTO table_name SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
+            "DROP VIEW IF EXISTS t12441",
+            "CACHE LAZY TABLE t12441 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+            "INSERT INTO table_name SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`",
         ]
         self.compare_sql(df, expected_statements)
@@ -50,39 +50,39 @@ class TestDataFrameWriter(DataFrameSQLValidator):
     def test_saveAsTable_append(self):
         df = self.df_employee.write.saveAsTable("table_name", mode="append")
-        expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)
     def test_saveAsTable_overwrite(self):
         df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
-        expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "CREATE OR REPLACE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)
     def test_saveAsTable_error(self):
         df = self.df_employee.write.saveAsTable("table_name", mode="error")
-        expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "CREATE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)
     def test_saveAsTable_ignore(self):
         df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
-        expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)
     def test_mode_standalone(self):
         df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
-        expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)
     def test_mode_override(self):
         df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
-        expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "CREATE OR REPLACE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)
     def test_saveAsTable_cache(self):
         df = self.df_employee.cache().write.saveAsTable("table_name")
         expected_statements = [
-            "DROP VIEW IF EXISTS t37164",
-            "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
-            "CREATE TABLE table_name AS SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
+            "DROP VIEW IF EXISTS t12441",
+            "CACHE LAZY TABLE t12441 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+            "CREATE TABLE table_name AS SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`",
         ]
         self.compare_sql(df, expected_statements)

View file

@@ -36,7 +36,7 @@ class TestDataframeSession(DataFrameSQLValidator):
     def test_cdf_str_schema(self):
         df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
-        expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
+        expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
         self.compare_sql(df, expected)
     def test_typed_schema_basic(self):
@@ -47,7 +47,7 @@ class TestDataframeSession(DataFrameSQLValidator):
             ]
         )
         df = self.spark.createDataFrame([[1, "test"]], schema)
-        expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
+        expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
         self.compare_sql(df, expected)
     def test_typed_schema_nested(self):

View file

@@ -6,6 +6,11 @@ class TestBigQuery(Validator):
     dialect = "bigquery"
     def test_bigquery(self):
+        self.validate_all(
+            "REGEXP_CONTAINS('foo', '.*')",
+            read={"bigquery": "REGEXP_CONTAINS('foo', '.*')"},
+            write={"mysql": "REGEXP_LIKE('foo', '.*')"},
+        ),
         self.validate_all(
             '"""x"""',
             write={
@@ -94,6 +99,20 @@ class TestBigQuery(Validator):
                 "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)",
             },
         )
+        self.validate_all(
+            "SELECT ARRAY(SELECT x FROM UNNEST([0, 1]) AS x)",
+            write={"bigquery": "SELECT ARRAY(SELECT x FROM UNNEST([0, 1]) AS x)"},
+        )
+        self.validate_all(
+            "SELECT ARRAY(SELECT DISTINCT x FROM UNNEST(some_numbers) AS x) AS unique_numbers",
+            write={
+                "bigquery": "SELECT ARRAY(SELECT DISTINCT x FROM UNNEST(some_numbers) AS x) AS unique_numbers"
+            },
+        )
+        self.validate_all(
+            "SELECT ARRAY(SELECT * FROM foo JOIN bla ON x = y)",
+            write={"bigquery": "SELECT ARRAY(SELECT * FROM foo JOIN bla ON x = y)"},
+        )
         self.validate_all(
             "x IS unknown",

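Note: the new cases above pin down two BigQuery fixes: ARRAY(...) now accepts a full SELECT expression as its argument, and REGEXP_CONTAINS transpiles to MySQL's REGEXP_LIKE. A short sketch through the public transpile API (outputs taken from the expected values above):

    import sqlglot

    print(sqlglot.transpile("REGEXP_CONTAINS('foo', '.*')", read="bigquery", write="mysql")[0])
    # REGEXP_LIKE('foo', '.*')

    sql = "SELECT ARRAY(SELECT x FROM UNNEST([0, 1]) AS x)"
    print(sqlglot.transpile(sql, read="bigquery", write="bigquery")[0])
    # SELECT ARRAY(SELECT x FROM UNNEST([0, 1]) AS x)  -- round-trips unchanged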
View file

@@ -1318,3 +1318,39 @@ SELECT
             "BEGIN IMMEDIATE TRANSACTION",
             write={"sqlite": "BEGIN IMMEDIATE TRANSACTION"},
         )
+    def test_merge(self):
+        self.validate_all(
+            """
+            MERGE INTO target USING source ON target.id = source.id
+            WHEN NOT MATCHED THEN INSERT (id) values (source.id)
+            """,
+            write={
+                "bigquery": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)",
+                "snowflake": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)",
+                "spark": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)",
+            },
+        )
+        self.validate_all(
+            """
+            MERGE INTO target USING source ON target.id = source.id
+            WHEN MATCHED AND source.is_deleted = 1 THEN DELETE
+            WHEN MATCHED THEN UPDATE SET val = source.val
+            WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)
+            """,
+            write={
+                "bigquery": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)",
+                "snowflake": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)",
+                "spark": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)",
+            },
+        )
+        self.validate_all(
+            """
+            MERGE INTO target USING source ON target.id = source.id
+            WHEN MATCHED THEN UPDATE *
+            WHEN NOT MATCHED THEN INSERT *
+            """,
+            write={
+                "spark": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED THEN UPDATE * WHEN NOT MATCHED THEN INSERT *",
+            },
+        )

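Note: these tests exercise the new MERGE parsing and generation end to end. A minimal sketch of the same round trip through the public API (the output is copied from the expected strings above; sqlglot emits it as a single line):

    import sqlglot

    sql = """
    MERGE INTO target USING source ON target.id = source.id
    WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)
    """
    print(sqlglot.transpile(sql, read="bigquery", write="snowflake")[0])
    # MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)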
View file

@@ -145,6 +145,10 @@ class TestHive(Validator):
             },
         )
+        self.validate_identity(
+            """CREATE EXTERNAL TABLE x (y INT) ROW FORMAT SERDE 'serde' ROW FORMAT DELIMITED FIELDS TERMINATED BY '1' WITH SERDEPROPERTIES ('input.regex'='')""",
+        )
     def test_lateral_view(self):
         self.validate_all(
             "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b",

View file

@@ -256,3 +256,7 @@ class TestPostgres(Validator):
             "SELECT $$Dianne's horse$$",
             write={"postgres": "SELECT 'Dianne''s horse'"},
         )
+        self.validate_all(
+            "UPDATE MYTABLE T1 SET T1.COL = 13",
+            write={"postgres": "UPDATE MYTABLE AS T1 SET T1.COL = 13"},
+        )

View file

@@ -56,8 +56,27 @@ class TestRedshift(Validator):
                 "redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
             },
         )
+        self.validate_all(
+            "DECODE(x, a, b, c, d)",
+            write={
+                "": "MATCHES(x, a, b, c, d)",
+                "oracle": "DECODE(x, a, b, c, d)",
+                "snowflake": "DECODE(x, a, b, c, d)",
+            },
+        )
+        self.validate_all(
+            "NVL(a, b, c, d)",
+            write={
+                "redshift": "COALESCE(a, b, c, d)",
+                "mysql": "COALESCE(a, b, c, d)",
+                "postgres": "COALESCE(a, b, c, d)",
+            },
+        )
     def test_identity(self):
+        self.validate_identity(
+            "SELECT DECODE(COL1, 'replace_this', 'with_this', 'replace_that', 'with_that')"
+        )
         self.validate_identity("CAST('bla' AS SUPER)")
         self.validate_identity("CREATE TABLE real1 (realcol REAL)")
         self.validate_identity("CAST('foo' AS HLLSKETCH)")
@@ -70,9 +89,9 @@ class TestRedshift(Validator):
         self.validate_identity(
             "SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
         )
-        self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE AUTO")
+        self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL")
         self.validate_identity(
-            "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid)"
+            "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid) DISTSTYLE AUTO"
         )
         self.validate_identity(
             "COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'"
@@ -80,3 +99,6 @@ class TestRedshift(Validator):
         self.validate_identity(
             "UNLOAD ('select * from venue') TO 's3://mybucket/unload/' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'"
         )
+        self.validate_identity(
+            "CREATE TABLE SOUP (SOUP1 VARCHAR(50) NOT NULL ENCODE ZSTD, SOUP2 VARCHAR(70) NULL ENCODE DELTA)"
+        )

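Note: the Redshift cases above cover the new ENCODE column property plus DECODE and variadic NVL handling. A short sketch of the function behavior (outputs taken from the expected values above):

    import sqlglot

    print(sqlglot.transpile("NVL(a, b, c, d)", read="redshift", write="postgres")[0])
    # COALESCE(a, b, c, d)

    print(sqlglot.transpile("DECODE(x, a, b, c, d)", read="redshift", write="snowflake")[0])
    # DECODE(x, a, b, c, d)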
View file

@@ -500,3 +500,12 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL F
             },
             pretty=True,
         )
+    def test_minus(self):
+        self.validate_all(
+            "SELECT 1 EXCEPT SELECT 1",
+            read={
+                "oracle": "SELECT 1 MINUS SELECT 1",
+                "snowflake": "SELECT 1 MINUS SELECT 1",
+            },
+        )

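Note: MINUS is the Oracle/Snowflake spelling of EXCEPT; the read-side test above confirms it normalizes at parse time. Sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT 1 MINUS SELECT 1", read="oracle")[0])
    # SELECT 1 EXCEPT SELECT 1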
View file

@@ -75,6 +75,7 @@ ARRAY(1, 2)
 ARRAY_CONTAINS(x, 1)
 EXTRACT(x FROM y)
 EXTRACT(DATE FROM y)
+EXTRACT(WEEK(monday) FROM created_at)
 CONCAT_WS('-', 'a', 'b')
 CONCAT_WS('-', 'a', 'b', 'c')
 POSEXPLODE("x") AS ("a", "b")

View file

@@ -3,3 +3,9 @@ SELECT CONCAT(w.d, w.e) AS c FROM w AS w;
 SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w;
 SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w;

+SELECT CAST(1 AS VARCHAR) AS a FROM w AS w;
+SELECT CAST(1 AS VARCHAR) AS a FROM w AS w;
+
+SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w;
+SELECT 1 + 3.2 AS a FROM w AS w;

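Note: the new fixture pairs above (input line, then expected line) cover redundant-cast removal: a cast whose operand already has the target type is dropped, while CAST(1 AS VARCHAR) is kept because it actually changes the type. Assuming these pairs feed the optimizer's canonicalize rule (the module path below is an assumption), the effect can be reproduced directly; the rule needs annotated types first:

    from sqlglot import parse_one
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.canonicalize import canonicalize  # assumed location of the rule

    expression = annotate_types(parse_one("SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w"))
    print(canonicalize(expression).sql())
    # SELECT 1 + 3.2 AS a FROM w AS w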
View file

@@ -79,14 +79,16 @@ NULL;
 NULL = NULL;
 NULL;

+-- Can't optimize this because different engines do different things
+-- mysql converts to 0 and 1 but tsql does true and false
 NULL <=> NULL;
-TRUE;
+NULL <=> NULL;

 a IS NOT DISTINCT FROM a;
-TRUE;
+a IS NOT DISTINCT FROM a;

 NULL IS DISTINCT FROM NULL;
-FALSE;
+NULL IS DISTINCT FROM NULL;

 NOT (NOT TRUE);
 TRUE;
@@ -239,10 +241,10 @@ TRUE;
 FALSE;

 ((NOT FALSE) AND (x = x)) AND (TRUE OR 1 <> 3);
-TRUE;
+x = x;

 ((NOT FALSE) AND (x = x)) AND (FALSE OR 1 <> 2);
-TRUE;
+x = x;

 (('a' = 'a') AND TRUE and NOT FALSE);
 TRUE;
@@ -372,3 +374,171 @@ CAST('1998-12-01' AS DATE) - INTERVAL '90' foo;
 date '1998-12-01' + interval '90' foo;
 CAST('1998-12-01' AS DATE) + INTERVAL '90' foo;
+
+--------------------------------------
+-- Comparisons
+--------------------------------------
+x < 0 OR x > 1;
+x < 0 OR x > 1;
+
+x < 0 OR x > 0;
+x < 0 OR x > 0;
+
+x < 1 OR x > 0;
+x < 1 OR x > 0;
+
+x < 1 OR x >= 0;
+x < 1 OR x >= 0;
+
+x <= 1 OR x > 0;
+x <= 1 OR x > 0;
+
+x <= 1 OR x >= 0;
+x <= 1 OR x >= 0;
+
+x <= 1 AND x <= 0;
+x <= 0;
+
+x <= 1 AND x > 0;
+x <= 1 AND x > 0;
+
+x <= 1 OR x > 0;
+x <= 1 OR x > 0;
+
+x <= 0 OR x < 0;
+x <= 0;
+
+x >= 0 OR x > 0;
+x >= 0;
+
+x >= 0 OR x > 1;
+x >= 0;
+
+x <= 0 OR x >= 0;
+x <= 0 OR x >= 0;
+
+x <= 0 AND x >= 0;
+x <= 0 AND x >= 0;
+
+x < 1 AND x < 2;
+x < 1;
+
+x < 1 OR x < 2;
+x < 2;
+
+x < 2 AND x < 1;
+x < 1;
+
+x < 2 OR x < 1;
+x < 2;
+
+x < 1 AND x < 1;
+x < 1;
+
+x < 1 OR x < 1;
+x < 1;
+
+x <= 1 AND x < 1;
+x < 1;
+
+x <= 1 OR x < 1;
+x <= 1;
+
+x < 1 AND x <= 1;
+x < 1;
+
+x < 1 OR x <= 1;
+x <= 1;
+
+x > 1 AND x > 2;
+x > 2;
+
+x > 1 OR x > 2;
+x > 1;
+
+x > 2 AND x > 1;
+x > 2;
+
+x > 2 OR x > 1;
+x > 1;
+
+x > 1 AND x > 1;
+x > 1;
+
+x > 1 OR x > 1;
+x > 1;
+
+x >= 1 AND x > 1;
+x > 1;
+
+x >= 1 OR x > 1;
+x >= 1;
+
+x > 1 AND x >= 1;
+x > 1;
+
+x > 1 OR x >= 1;
+x >= 1;
+
+x > 1 AND x >= 2;
+x >= 2;
+
+x > 1 OR x >= 2;
+x > 1;
+
+x > 1 AND x >= 2 AND x > 3 AND x > 0;
+x > 3;
+
+(x > 1 AND x >= 2 AND x > 3 AND x > 0) OR x > 0;
+x > 0;
+
+x > 1 AND x < 2 AND x > 3;
+FALSE;
+
+x > 1 AND x < 1;
+FALSE;
+
+x < 2 AND x > 1;
+x < 2 AND x > 1;
+
+x = 1 AND x < 1;
+FALSE;
+
+x = 1 AND x < 1.1;
+x = 1;
+
+x = 1 AND x <= 1;
+x = 1;
+
+x = 1 AND x <= 0.9;
+FALSE;
+
+x = 1 AND x > 0.9;
+x = 1;
+
+x = 1 AND x > 1;
+FALSE;
+
+x = 1 AND x >= 1;
+x = 1;
+
+x = 1 AND x >= 2;
+FALSE;
+
+x = 1 AND x <> 2;
+x = 1;
+
+x <> 1 AND x = 1;
+FALSE;
+
+x BETWEEN 0 AND 5 AND x > 3;
+x <= 5 AND x > 3;
+
+x > 3 AND 5 > x AND x BETWEEN 0 AND 10;
+x < 5 AND x > 3;
+
+x > 3 AND 5 < x AND x BETWEEN 9 AND 10;
+x <= 10 AND x >= 9;
+
+1 < x AND 3 < x;
+x > 3;

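Note: the long run of pairs above is the fixture for the new comparison simplification: each input line is followed by the expected output line. The rule merges inequalities over the same column, folds contradictions to FALSE, and expands BETWEEN into explicit comparisons so it can participate. A direct sketch (expected outputs copied from the fixture):

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    print(simplify(parse_one("x > 1 AND x >= 2 AND x > 3 AND x > 0")).sql())
    # x > 3

    print(simplify(parse_one("x BETWEEN 0 AND 5 AND x > 3")).sql())
    # x <= 5 AND x > 3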
View file

@@ -190,7 +190,7 @@ SELECT
   SUM("lineitem"."l_extendedprice" * (
     1 - "lineitem"."l_discount"
   )) AS "revenue",
-  CAST("orders"."o_orderdate" AS TEXT) AS "o_orderdate",
+  "orders"."o_orderdate" AS "o_orderdate",
   "orders"."o_shippriority" AS "o_shippriority"
 FROM "customer" AS "customer"
 JOIN "orders" AS "orders"
@@ -326,7 +326,8 @@ SELECT
   SUM("lineitem"."l_extendedprice" * "lineitem"."l_discount") AS "revenue"
 FROM "lineitem" AS "lineitem"
 WHERE
-  "lineitem"."l_discount" BETWEEN 0.05 AND 0.07
+  "lineitem"."l_discount" <= 0.07
+  AND "lineitem"."l_discount" >= 0.05
   AND "lineitem"."l_quantity" < 24
   AND CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE)
   AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE);
@@ -344,7 +345,7 @@ from
 select
   n1.n_name as supp_nation,
   n2.n_name as cust_nation,
-  extract(year from l_shipdate) as l_year,
+  extract(year from cast(l_shipdate as date)) as l_year,
   l_extendedprice * (1 - l_discount) as volume
 from
   supplier,
@@ -384,13 +385,14 @@ WITH "n1" AS (
 SELECT
   "n1"."n_name" AS "supp_nation",
   "n2"."n_name" AS "cust_nation",
-  EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) AS "l_year",
+  EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATE)) AS "l_year",
   SUM("lineitem"."l_extendedprice" * (
     1 - "lineitem"."l_discount"
   )) AS "revenue"
 FROM "supplier" AS "supplier"
 JOIN "lineitem" AS "lineitem"
-  ON CAST("lineitem"."l_shipdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+  ON CAST("lineitem"."l_shipdate" AS DATE) <= CAST('1996-12-31' AS DATE)
+  AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1995-01-01' AS DATE)
   AND "supplier"."s_suppkey" = "lineitem"."l_suppkey"
 JOIN "orders" AS "orders"
   ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
@@ -409,7 +411,7 @@ JOIN "n1" AS "n2"
 GROUP BY
   "n1"."n_name",
   "n2"."n_name",
-  EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME))
+  EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATE))
 ORDER BY
   "supp_nation",
   "cust_nation",
@@ -427,7 +429,7 @@ select
 from
   (
     select
-      extract(year from o_orderdate) as o_year,
+      extract(year from cast(o_orderdate as date)) as o_year,
       l_extendedprice * (1 - l_discount) as volume,
       n2.n_name as nation
     from
@@ -456,7 +458,7 @@ group by
 order by
   o_year;
 SELECT
-  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
+  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) AS "o_year",
   SUM(
     CASE
       WHEN "nation_2"."n_name" = 'BRAZIL'
@@ -477,7 +479,8 @@ JOIN "customer" AS "customer"
   ON "customer"."c_nationkey" = "nation"."n_nationkey"
 JOIN "orders" AS "orders"
   ON "orders"."o_custkey" = "customer"."c_custkey"
-  AND CAST("orders"."o_orderdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+  AND CAST("orders"."o_orderdate" AS DATE) <= CAST('1996-12-31' AS DATE)
+  AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1995-01-01' AS DATE)
 JOIN "lineitem" AS "lineitem"
   ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
   AND "part"."p_partkey" = "lineitem"."l_partkey"
@@ -488,7 +491,7 @@ JOIN "nation" AS "nation_2"
 WHERE
   "part"."p_type" = 'ECONOMY ANODIZED STEEL'
 GROUP BY
-  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
+  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE))
 ORDER BY
   "o_year";
@@ -503,7 +506,7 @@ from
   (
     select
       n_name as nation,
-      extract(year from o_orderdate) as o_year,
+      extract(year from cast(o_orderdate as date)) as o_year,
       l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount
     from
       part,
@@ -529,7 +532,7 @@ order by
   o_year desc;
 SELECT
   "nation"."n_name" AS "nation",
-  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
+  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) AS "o_year",
   SUM(
     "lineitem"."l_extendedprice" * (
       1 - "lineitem"."l_discount"
@@ -551,7 +554,7 @@ WHERE
   "part"."p_name" LIKE '%green%'
 GROUP BY
   "nation"."n_name",
-  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
+  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE))
 ORDER BY
   "nation",
   "o_year" DESC;
@@ -1016,7 +1019,7 @@ select
   o_orderkey,
   o_orderdate,
   o_totalprice,
-  sum(l_quantity)
+  sum(l_quantity) total_quantity
 from
   customer,
   orders,
@@ -1060,7 +1063,7 @@ SELECT
   "orders"."o_orderkey" AS "o_orderkey",
   "orders"."o_orderdate" AS "o_orderdate",
   "orders"."o_totalprice" AS "o_totalprice",
-  SUM("lineitem"."l_quantity") AS "_col_5"
+  SUM("lineitem"."l_quantity") AS "total_quantity"
 FROM "customer" AS "customer"
 JOIN "orders" AS "orders"
   ON "customer"."c_custkey" = "orders"."o_custkey"
@@ -1129,19 +1132,22 @@ JOIN "part" AS "part"
     "part"."p_brand" = 'Brand#12'
     AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG')
     AND "part"."p_partkey" = "lineitem"."l_partkey"
-    AND "part"."p_size" BETWEEN 1 AND 5
+    AND "part"."p_size" <= 5
+    AND "part"."p_size" >= 1
   )
   OR (
     "part"."p_brand" = 'Brand#23'
     AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK')
     AND "part"."p_partkey" = "lineitem"."l_partkey"
-    AND "part"."p_size" BETWEEN 1 AND 10
+    AND "part"."p_size" <= 10
+    AND "part"."p_size" >= 1
   )
   OR (
     "part"."p_brand" = 'Brand#34'
     AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG')
     AND "part"."p_partkey" = "lineitem"."l_partkey"
-    AND "part"."p_size" BETWEEN 1 AND 15
+    AND "part"."p_size" <= 15
+    AND "part"."p_size" >= 1
   )
 WHERE
   (
@@ -1152,7 +1158,8 @@ WHERE
     AND "part"."p_brand" = 'Brand#12'
     AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG')
     AND "part"."p_partkey" = "lineitem"."l_partkey"
-    AND "part"."p_size" BETWEEN 1 AND 5
+    AND "part"."p_size" <= 5
+    AND "part"."p_size" >= 1
   )
   OR (
     "lineitem"."l_quantity" <= 20
@@ -1162,7 +1169,8 @@ WHERE
     AND "part"."p_brand" = 'Brand#23'
     AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK')
     AND "part"."p_partkey" = "lineitem"."l_partkey"
-    AND "part"."p_size" BETWEEN 1 AND 10
+    AND "part"."p_size" <= 10
+    AND "part"."p_size" >= 1
   )
   OR (
     "lineitem"."l_quantity" <= 30
@@ -1172,7 +1180,8 @@ WHERE
     AND "part"."p_brand" = 'Brand#34'
     AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG')
     AND "part"."p_partkey" = "lineitem"."l_partkey"
-    AND "part"."p_size" BETWEEN 1 AND 15
+    AND "part"."p_size" <= 15
+    AND "part"."p_size" >= 1
   );
 --------------------------------------

View file

@@ -26,12 +26,12 @@ class TestExecutor(unittest.TestCase):
     def setUpClass(cls):
         cls.conn = duckdb.connect()
-        for table in TPCH_SCHEMA:
+        for table, columns in TPCH_SCHEMA.items():
             cls.conn.execute(
                 f"""
                 CREATE VIEW {table} AS
                 SELECT *
-                FROM READ_CSV_AUTO('{DIR}{table}.csv.gz')
+                FROM READ_CSV('{DIR}{table}.csv.gz', delim='|', header=True, columns={columns})
                 """
             )
@@ -74,13 +74,13 @@ class TestExecutor(unittest.TestCase):
             )
             return expression
-        for i, (sql, _) in enumerate(self.sqls[0:16]):
+        for i, (sql, _) in enumerate(self.sqls[0:18]):
             with self.subTest(f"tpch-h {i + 1}"):
                 a = self.cached_execute(sql)
                 sql = parse_one(sql).transform(to_csv).sql(pretty=True)
                 table = execute(sql, TPCH_SCHEMA)
                 b = pd.DataFrame(table.rows, columns=table.columns)
-                assert_frame_equal(a, b, check_dtype=False)
+                assert_frame_equal(a, b, check_dtype=False, check_index_type=False)
     def test_execute_callable(self):
         tables = {
@@ -456,8 +456,13 @@ class TestExecutor(unittest.TestCase):
             ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
             ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
             ("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
-            ("SELECT SUM(x) FROM (SELECT 1 AS x WHERE FALSE)", ["_col_0"], [(0,)]),
+            (
+                "SELECT SUM(x), COUNT(x) FROM (SELECT 1 AS x WHERE FALSE)",
+                ["_col_0", "_col_1"],
+                [(None, 0)],
+            ),
         ]:
+            with self.subTest(sql):
                 result = execute(sql)
                 self.assertEqual(result.columns, tuple(cols))
                 self.assertEqual(result.rows, rows)

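Note: the reworked executor case pins down standard SQL aggregate semantics over an empty relation: SUM yields NULL while COUNT yields 0. A sketch against the public executor API (expected values copied from the test):

    from sqlglot.executor import execute

    result = execute("SELECT SUM(x), COUNT(x) FROM (SELECT 1 AS x WHERE FALSE)")
    print(result.columns)  # ('_col_0', '_col_1')
    print(result.rows)     # [(None, 0)]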
View file

@ -333,7 +333,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
for sql, target_type in tests.items(): for sql, target_type in tests.items():
expression = annotate_types(parse_one(sql)) expression = annotate_types(parse_one(sql))
self.assertEqual(expression.find(exp.Literal).type, target_type) self.assertEqual(expression.find(exp.Literal).type.this, target_type)
def test_boolean_type_annotation(self): def test_boolean_type_annotation(self):
tests = { tests = {
@ -343,31 +343,33 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
for sql, target_type in tests.items(): for sql, target_type in tests.items():
expression = annotate_types(parse_one(sql)) expression = annotate_types(parse_one(sql))
self.assertEqual(expression.find(exp.Boolean).type, target_type) self.assertEqual(expression.find(exp.Boolean).type.this, target_type)
def test_cast_type_annotation(self): def test_cast_type_annotation(self):
expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")) expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))"))
self.assertEqual(expression.type.this, exp.DataType.Type.TIMESTAMPTZ)
self.assertEqual(expression.this.type.this, exp.DataType.Type.VARCHAR)
self.assertEqual(expression.args["to"].type.this, exp.DataType.Type.TIMESTAMPTZ)
self.assertEqual(expression.args["to"].expressions[0].type.this, exp.DataType.Type.INT)
self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ) expression = annotate_types(parse_one("ARRAY(1)::ARRAY<INT>"))
self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR) self.assertEqual(expression.type, parse_one("ARRAY<INT>", into=exp.DataType))
self.assertEqual(expression.args["to"].type, exp.DataType.Type.TIMESTAMPTZ)
self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)
def test_cache_annotation(self): def test_cache_annotation(self):
expression = annotate_types( expression = annotate_types(
parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1") parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
) )
self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT) self.assertEqual(expression.expression.expressions[0].type.this, exp.DataType.Type.INT)
def test_binary_annotation(self): def test_binary_annotation(self):
expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0] expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0]
self.assertEqual(expression.type, exp.DataType.Type.DOUBLE) self.assertEqual(expression.type.this, exp.DataType.Type.DOUBLE)
self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE) self.assertEqual(expression.left.type.this, exp.DataType.Type.DOUBLE)
self.assertEqual(expression.right.type, exp.DataType.Type.INT) self.assertEqual(expression.right.type.this, exp.DataType.Type.INT)
self.assertEqual(expression.right.this.type, exp.DataType.Type.INT) self.assertEqual(expression.right.this.type.this, exp.DataType.Type.INT)
self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT) self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT)
self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT) self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT)
def test_derived_tables_column_annotation(self): def test_derived_tables_column_annotation(self):
schema = {"x": {"cola": "INT"}, "y": {"cola": "FLOAT"}} schema = {"x": {"cola": "INT"}, "y": {"cola": "FLOAT"}}
@ -387,128 +389,169 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
""" """
expression = annotate_types(parse_one(sql), schema=schema) expression = annotate_types(parse_one(sql), schema=schema)
self.assertEqual(expression.expressions[0].type, exp.DataType.Type.FLOAT) # a.cola AS cola self.assertEqual(
expression.expressions[0].type.this, exp.DataType.Type.FLOAT
) # a.cola AS cola
addition_alias = expression.args["from"].expressions[0].this.expressions[0] addition_alias = expression.args["from"].expressions[0].this.expressions[0]
self.assertEqual(addition_alias.type, exp.DataType.Type.FLOAT) # x.cola + y.cola AS cola self.assertEqual(
addition_alias.type.this, exp.DataType.Type.FLOAT
) # x.cola + y.cola AS cola
addition = addition_alias.this addition = addition_alias.this
self.assertEqual(addition.type, exp.DataType.Type.FLOAT) self.assertEqual(addition.type.this, exp.DataType.Type.FLOAT)
self.assertEqual(addition.this.type, exp.DataType.Type.INT) self.assertEqual(addition.this.type.this, exp.DataType.Type.INT)
self.assertEqual(addition.expression.type, exp.DataType.Type.FLOAT) self.assertEqual(addition.expression.type.this, exp.DataType.Type.FLOAT)
def test_cte_column_annotation(self): def test_cte_column_annotation(self):
schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT"}} schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT", "colc": "BOOLEAN"}}
sql = """ sql = """
WITH tbl AS ( WITH tbl AS (
SELECT x.cola + 'bla' AS cola, y.colb AS colb SELECT x.cola + 'bla' AS cola, y.colb AS colb, y.colc AS colc
FROM ( FROM (
SELECT x.cola AS cola SELECT x.cola AS cola
FROM x AS x FROM x AS x
) AS x ) AS x
JOIN ( JOIN (
SELECT y.colb AS colb SELECT y.colb AS colb, y.colc AS colc
FROM y AS y FROM y AS y
) AS y ) AS y
) )
SELECT tbl.cola + tbl.colb + 'foo' AS col SELECT tbl.cola + tbl.colb + 'foo' AS col
FROM tbl AS tbl FROM tbl AS tbl
WHERE tbl.colc = True
""" """
expression = annotate_types(parse_one(sql), schema=schema) expression = annotate_types(parse_one(sql), schema=schema)
self.assertEqual( self.assertEqual(
expression.expressions[0].type, exp.DataType.Type.TEXT expression.expressions[0].type.this, exp.DataType.Type.TEXT
) # tbl.cola + tbl.colb + 'foo' AS col ) # tbl.cola + tbl.colb + 'foo' AS col
outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo' outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo'
self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT) self.assertEqual(outer_addition.type.this, exp.DataType.Type.TEXT)
self.assertEqual(outer_addition.left.type, exp.DataType.Type.TEXT) self.assertEqual(outer_addition.left.type.this, exp.DataType.Type.TEXT)
self.assertEqual(outer_addition.right.type, exp.DataType.Type.VARCHAR) self.assertEqual(outer_addition.right.type.this, exp.DataType.Type.VARCHAR)
inner_addition = expression.expressions[0].this.left # tbl.cola + tbl.colb inner_addition = expression.expressions[0].this.left # tbl.cola + tbl.colb
self.assertEqual(inner_addition.left.type, exp.DataType.Type.VARCHAR) self.assertEqual(inner_addition.left.type.this, exp.DataType.Type.VARCHAR)
self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT) self.assertEqual(inner_addition.right.type.this, exp.DataType.Type.TEXT)
# WHERE tbl.colc = True
self.assertEqual(expression.args["where"].this.type.this, exp.DataType.Type.BOOLEAN)
cte_select = expression.args["with"].expressions[0].this cte_select = expression.args["with"].expressions[0].this
self.assertEqual( self.assertEqual(
cte_select.expressions[0].type, exp.DataType.Type.VARCHAR cte_select.expressions[0].type.this, exp.DataType.Type.VARCHAR
) # x.cola + 'bla' AS cola ) # x.cola + 'bla' AS cola
self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb self.assertEqual(
cte_select.expressions[1].type.this, exp.DataType.Type.TEXT
) # y.colb AS colb
self.assertEqual(
cte_select.expressions[2].type.this, exp.DataType.Type.BOOLEAN
) # y.colc AS colc
cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla' cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla'
self.assertEqual(cte_select_addition.type, exp.DataType.Type.VARCHAR) self.assertEqual(cte_select_addition.type.this, exp.DataType.Type.VARCHAR)
self.assertEqual(cte_select_addition.left.type, exp.DataType.Type.CHAR) self.assertEqual(cte_select_addition.left.type.this, exp.DataType.Type.CHAR)
self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR) self.assertEqual(cte_select_addition.right.type.this, exp.DataType.Type.VARCHAR)
# Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
for d, t in zip( for d, t in zip(
cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT] cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]
): ):
self.assertEqual(d.this.expressions[0].this.type, t) self.assertEqual(d.this.expressions[0].this.type.this, t)
def test_function_annotation(self): def test_function_annotation(self):
schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}} schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}}
sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x" sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x"
concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
self.assertEqual(concat_expr_alias.type, exp.DataType.Type.VARCHAR) self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.VARCHAR)
concat_expr = concat_expr_alias.this concat_expr = concat_expr_alias.this
self.assertEqual(concat_expr.type, exp.DataType.Type.VARCHAR) self.assertEqual(concat_expr.type.this, exp.DataType.Type.VARCHAR)
self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.VARCHAR) # x.cola
-        self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR)  # TRIM(x.colb)
-        self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR)  # x.colb
+        self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.VARCHAR)  # TRIM(x.colb)
+        self.assertEqual(concat_expr.right.this.type.this, exp.DataType.Type.CHAR)  # x.colb

         sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x"
         case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
-        self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR)
+        self.assertEqual(case_expr_alias.type.this, exp.DataType.Type.VARCHAR)
         case_expr = case_expr_alias.this
-        self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR)
-        self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR)
+        self.assertEqual(case_expr.type.this, exp.DataType.Type.VARCHAR)
+        self.assertEqual(case_expr.args["default"].type.this, exp.DataType.Type.CHAR)
         case_ifs_expr = case_expr.args["ifs"][0]
-        self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR)
-        self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR)
+        self.assertEqual(case_ifs_expr.type.this, exp.DataType.Type.VARCHAR)
+        self.assertEqual(case_ifs_expr.args["true"].type.this, exp.DataType.Type.VARCHAR)

     def test_unknown_annotation(self):
         schema = {"x": {"cola": "VARCHAR"}}
         sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x"
         concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
-        self.assertEqual(concat_expr_alias.type, exp.DataType.Type.UNKNOWN)
+        self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.UNKNOWN)
         concat_expr = concat_expr_alias.this
-        self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN)
-        self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR)  # x.cola
+        self.assertEqual(concat_expr.type.this, exp.DataType.Type.UNKNOWN)
+        self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.VARCHAR)  # x.cola
         self.assertEqual(
-            concat_expr.right.type, exp.DataType.Type.UNKNOWN
+            concat_expr.right.type.this, exp.DataType.Type.UNKNOWN
         )  # SOME_ANONYMOUS_FUNC(x.cola)
         self.assertEqual(
-            concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR
+            concat_expr.right.expressions[0].type.this, exp.DataType.Type.VARCHAR
         )  # x.cola (arg)

     def test_null_annotation(self):
         expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this
-        self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
-        self.assertEqual(expression.right.type, exp.DataType.Type.INT)
+        self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL)
+        self.assertEqual(expression.right.type.this, exp.DataType.Type.INT)

         # NULL <op> UNKNOWN should yield NULL
         sql = "SELECT NULL || SOME_ANONYMOUS_FUNC() AS result"
         concat_expr_alias = annotate_types(parse_one(sql)).expressions[0]
-        self.assertEqual(concat_expr_alias.type, exp.DataType.Type.NULL)
+        self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.NULL)
         concat_expr = concat_expr_alias.this
-        self.assertEqual(concat_expr.type, exp.DataType.Type.NULL)
-        self.assertEqual(concat_expr.left.type, exp.DataType.Type.NULL)
-        self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN)
+        self.assertEqual(concat_expr.type.this, exp.DataType.Type.NULL)
+        self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.NULL)
+        self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.UNKNOWN)

     def test_nullable_annotation(self):
         nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
         expression = annotate_types(parse_one("NULL AND FALSE"))
         self.assertEqual(expression.type, nullable)
-        self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
-        self.assertEqual(expression.right.type, exp.DataType.Type.BOOLEAN)
+        self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL)
+        self.assertEqual(expression.right.type.this, exp.DataType.Type.BOOLEAN)
+
+    def test_predicate_annotation(self):
+        expression = annotate_types(parse_one("x BETWEEN a AND b"))
+        self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN)
+
+        expression = annotate_types(parse_one("x IN (a, b, c, d)"))
+        self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN)
+
+    def test_aggfunc_annotation(self):
+        schema = {"x": {"cola": "SMALLINT", "colb": "FLOAT", "colc": "TEXT", "cold": "DATE"}}
+        tests = {
+            ("AVG", "cola"): exp.DataType.Type.DOUBLE,
+            ("SUM", "cola"): exp.DataType.Type.BIGINT,
+            ("SUM", "colb"): exp.DataType.Type.DOUBLE,
+            ("MIN", "cola"): exp.DataType.Type.SMALLINT,
+            ("MIN", "colb"): exp.DataType.Type.FLOAT,
+            ("MAX", "colc"): exp.DataType.Type.TEXT,
+            ("MAX", "cold"): exp.DataType.Type.DATE,
+            ("COUNT", "colb"): exp.DataType.Type.BIGINT,
+            ("STDDEV", "cola"): exp.DataType.Type.DOUBLE,
+        }
+        for (func, col), target_type in tests.items():
+            expression = annotate_types(
+                parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema
+            )
+            self.assertEqual(expression.expressions[0].type.this, target_type)
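
The assertions above track the 10.2 breaking change noted in the changelog: `annotate_types` now attaches `exp.DataType` expression nodes instead of bare `exp.DataType.Type` enum members, so the enum value is reached through `.this`. A minimal migration sketch for caller code (the query and schema here are illustrative, not taken from this diff):

```python
from sqlglot import exp, parse_one
from sqlglot.optimizer.annotate_types import annotate_types

annotated = annotate_types(
    parse_one("SELECT x.cola || x.colb AS col FROM x AS x"),
    schema={"x": {"cola": "VARCHAR", "colb": "CHAR"}},
)

col = annotated.expressions[0]
assert isinstance(col.type, exp.DataType)          # a DataType node, not an enum member
assert col.type.this == exp.DataType.Type.VARCHAR  # the enum now lives under .this
print(col.type.sql())                              # DataType renders as SQL, e.g. VARCHAR
```

One payoff of the node representation is that inferred types can carry parameters (e.g. `VARCHAR(10)`) and be rendered per dialect via `.sql(dialect)`, as the Spark dataframe code in this release already does.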


@@ -151,31 +151,33 @@ class TestSchema(unittest.TestCase):
     def test_schema_get_column_type(self):
         schema = MappingSchema({"a": {"b": "varchar"}})
-        self.assertEqual(schema.get_column_type("a", "b"), exp.DataType.Type.VARCHAR)
+        self.assertEqual(schema.get_column_type("a", "b").this, exp.DataType.Type.VARCHAR)
         self.assertEqual(
-            schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")),
+            schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")).this,
             exp.DataType.Type.VARCHAR,
         )
         self.assertEqual(
-            schema.get_column_type("a", exp.Column(this="b")), exp.DataType.Type.VARCHAR
+            schema.get_column_type("a", exp.Column(this="b")).this, exp.DataType.Type.VARCHAR
         )
         self.assertEqual(
-            schema.get_column_type(exp.Table(this="a"), "b"), exp.DataType.Type.VARCHAR
+            schema.get_column_type(exp.Table(this="a"), "b").this, exp.DataType.Type.VARCHAR
         )

         schema = MappingSchema({"a": {"b": {"c": "varchar"}}})
         self.assertEqual(
-            schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")),
+            schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")).this,
             exp.DataType.Type.VARCHAR,
         )
         self.assertEqual(
-            schema.get_column_type(exp.Table(this="b", db="a"), "c"), exp.DataType.Type.VARCHAR
+            schema.get_column_type(exp.Table(this="b", db="a"), "c").this, exp.DataType.Type.VARCHAR
         )

         schema = MappingSchema({"a": {"b": {"c": {"d": "varchar"}}}})
         self.assertEqual(
-            schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")),
+            schema.get_column_type(
+                exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")
+            ).this,
             exp.DataType.Type.VARCHAR,
         )
         self.assertEqual(
-            schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d"),
+            schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d").this,
             exp.DataType.Type.VARCHAR,
         )
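
In the same vein, `MappingSchema.get_column_type` now returns an `exp.DataType` node rather than the enum, hence the `.this` accessors above. A small sketch (hypothetical table and column names):

```python
from sqlglot import exp
from sqlglot.schema import MappingSchema

schema = MappingSchema({"t": {"c": "varchar"}})
dtype = schema.get_column_type("t", "c")

assert isinstance(dtype, exp.DataType)          # node, not a bare DataType.Type
assert dtype.this == exp.DataType.Type.VARCHAR  # enum member via .this
print(dtype.sql())                              # VARCHAR
```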


@@ -1,6 +1,6 @@
 import unittest

-from sqlglot.tokens import Tokenizer
+from sqlglot.tokens import Tokenizer, TokenType


 class TestTokens(unittest.TestCase):
@@ -17,3 +17,48 @@ class TestTokens(unittest.TestCase):
         for sql, comment in sql_comment:
             self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment)
+
+    def test_jinja(self):
+        tokenizer = Tokenizer()
+
+        tokens = tokenizer.tokenize(
+            """
+            SELECT
+               {{ x }},
+               {{- x -}},
+               {% for x in y -%}
+                  a {{+ b }}
+               {% endfor %};
+            """
+        )
+
+        tokens = [(token.token_type, token.text) for token in tokens]
+
+        self.assertEqual(
+            tokens,
+            [
+                (TokenType.SELECT, "SELECT"),
+                (TokenType.BLOCK_START, "{{"),
+                (TokenType.VAR, "x"),
+                (TokenType.BLOCK_END, "}}"),
+                (TokenType.COMMA, ","),
+                (TokenType.BLOCK_START, "{{-"),
+                (TokenType.VAR, "x"),
+                (TokenType.BLOCK_END, "-}}"),
+                (TokenType.COMMA, ","),
+                (TokenType.BLOCK_START, "{%"),
+                (TokenType.FOR, "for"),
+                (TokenType.VAR, "x"),
+                (TokenType.IN, "in"),
+                (TokenType.VAR, "y"),
+                (TokenType.BLOCK_END, "-%}"),
+                (TokenType.VAR, "a"),
+                (TokenType.BLOCK_START, "{{+"),
+                (TokenType.VAR, "b"),
+                (TokenType.BLOCK_END, "}}"),
+                (TokenType.BLOCK_START, "{%"),
+                (TokenType.VAR, "endfor"),
+                (TokenType.BLOCK_END, "%}"),
+                (TokenType.SEMICOLON, ";"),
+            ],
+        )
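
This new test covers the tokenizer's jinja handling: `{{ ... }}` and `{% ... %}` delimiters, including the whitespace-control variants (`{{-`, `-%}`, `{{+`), are emitted as `BLOCK_START`/`BLOCK_END` tokens rather than tripping the tokenizer. A minimal sketch of scanning a templated query before rendering (the query string is illustrative):

```python
from sqlglot.tokens import Tokenizer, TokenType

tokens = Tokenizer().tokenize("SELECT {{ col }} FROM {{ source_table }}")

# Template delimiters surface as BLOCK_START / BLOCK_END tokens, so templated
# SQL can be inspected without rendering the jinja first.
delimiters = [
    token.text
    for token in tokens
    if token.token_type in (TokenType.BLOCK_START, TokenType.BLOCK_END)
]
print(delimiters)  # ['{{', '}}', '{{', '}}']
```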