Merging upstream version 10.2.6.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent 40155883c5
commit 17f6b2c749
36 changed files with 1281 additions and 493 deletions
CHANGELOG.md (37 changes)
@@ -1,6 +1,43 @@
 Changelog
 =========
 
+v10.2.0
+------
+
+Changes:
+
+- Breaking: types inferred from annotate_types are now DataType objects, instead of DataType.Type.
+
+- New: the optimizer can now simplify [BETWEEN expressions expressed as explicit comparisons](https://github.com/tobymao/sqlglot/commit/e24d0317dfa644104ff21d009b790224bf84d698).
+
+- New: the optimizer now removes redundant casts.
+
+- New: added support for Redshift's ENCODE/DECODE.
+
+- New: the optimizer now [treats identifiers as case-insensitive](https://github.com/tobymao/sqlglot/commit/638ed265f195219d7226f4fbae128f1805ae8988).
+
+- New: the optimizer now [handles nested CTEs](https://github.com/tobymao/sqlglot/commit/1bdd652792889a8aaffb1c6d2c8aa1fe4a066281).
+
+- New: the executor can now execute SELECT DISTINCT expressions.
+
+- New: added support for Redshift's COPY and UNLOAD commands.
+
+- New: added ability to parse LIKE in CREATE TABLE statements.
+
+- New: the optimizer now [unnests scalar subqueries as cross joins](https://github.com/tobymao/sqlglot/commit/4373ad8518ede4ef1fda8b247b648c680a93d12d).
+
+- Improvement: fixed BigQuery's ARRAY function parsing, so that it can now handle a SELECT expression as an argument.
+
+- Improvement: improved Snowflake's [ARRAY and MAP constructs](https://github.com/tobymao/sqlglot/commit/0506657dba55fe71d004c81c907e23cdd2b37d82).
+
+- Improvement: fixed transpilation between STRING_AGG and GROUP_CONCAT.
+
+- Improvement: the INTO clause can now be parsed in SELECT expressions.
+
+- Improvement: improved the executor; it currently executes all TPC-H queries up to TPC-H 17 (inclusive).
+
+- Improvement: DISTINCT ON is now transpiled to a SELECT expression from a subquery for Redshift.
+
 v10.1.0
 ------
 
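The BETWEEN entry above can be exercised end to end through the optimizer. A minimal sketch, assuming this release's public `sqlglot.optimizer.simplify.simplify` entry point:

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    # BETWEEN is first rewritten to x >= 1 AND x <= 10, after which the
    # comparison merging can fold the overlapping x > 5 predicate.
    expression = sqlglot.parse_one("SELECT * FROM t WHERE x > 5 AND x BETWEEN 1 AND 10")
    print(simplify(expression).sql())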
@@ -30,7 +30,7 @@ from sqlglot.parser import Parser
 from sqlglot.schema import MappingSchema
 from sqlglot.tokens import Tokenizer, TokenType
 
-__version__ = "10.1.3"
+__version__ = "10.2.6"
 
 pretty = False
 
@@ -317,7 +317,7 @@ class DataFrame:
         sqlglot.schema.add_table(
             cache_table_name,
             {
-                expression.alias_or_name: expression.type.name
+                expression.alias_or_name: expression.type.sql("spark")
                 for expression in select_expression.expressions
             },
         )
@@ -110,17 +110,17 @@ class BigQuery(Dialect):
 
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
+            "BEGIN": TokenType.COMMAND,
+            "BEGIN TRANSACTION": TokenType.BEGIN,
             "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
             "CURRENT_TIME": TokenType.CURRENT_TIME,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
-            "INT64": TokenType.BIGINT,
             "FLOAT64": TokenType.DOUBLE,
+            "INT64": TokenType.BIGINT,
+            "NOT DETERMINISTIC": TokenType.VOLATILE,
             "QUALIFY": TokenType.QUALIFY,
             "UNKNOWN": TokenType.NULL,
             "WINDOW": TokenType.WINDOW,
-            "NOT DETERMINISTIC": TokenType.VOLATILE,
-            "BEGIN": TokenType.COMMAND,
-            "BEGIN TRANSACTION": TokenType.BEGIN,
         }
         KEYWORDS.pop("DIV")
 
@@ -131,6 +131,7 @@ class BigQuery(Dialect):
             "DATE_ADD": _date_add(exp.DateAdd),
             "DATETIME_ADD": _date_add(exp.DatetimeAdd),
             "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
+            "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
             "TIME_ADD": _date_add(exp.TimeAdd),
             "TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
             "DATE_SUB": _date_add(exp.DateSub),
@@ -144,6 +145,7 @@ class BigQuery(Dialect):
 
         FUNCTION_PARSERS = {
             **parser.Parser.FUNCTION_PARSERS,
+            "ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]),
         }
         FUNCTION_PARSERS.pop("TRIM")
 
@@ -161,7 +163,6 @@ class BigQuery(Dialect):
     class Generator(generator.Generator):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
-            exp.Array: inline_array_sql,
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.DateAdd: _date_add_sql("DATE", "ADD"),
             exp.DateSub: _date_add_sql("DATE", "SUB"),
@@ -183,6 +184,7 @@ class BigQuery(Dialect):
             exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
             if e.name == "IMMUTABLE"
             else "NOT DETERMINISTIC",
+            exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
         }
 
         TYPE_MAPPING = {
@@ -210,24 +212,31 @@ class BigQuery(Dialect):
 
         EXPLICIT_UNION = True
 
-        def transaction_sql(self, *_):
+        def array_sql(self, expression: exp.Array) -> str:
+            first_arg = seq_get(expression.expressions, 0)
+            if isinstance(first_arg, exp.Subqueryable):
+                return f"ARRAY{self.wrap(self.sql(first_arg))}"
+
+            return inline_array_sql(self, expression)
+
+        def transaction_sql(self, *_) -> str:
             return "BEGIN TRANSACTION"
 
-        def commit_sql(self, *_):
+        def commit_sql(self, *_) -> str:
             return "COMMIT TRANSACTION"
 
-        def rollback_sql(self, *_):
+        def rollback_sql(self, *_) -> str:
             return "ROLLBACK TRANSACTION"
 
-        def in_unnest_op(self, unnest):
-            return self.sql(unnest)
+        def in_unnest_op(self, expression: exp.Unnest) -> str:
+            return self.sql(expression)
 
-        def except_op(self, expression):
+        def except_op(self, expression: exp.Except) -> str:
             if not expression.args.get("distinct", False):
                 self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery")
             return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
 
-        def intersect_op(self, expression):
+        def intersect_op(self, expression: exp.Intersect) -> str:
             if not expression.args.get("distinct", False):
                 self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
             return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
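The new `array_sql` above special-cases subquery arguments, so `ARRAY(SELECT ...)` keeps its wrapper while plain arrays still go through `inline_array_sql`. A round-trip sketch; the exact output shapes are assumed rather than guaranteed:

    import sqlglot

    # a subquery argument keeps the ARRAY(...) form
    print(sqlglot.transpile("SELECT ARRAY(SELECT 1)", read="bigquery", write="bigquery")[0])
    # a literal array is still rendered inline
    print(sqlglot.transpile("SELECT [1, 2, 3]", read="bigquery", write="bigquery")[0])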
@@ -190,6 +190,7 @@ class Hive(Dialect):
         "ADD FILES": TokenType.COMMAND,
         "ADD JAR": TokenType.COMMAND,
         "ADD JARS": TokenType.COMMAND,
+        "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
     }
 
     class Parser(parser.Parser):
@@ -238,6 +239,13 @@ class Hive(Dialect):
             "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
         }
 
+        PROPERTY_PARSERS = {
+            **parser.Parser.PROPERTY_PARSERS,
+            TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties(
+                expressions=self._parse_wrapped_csv(self._parse_property)
+            ),
+        }
+
     class Generator(generator.Generator):
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
@@ -297,6 +305,8 @@ class Hive(Dialect):
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
             exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
+            exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}",
+            exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
             exp.NumberToStr: rename_func("FORMAT_NUMBER"),
         }
 
@@ -308,12 +318,15 @@ class Hive(Dialect):
             exp.SchemaCommentProperty,
             exp.LocationProperty,
             exp.TableFormatProperty,
+            exp.RowFormatDelimitedProperty,
+            exp.RowFormatSerdeProperty,
+            exp.SerdeProperties,
         }
 
         def with_properties(self, properties):
             return self.properties(
                 properties,
-                prefix="TBLPROPERTIES",
+                prefix=self.seg("TBLPROPERTIES"),
             )
 
         def datatype_sql(self, expression):
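Taken together with the SERDE_PROPERTIES keyword and property parser above, Hive SERDE clauses should now round-trip. A sketch; the serde class name is only an illustrative value:

    import sqlglot

    sql = (
        "CREATE TABLE x (a STRING) "
        "ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' "
        "WITH SERDEPROPERTIES ('separatorChar' = ',')"
    )
    print(sqlglot.transpile(sql, read="hive", write="hive")[0])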
@@ -98,6 +98,7 @@ class Oracle(Dialect):
     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
+            "MINUS": TokenType.EXCEPT,
             "START": TokenType.BEGIN,
             "TOP": TokenType.TOP,
             "VARCHAR2": TokenType.VARCHAR,
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from sqlglot import exp, transforms
+from sqlglot.dialects.dialect import rename_func
 from sqlglot.dialects.postgres import Postgres
 from sqlglot.tokens import TokenType
 
@@ -13,12 +14,20 @@ class Redshift(Postgres):
         "HH": "%H",
     }
 
+    class Parser(Postgres.Parser):
+        FUNCTIONS = {
+            **Postgres.Parser.FUNCTIONS,  # type: ignore
+            "DECODE": exp.Matches.from_arg_list,
+            "NVL": exp.Coalesce.from_arg_list,
+        }
+
     class Tokenizer(Postgres.Tokenizer):
         ESCAPES = ["\\"]
 
         KEYWORDS = {
             **Postgres.Tokenizer.KEYWORDS,  # type: ignore
             "COPY": TokenType.COMMAND,
+            "ENCODE": TokenType.ENCODE,
             "GEOMETRY": TokenType.GEOMETRY,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
             "HLLSKETCH": TokenType.HLLSKETCH,
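With the new Parser, Redshift's NVL lands on the portable `exp.Coalesce` node, and DECODE parses into `exp.Matches` (rendered back as DECODE by the transform added below). A sketch of the NVL path:

    import sqlglot

    # NVL parses to exp.Coalesce, so other dialects render it as COALESCE
    print(sqlglot.transpile("SELECT NVL(a, b) FROM t", read="redshift", write="postgres")[0])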
@@ -50,4 +59,5 @@ class Redshift(Postgres):
             exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
             exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
             exp.DistStyleProperty: lambda self, e: self.naked_property(e),
+            exp.Matches: rename_func("DECODE"),
         }
@@ -198,6 +198,7 @@ class Snowflake(Dialect):
             "TIMESTAMP_NTZ": TokenType.TIMESTAMP,
             "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
             "TIMESTAMPNTZ": TokenType.TIMESTAMP,
+            "MINUS": TokenType.EXCEPT,
             "SAMPLE": TokenType.TABLE_SAMPLE,
         }
 
@@ -19,10 +19,13 @@ class reverse_key:
         return other.obj < self.obj
 
 
-def filter_nulls(func):
+def filter_nulls(func, empty_null=True):
     @wraps(func)
     def _func(values):
-        return func(v for v in values if v is not None)
+        filtered = tuple(v for v in values if v is not None)
+        if not filtered and empty_null:
+            return None
+        return func(filtered)
 
     return _func
 
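The new `empty_null` flag decides what an aggregate returns when every input is NULL. A small sketch, assuming `filter_nulls` is importable from the executor's env module:

    from sqlglot.executor.env import filter_nulls

    sum_agg = filter_nulls(sum)  # empty_null=True by default
    count_agg = filter_nulls(lambda acc: sum(1 for _ in acc), False)

    print(sum_agg([None, None]))    # None: SUM over only NULLs is NULL
    print(count_agg([None, None]))  # 0: COUNT over only NULLs is 0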
@@ -126,7 +129,7 @@ ENV = {
     # aggs
     "SUM": filter_nulls(sum),
     "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean),  # type: ignore
-    "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)),
+    "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
     "MAX": filter_nulls(max),
     "MIN": filter_nulls(min),
     # scalar functions
@@ -310,9 +310,9 @@ class PythonExecutor:
             if i == length - 1:
                 context.set_range(start, end - 1)
                 add_row()
-            elif step.limit > 0:
+            elif step.limit > 0 and not group_by:
                 context.set_range(0, 0)
-                table.append(context.eval_tuple(group_by) + context.eval_tuple(aggregations))
+                table.append(context.eval_tuple(aggregations))
 
         context = self.context({step.name: table, **{name: table for name in context.tables}})
 
@@ -43,14 +43,14 @@ class Expression(metaclass=_Expression):
 
     key = "Expression"
     arg_types = {"this": True}
-    __slots__ = ("args", "parent", "arg_key", "type", "comments")
+    __slots__ = ("args", "parent", "arg_key", "comments", "_type")
 
     def __init__(self, **args):
         self.args = args
         self.parent = None
         self.arg_key = None
-        self.type = None
         self.comments = None
+        self._type: t.Optional[DataType] = None
 
         for arg_key, value in self.args.items():
             self._set_parent(arg_key, value)
@@ -122,6 +122,16 @@ class Expression(metaclass=_Expression):
             return "NULL"
         return self.alias or self.name
 
+    @property
+    def type(self) -> t.Optional[DataType]:
+        return self._type
+
+    @type.setter
+    def type(self, dtype: t.Optional[DataType | DataType.Type | str]) -> None:
+        if dtype and not isinstance(dtype, DataType):
+            dtype = DataType.build(dtype)
+        self._type = dtype  # type: ignore
+
     def __deepcopy__(self, memo):
         copy = self.__class__(**deepcopy(self.args))
         copy.comments = self.comments
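This property pair is what makes the breaking type change ergonomic: assigning a string or a `DataType.Type` is coerced into a full `DataType` node via `DataType.build`. A sketch, assuming the `exp.column` helper:

    from sqlglot import exp

    column = exp.column("a")
    column.type = "VARCHAR(20)"  # coerced through DataType.build
    assert isinstance(column.type, exp.DataType)
    print(column.type.this)  # Type.VARCHAR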
@@ -348,7 +358,7 @@ class Expression(metaclass=_Expression):
         indent += "".join(["  "] * level)
         left = f"({self.key.upper()} "
 
-        args = {
+        args: t.Dict[str, t.Any] = {
             k: ", ".join(
                 v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
                 for v in ensure_collection(vs)
@@ -612,6 +622,7 @@ class Create(Expression):
         "properties": False,
         "temporary": False,
         "transient": False,
+        "external": False,
         "replace": False,
         "unique": False,
         "materialized": False,
@@ -744,13 +755,17 @@ class DefaultColumnConstraint(ColumnConstraintKind):
     pass
 
 
+class EncodeColumnConstraint(ColumnConstraintKind):
+    pass
+
+
 class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
     # this: True -> ALWAYS, this: False -> BY DEFAULT
     arg_types = {"this": True, "expression": False}
 
 
 class NotNullColumnConstraint(ColumnConstraintKind):
-    pass
+    arg_types = {"allow_null": False}
 
 
 class PrimaryKeyColumnConstraint(ColumnConstraintKind):
@@ -766,7 +781,7 @@ class Constraint(Expression):
 
 
 class Delete(Expression):
-    arg_types = {"with": False, "this": True, "using": False, "where": False}
+    arg_types = {"with": False, "this": False, "using": False, "where": False}
 
 
 class Drop(Expression):
@@ -850,7 +865,7 @@ class Insert(Expression):
     arg_types = {
         "with": False,
         "this": True,
-        "expression": True,
+        "expression": False,
         "overwrite": False,
        "exists": False,
        "partition": False,
@@ -1125,6 +1140,27 @@ class VolatilityProperty(Property):
     arg_types = {"this": True}
 
 
+class RowFormatDelimitedProperty(Property):
+    # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
+    arg_types = {
+        "fields": False,
+        "escaped": False,
+        "collection_items": False,
+        "map_keys": False,
+        "lines": False,
+        "null": False,
+        "serde": False,
+    }
+
+
+class RowFormatSerdeProperty(Property):
+    arg_types = {"this": True}
+
+
+class SerdeProperties(Property):
+    arg_types = {"expressions": True}
+
+
 class Properties(Expression):
     arg_types = {"expressions": True}
 
@@ -1169,18 +1205,6 @@ class Reference(Expression):
     arg_types = {"this": True, "expressions": True}
 
 
-class RowFormat(Expression):
-    # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
-    arg_types = {
-        "fields": False,
-        "escaped": False,
-        "collection_items": False,
-        "map_keys": False,
-        "lines": False,
-        "null": False,
-    }
-
-
 class Tuple(Expression):
     arg_types = {"expressions": False}
 
@@ -1208,6 +1232,9 @@ class Subqueryable(Unionable):
             alias=TableAlias(this=to_identifier(alias)),
         )
 
+    def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
+        raise NotImplementedError
+
     @property
     def ctes(self):
         with_ = self.args.get("with")
@@ -1320,6 +1347,32 @@ class Union(Subqueryable):
         **QUERY_MODIFIERS,
     }
 
+    def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
+        """
+        Set the LIMIT expression.
+
+        Example:
+            >>> select("1").union(select("1")).limit(1).sql()
+            'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1'
+
+        Args:
+            expression (str | int | Expression): the SQL code string to parse.
+                This can also be an integer.
+                If a `Limit` instance is passed, this is used as-is.
+                If another `Expression` instance is passed, it will be wrapped in a `Limit`.
+            dialect (str): the dialect used to parse the input expression.
+            copy (bool): if `False`, modify this expression instance in-place.
+            opts (kwargs): other options to use to parse the input expressions.
+
+        Returns:
+            Select: The limited subqueryable.
+        """
+        return (
+            select("*")
+            .from_(self.subquery(alias="_l_0", copy=copy))
+            .limit(expression, dialect=dialect, copy=False, **opts)
+        )
+
     @property
     def named_selects(self):
         return self.this.unnest().named_selects
@@ -1356,7 +1409,7 @@ class Unnest(UDTF):
 class Update(Expression):
     arg_types = {
         "with": False,
-        "this": True,
+        "this": False,
         "expressions": True,
         "from": False,
         "where": False,
@@ -2057,15 +2110,20 @@ class DataType(Expression):
         Type.TEXT,
     }
 
-    NUMERIC_TYPES = {
+    INTEGER_TYPES = {
         Type.INT,
         Type.TINYINT,
         Type.SMALLINT,
         Type.BIGINT,
+    }
+
+    FLOAT_TYPES = {
         Type.FLOAT,
         Type.DOUBLE,
     }
 
+    NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES}
+
     TEMPORAL_TYPES = {
         Type.TIMESTAMP,
         Type.TIMESTAMPTZ,
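The split keeps the public NUMERIC_TYPES contract intact while letting the annotator's new promote logic (further down) tell integer widths from float widths:

    from sqlglot import exp

    assert exp.DataType.Type.SMALLINT in exp.DataType.INTEGER_TYPES
    assert exp.DataType.Type.DOUBLE in exp.DataType.FLOAT_TYPES
    # NUMERIC_TYPES is now exactly the union of the two narrower sets
    assert exp.DataType.NUMERIC_TYPES == exp.DataType.INTEGER_TYPES | exp.DataType.FLOAT_TYPES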
@@ -2968,6 +3026,14 @@ class Use(Expression):
     pass
 
 
+class Merge(Expression):
+    arg_types = {"this": True, "using": True, "on": True, "expressions": True}
+
+
+class When(Func):
+    arg_types = {"this": True, "then": True}
+
+
 def _norm_args(expression):
     args = {}
 
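Paired with the TokenType.MERGE statement parser registered further down, MERGE statements now parse into the new `exp.Merge` node with `exp.When` arms. A parse-only sketch, hedged because the exact MERGE grammar coverage is not shown in this excerpt:

    import sqlglot

    expression = sqlglot.parse_one(
        "MERGE INTO target USING source ON target.id = source.id "
        "WHEN MATCHED THEN UPDATE SET target.v = source.v "
        "WHEN NOT MATCHED THEN INSERT (id, v) VALUES (source.id, source.v)"
    )
    print(type(expression).__name__)  # Merge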
[File diff suppressed because it is too large]
@@ -385,3 +385,11 @@ def dict_depth(d: t.Dict) -> int:
     except StopIteration:
         # d.values() returns an empty sequence
         return 1
+
+
+def first(it: t.Iterable[T]) -> T:
+    """Returns the first element from an iterable.
+
+    Useful for sets.
+    """
+    return next(i for i in it)
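`first` exists mainly to pull a member out of a set without materializing a list; note it raises StopIteration on an empty iterable, which the comparison simplification later in this commit relies on catching:

    from sqlglot.helper import first

    assert first({42}) == 42
    assert first(iter([3, 2, 1])) == 3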
@@ -14,7 +14,7 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
         >>> schema = {"y": {"cola": "SMALLINT"}}
         >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
         >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
-        >>> annotated_expr.expressions[0].type  # Get the type of "x.cola + 2.5 AS cola"
+        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
         <Type.DOUBLE: 'DOUBLE'>
 
     Args:
@@ -41,9 +41,12 @@ class TypeAnnotator:
             expr_type: lambda self, expr: self._annotate_binary(expr)
             for expr_type in subclasses(exp.__name__, exp.Binary)
         },
-        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this),
-        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this),
+        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
+        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
+        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr),
         exp.Alias: lambda self, expr: self._annotate_unary(expr),
+        exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+        exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
         exp.Literal: lambda self, expr: self._annotate_literal(expr),
         exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
         exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
@@ -52,6 +55,9 @@ class TypeAnnotator:
             expr, exp.DataType.Type.BIGINT
         ),
         exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"),
+        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"),
+        exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True),
         exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
         exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
         exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
@@ -263,10 +269,10 @@ class TypeAnnotator:
         }
         # First annotate the current scope's column references
         for col in scope.columns:
-            source = scope.sources[col.table]
+            source = scope.sources.get(col.table)
             if isinstance(source, exp.Table):
                 col.type = self.schema.get_column_type(source, col)
-            else:
+            elif source:
                 col.type = selects[col.table][col.name].type
         # Then (possibly) annotate the remaining expressions in the scope
         self._maybe_annotate(scope.expression)
|
@ -280,6 +286,7 @@ class TypeAnnotator:
|
||||||
return expression # We've already inferred the expression's type
|
return expression # We've already inferred the expression's type
|
||||||
|
|
||||||
annotator = self.annotators.get(expression.__class__)
|
annotator = self.annotators.get(expression.__class__)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
annotator(self, expression)
|
annotator(self, expression)
|
||||||
if annotator
|
if annotator
|
||||||
|
@@ -295,18 +302,23 @@ class TypeAnnotator:
 
     def _maybe_coerce(self, type1, type2):
         # We propagate the NULL / UNKNOWN types upwards if found
+        if isinstance(type1, exp.DataType):
+            type1 = type1.this
+        if isinstance(type2, exp.DataType):
+            type2 = type2.this
+
         if exp.DataType.Type.NULL in (type1, type2):
             return exp.DataType.Type.NULL
         if exp.DataType.Type.UNKNOWN in (type1, type2):
             return exp.DataType.Type.UNKNOWN
 
-        return type2 if type2 in self.coerces_to[type1] else type1
+        return type2 if type2 in self.coerces_to.get(type1, {}) else type1
 
     def _annotate_binary(self, expression):
         self._annotate_args(expression)
 
-        left_type = expression.left.type
-        right_type = expression.right.type
+        left_type = expression.left.type.this
+        right_type = expression.right.type.this
 
         if isinstance(expression, (exp.And, exp.Or)):
             if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
@@ -348,7 +360,7 @@ class TypeAnnotator:
         expression.type = target_type
         return self._annotate_args(expression)
 
-    def _annotate_by_args(self, expression, *args):
+    def _annotate_by_args(self, expression, *args, promote=False):
         self._annotate_args(expression)
         expressions = []
         for arg in args:
@@ -360,4 +372,11 @@ class TypeAnnotator:
             last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
 
         expression.type = last_datatype or exp.DataType.Type.UNKNOWN
+
+        if promote:
+            if expression.type.this in exp.DataType.INTEGER_TYPES:
+                expression.type = exp.DataType.Type.BIGINT
+            elif expression.type.this in exp.DataType.FLOAT_TYPES:
+                expression.type = exp.DataType.Type.DOUBLE
+
         return expression
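The new `promote` flag is what gives SUM a wider result type than its argument. A sketch, using a table-qualified column so the annotator can resolve the source type:

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types

    schema = {"t": {"x": "INT"}}
    expression = annotate_types(sqlglot.parse_one("SELECT SUM(t.x) AS s FROM t"), schema=schema)
    # INT promotes to BIGINT; per the breaking change, .type is a DataType
    print(expression.expressions[0].type.this)  # Type.BIGINT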
@@ -13,13 +13,16 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
         expression: The expression to canonicalize.
     """
     exp.replace_children(expression, canonicalize)
 
     expression = add_text_to_concat(expression)
     expression = coerce_type(expression)
+    expression = remove_redundant_casts(expression)
 
     return expression
 
 
 def add_text_to_concat(node: exp.Expression) -> exp.Expression:
-    if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES:
+    if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
         node = exp.Concat(this=node.this, expression=node.expression)
     return node
 
@@ -30,14 +33,30 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
     elif isinstance(node, exp.Between):
         _coerce_date(node.this, node.args["low"])
     elif isinstance(node, exp.Extract):
-        if node.expression.type not in exp.DataType.TEMPORAL_TYPES:
+        if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES:
             _replace_cast(node.expression, "datetime")
     return node
 
 
+def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
+    if (
+        isinstance(expression, exp.Cast)
+        and expression.to.type
+        and expression.this.type
+        and expression.to.type.this == expression.this.type.this
+    ):
+        return expression.this
+    return expression
+
+
 def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
     for a, b in itertools.permutations([a, b]):
-        if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE:
+        if (
+            a.type
+            and a.type.this == exp.DataType.Type.DATE
+            and b.type
+            and b.type.this != exp.DataType.Type.DATE
+        ):
             _replace_cast(b, "date")
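`canonicalize` runs on type-annotated trees, so a cast whose child already carries the target type can simply be dropped. A sketch:

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.canonicalize import canonicalize

    sql = "SELECT CAST(CAST(x AS INT) AS INT) AS y FROM t"
    expression = annotate_types(sqlglot.parse_one(sql))
    # the outer cast is redundant once the inner one establishes the type
    print(canonicalize(expression).sql())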
@@ -7,7 +7,7 @@ from decimal import Decimal
 from sqlglot import exp
 from sqlglot.expressions import FALSE, NULL, TRUE
 from sqlglot.generator import Generator
-from sqlglot.helper import while_changing
+from sqlglot.helper import first, while_changing
 
 GENERATOR = Generator(normalize=True, identify=True)
 
@@ -30,6 +30,7 @@ def simplify(expression):
 
     def _simplify(expression, root=True):
         node = expression
+        node = rewrite_between(node)
         node = uniq_sort(node)
         node = absorb_and_eliminate(node)
         exp.replace_children(node, lambda e: _simplify(e, False))
@@ -49,6 +50,19 @@ def simplify(expression):
     return expression
 
 
+def rewrite_between(expression: exp.Expression) -> exp.Expression:
+    """Rewrite x between y and z to x >= y AND x <= z.
+
+    This is done because comparison simplification is only done on lt/lte/gt/gte.
+    """
+    if isinstance(expression, exp.Between):
+        return exp.and_(
+            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
+            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
+        )
+    return expression
+
+
 def simplify_not(expression):
     """
     Demorgan's Law
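Standalone, the rewrite is easy to observe:

    import sqlglot
    from sqlglot.optimizer.simplify import rewrite_between

    expression = sqlglot.parse_one("x BETWEEN 1 AND 10")
    print(rewrite_between(expression).sql())  # x >= 1 AND x <= 10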
@@ -57,7 +71,7 @@ def simplify_not(expression):
     """
     if isinstance(expression, exp.Not):
         if isinstance(expression.this, exp.Null):
-            return NULL
+            return exp.null()
         if isinstance(expression.this, exp.Paren):
             condition = expression.this.unnest()
             if isinstance(condition, exp.And):
@@ -65,11 +79,11 @@ def simplify_not(expression):
         if isinstance(condition, exp.Or):
             return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
         if isinstance(condition, exp.Null):
-            return NULL
+            return exp.null()
         if always_true(expression.this):
-            return FALSE
+            return exp.false()
         if expression.this == FALSE:
-            return TRUE
+            return exp.true()
         if isinstance(expression.this, exp.Not):
             # double negation
             # NOT NOT x -> x
@@ -91,41 +105,120 @@ def flatten(expression):
 
 
 def simplify_connectors(expression):
+    def _simplify_connectors(expression, left, right):
         if isinstance(expression, exp.Connector):
-            left = expression.left
-            right = expression.right
-
             if left == right:
                 return left
 
             if isinstance(expression, exp.And):
                 if FALSE in (left, right):
-                    return FALSE
+                    return exp.false()
                 if NULL in (left, right):
-                    return NULL
+                    return exp.null()
                 if always_true(left) and always_true(right):
-                    return TRUE
+                    return exp.true()
                 if always_true(left):
                     return right
                 if always_true(right):
                     return left
+                return _simplify_comparison(expression, left, right)
             elif isinstance(expression, exp.Or):
                 if always_true(left) or always_true(right):
-                    return TRUE
+                    return exp.true()
                 if left == FALSE and right == FALSE:
-                    return FALSE
+                    return exp.false()
                 if (
                     (left == NULL and right == NULL)
                     or (left == NULL and right == FALSE)
                     or (left == FALSE and right == NULL)
                 ):
-                    return NULL
+                    return exp.null()
                 if left == FALSE:
                     return right
                 if right == FALSE:
                     return left
+                return _simplify_comparison(expression, left, right, or_=True)
+        return None
+
+    return _flat_simplify(expression, _simplify_connectors)
+
+
+LT_LTE = (exp.LT, exp.LTE)
+GT_GTE = (exp.GT, exp.GTE)
+
+COMPARISONS = (
+    *LT_LTE,
+    *GT_GTE,
+    exp.EQ,
+    exp.NEQ,
+)
+
+INVERSE_COMPARISONS = {
+    exp.LT: exp.GT,
+    exp.GT: exp.LT,
+    exp.LTE: exp.GTE,
+    exp.GTE: exp.LTE,
+}
+
+
+def _simplify_comparison(expression, left, right, or_=False):
+    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
+        ll, lr = left.args.values()
+        rl, rr = right.args.values()
+
+        largs = {ll, lr}
+        rargs = {rl, rr}
+
+        matching = largs & rargs
+        columns = {m for m in matching if isinstance(m, exp.Column)}
+
+        if matching and columns:
+            try:
+                l = first(largs - columns)
+                r = first(rargs - columns)
+            except StopIteration:
                 return expression
+
+            # make sure the comparison is always of the form x > 1 instead of 1 < x
+            if left.__class__ in INVERSE_COMPARISONS and l == ll:
+                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
+            if right.__class__ in INVERSE_COMPARISONS and r == rl:
+                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
+
+            if l.is_number and r.is_number:
+                l = float(l.name)
+                r = float(r.name)
+            elif l.is_string and r.is_string:
+                l = l.name
+                r = r.name
+            else:
+                return None
+
+            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
+                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
+                    return left if (av > bv if or_ else av <= bv) else right
+                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
+                    return left if (av < bv if or_ else av >= bv) else right
+
+                # we can't ever shortcut to true because the column could be null
+                if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
+                    if not or_ and av <= bv:
+                        return exp.false()
+                elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
+                    if not or_ and av >= bv:
+                        return exp.false()
+                elif isinstance(a, exp.EQ):
+                    if isinstance(b, exp.LT):
+                        return exp.false() if av >= bv else a
+                    if isinstance(b, exp.LTE):
+                        return exp.false() if av > bv else a
+                    if isinstance(b, exp.GT):
+                        return exp.false() if av <= bv else a
+                    if isinstance(b, exp.GTE):
+                        return exp.false() if av < bv else a
+                    if isinstance(b, exp.NEQ):
+                        return exp.false() if av == bv else a
+    return None
 
 
 def remove_compliments(expression):
     """
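The effect of `_simplify_comparison` on overlapping predicates, sketched through the public simplify entry point:

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    # AND keeps the tighter bound, OR keeps the looser one
    print(simplify(sqlglot.parse_one("x > 5 AND x > 3")).sql())  # x > 5
    print(simplify(sqlglot.parse_one("x > 5 OR x > 3")).sql())   # x > 3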
@@ -135,7 +228,7 @@ def remove_compliments(expression):
     A OR NOT A -> TRUE
     """
     if isinstance(expression, exp.Connector):
-        compliment = FALSE if isinstance(expression, exp.And) else TRUE
+        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
 
         for a, b in itertools.permutations(expression.flatten(), 2):
             if is_complement(a, b):
@@ -211,27 +304,7 @@ def absorb_and_eliminate(expression):
 
 def simplify_literals(expression):
     if isinstance(expression, exp.Binary):
-        operands = []
-        queue = deque(expression.flatten(unnest=False))
-        size = len(queue)
-
-        while queue:
-            a = queue.popleft()
-
-            for b in queue:
-                result = _simplify_binary(expression, a, b)
-
-                if result:
-                    queue.remove(b)
-                    queue.append(result)
-                    break
-            else:
-                operands.append(a)
-
-        if len(operands) < size:
-            return functools.reduce(
-                lambda a, b: expression.__class__(this=a, expression=b), operands
-            )
+        return _flat_simplify(expression, _simplify_binary)
     elif isinstance(expression, exp.Neg):
         this = expression.this
         if this.is_number:
@@ -254,20 +327,13 @@ def _simplify_binary(expression, a, b):
 
     if c == NULL:
         if isinstance(a, exp.Literal):
-            return TRUE if not_ else FALSE
+            return exp.true() if not_ else exp.false()
         if a == NULL:
-            return FALSE if not_ else TRUE
-    elif isinstance(expression, exp.NullSafeEQ):
-        if a == b:
-            return TRUE
-    elif isinstance(expression, exp.NullSafeNEQ):
-        if a == b:
-            return FALSE
+            return exp.false() if not_ else exp.true()
+    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
+        return None
     elif NULL in (a, b):
-        return NULL
-
-    if isinstance(expression, exp.EQ) and a == b:
-        return TRUE
+        return exp.null()
 
     if a.is_number and b.is_number:
         a = int(a.name) if a.is_int else Decimal(a.name)
@@ -388,4 +454,27 @@ def date_literal(date):
 
 
 def boolean_literal(condition):
-    return TRUE if condition else FALSE
+    return exp.true() if condition else exp.false()
+
+
+def _flat_simplify(expression, simplifier):
+    operands = []
+    queue = deque(expression.flatten(unnest=False))
+    size = len(queue)
+
+    while queue:
+        a = queue.popleft()
+
+        for b in queue:
+            result = simplifier(expression, a, b)
+
+            if result:
+                queue.remove(b)
+                queue.append(result)
+                break
+        else:
+            operands.append(a)
+
+    if len(operands) < size:
+        return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
+    return expression
@@ -185,6 +185,7 @@ class Parser(metaclass=_Parser):
         TokenType.LOCAL,
         TokenType.LOCATION,
         TokenType.MATERIALIZED,
+        TokenType.MERGE,
         TokenType.NATURAL,
         TokenType.NEXT,
         TokenType.ONLY,
@@ -211,7 +212,6 @@ class Parser(metaclass=_Parser):
         TokenType.TABLE,
         TokenType.TABLE_FORMAT,
         TokenType.TEMPORARY,
-        TokenType.TRANSIENT,
         TokenType.TOP,
         TokenType.TRAILING,
         TokenType.TRUE,
@@ -229,6 +229,8 @@ class Parser(metaclass=_Parser):
 
     TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY}
 
+    UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
+
     TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
 
     FUNC_TOKENS = {
@@ -241,6 +243,7 @@ class Parser(metaclass=_Parser):
         TokenType.FORMAT,
         TokenType.IDENTIFIER,
         TokenType.ISNULL,
+        TokenType.MERGE,
         TokenType.OFFSET,
         TokenType.PRIMARY_KEY,
         TokenType.REPLACE,
@@ -407,6 +410,7 @@ class Parser(metaclass=_Parser):
         TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
         TokenType.END: lambda self: self._parse_commit_or_rollback(),
         TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
+        TokenType.MERGE: lambda self: self._parse_merge(),
     }
 
     UNARY_PARSERS = {
@@ -474,6 +478,7 @@ class Parser(metaclass=_Parser):
         TokenType.SORTKEY: lambda self: self._parse_sortkey(),
         TokenType.LIKE: lambda self: self._parse_create_like(),
         TokenType.RETURNS: lambda self: self._parse_returns(),
+        TokenType.ROW: lambda self: self._parse_row(),
         TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
         TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
         TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
@@ -495,6 +500,8 @@ class Parser(metaclass=_Parser):
         TokenType.VOLATILE: lambda self: self.expression(
             exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
         ),
+        TokenType.WITH: lambda self: self._parse_wrapped_csv(self._parse_property),
+        TokenType.PROPERTIES: lambda self: self._parse_wrapped_csv(self._parse_property),
     }
 
     CONSTRAINT_PARSERS = {
@@ -802,7 +809,8 @@ class Parser(metaclass=_Parser):
     def _parse_create(self):
         replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
         temporary = self._match(TokenType.TEMPORARY)
-        transient = self._match(TokenType.TRANSIENT)
+        transient = self._match_text_seq("TRANSIENT")
+        external = self._match_text_seq("EXTERNAL")
         unique = self._match(TokenType.UNIQUE)
         materialized = self._match(TokenType.MATERIALIZED)
 
@@ -846,6 +854,7 @@ class Parser(metaclass=_Parser):
             properties=properties,
             temporary=temporary,
             transient=transient,
+            external=external,
             replace=replace,
             unique=unique,
             materialized=materialized,
@@ -861,8 +870,12 @@ class Parser(metaclass=_Parser):
         if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
             return self._parse_sortkey(compound=True)
 
-        if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False):
-            key = self._parse_var()
+        assignment = self._match_pair(
+            TokenType.VAR, TokenType.EQ, advance=False
+        ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False)
+
+        if assignment:
+            key = self._parse_var() or self._parse_string()
             self._match(TokenType.EQ)
             return self.expression(exp.Property, this=key, value=self._parse_column())
 
@@ -871,7 +884,10 @@ class Parser(metaclass=_Parser):
     def _parse_property_assignment(self, exp_class):
         self._match(TokenType.EQ)
         self._match(TokenType.ALIAS)
-        return self.expression(exp_class, this=self._parse_var_or_string() or self._parse_number())
+        return self.expression(
+            exp_class,
+            this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
+        )
 
     def _parse_partitioned_by(self):
         self._match(TokenType.EQ)
@@ -881,7 +897,7 @@ class Parser(metaclass=_Parser):
     )
 
     def _parse_distkey(self):
-        return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_var))
+        return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
 
     def _parse_create_like(self):
         table = self._parse_table(schema=True)
@@ -898,7 +914,7 @@ class Parser(metaclass=_Parser):
 
     def _parse_sortkey(self, compound=False):
         return self.expression(
-            exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_var), compound=compound
+            exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_id_var), compound=compound
         )
 
     def _parse_character_set(self, default=False):
@@ -929,23 +945,11 @@ class Parser(metaclass=_Parser):
         properties = []
 
         while True:
-            if self._match(TokenType.WITH):
-                properties.extend(self._parse_wrapped_csv(self._parse_property))
-            elif self._match(TokenType.PROPERTIES):
-                properties.extend(
-                    self._parse_wrapped_csv(
-                        lambda: self.expression(
-                            exp.Property,
-                            this=self._parse_string(),
-                            value=self._match(TokenType.EQ) and self._parse_string(),
-                        )
-                    )
-                )
-            else:
-                identified_property = self._parse_property()
-                if not identified_property:
-                    break
-                properties.append(identified_property)
+            identified_property = self._parse_property()
+            if not identified_property:
+                break
+            for p in ensure_collection(identified_property):
+                properties.append(p)
 
         if properties:
             return self.expression(exp.Properties, expressions=properties)
@ -963,7 +967,7 @@ class Parser(metaclass=_Parser):
|
||||||
exp.Directory,
|
exp.Directory,
|
||||||
this=self._parse_var_or_string(),
|
this=self._parse_var_or_string(),
|
||||||
local=local,
|
local=local,
|
||||||
row_format=self._parse_row_format(),
|
row_format=self._parse_row_format(match_row=True),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._match(TokenType.INTO)
|
self._match(TokenType.INTO)
|
||||||
|
@ -978,9 +982,17 @@ class Parser(metaclass=_Parser):
|
||||||
overwrite=overwrite,
|
overwrite=overwrite,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_row_format(self):
|
def _parse_row(self):
|
||||||
if not self._match_pair(TokenType.ROW, TokenType.FORMAT):
|
if not self._match(TokenType.FORMAT):
|
||||||
return None
|
return None
|
||||||
|
return self._parse_row_format()
|
||||||
|
|
||||||
|
def _parse_row_format(self, match_row=False):
|
||||||
|
if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT):
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self._match_text_seq("SERDE"):
|
||||||
|
return self.expression(exp.RowFormatSerdeProperty, this=self._parse_string())
|
||||||
|
|
||||||
self._match_text_seq("DELIMITED")
|
self._match_text_seq("DELIMITED")
|
||||||
|
|
||||||
|
@ -998,7 +1010,7 @@ class Parser(metaclass=_Parser):
|
||||||
kwargs["lines"] = self._parse_string()
|
kwargs["lines"] = self._parse_string()
|
||||||
if self._match_text_seq("NULL", "DEFINED", "AS"):
|
if self._match_text_seq("NULL", "DEFINED", "AS"):
|
||||||
kwargs["null"] = self._parse_string()
|
kwargs["null"] = self._parse_string()
|
||||||
return self.expression(exp.RowFormat, **kwargs)
|
return self.expression(exp.RowFormatDelimitedProperty, **kwargs)
|
||||||
|
|
||||||
def _parse_load_data(self):
|
def _parse_load_data(self):
|
||||||
local = self._match(TokenType.LOCAL)
|
local = self._match(TokenType.LOCAL)
|
||||||
|
@ -1032,7 +1044,7 @@ class Parser(metaclass=_Parser):
|
||||||
return self.expression(
|
return self.expression(
|
||||||
exp.Update,
|
exp.Update,
|
||||||
**{
|
**{
|
||||||
"this": self._parse_table(schema=True),
|
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
|
||||||
"expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
|
"expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
|
||||||
"from": self._parse_from(),
|
"from": self._parse_from(),
|
||||||
"where": self._parse_where(),
|
"where": self._parse_where(),
|
||||||
|
@ -1183,9 +1195,11 @@ class Parser(metaclass=_Parser):
|
||||||
alias=alias,
|
alias=alias,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_table_alias(self):
|
def _parse_table_alias(self, alias_tokens=None):
|
||||||
any_token = self._match(TokenType.ALIAS)
|
any_token = self._match(TokenType.ALIAS)
|
||||||
alias = self._parse_id_var(any_token=any_token, tokens=self.TABLE_ALIAS_TOKENS)
|
alias = self._parse_id_var(
|
||||||
|
any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS
|
||||||
|
)
|
||||||
columns = None
|
columns = None
|
||||||
|
|
||||||
if self._match(TokenType.L_PAREN):
|
if self._match(TokenType.L_PAREN):
|
||||||
|
@ -1337,7 +1351,7 @@ class Parser(metaclass=_Parser):
|
||||||
columns=self._parse_expression(),
|
columns=self._parse_expression(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_table(self, schema=False):
|
def _parse_table(self, schema=False, alias_tokens=None):
|
||||||
lateral = self._parse_lateral()
|
lateral = self._parse_lateral()
|
||||||
|
|
||||||
if lateral:
|
if lateral:
|
||||||
|
@ -1372,7 +1386,7 @@ class Parser(metaclass=_Parser):
|
||||||
table = self._parse_id_var()
|
table = self._parse_id_var()
|
||||||
|
|
||||||
if not table:
|
if not table:
|
||||||
self.raise_error("Expected table name")
|
self.raise_error(f"Expected table name but got {self._curr}")
|
||||||
|
|
||||||
this = self.expression(
|
this = self.expression(
|
||||||
exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
|
exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
|
||||||
|
@ -1384,7 +1398,7 @@ class Parser(metaclass=_Parser):
|
||||||
if self.alias_post_tablesample:
|
if self.alias_post_tablesample:
|
||||||
table_sample = self._parse_table_sample()
|
table_sample = self._parse_table_sample()
|
||||||
|
|
||||||
alias = self._parse_table_alias()
|
alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
|
||||||
|
|
||||||
if alias:
|
if alias:
|
||||||
this.set("alias", alias)
|
this.set("alias", alias)
|
||||||
|
@ -2092,10 +2106,14 @@ class Parser(metaclass=_Parser):
|
||||||
kind = self.expression(exp.CheckColumnConstraint, this=constraint)
|
kind = self.expression(exp.CheckColumnConstraint, this=constraint)
|
||||||
elif self._match(TokenType.COLLATE):
|
elif self._match(TokenType.COLLATE):
|
||||||
kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
|
kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
|
||||||
|
elif self._match(TokenType.ENCODE):
|
||||||
|
kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var())
|
||||||
elif self._match(TokenType.DEFAULT):
|
elif self._match(TokenType.DEFAULT):
|
||||||
kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction())
|
kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction())
|
||||||
elif self._match_pair(TokenType.NOT, TokenType.NULL):
|
elif self._match_pair(TokenType.NOT, TokenType.NULL):
|
||||||
kind = exp.NotNullColumnConstraint()
|
kind = exp.NotNullColumnConstraint()
|
||||||
|
elif self._match(TokenType.NULL):
|
||||||
|
kind = exp.NotNullColumnConstraint(allow_null=True)
|
||||||
elif self._match(TokenType.SCHEMA_COMMENT):
|
elif self._match(TokenType.SCHEMA_COMMENT):
|
||||||
kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
|
kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
|
||||||
elif self._match(TokenType.PRIMARY_KEY):
|
elif self._match(TokenType.PRIMARY_KEY):
|
||||||
|
@ -2234,7 +2252,7 @@ class Parser(metaclass=_Parser):
|
||||||
return self._parse_window(this)
|
return self._parse_window(this)
|
||||||
|
|
||||||
def _parse_extract(self):
|
def _parse_extract(self):
|
||||||
this = self._parse_var() or self._parse_type()
|
this = self._parse_function() or self._parse_var() or self._parse_type()
|
||||||
|
|
||||||
if self._match(TokenType.FROM):
|
if self._match(TokenType.FROM):
|
||||||
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
|
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
|
||||||
|
@ -2635,6 +2653,54 @@ class Parser(metaclass=_Parser):
|
||||||
parser = self._find_parser(self.SET_PARSERS, self._set_trie)
|
parser = self._find_parser(self.SET_PARSERS, self._set_trie)
|
||||||
return parser(self) if parser else self._default_parse_set_item()
|
return parser(self) if parser else self._default_parse_set_item()
|
||||||
|
|
||||||
|
def _parse_merge(self):
|
||||||
|
self._match(TokenType.INTO)
|
||||||
|
target = self._parse_table(schema=True)
|
||||||
|
|
||||||
|
self._match(TokenType.USING)
|
||||||
|
using = self._parse_table()
|
||||||
|
|
||||||
|
self._match(TokenType.ON)
|
||||||
|
on = self._parse_conjunction()
|
||||||
|
|
||||||
|
whens = []
|
||||||
|
while self._match(TokenType.WHEN):
|
||||||
|
this = self._parse_conjunction()
|
||||||
|
self._match(TokenType.THEN)
|
||||||
|
|
||||||
|
if self._match(TokenType.INSERT):
|
||||||
|
_this = self._parse_star()
|
||||||
|
if _this:
|
||||||
|
then = self.expression(exp.Insert, this=_this)
|
||||||
|
else:
|
||||||
|
then = self.expression(
|
||||||
|
exp.Insert,
|
||||||
|
this=self._parse_value(),
|
||||||
|
expression=self._match(TokenType.VALUES) and self._parse_value(),
|
||||||
|
)
|
||||||
|
elif self._match(TokenType.UPDATE):
|
||||||
|
expressions = self._parse_star()
|
||||||
|
if expressions:
|
||||||
|
then = self.expression(exp.Update, expressions=expressions)
|
||||||
|
else:
|
||||||
|
then = self.expression(
|
||||||
|
exp.Update,
|
||||||
|
expressions=self._match(TokenType.SET)
|
||||||
|
and self._parse_csv(self._parse_equality),
|
||||||
|
)
|
||||||
|
elif self._match(TokenType.DELETE):
|
||||||
|
then = self.expression(exp.Var, this=self._prev.text)
|
||||||
|
|
||||||
|
whens.append(self.expression(exp.When, this=this, then=then))
|
||||||
|
|
||||||
|
return self.expression(
|
||||||
|
exp.Merge,
|
||||||
|
this=target,
|
||||||
|
using=using,
|
||||||
|
on=on,
|
||||||
|
expressions=whens,
|
||||||
|
)
|
||||||
|
|
||||||
def _parse_set(self):
|
def _parse_set(self):
|
||||||
return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
|
return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
|
||||||
|
|
||||||
|
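The bulk of this last hunk is the new `_parse_merge`, which backs the MERGE support in this release. A minimal sketch of exercising it through the public API (assuming sqlglot 10.2.6 is installed, and that the new `MERGE` token added in the tokenizer below routes statements to this method):

```python
import sqlglot
from sqlglot import exp

# Parse a MERGE statement with the default dialect; _parse_merge builds an
# exp.Merge node whose "expressions" arg holds one exp.When per WHEN clause.
merge = sqlglot.parse_one(
    "MERGE INTO target USING source ON target.id = source.id "
    "WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)"
)
assert isinstance(merge, exp.Merge)
assert len(merge.args["expressions"]) == 1  # one WHEN clause
```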
@@ -47,7 +47,7 @@ class Schema(abc.ABC):
         """

     @abc.abstractmethod
-    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType.Type:
+    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
         """
         Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

@@ -160,8 +160,8 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
         super().__init__(schema)
         self.visible = visible or {}
         self.dialect = dialect
-        self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {
-            "STR": exp.DataType.Type.TEXT,
+        self._type_mapping_cache: t.Dict[str, exp.DataType] = {
+            "STR": exp.DataType.build("text"),
         }

     @classmethod
@@ -231,18 +231,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
             visible = self._nested_get(self.table_parts(table_), self.visible)
             return [col for col in schema if col in visible]  # type: ignore

-    def get_column_type(
-        self, table: exp.Table | str, column: exp.Column | str
-    ) -> exp.DataType.Type:
+    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
         column_name = column if isinstance(column, str) else column.name
         table_ = exp.to_table(table)
         if table_:
-            table_schema = self.find(table_)
-            schema_type = table_schema.get(column_name).upper()  # type: ignore
-            return self._convert_type(schema_type)
+            table_schema = self.find(table_, raise_on_missing=False)
+            if table_schema:
+                schema_type = table_schema.get(column_name).upper()  # type: ignore
+                return self._convert_type(schema_type)
+            return exp.DataType(this=exp.DataType.Type.UNKNOWN)
         raise SchemaError(f"Could not convert table '{table}'")

-    def _convert_type(self, schema_type: str) -> exp.DataType.Type:
+    def _convert_type(self, schema_type: str) -> exp.DataType:
         """
         Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.

@@ -257,7 +257,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
             expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
             if expression is None:
                 raise ValueError(f"Could not parse {schema_type}")
-            self._type_mapping_cache[schema_type] = expression.this
+            self._type_mapping_cache[schema_type] = expression  # type: ignore
         except AttributeError:
             raise SchemaError(f"Failed to convert type {schema_type}")
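Since `get_column_type` now returns a full `exp.DataType` rather than a bare `DataType.Type` enum member, callers can render dialect-specific type SQL directly. A minimal sketch (the table and column names are made up for illustration):

```python
from sqlglot import exp
from sqlglot.schema import MappingSchema

schema = MappingSchema({"t": {"a": "INT"}})  # hypothetical one-table schema
dtype = schema.get_column_type("t", "a")

assert isinstance(dtype, exp.DataType)       # previously an exp.DataType.Type
assert dtype.this == exp.DataType.Type.INT   # the enum is still reachable via .this
print(dtype.sql())                           # INT
```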
@@ -49,6 +49,9 @@ class TokenType(AutoName):
     PARAMETER = auto()
     SESSION_PARAMETER = auto()

+    BLOCK_START = auto()
+    BLOCK_END = auto()
+
     SPACE = auto()
     BREAK = auto()

@@ -156,6 +159,7 @@ class TokenType(AutoName):
     DIV = auto()
     DROP = auto()
     ELSE = auto()
+    ENCODE = auto()
     END = auto()
     ENGINE = auto()
     ESCAPE = auto()
@@ -207,6 +211,7 @@ class TokenType(AutoName):
     LOCATION = auto()
     MAP = auto()
     MATERIALIZED = auto()
+    MERGE = auto()
     MOD = auto()
     NATURAL = auto()
     NEXT = auto()
@@ -255,6 +260,7 @@ class TokenType(AutoName):
     SELECT = auto()
     SEMI = auto()
     SEPARATOR = auto()
+    SERDE_PROPERTIES = auto()
     SET = auto()
     SHOW = auto()
     SIMILAR_TO = auto()
@@ -267,7 +273,6 @@ class TokenType(AutoName):
     TABLE_FORMAT = auto()
     TABLE_SAMPLE = auto()
     TEMPORARY = auto()
-    TRANSIENT = auto()
     TOP = auto()
     THEN = auto()
     TRAILING = auto()
@@ -420,6 +425,16 @@ class Tokenizer(metaclass=_Tokenizer):
     ESCAPES = ["'"]

     KEYWORDS = {
+        **{
+            f"{key}{postfix}": TokenType.BLOCK_START
+            for key in ("{{", "{%", "{#")
+            for postfix in ("", "+", "-")
+        },
+        **{
+            f"{prefix}{key}": TokenType.BLOCK_END
+            for key in ("}}", "%}", "#}")
+            for prefix in ("", "+", "-")
+        },
         "/*+": TokenType.HINT,
         "==": TokenType.EQ,
         "::": TokenType.DCOLON,
@@ -523,6 +538,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "LOCAL": TokenType.LOCAL,
         "LOCATION": TokenType.LOCATION,
         "MATERIALIZED": TokenType.MATERIALIZED,
+        "MERGE": TokenType.MERGE,
         "NATURAL": TokenType.NATURAL,
         "NEXT": TokenType.NEXT,
         "NO ACTION": TokenType.NO_ACTION,
@@ -582,7 +598,6 @@ class Tokenizer(metaclass=_Tokenizer):
         "TABLESAMPLE": TokenType.TABLE_SAMPLE,
         "TEMP": TokenType.TEMPORARY,
         "TEMPORARY": TokenType.TEMPORARY,
-        "TRANSIENT": TokenType.TRANSIENT,
         "THEN": TokenType.THEN,
         "TRUE": TokenType.TRUE,
         "TRAILING": TokenType.TRAILING,
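The new `BLOCK_START`/`BLOCK_END` keywords make the tokenizer recognize Jinja-style template delimiters, including the whitespace-control variants (`{%-`, `-%}`, and so on). A minimal sketch:

```python
from sqlglot.tokens import Tokenizer, TokenType

tokens = Tokenizer().tokenize("{{ ds }}")
print([t.token_type for t in tokens])
# Expected to begin with TokenType.BLOCK_START and end with TokenType.BLOCK_END,
# with the variable name tokenized in between.
```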
@@ -4,6 +4,7 @@ import unittest
 from sqlglot.dataframe.sql import types
 from sqlglot.dataframe.sql.dataframe import DataFrame
 from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.helper import ensure_list


 class DataFrameSQLValidator(unittest.TestCase):
@@ -33,9 +34,7 @@ class DataFrameSQLValidator(unittest.TestCase):
         self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
     ):
         actual_sqls = df.sql(pretty=pretty)
-        expected_statements = (
-            [expected_statements] if isinstance(expected_statements, str) else expected_statements
-        )
+        expected_statements = ensure_list(expected_statements)
         self.assertEqual(len(expected_statements), len(actual_sqls))
         for expected, actual in zip(expected_statements, actual_sqls):
             self.assertEqual(expected, actual)
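`ensure_list` replaces the inline coercion with the library helper; it wraps a scalar in a list and passes a list through unchanged:

```python
from sqlglot.helper import ensure_list

assert ensure_list("SELECT 1") == ["SELECT 1"]
assert ensure_list(["SELECT 1", "SELECT 2"]) == ["SELECT 1", "SELECT 2"]
```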
@@ -10,37 +10,37 @@ class TestDataFrameWriter(DataFrameSQLValidator):

     def test_insertInto_full_path(self):
         df = self.df_employee.write.insertInto("catalog.db.table_name")
-        expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT INTO catalog.db.table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)

     def test_insertInto_db_table(self):
         df = self.df_employee.write.insertInto("db.table_name")
-        expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT INTO db.table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)

     def test_insertInto_table(self):
         df = self.df_employee.write.insertInto("table_name")
-        expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)

     def test_insertInto_overwrite(self):
         df = self.df_employee.write.insertInto("table_name", overwrite=True)
-        expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT OVERWRITE TABLE table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)

     @mock.patch("sqlglot.schema", MappingSchema())
     def test_insertInto_byName(self):
         sqlglot.schema.add_table("table_name", {"employee_id": "INT"})
         df = self.df_employee.write.byName.insertInto("table_name")
-        expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)

     def test_insertInto_cache(self):
         df = self.df_employee.cache().write.insertInto("table_name")
         expected_statements = [
-            "DROP VIEW IF EXISTS t37164",
-            "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
-            "INSERT INTO table_name SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
+            "DROP VIEW IF EXISTS t12441",
+            "CACHE LAZY TABLE t12441 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+            "INSERT INTO table_name SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`",
         ]
         self.compare_sql(df, expected_statements)
@@ -50,39 +50,39 @@ class TestDataFrameWriter(DataFrameSQLValidator):

     def test_saveAsTable_append(self):
         df = self.df_employee.write.saveAsTable("table_name", mode="append")
-        expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)

     def test_saveAsTable_overwrite(self):
         df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
-        expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "CREATE OR REPLACE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)

     def test_saveAsTable_error(self):
         df = self.df_employee.write.saveAsTable("table_name", mode="error")
-        expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "CREATE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)

     def test_saveAsTable_ignore(self):
         df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
-        expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)

     def test_mode_standalone(self):
         df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
-        expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)

     def test_mode_override(self):
         df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
-        expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+        expected = "CREATE OR REPLACE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
         self.compare_sql(df, expected)

     def test_saveAsTable_cache(self):
         df = self.df_employee.cache().write.saveAsTable("table_name")
         expected_statements = [
-            "DROP VIEW IF EXISTS t37164",
-            "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
-            "CREATE TABLE table_name AS SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
+            "DROP VIEW IF EXISTS t12441",
+            "CACHE LAZY TABLE t12441 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+            "CREATE TABLE table_name AS SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`",
         ]
         self.compare_sql(df, expected_statements)
@@ -36,7 +36,7 @@ class TestDataframeSession(DataFrameSQLValidator):

     def test_cdf_str_schema(self):
         df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
-        expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
+        expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
         self.compare_sql(df, expected)

     def test_typed_schema_basic(self):
@@ -47,7 +47,7 @@ class TestDataframeSession(DataFrameSQLValidator):
             ]
         )
         df = self.spark.createDataFrame([[1, "test"]], schema)
-        expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
+        expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
         self.compare_sql(df, expected)

     def test_typed_schema_nested(self):
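The expected SQL changes because columns whose declared type already matches the type inferred from the VALUES literals no longer get a redundant CAST. A minimal sketch of reproducing one of these outputs:

```python
from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
for statement in df.sql(pretty=False):
    print(statement)
# The INT column is selected as-is; only the STRING column keeps its CAST.
```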
@@ -6,6 +6,11 @@ class TestBigQuery(Validator):
     dialect = "bigquery"

     def test_bigquery(self):
+        self.validate_all(
+            "REGEXP_CONTAINS('foo', '.*')",
+            read={"bigquery": "REGEXP_CONTAINS('foo', '.*')"},
+            write={"mysql": "REGEXP_LIKE('foo', '.*')"},
+        ),
         self.validate_all(
             '"""x"""',
             write={
@@ -94,6 +99,20 @@ class TestBigQuery(Validator):
                 "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)",
             },
         )
+        self.validate_all(
+            "SELECT ARRAY(SELECT x FROM UNNEST([0, 1]) AS x)",
+            write={"bigquery": "SELECT ARRAY(SELECT x FROM UNNEST([0, 1]) AS x)"},
+        )
+        self.validate_all(
+            "SELECT ARRAY(SELECT DISTINCT x FROM UNNEST(some_numbers) AS x) AS unique_numbers",
+            write={
+                "bigquery": "SELECT ARRAY(SELECT DISTINCT x FROM UNNEST(some_numbers) AS x) AS unique_numbers"
+            },
+        )
+        self.validate_all(
+            "SELECT ARRAY(SELECT * FROM foo JOIN bla ON x = y)",
+            write={"bigquery": "SELECT ARRAY(SELECT * FROM foo JOIN bla ON x = y)"},
+        )

         self.validate_all(
             "x IS unknown",
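Both additions are reproducible outside the test harness; a minimal sketch:

```python
import sqlglot

print(sqlglot.transpile("REGEXP_CONTAINS('foo', '.*')", read="bigquery", write="mysql"))
# ["REGEXP_LIKE('foo', '.*')"]

print(sqlglot.transpile("SELECT ARRAY(SELECT x FROM UNNEST([0, 1]) AS x)", read="bigquery", write="bigquery"))
# ['SELECT ARRAY(SELECT x FROM UNNEST([0, 1]) AS x)']
```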
@@ -1318,3 +1318,39 @@ SELECT
             "BEGIN IMMEDIATE TRANSACTION",
             write={"sqlite": "BEGIN IMMEDIATE TRANSACTION"},
         )
+
+    def test_merge(self):
+        self.validate_all(
+            """
+            MERGE INTO target USING source ON target.id = source.id
+            WHEN NOT MATCHED THEN INSERT (id) values (source.id)
+            """,
+            write={
+                "bigquery": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)",
+                "snowflake": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)",
+                "spark": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)",
+            },
+        )
+        self.validate_all(
+            """
+            MERGE INTO target USING source ON target.id = source.id
+            WHEN MATCHED AND source.is_deleted = 1 THEN DELETE
+            WHEN MATCHED THEN UPDATE SET val = source.val
+            WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)
+            """,
+            write={
+                "bigquery": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)",
+                "snowflake": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)",
+                "spark": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)",
+            },
+        )
+        self.validate_all(
+            """
+            MERGE INTO target USING source ON target.id = source.id
+            WHEN MATCHED THEN UPDATE *
+            WHEN NOT MATCHED THEN INSERT *
+            """,
+            write={
+                "spark": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED THEN UPDATE * WHEN NOT MATCHED THEN INSERT *",
+            },
+        )
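The same statements transpile outside the harness too (a sketch; the leading whitespace in the test inputs is normalized away):

```python
import sqlglot

sql = """
    MERGE INTO target USING source ON target.id = source.id
    WHEN NOT MATCHED THEN INSERT (id) values (source.id)
"""
print(sqlglot.transpile(sql, write="spark")[0])
# MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)
```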
@@ -145,6 +145,10 @@ class TestHive(Validator):
             },
         )

+        self.validate_identity(
+            """CREATE EXTERNAL TABLE x (y INT) ROW FORMAT SERDE 'serde' ROW FORMAT DELIMITED FIELDS TERMINATED BY '1' WITH SERDEPROPERTIES ('input.regex'='')""",
+        )
+
     def test_lateral_view(self):
         self.validate_all(
             "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b",
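The identity test above implies this statement survives a Hive round trip; a sketch of the same check done directly (assuming sqlglot 10.2.6):

```python
import sqlglot

sql = (
    "CREATE EXTERNAL TABLE x (y INT) ROW FORMAT SERDE 'serde' "
    "ROW FORMAT DELIMITED FIELDS TERMINATED BY '1' WITH SERDEPROPERTIES ('input.regex'='')"
)
assert sqlglot.transpile(sql, read="hive", write="hive")[0] == sql
```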
@@ -256,3 +256,7 @@ class TestPostgres(Validator):
             "SELECT $$Dianne's horse$$",
             write={"postgres": "SELECT 'Dianne''s horse'"},
         )
+        self.validate_all(
+            "UPDATE MYTABLE T1 SET T1.COL = 13",
+            write={"postgres": "UPDATE MYTABLE AS T1 SET T1.COL = 13"},
+        )
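This is the behavior enabled by the `UPDATE_ALIAS_TOKENS` change in the parser above: a bare alias after the UPDATE target now parses and is rendered with an explicit AS. A sketch:

```python
import sqlglot

print(sqlglot.transpile("UPDATE MYTABLE T1 SET T1.COL = 13", write="postgres")[0])
# UPDATE MYTABLE AS T1 SET T1.COL = 13
```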
@@ -56,8 +56,27 @@ class TestRedshift(Validator):
                 "redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
             },
         )
+        self.validate_all(
+            "DECODE(x, a, b, c, d)",
+            write={
+                "": "MATCHES(x, a, b, c, d)",
+                "oracle": "DECODE(x, a, b, c, d)",
+                "snowflake": "DECODE(x, a, b, c, d)",
+            },
+        )
+        self.validate_all(
+            "NVL(a, b, c, d)",
+            write={
+                "redshift": "COALESCE(a, b, c, d)",
+                "mysql": "COALESCE(a, b, c, d)",
+                "postgres": "COALESCE(a, b, c, d)",
+            },
+        )

     def test_identity(self):
+        self.validate_identity(
+            "SELECT DECODE(COL1, 'replace_this', 'with_this', 'replace_that', 'with_that')"
+        )
         self.validate_identity("CAST('bla' AS SUPER)")
         self.validate_identity("CREATE TABLE real1 (realcol REAL)")
         self.validate_identity("CAST('foo' AS HLLSKETCH)")
@@ -70,9 +89,9 @@ class TestRedshift(Validator):
         self.validate_identity(
             "SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
         )
-        self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE AUTO")
+        self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL")
         self.validate_identity(
-            "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid)"
+            "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid) DISTSTYLE AUTO"
         )
         self.validate_identity(
             "COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'"
@@ -80,3 +99,6 @@ class TestRedshift(Validator):
         self.validate_identity(
             "UNLOAD ('select * from venue') TO 's3://mybucket/unload/' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'"
         )
+        self.validate_identity(
+            "CREATE TABLE SOUP (SOUP1 VARCHAR(50) NOT NULL ENCODE ZSTD, SOUP2 VARCHAR(70) NULL ENCODE DELTA)"
+        )
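A sketch of the new DECODE/NVL handling from the public API:

```python
import sqlglot

# DECODE is kept verbatim for dialects that support it natively...
print(sqlglot.transpile("DECODE(x, a, b, c, d)", read="redshift", write="snowflake"))
# ['DECODE(x, a, b, c, d)']

# ...while variadic NVL is rewritten to COALESCE where needed.
print(sqlglot.transpile("NVL(a, b, c, d)", read="redshift", write="postgres"))
# ['COALESCE(a, b, c, d)']
```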
@@ -500,3 +500,12 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL F
             },
             pretty=True,
         )
+
+    def test_minus(self):
+        self.validate_all(
+            "SELECT 1 EXCEPT SELECT 1",
+            read={
+                "oracle": "SELECT 1 MINUS SELECT 1",
+                "snowflake": "SELECT 1 MINUS SELECT 1",
+            },
+        )
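MINUS is read as a synonym for EXCEPT; a sketch:

```python
import sqlglot

print(sqlglot.transpile("SELECT 1 MINUS SELECT 1", read="snowflake")[0])
# SELECT 1 EXCEPT SELECT 1
```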
1 tests/fixtures/identity.sql (vendored)
@@ -75,6 +75,7 @@ ARRAY(1, 2)
 ARRAY_CONTAINS(x, 1)
 EXTRACT(x FROM y)
 EXTRACT(DATE FROM y)
+EXTRACT(WEEK(monday) FROM created_at)
 CONCAT_WS('-', 'a', 'b')
 CONCAT_WS('-', 'a', 'b', 'c')
 POSEXPLODE("x") AS ("a", "b")
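This fixture exercises the `_parse_extract` change above: a function call is now accepted as the extracted unit. A sketch:

```python
import sqlglot

expr = sqlglot.parse_one("EXTRACT(WEEK(monday) FROM created_at)")
assert expr.sql() == "EXTRACT(WEEK(monday) FROM created_at)"
```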
6 tests/fixtures/optimizer/canonicalize.sql (vendored)
@@ -3,3 +3,9 @@ SELECT CONCAT(w.d, w.e) AS c FROM w AS w;

 SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w;
 SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w;
+
+SELECT CAST(1 AS VARCHAR) AS a FROM w AS w;
+SELECT CAST(1 AS VARCHAR) AS a FROM w AS w;
+
+SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w;
+SELECT 1 + 3.2 AS a FROM w AS w;
180 tests/fixtures/optimizer/simplify.sql (vendored)
@@ -79,14 +79,16 @@ NULL;
 NULL = NULL;
 NULL;

+-- Can't optimize this because different engines do different things
+-- mysql converts to 0 and 1 but tsql does true and false
 NULL <=> NULL;
-TRUE;
+NULL IS NOT DISTINCT FROM NULL;

 a IS NOT DISTINCT FROM a;
-TRUE;
+a IS NOT DISTINCT FROM a;

 NULL IS DISTINCT FROM NULL;
-FALSE;
+NULL IS DISTINCT FROM NULL;

 NOT (NOT TRUE);
 TRUE;
@@ -239,10 +241,10 @@ TRUE;
 FALSE;

 ((NOT FALSE) AND (x = x)) AND (TRUE OR 1 <> 3);
-TRUE;
+x = x;

 ((NOT FALSE) AND (x = x)) AND (FALSE OR 1 <> 2);
-TRUE;
+x = x;

 (('a' = 'a') AND TRUE and NOT FALSE);
 TRUE;
@@ -372,3 +374,171 @@ CAST('1998-12-01' AS DATE) - INTERVAL '90' foo;

 date '1998-12-01' + interval '90' foo;
 CAST('1998-12-01' AS DATE) + INTERVAL '90' foo;
+
+--------------------------------------
+-- Comparisons
+--------------------------------------
+x < 0 OR x > 1;
+x < 0 OR x > 1;
+
+x < 0 OR x > 0;
+x < 0 OR x > 0;
+
+x < 1 OR x > 0;
+x < 1 OR x > 0;
+
+x < 1 OR x >= 0;
+x < 1 OR x >= 0;
+
+x <= 1 OR x > 0;
+x <= 1 OR x > 0;
+
+x <= 1 OR x >= 0;
+x <= 1 OR x >= 0;
+
+x <= 1 AND x <= 0;
+x <= 0;
+
+x <= 1 AND x > 0;
+x <= 1 AND x > 0;
+
+x <= 1 OR x > 0;
+x <= 1 OR x > 0;
+
+x <= 0 OR x < 0;
+x <= 0;
+
+x >= 0 OR x > 0;
+x >= 0;
+
+x >= 0 OR x > 1;
+x >= 0;
+
+x <= 0 OR x >= 0;
+x <= 0 OR x >= 0;
+
+x <= 0 AND x >= 0;
+x <= 0 AND x >= 0;
+
+x < 1 AND x < 2;
+x < 1;
+
+x < 1 OR x < 2;
+x < 2;
+
+x < 2 AND x < 1;
+x < 1;
+
+x < 2 OR x < 1;
+x < 2;
+
+x < 1 AND x < 1;
+x < 1;
+
+x < 1 OR x < 1;
+x < 1;
+
+x <= 1 AND x < 1;
+x < 1;
+
+x <= 1 OR x < 1;
+x <= 1;
+
+x < 1 AND x <= 1;
+x < 1;
+
+x < 1 OR x <= 1;
+x <= 1;
+
+x > 1 AND x > 2;
+x > 2;
+
+x > 1 OR x > 2;
+x > 1;
+
+x > 2 AND x > 1;
+x > 2;
+
+x > 2 OR x > 1;
+x > 1;
+
+x > 1 AND x > 1;
+x > 1;
+
+x > 1 OR x > 1;
+x > 1;
+
+x >= 1 AND x > 1;
+x > 1;
+
+x >= 1 OR x > 1;
+x >= 1;
+
+x > 1 AND x >= 1;
+x > 1;
+
+x > 1 OR x >= 1;
+x >= 1;
+
+x > 1 AND x >= 2;
+x >= 2;
+
+x > 1 OR x >= 2;
+x > 1;
+
+x > 1 AND x >= 2 AND x > 3 AND x > 0;
+x > 3;
+
+(x > 1 AND x >= 2 AND x > 3 AND x > 0) OR x > 0;
+x > 0;
+
+x > 1 AND x < 2 AND x > 3;
+FALSE;
+
+x > 1 AND x < 1;
+FALSE;
+
+x < 2 AND x > 1;
+x < 2 AND x > 1;
+
+x = 1 AND x < 1;
+FALSE;
+
+x = 1 AND x < 1.1;
+x = 1;
+
+x = 1 AND x <= 1;
+x = 1;
+
+x = 1 AND x <= 0.9;
+FALSE;
+
+x = 1 AND x > 0.9;
+x = 1;
+
+x = 1 AND x > 1;
+FALSE;
+
+x = 1 AND x >= 1;
+x = 1;
+
+x = 1 AND x >= 2;
+FALSE;
+
+x = 1 AND x <> 2;
+x = 1;
+
+x <> 1 AND x = 1;
+FALSE;
+
+x BETWEEN 0 AND 5 AND x > 3;
+x <= 5 AND x > 3;
+
+x > 3 AND 5 > x AND x BETWEEN 0 AND 10;
+x < 5 AND x > 3;
+
+x > 3 AND 5 < x AND x BETWEEN 9 AND 10;
+x <= 10 AND x >= 9;
+
+1 < x AND 3 < x;
+x > 3;
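All of the new comparison fixtures run through the simplify rules; a minimal sketch of invoking them directly:

```python
from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

print(simplify(parse_one("x > 1 AND x >= 2 AND x > 3 AND x > 0")).sql())  # x > 3
print(simplify(parse_one("x BETWEEN 0 AND 5 AND x > 3")).sql())           # x <= 5 AND x > 3
print(simplify(parse_one("x = 1 AND x > 1")).sql())                       # FALSE
```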
51 tests/fixtures/optimizer/tpc-h/tpc-h.sql (vendored)
@@ -190,7 +190,7 @@ SELECT
   SUM("lineitem"."l_extendedprice" * (
     1 - "lineitem"."l_discount"
   )) AS "revenue",
-  CAST("orders"."o_orderdate" AS TEXT) AS "o_orderdate",
+  "orders"."o_orderdate" AS "o_orderdate",
   "orders"."o_shippriority" AS "o_shippriority"
 FROM "customer" AS "customer"
 JOIN "orders" AS "orders"
@@ -326,7 +326,8 @@ SELECT
   SUM("lineitem"."l_extendedprice" * "lineitem"."l_discount") AS "revenue"
 FROM "lineitem" AS "lineitem"
 WHERE
-  "lineitem"."l_discount" BETWEEN 0.05 AND 0.07
+  "lineitem"."l_discount" <= 0.07
+  AND "lineitem"."l_discount" >= 0.05
   AND "lineitem"."l_quantity" < 24
   AND CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE)
   AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE);
@@ -344,7 +345,7 @@ from
 select
   n1.n_name as supp_nation,
   n2.n_name as cust_nation,
-  extract(year from l_shipdate) as l_year,
+  extract(year from cast(l_shipdate as date)) as l_year,
   l_extendedprice * (1 - l_discount) as volume
 from
   supplier,
@@ -384,13 +385,14 @@ WITH "n1" AS (
 SELECT
   "n1"."n_name" AS "supp_nation",
   "n2"."n_name" AS "cust_nation",
-  EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) AS "l_year",
+  EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATE)) AS "l_year",
   SUM("lineitem"."l_extendedprice" * (
     1 - "lineitem"."l_discount"
   )) AS "revenue"
 FROM "supplier" AS "supplier"
 JOIN "lineitem" AS "lineitem"
-  ON CAST("lineitem"."l_shipdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+  ON CAST("lineitem"."l_shipdate" AS DATE) <= CAST('1996-12-31' AS DATE)
+  AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1995-01-01' AS DATE)
   AND "supplier"."s_suppkey" = "lineitem"."l_suppkey"
 JOIN "orders" AS "orders"
   ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
@@ -409,7 +411,7 @@ JOIN "n1" AS "n2"
 GROUP BY
   "n1"."n_name",
   "n2"."n_name",
-  EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME))
+  EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATE))
 ORDER BY
   "supp_nation",
   "cust_nation",
@@ -427,7 +429,7 @@ select
 from
   (
     select
-      extract(year from o_orderdate) as o_year,
+      extract(year from cast(o_orderdate as date)) as o_year,
       l_extendedprice * (1 - l_discount) as volume,
       n2.n_name as nation
     from
@@ -456,7 +458,7 @@ group by
 order by
   o_year;
 SELECT
-  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
+  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) AS "o_year",
   SUM(
     CASE
       WHEN "nation_2"."n_name" = 'BRAZIL'
@@ -477,7 +479,8 @@ JOIN "customer" AS "customer"
   ON "customer"."c_nationkey" = "nation"."n_nationkey"
 JOIN "orders" AS "orders"
   ON "orders"."o_custkey" = "customer"."c_custkey"
-  AND CAST("orders"."o_orderdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+  AND CAST("orders"."o_orderdate" AS DATE) <= CAST('1996-12-31' AS DATE)
+  AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1995-01-01' AS DATE)
 JOIN "lineitem" AS "lineitem"
   ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
   AND "part"."p_partkey" = "lineitem"."l_partkey"
@@ -488,7 +491,7 @@ JOIN "nation" AS "nation_2"
 WHERE
   "part"."p_type" = 'ECONOMY ANODIZED STEEL'
 GROUP BY
-  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
+  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE))
 ORDER BY
   "o_year";

@@ -503,7 +506,7 @@ from
   (
     select
       n_name as nation,
-      extract(year from o_orderdate) as o_year,
+      extract(year from cast(o_orderdate as date)) as o_year,
       l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount
     from
       part,
@@ -529,7 +532,7 @@ order by
   o_year desc;
 SELECT
   "nation"."n_name" AS "nation",
-  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
+  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) AS "o_year",
   SUM(
     "lineitem"."l_extendedprice" * (
       1 - "lineitem"."l_discount"
@@ -551,7 +554,7 @@ WHERE
   "part"."p_name" LIKE '%green%'
 GROUP BY
   "nation"."n_name",
-  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
+  EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE))
 ORDER BY
   "nation",
   "o_year" DESC;
@@ -1016,7 +1019,7 @@ select
   o_orderkey,
   o_orderdate,
   o_totalprice,
-  sum(l_quantity)
+  sum(l_quantity) total_quantity
 from
   customer,
   orders,
@@ -1060,7 +1063,7 @@ SELECT
   "orders"."o_orderkey" AS "o_orderkey",
   "orders"."o_orderdate" AS "o_orderdate",
   "orders"."o_totalprice" AS "o_totalprice",
-  SUM("lineitem"."l_quantity") AS "_col_5"
+  SUM("lineitem"."l_quantity") AS "total_quantity"
 FROM "customer" AS "customer"
 JOIN "orders" AS "orders"
   ON "customer"."c_custkey" = "orders"."o_custkey"
@@ -1129,19 +1132,22 @@ JOIN "part" AS "part"
   "part"."p_brand" = 'Brand#12'
   AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG')
   AND "part"."p_partkey" = "lineitem"."l_partkey"
|
AND "part"."p_partkey" = "lineitem"."l_partkey"
|
||||||
AND "part"."p_size" BETWEEN 1 AND 5
|
AND "part"."p_size" <= 5
|
||||||
|
AND "part"."p_size" >= 1
|
||||||
)
|
)
|
||||||
OR (
|
OR (
|
||||||
"part"."p_brand" = 'Brand#23'
|
"part"."p_brand" = 'Brand#23'
|
||||||
AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK')
|
AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK')
|
||||||
AND "part"."p_partkey" = "lineitem"."l_partkey"
|
AND "part"."p_partkey" = "lineitem"."l_partkey"
|
||||||
AND "part"."p_size" BETWEEN 1 AND 10
|
AND "part"."p_size" <= 10
|
||||||
|
AND "part"."p_size" >= 1
|
||||||
)
|
)
|
||||||
OR (
|
OR (
|
||||||
"part"."p_brand" = 'Brand#34'
|
"part"."p_brand" = 'Brand#34'
|
||||||
AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG')
|
AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG')
|
||||||
AND "part"."p_partkey" = "lineitem"."l_partkey"
|
AND "part"."p_partkey" = "lineitem"."l_partkey"
|
||||||
AND "part"."p_size" BETWEEN 1 AND 15
|
AND "part"."p_size" <= 15
|
||||||
|
AND "part"."p_size" >= 1
|
||||||
)
|
)
|
||||||
WHERE
|
WHERE
|
||||||
(
|
(
|
||||||
|
@ -1152,7 +1158,8 @@ WHERE
|
||||||
AND "part"."p_brand" = 'Brand#12'
|
AND "part"."p_brand" = 'Brand#12'
|
||||||
AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG')
|
AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG')
|
||||||
AND "part"."p_partkey" = "lineitem"."l_partkey"
|
AND "part"."p_partkey" = "lineitem"."l_partkey"
|
||||||
AND "part"."p_size" BETWEEN 1 AND 5
|
AND "part"."p_size" <= 5
|
||||||
|
AND "part"."p_size" >= 1
|
||||||
)
|
)
|
||||||
OR (
|
OR (
|
||||||
"lineitem"."l_quantity" <= 20
|
"lineitem"."l_quantity" <= 20
|
||||||
|
@ -1162,7 +1169,8 @@ WHERE
|
||||||
AND "part"."p_brand" = 'Brand#23'
|
AND "part"."p_brand" = 'Brand#23'
|
||||||
AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK')
|
AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK')
|
||||||
AND "part"."p_partkey" = "lineitem"."l_partkey"
|
AND "part"."p_partkey" = "lineitem"."l_partkey"
|
||||||
AND "part"."p_size" BETWEEN 1 AND 10
|
AND "part"."p_size" <= 10
|
||||||
|
AND "part"."p_size" >= 1
|
||||||
)
|
)
|
||||||
OR (
|
OR (
|
||||||
"lineitem"."l_quantity" <= 30
|
"lineitem"."l_quantity" <= 30
|
||||||
|
@ -1172,7 +1180,8 @@ WHERE
|
||||||
AND "part"."p_brand" = 'Brand#34'
|
AND "part"."p_brand" = 'Brand#34'
|
||||||
AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG')
|
AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG')
|
||||||
AND "part"."p_partkey" = "lineitem"."l_partkey"
|
AND "part"."p_partkey" = "lineitem"."l_partkey"
|
||||||
AND "part"."p_size" BETWEEN 1 AND 15
|
AND "part"."p_size" <= 15
|
||||||
|
AND "part"."p_size" >= 1
|
||||||
);
|
);
|
||||||
|
|
||||||
--------------------------------------
|
--------------------------------------
|
||||||
|
|
|
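These fixture hunks track several optimizer changes visible in the diff itself: EXTRACT's argument is now canonicalized with a DATE cast rather than DATETIME, BETWEEN predicates in the optimized output are expanded into explicit <=/>= comparisons, and the aggregate alias (sum(l_quantity) total_quantity) now propagates instead of a generated _col_5. A minimal sketch, assuming a toy schema, of how such an input/output pair can be reproduced with the public optimizer entry point:

    # Run the optimizer on a BETWEEN predicate and print the canonicalized SQL.
    import sqlglot
    from sqlglot.optimizer import optimize

    schema = {"part": {"p_partkey": "int", "p_size": "int"}}  # assumed toy schema
    sql = "SELECT p_partkey FROM part WHERE p_size BETWEEN 1 AND 5"
    print(optimize(sqlglot.parse_one(sql), schema=schema).sql(pretty=True))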
@@ -26,12 +26,12 @@ class TestExecutor(unittest.TestCase):
     def setUpClass(cls):
         cls.conn = duckdb.connect()
 
-        for table in TPCH_SCHEMA:
+        for table, columns in TPCH_SCHEMA.items():
             cls.conn.execute(
                 f"""
                 CREATE VIEW {table} AS
                 SELECT *
-                FROM READ_CSV_AUTO('{DIR}{table}.csv.gz')
+                FROM READ_CSV('{DIR}{table}.csv.gz', delim='|', header=True, columns={columns})
                 """
             )
 
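The setUpClass hunk above swaps READ_CSV_AUTO for an explicit READ_CSV call so that the columns of DuckDB's views are pinned to the names and types declared in TPCH_SCHEMA rather than whatever type inference guesses. A self-contained sketch of the same pattern (the file name and column list here are illustrative assumptions, not the project's fixtures):

    import duckdb

    # Write a tiny pipe-delimited file so the example runs on its own.
    with open("orders.csv", "w") as f:
        f.write("o_orderkey|o_orderdate\n1|1995-01-01\n")

    columns = {"o_orderkey": "INT", "o_orderdate": "DATE"}  # assumed schema
    conn = duckdb.connect()
    conn.execute(
        f"""
        CREATE VIEW orders AS
        SELECT * FROM READ_CSV('orders.csv', delim='|', header=True, columns={columns})
        """
    )
    print(conn.execute("SELECT * FROM orders").fetchall())  # [(1, datetime.date(1995, 1, 1))]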
@@ -74,13 +74,13 @@ class TestExecutor(unittest.TestCase):
                 )
             return expression
 
-        for i, (sql, _) in enumerate(self.sqls[0:16]):
+        for i, (sql, _) in enumerate(self.sqls[0:18]):
             with self.subTest(f"tpch-h {i + 1}"):
                 a = self.cached_execute(sql)
                 sql = parse_one(sql).transform(to_csv).sql(pretty=True)
                 table = execute(sql, TPCH_SCHEMA)
                 b = pd.DataFrame(table.rows, columns=table.columns)
-                assert_frame_equal(a, b, check_dtype=False)
+                assert_frame_equal(a, b, check_dtype=False, check_index_type=False)
 
     def test_execute_callable(self):
         tables = {
@@ -456,8 +456,13 @@ class TestExecutor(unittest.TestCase):
             ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
             ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
             ("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
-            ("SELECT SUM(x) FROM (SELECT 1 AS x WHERE FALSE)", ["_col_0"], [(0,)]),
+            (
+                "SELECT SUM(x), COUNT(x) FROM (SELECT 1 AS x WHERE FALSE)",
+                ["_col_0", "_col_1"],
+                [(None, 0)],
+            ),
         ]:
+            with self.subTest(sql):
                 result = execute(sql)
                 self.assertEqual(result.columns, tuple(cols))
                 self.assertEqual(result.rows, rows)
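The second hunk pins down SQL aggregate semantics over an empty relation: SUM now returns NULL (previously 0) while COUNT returns 0, and the TPC-H loop above now covers queries 1 through 18. Mirroring the new test case directly:

    from sqlglot.executor import execute

    result = execute("SELECT SUM(x), COUNT(x) FROM (SELECT 1 AS x WHERE FALSE)")
    print(result.columns)  # ('_col_0', '_col_1')
    print(result.rows)     # [(None, 0)]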
@@ -333,7 +333,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
 
         for sql, target_type in tests.items():
             expression = annotate_types(parse_one(sql))
-            self.assertEqual(expression.find(exp.Literal).type, target_type)
+            self.assertEqual(expression.find(exp.Literal).type.this, target_type)
 
     def test_boolean_type_annotation(self):
         tests = {
@@ -343,31 +343,33 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
 
         for sql, target_type in tests.items():
             expression = annotate_types(parse_one(sql))
-            self.assertEqual(expression.find(exp.Boolean).type, target_type)
+            self.assertEqual(expression.find(exp.Boolean).type.this, target_type)
 
     def test_cast_type_annotation(self):
         expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))"))
+        self.assertEqual(expression.type.this, exp.DataType.Type.TIMESTAMPTZ)
+        self.assertEqual(expression.this.type.this, exp.DataType.Type.VARCHAR)
+        self.assertEqual(expression.args["to"].type.this, exp.DataType.Type.TIMESTAMPTZ)
+        self.assertEqual(expression.args["to"].expressions[0].type.this, exp.DataType.Type.INT)
 
-        self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ)
-        self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR)
-        self.assertEqual(expression.args["to"].type, exp.DataType.Type.TIMESTAMPTZ)
-        self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)
+        expression = annotate_types(parse_one("ARRAY(1)::ARRAY<INT>"))
+        self.assertEqual(expression.type, parse_one("ARRAY<INT>", into=exp.DataType))
 
     def test_cache_annotation(self):
         expression = annotate_types(
             parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
         )
-        self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT)
+        self.assertEqual(expression.expression.expressions[0].type.this, exp.DataType.Type.INT)
 
     def test_binary_annotation(self):
         expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0]
 
-        self.assertEqual(expression.type, exp.DataType.Type.DOUBLE)
-        self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE)
-        self.assertEqual(expression.right.type, exp.DataType.Type.INT)
-        self.assertEqual(expression.right.this.type, exp.DataType.Type.INT)
-        self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT)
-        self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT)
+        self.assertEqual(expression.type.this, exp.DataType.Type.DOUBLE)
+        self.assertEqual(expression.left.type.this, exp.DataType.Type.DOUBLE)
+        self.assertEqual(expression.right.type.this, exp.DataType.Type.INT)
+        self.assertEqual(expression.right.this.type.this, exp.DataType.Type.INT)
+        self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT)
+        self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT)
 
     def test_derived_tables_column_annotation(self):
         schema = {"x": {"cola": "INT"}, "y": {"cola": "FLOAT"}}
@@ -387,128 +389,169 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         """
 
         expression = annotate_types(parse_one(sql), schema=schema)
-        self.assertEqual(expression.expressions[0].type, exp.DataType.Type.FLOAT)  # a.cola AS cola
+        self.assertEqual(
+            expression.expressions[0].type.this, exp.DataType.Type.FLOAT
+        )  # a.cola AS cola
 
         addition_alias = expression.args["from"].expressions[0].this.expressions[0]
-        self.assertEqual(addition_alias.type, exp.DataType.Type.FLOAT)  # x.cola + y.cola AS cola
+        self.assertEqual(
+            addition_alias.type.this, exp.DataType.Type.FLOAT
+        )  # x.cola + y.cola AS cola
 
         addition = addition_alias.this
-        self.assertEqual(addition.type, exp.DataType.Type.FLOAT)
-        self.assertEqual(addition.this.type, exp.DataType.Type.INT)
-        self.assertEqual(addition.expression.type, exp.DataType.Type.FLOAT)
+        self.assertEqual(addition.type.this, exp.DataType.Type.FLOAT)
+        self.assertEqual(addition.this.type.this, exp.DataType.Type.INT)
+        self.assertEqual(addition.expression.type.this, exp.DataType.Type.FLOAT)
 
     def test_cte_column_annotation(self):
-        schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT"}}
+        schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT", "colc": "BOOLEAN"}}
         sql = """
         WITH tbl AS (
-            SELECT x.cola + 'bla' AS cola, y.colb AS colb
+            SELECT x.cola + 'bla' AS cola, y.colb AS colb, y.colc AS colc
             FROM (
                 SELECT x.cola AS cola
                 FROM x AS x
            ) AS x
            JOIN (
-                SELECT y.colb AS colb
+                SELECT y.colb AS colb, y.colc AS colc
                FROM y AS y
            ) AS y
        )
        SELECT tbl.cola + tbl.colb + 'foo' AS col
        FROM tbl AS tbl
+        WHERE tbl.colc = True
        """
 
         expression = annotate_types(parse_one(sql), schema=schema)
         self.assertEqual(
-            expression.expressions[0].type, exp.DataType.Type.TEXT
+            expression.expressions[0].type.this, exp.DataType.Type.TEXT
         )  # tbl.cola + tbl.colb + 'foo' AS col
 
         outer_addition = expression.expressions[0].this  # (tbl.cola + tbl.colb) + 'foo'
-        self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT)
-        self.assertEqual(outer_addition.left.type, exp.DataType.Type.TEXT)
-        self.assertEqual(outer_addition.right.type, exp.DataType.Type.VARCHAR)
+        self.assertEqual(outer_addition.type.this, exp.DataType.Type.TEXT)
+        self.assertEqual(outer_addition.left.type.this, exp.DataType.Type.TEXT)
+        self.assertEqual(outer_addition.right.type.this, exp.DataType.Type.VARCHAR)
 
         inner_addition = expression.expressions[0].this.left  # tbl.cola + tbl.colb
-        self.assertEqual(inner_addition.left.type, exp.DataType.Type.VARCHAR)
-        self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT)
+        self.assertEqual(inner_addition.left.type.this, exp.DataType.Type.VARCHAR)
+        self.assertEqual(inner_addition.right.type.this, exp.DataType.Type.TEXT)
+
+        # WHERE tbl.colc = True
+        self.assertEqual(expression.args["where"].this.type.this, exp.DataType.Type.BOOLEAN)
 
         cte_select = expression.args["with"].expressions[0].this
         self.assertEqual(
-            cte_select.expressions[0].type, exp.DataType.Type.VARCHAR
+            cte_select.expressions[0].type.this, exp.DataType.Type.VARCHAR
         )  # x.cola + 'bla' AS cola
-        self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT)  # y.colb AS colb
+        self.assertEqual(
+            cte_select.expressions[1].type.this, exp.DataType.Type.TEXT
+        )  # y.colb AS colb
+        self.assertEqual(
+            cte_select.expressions[2].type.this, exp.DataType.Type.BOOLEAN
+        )  # y.colc AS colc
 
         cte_select_addition = cte_select.expressions[0].this  # x.cola + 'bla'
-        self.assertEqual(cte_select_addition.type, exp.DataType.Type.VARCHAR)
-        self.assertEqual(cte_select_addition.left.type, exp.DataType.Type.CHAR)
-        self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR)
+        self.assertEqual(cte_select_addition.type.this, exp.DataType.Type.VARCHAR)
+        self.assertEqual(cte_select_addition.left.type.this, exp.DataType.Type.CHAR)
+        self.assertEqual(cte_select_addition.right.type.this, exp.DataType.Type.VARCHAR)
 
         # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
         for d, t in zip(
             cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]
         ):
-            self.assertEqual(d.this.expressions[0].this.type, t)
+            self.assertEqual(d.this.expressions[0].this.type.this, t)
 
     def test_function_annotation(self):
         schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}}
         sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x"
 
         concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
-        self.assertEqual(concat_expr_alias.type, exp.DataType.Type.VARCHAR)
+        self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.VARCHAR)
 
         concat_expr = concat_expr_alias.this
-        self.assertEqual(concat_expr.type, exp.DataType.Type.VARCHAR)
-        self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR)  # x.cola
-        self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR)  # TRIM(x.colb)
-        self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR)  # x.colb
+        self.assertEqual(concat_expr.type.this, exp.DataType.Type.VARCHAR)
+        self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.VARCHAR)  # x.cola
+        self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.VARCHAR)  # TRIM(x.colb)
+        self.assertEqual(concat_expr.right.this.type.this, exp.DataType.Type.CHAR)  # x.colb
 
         sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x"
 
         case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
-        self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR)
+        self.assertEqual(case_expr_alias.type.this, exp.DataType.Type.VARCHAR)
 
         case_expr = case_expr_alias.this
-        self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR)
-        self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR)
+        self.assertEqual(case_expr.type.this, exp.DataType.Type.VARCHAR)
+        self.assertEqual(case_expr.args["default"].type.this, exp.DataType.Type.CHAR)
 
         case_ifs_expr = case_expr.args["ifs"][0]
-        self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR)
-        self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR)
+        self.assertEqual(case_ifs_expr.type.this, exp.DataType.Type.VARCHAR)
+        self.assertEqual(case_ifs_expr.args["true"].type.this, exp.DataType.Type.VARCHAR)
 
     def test_unknown_annotation(self):
         schema = {"x": {"cola": "VARCHAR"}}
         sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x"
 
         concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
-        self.assertEqual(concat_expr_alias.type, exp.DataType.Type.UNKNOWN)
+        self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.UNKNOWN)
 
         concat_expr = concat_expr_alias.this
-        self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN)
-        self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR)  # x.cola
+        self.assertEqual(concat_expr.type.this, exp.DataType.Type.UNKNOWN)
+        self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.VARCHAR)  # x.cola
         self.assertEqual(
-            concat_expr.right.type, exp.DataType.Type.UNKNOWN
+            concat_expr.right.type.this, exp.DataType.Type.UNKNOWN
         )  # SOME_ANONYMOUS_FUNC(x.cola)
         self.assertEqual(
-            concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR
+            concat_expr.right.expressions[0].type.this, exp.DataType.Type.VARCHAR
         )  # x.cola (arg)
 
     def test_null_annotation(self):
         expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this
-        self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
-        self.assertEqual(expression.right.type, exp.DataType.Type.INT)
+        self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL)
+        self.assertEqual(expression.right.type.this, exp.DataType.Type.INT)
 
         # NULL <op> UNKNOWN should yield NULL
         sql = "SELECT NULL || SOME_ANONYMOUS_FUNC() AS result"
 
         concat_expr_alias = annotate_types(parse_one(sql)).expressions[0]
-        self.assertEqual(concat_expr_alias.type, exp.DataType.Type.NULL)
+        self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.NULL)
 
         concat_expr = concat_expr_alias.this
-        self.assertEqual(concat_expr.type, exp.DataType.Type.NULL)
-        self.assertEqual(concat_expr.left.type, exp.DataType.Type.NULL)
-        self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN)
+        self.assertEqual(concat_expr.type.this, exp.DataType.Type.NULL)
+        self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.NULL)
+        self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.UNKNOWN)
 
     def test_nullable_annotation(self):
         nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
         expression = annotate_types(parse_one("NULL AND FALSE"))
 
         self.assertEqual(expression.type, nullable)
-        self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
-        self.assertEqual(expression.right.type, exp.DataType.Type.BOOLEAN)
+        self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL)
+        self.assertEqual(expression.right.type.this, exp.DataType.Type.BOOLEAN)
+
+    def test_predicate_annotation(self):
+        expression = annotate_types(parse_one("x BETWEEN a AND b"))
+        self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN)
+
+        expression = annotate_types(parse_one("x IN (a, b, c, d)"))
+        self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN)
+
+    def test_aggfunc_annotation(self):
+        schema = {"x": {"cola": "SMALLINT", "colb": "FLOAT", "colc": "TEXT", "cold": "DATE"}}
+
+        tests = {
+            ("AVG", "cola"): exp.DataType.Type.DOUBLE,
+            ("SUM", "cola"): exp.DataType.Type.BIGINT,
+            ("SUM", "colb"): exp.DataType.Type.DOUBLE,
+            ("MIN", "cola"): exp.DataType.Type.SMALLINT,
+            ("MIN", "colb"): exp.DataType.Type.FLOAT,
+            ("MAX", "colc"): exp.DataType.Type.TEXT,
+            ("MAX", "cold"): exp.DataType.Type.DATE,
+            ("COUNT", "colb"): exp.DataType.Type.BIGINT,
+            ("STDDEV", "cola"): exp.DataType.Type.DOUBLE,
+        }
+
+        for (func, col), target_type in tests.items():
+            expression = annotate_types(
+                parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema
+            )
+            self.assertEqual(expression.expressions[0].type.this, target_type)
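All of these assertions change for the same reason: annotate_types now attaches full exp.DataType nodes instead of bare exp.DataType.Type enums, so parameterized and nested types survive annotation and the enum itself moves under .this (note the unchanged assertEqual comparing a whole DataType in test_nullable_annotation). A minimal illustration in the spirit of the new test_cast_type_annotation:

    from sqlglot import exp, parse_one
    from sqlglot.optimizer.annotate_types import annotate_types

    expression = annotate_types(parse_one("ARRAY(1)::ARRAY<INT>"))
    assert isinstance(expression.type, exp.DataType)                     # full node, not an enum
    assert expression.type == parse_one("ARRAY<INT>", into=exp.DataType)
    assert expression.type.this == exp.DataType.Type.ARRAY               # the enum lives under .this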
@@ -151,31 +151,33 @@ class TestSchema(unittest.TestCase):
 
     def test_schema_get_column_type(self):
         schema = MappingSchema({"a": {"b": "varchar"}})
-        self.assertEqual(schema.get_column_type("a", "b"), exp.DataType.Type.VARCHAR)
+        self.assertEqual(schema.get_column_type("a", "b").this, exp.DataType.Type.VARCHAR)
         self.assertEqual(
-            schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")),
+            schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")).this,
             exp.DataType.Type.VARCHAR,
         )
         self.assertEqual(
-            schema.get_column_type("a", exp.Column(this="b")), exp.DataType.Type.VARCHAR
+            schema.get_column_type("a", exp.Column(this="b")).this, exp.DataType.Type.VARCHAR
         )
         self.assertEqual(
-            schema.get_column_type(exp.Table(this="a"), "b"), exp.DataType.Type.VARCHAR
+            schema.get_column_type(exp.Table(this="a"), "b").this, exp.DataType.Type.VARCHAR
         )
         schema = MappingSchema({"a": {"b": {"c": "varchar"}}})
         self.assertEqual(
-            schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")),
+            schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")).this,
             exp.DataType.Type.VARCHAR,
         )
         self.assertEqual(
-            schema.get_column_type(exp.Table(this="b", db="a"), "c"), exp.DataType.Type.VARCHAR
+            schema.get_column_type(exp.Table(this="b", db="a"), "c").this, exp.DataType.Type.VARCHAR
         )
         schema = MappingSchema({"a": {"b": {"c": {"d": "varchar"}}}})
         self.assertEqual(
-            schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")),
+            schema.get_column_type(
+                exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")
+            ).this,
             exp.DataType.Type.VARCHAR,
         )
         self.assertEqual(
-            schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d"),
+            schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d").this,
             exp.DataType.Type.VARCHAR,
         )
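MappingSchema.get_column_type follows the same change: it now returns an exp.DataType, so callers that want the bare enum read it off .this. A short sketch:

    from sqlglot import exp
    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"a": {"b": "varchar"}})
    column_type = schema.get_column_type("a", "b")
    assert isinstance(column_type, exp.DataType)
    assert column_type.this == exp.DataType.Type.VARCHAR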
@@ -1,6 +1,6 @@
 import unittest
 
-from sqlglot.tokens import Tokenizer
+from sqlglot.tokens import Tokenizer, TokenType
 
 
 class TestTokens(unittest.TestCase):
@@ -17,3 +17,48 @@ class TestTokens(unittest.TestCase):
 
         for sql, comment in sql_comment:
             self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment)
+
+    def test_jinja(self):
+        tokenizer = Tokenizer()
+
+        tokens = tokenizer.tokenize(
+            """
+            SELECT
+                {{ x }},
+                {{- x -}},
+                {% for x in y -%}
+                    a {{+ b }}
+                {% endfor %};
+            """
+        )
+
+        tokens = [(token.token_type, token.text) for token in tokens]
+
+        self.assertEqual(
+            tokens,
+            [
+                (TokenType.SELECT, "SELECT"),
+                (TokenType.BLOCK_START, "{{"),
+                (TokenType.VAR, "x"),
+                (TokenType.BLOCK_END, "}}"),
+                (TokenType.COMMA, ","),
+                (TokenType.BLOCK_START, "{{-"),
+                (TokenType.VAR, "x"),
+                (TokenType.BLOCK_END, "-}}"),
+                (TokenType.COMMA, ","),
+                (TokenType.BLOCK_START, "{%"),
+                (TokenType.FOR, "for"),
+                (TokenType.VAR, "x"),
+                (TokenType.IN, "in"),
+                (TokenType.VAR, "y"),
+                (TokenType.BLOCK_END, "-%}"),
+                (TokenType.VAR, "a"),
+                (TokenType.BLOCK_START, "{{+"),
+                (TokenType.VAR, "b"),
+                (TokenType.BLOCK_END, "}}"),
+                (TokenType.BLOCK_START, "{%"),
+                (TokenType.VAR, "endfor"),
+                (TokenType.BLOCK_END, "%}"),
+                (TokenType.SEMICOLON, ";"),
+            ],
+        )
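The new test_jinja covers tokenizer support for jinja-style templating: {{ ... }} and {% ... %} delimiters (including the whitespace-control variants) come back as BLOCK_START/BLOCK_END tokens instead of tripping up the tokenizer. A quick way to inspect this, mirroring the test:

    from sqlglot.tokens import Tokenizer

    for token in Tokenizer().tokenize("SELECT {{ x }}"):
        print(token.token_type, repr(token.text))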