Adding upstream version 6.3.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
24cf9d8984
commit
291e0c125c
41 changed files with 1558 additions and 267 deletions
22
CHANGELOG.md
22
CHANGELOG.md
|
@ -1,6 +1,28 @@
|
|||
Changelog
|
||||
=========
|
||||
|
||||
v6.3.0
|
||||
------
|
||||
|
||||
Changes:
|
||||
|
||||
- New: Snowflake [table literals](https://docs.snowflake.com/en/sql-reference/literals-table.html)
|
||||
|
||||
- New: Anti and semi joins
|
||||
|
||||
- New: Vacuum as a command
|
||||
|
||||
- New: Stored procedures
|
||||
|
||||
- New: Reweriting derived tables as CTES
|
||||
|
||||
- Improvement: Various clickhouse improvements
|
||||
|
||||
- Improvement: Optimizer predicate pushdown
|
||||
|
||||
- Breaking: DATE\_DIFF default renamed to DATEDIFF
|
||||
|
||||
|
||||
v6.2.0
|
||||
------
|
||||
|
||||
|
|
|
@ -8,7 +8,9 @@ from sqlglot.expressions import (
|
|||
and_,
|
||||
column,
|
||||
condition,
|
||||
except_,
|
||||
from_,
|
||||
intersect,
|
||||
maybe_parse,
|
||||
not_,
|
||||
or_,
|
||||
|
@ -16,11 +18,12 @@ from sqlglot.expressions import (
|
|||
subquery,
|
||||
)
|
||||
from sqlglot.expressions import table_ as table
|
||||
from sqlglot.expressions import union
|
||||
from sqlglot.generator import Generator
|
||||
from sqlglot.parser import Parser
|
||||
from sqlglot.tokens import Tokenizer, TokenType
|
||||
|
||||
__version__ = "6.2.8"
|
||||
__version__ = "6.3.1"
|
||||
|
||||
pretty = False
|
||||
|
||||
|
|
|
@ -135,6 +135,7 @@ class BigQuery(Dialect):
|
|||
exp.DateSub: _date_add_sql("DATE", "SUB"),
|
||||
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
|
||||
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
|
||||
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
|
||||
exp.ILike: no_ilike_sql,
|
||||
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
|
||||
exp.TimeSub: _date_add_sql("TIME", "SUB"),
|
||||
|
@ -172,12 +173,11 @@ class BigQuery(Dialect):
|
|||
exp.AnonymousProperty,
|
||||
}
|
||||
|
||||
EXPLICIT_UNION = True
|
||||
|
||||
def in_unnest_op(self, unnest):
|
||||
return self.sql(unnest)
|
||||
|
||||
def union_op(self, expression):
|
||||
return f"UNION{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
|
||||
|
||||
def except_op(self, expression):
|
||||
if not expression.args.get("distinct", False):
|
||||
self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery")
|
||||
|
|
|
@ -1,10 +1,16 @@
|
|||
from sqlglot import exp
|
||||
from sqlglot.dialects.dialect import Dialect, inline_array_sql
|
||||
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
|
||||
from sqlglot.generator import Generator
|
||||
from sqlglot.parser import Parser
|
||||
from sqlglot.helper import csv
|
||||
from sqlglot.parser import Parser, parse_var_map
|
||||
from sqlglot.tokens import Tokenizer, TokenType
|
||||
|
||||
|
||||
def _lower_func(sql):
|
||||
index = sql.index("(")
|
||||
return sql[:index].lower() + sql[index:]
|
||||
|
||||
|
||||
class ClickHouse(Dialect):
|
||||
normalize_functions = None
|
||||
null_ordering = "nulls_are_last"
|
||||
|
@ -14,17 +20,23 @@ class ClickHouse(Dialect):
|
|||
|
||||
KEYWORDS = {
|
||||
**Tokenizer.KEYWORDS,
|
||||
"NULLABLE": TokenType.NULLABLE,
|
||||
"FINAL": TokenType.FINAL,
|
||||
"DATETIME64": TokenType.DATETIME,
|
||||
"INT8": TokenType.TINYINT,
|
||||
"INT16": TokenType.SMALLINT,
|
||||
"INT32": TokenType.INT,
|
||||
"INT64": TokenType.BIGINT,
|
||||
"FLOAT32": TokenType.FLOAT,
|
||||
"FLOAT64": TokenType.DOUBLE,
|
||||
"TUPLE": TokenType.STRUCT,
|
||||
}
|
||||
|
||||
class Parser(Parser):
|
||||
FUNCTIONS = {
|
||||
**Parser.FUNCTIONS,
|
||||
"MAP": parse_var_map,
|
||||
}
|
||||
|
||||
def _parse_table(self, schema=False):
|
||||
this = super()._parse_table(schema)
|
||||
|
||||
|
@ -39,10 +51,25 @@ class ClickHouse(Dialect):
|
|||
TYPE_MAPPING = {
|
||||
**Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.NULLABLE: "Nullable",
|
||||
exp.DataType.Type.DATETIME: "DateTime64",
|
||||
exp.DataType.Type.MAP: "Map",
|
||||
exp.DataType.Type.ARRAY: "Array",
|
||||
exp.DataType.Type.STRUCT: "Tuple",
|
||||
exp.DataType.Type.TINYINT: "Int8",
|
||||
exp.DataType.Type.SMALLINT: "Int16",
|
||||
exp.DataType.Type.INT: "Int32",
|
||||
exp.DataType.Type.BIGINT: "Int64",
|
||||
exp.DataType.Type.FLOAT: "Float32",
|
||||
exp.DataType.Type.DOUBLE: "Float64",
|
||||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
**Generator.TRANSFORMS,
|
||||
exp.Array: inline_array_sql,
|
||||
exp.StrPosition: lambda self, e: f"position({csv(self.sql(e, 'this'), self.sql(e, 'substr'), self.sql(e, 'position'))})",
|
||||
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
|
||||
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
|
||||
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
|
||||
}
|
||||
|
||||
EXPLICIT_UNION = True
|
||||
|
|
|
@ -77,7 +77,6 @@ class Dialect(metaclass=_Dialect):
|
|||
alias_post_tablesample = False
|
||||
normalize_functions = "upper"
|
||||
null_ordering = "nulls_are_small"
|
||||
wrap_derived_values = True
|
||||
|
||||
date_format = "'%Y-%m-%d'"
|
||||
dateint_format = "'%Y%m%d'"
|
||||
|
@ -170,7 +169,6 @@ class Dialect(metaclass=_Dialect):
|
|||
"alias_post_tablesample": self.alias_post_tablesample,
|
||||
"normalize_functions": self.normalize_functions,
|
||||
"null_ordering": self.null_ordering,
|
||||
"wrap_derived_values": self.wrap_derived_values,
|
||||
**opts,
|
||||
}
|
||||
)
|
||||
|
@ -271,6 +269,21 @@ def struct_extract_sql(self, expression):
|
|||
return f"{this}.{struct_key}"
|
||||
|
||||
|
||||
def var_map_sql(self, expression):
|
||||
keys = expression.args["keys"]
|
||||
values = expression.args["values"]
|
||||
|
||||
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
|
||||
self.unsupported("Cannot convert array columns into map.")
|
||||
return f"MAP({self.sql(keys)}, {self.sql(values)})"
|
||||
|
||||
args = []
|
||||
for key, value in zip(keys.expressions, values.expressions):
|
||||
args.append(self.sql(key))
|
||||
args.append(self.sql(value))
|
||||
return f"MAP({csv(*args)})"
|
||||
|
||||
|
||||
def format_time_lambda(exp_class, dialect, default=None):
|
||||
"""Helper used for time expressions.
|
||||
|
||||
|
|
|
@ -11,40 +11,14 @@ from sqlglot.dialects.dialect import (
|
|||
no_trycast_sql,
|
||||
rename_func,
|
||||
struct_extract_sql,
|
||||
var_map_sql,
|
||||
)
|
||||
from sqlglot.generator import Generator
|
||||
from sqlglot.helper import csv, list_get
|
||||
from sqlglot.parser import Parser
|
||||
from sqlglot.parser import Parser, parse_var_map
|
||||
from sqlglot.tokens import Tokenizer
|
||||
|
||||
|
||||
def _parse_map(args):
|
||||
keys = []
|
||||
values = []
|
||||
for i in range(0, len(args), 2):
|
||||
keys.append(args[i])
|
||||
values.append(args[i + 1])
|
||||
return HiveMap(
|
||||
keys=exp.Array(expressions=keys),
|
||||
values=exp.Array(expressions=values),
|
||||
)
|
||||
|
||||
|
||||
def _map_sql(self, expression):
|
||||
keys = expression.args["keys"]
|
||||
values = expression.args["values"]
|
||||
|
||||
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
|
||||
self.unsupported("Cannot convert array columns into map use SparkSQL instead.")
|
||||
return f"MAP({self.sql(keys)}, {self.sql(values)})"
|
||||
|
||||
args = []
|
||||
for key, value in zip(keys.expressions, values.expressions):
|
||||
args.append(self.sql(key))
|
||||
args.append(self.sql(value))
|
||||
return f"MAP({csv(*args)})"
|
||||
|
||||
|
||||
def _array_sort(self, expression):
|
||||
if expression.expression:
|
||||
self.unsupported("Hive SORT_ARRAY does not support a comparator")
|
||||
|
@ -122,10 +96,6 @@ def _index_sql(self, expression):
|
|||
return f"{this} ON TABLE {table} {columns}"
|
||||
|
||||
|
||||
class HiveMap(exp.Map):
|
||||
is_var_len_args = True
|
||||
|
||||
|
||||
class Hive(Dialect):
|
||||
alias_post_tablesample = True
|
||||
|
||||
|
@ -206,7 +176,7 @@ class Hive(Dialect):
|
|||
position=list_get(args, 2),
|
||||
),
|
||||
"LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
|
||||
"MAP": _parse_map,
|
||||
"MAP": parse_var_map,
|
||||
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
|
||||
"PERCENTILE": exp.Quantile.from_arg_list,
|
||||
"PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list,
|
||||
|
@ -245,8 +215,8 @@ class Hive(Dialect):
|
|||
exp.Join: _unnest_to_explode_sql,
|
||||
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
|
||||
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
|
||||
exp.Map: _map_sql,
|
||||
HiveMap: _map_sql,
|
||||
exp.Map: var_map_sql,
|
||||
exp.VarMap: var_map_sql,
|
||||
exp.Create: create_with_partitions_sql,
|
||||
exp.Quantile: rename_func("PERCENTILE"),
|
||||
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
|
||||
|
|
|
@ -10,6 +10,32 @@ def _limit_sql(self, expression):
|
|||
|
||||
|
||||
class Oracle(Dialect):
|
||||
# https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
|
||||
# https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
|
||||
time_mapping = {
|
||||
"AM": "%p", # Meridian indicator with or without periods
|
||||
"A.M.": "%p", # Meridian indicator with or without periods
|
||||
"PM": "%p", # Meridian indicator with or without periods
|
||||
"P.M.": "%p", # Meridian indicator with or without periods
|
||||
"D": "%u", # Day of week (1-7)
|
||||
"DAY": "%A", # name of day
|
||||
"DD": "%d", # day of month (1-31)
|
||||
"DDD": "%j", # day of year (1-366)
|
||||
"DY": "%a", # abbreviated name of day
|
||||
"HH": "%I", # Hour of day (1-12)
|
||||
"HH12": "%I", # alias for HH
|
||||
"HH24": "%H", # Hour of day (0-23)
|
||||
"IW": "%V", # Calendar week of year (1-52 or 1-53), as defined by the ISO 8601 standard
|
||||
"MI": "%M", # Minute (0-59)
|
||||
"MM": "%m", # Month (01-12; January = 01)
|
||||
"MON": "%b", # Abbreviated name of month
|
||||
"MONTH": "%B", # Name of month
|
||||
"SS": "%S", # Second (0-59)
|
||||
"WW": "%W", # Week of year (1-53)
|
||||
"YY": "%y", # 15
|
||||
"YYYY": "%Y", # 2015
|
||||
}
|
||||
|
||||
class Generator(Generator):
|
||||
TYPE_MAPPING = {
|
||||
**Generator.TYPE_MAPPING,
|
||||
|
@ -30,6 +56,9 @@ class Oracle(Dialect):
|
|||
**transforms.UNALIAS_GROUP,
|
||||
exp.ILike: no_ilike_sql,
|
||||
exp.Limit: _limit_sql,
|
||||
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
|
||||
}
|
||||
|
||||
def query_modifiers(self, expression, *sqls):
|
||||
|
|
|
@ -118,13 +118,22 @@ def _serial_to_generated(expression):
|
|||
return expression
|
||||
|
||||
|
||||
def _to_timestamp(args):
|
||||
# TO_TIMESTAMP accepts either a single double argument or (text, text)
|
||||
if len(args) == 1 and args[0].is_number:
|
||||
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE
|
||||
return exp.UnixToTime.from_arg_list(args)
|
||||
# https://www.postgresql.org/docs/current/functions-formatting.html
|
||||
return format_time_lambda(exp.StrToTime, "postgres")(args)
|
||||
|
||||
|
||||
class Postgres(Dialect):
|
||||
null_ordering = "nulls_are_large"
|
||||
time_format = "'YYYY-MM-DD HH24:MI:SS'"
|
||||
time_mapping = {
|
||||
"AM": "%p",
|
||||
"PM": "%p",
|
||||
"D": "%w", # 1-based day of week
|
||||
"D": "%u", # 1-based day of week
|
||||
"DD": "%d", # day of month
|
||||
"DDD": "%j", # zero padded day of year
|
||||
"FMDD": "%-d", # - is no leading zero for Python; same for FM in postgres
|
||||
|
@ -172,7 +181,7 @@ class Postgres(Dialect):
|
|||
|
||||
FUNCTIONS = {
|
||||
**Parser.FUNCTIONS,
|
||||
"TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"),
|
||||
"TO_TIMESTAMP": _to_timestamp,
|
||||
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
|
||||
}
|
||||
|
||||
|
@ -211,4 +220,5 @@ class Postgres(Dialect):
|
|||
exp.TableSample: no_tablesample_sql,
|
||||
exp.Trim: _trim_sql,
|
||||
exp.TryCast: no_trycast_sql,
|
||||
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
|
||||
}
|
||||
|
|
|
@ -121,6 +121,7 @@ class Snowflake(Dialect):
|
|||
FUNC_TOKENS = {
|
||||
*Parser.FUNC_TOKENS,
|
||||
TokenType.RLIKE,
|
||||
TokenType.TABLE,
|
||||
}
|
||||
|
||||
COLUMN_OPERATORS = {
|
||||
|
@ -143,7 +144,7 @@ class Snowflake(Dialect):
|
|||
|
||||
SINGLE_TOKENS = {
|
||||
**Tokenizer.SINGLE_TOKENS,
|
||||
"$": TokenType.DOLLAR, # needed to break for quotes
|
||||
"$": TokenType.PARAMETER,
|
||||
}
|
||||
|
||||
KEYWORDS = {
|
||||
|
@ -164,6 +165,8 @@ class Snowflake(Dialect):
|
|||
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||
exp.UnixToTime: _unix_to_time,
|
||||
exp.Array: inline_array_sql,
|
||||
exp.StrPosition: rename_func("POSITION"),
|
||||
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
|
||||
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
|
||||
}
|
||||
|
||||
|
|
|
@ -4,8 +4,9 @@ from sqlglot.dialects.dialect import (
|
|||
no_ilike_sql,
|
||||
rename_func,
|
||||
)
|
||||
from sqlglot.dialects.hive import Hive, HiveMap
|
||||
from sqlglot.dialects.hive import Hive
|
||||
from sqlglot.helper import list_get
|
||||
from sqlglot.parser import Parser
|
||||
|
||||
|
||||
def _create_sql(self, e):
|
||||
|
@ -47,8 +48,6 @@ def _unix_to_time(self, expression):
|
|||
|
||||
|
||||
class Spark(Hive):
|
||||
wrap_derived_values = False
|
||||
|
||||
class Parser(Hive.Parser):
|
||||
FUNCTIONS = {
|
||||
**Hive.Parser.FUNCTIONS,
|
||||
|
@ -78,8 +77,19 @@ class Spark(Hive):
|
|||
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
|
||||
}
|
||||
|
||||
class Generator(Hive.Generator):
|
||||
FUNCTION_PARSERS = {
|
||||
**Parser.FUNCTION_PARSERS,
|
||||
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
|
||||
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
|
||||
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
|
||||
"MERGE": lambda self: self._parse_join_hint("MERGE"),
|
||||
"SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
|
||||
"MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
|
||||
"SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
|
||||
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
|
||||
}
|
||||
|
||||
class Generator(Hive.Generator):
|
||||
TYPE_MAPPING = {
|
||||
**Hive.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.TINYINT: "BYTE",
|
||||
|
@ -102,8 +112,9 @@ class Spark(Hive):
|
|||
exp.Map: _map_sql,
|
||||
exp.Reduce: rename_func("AGGREGATE"),
|
||||
exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
|
||||
HiveMap: _map_sql,
|
||||
}
|
||||
|
||||
WRAP_DERIVED_VALUES = False
|
||||
|
||||
class Tokenizer(Hive.Tokenizer):
|
||||
HEX_STRINGS = [("X'", "'")]
|
||||
|
|
|
@ -326,6 +326,7 @@ class Python(Dialect):
|
|||
exp.Alias: lambda self, e: self.sql(e.this),
|
||||
exp.Array: inline_array_sql,
|
||||
exp.And: lambda self, e: self.binary(e, "and"),
|
||||
exp.Boolean: lambda self, e: "True" if e.this else "False",
|
||||
exp.Cast: _cast_py,
|
||||
exp.Column: _column_py,
|
||||
exp.EQ: lambda self, e: self.binary(e, "=="),
|
||||
|
|
|
@ -508,7 +508,69 @@ class DerivedTable(Expression):
|
|||
return [select.alias_or_name for select in self.selects]
|
||||
|
||||
|
||||
class UDTF(DerivedTable):
|
||||
class Unionable:
|
||||
def union(self, expression, distinct=True, dialect=None, **opts):
|
||||
"""
|
||||
Builds a UNION expression.
|
||||
|
||||
Example:
|
||||
>>> import sqlglot
|
||||
>>> sqlglot.parse_one("SELECT * FROM foo").union("SELECT * FROM bla").sql()
|
||||
'SELECT * FROM foo UNION SELECT * FROM bla'
|
||||
|
||||
Args:
|
||||
expression (str or Expression): the SQL code string.
|
||||
If an `Expression` instance is passed, it will be used as-is.
|
||||
distinct (bool): set the DISTINCT flag if and only if this is true.
|
||||
dialect (str): the dialect used to parse the input expression.
|
||||
opts (kwargs): other options to use to parse the input expressions.
|
||||
Returns:
|
||||
Union: the Union expression.
|
||||
"""
|
||||
return union(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
|
||||
|
||||
def intersect(self, expression, distinct=True, dialect=None, **opts):
|
||||
"""
|
||||
Builds an INTERSECT expression.
|
||||
|
||||
Example:
|
||||
>>> import sqlglot
|
||||
>>> sqlglot.parse_one("SELECT * FROM foo").intersect("SELECT * FROM bla").sql()
|
||||
'SELECT * FROM foo INTERSECT SELECT * FROM bla'
|
||||
|
||||
Args:
|
||||
expression (str or Expression): the SQL code string.
|
||||
If an `Expression` instance is passed, it will be used as-is.
|
||||
distinct (bool): set the DISTINCT flag if and only if this is true.
|
||||
dialect (str): the dialect used to parse the input expression.
|
||||
opts (kwargs): other options to use to parse the input expressions.
|
||||
Returns:
|
||||
Intersect: the Intersect expression
|
||||
"""
|
||||
return intersect(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
|
||||
|
||||
def except_(self, expression, distinct=True, dialect=None, **opts):
|
||||
"""
|
||||
Builds an EXCEPT expression.
|
||||
|
||||
Example:
|
||||
>>> import sqlglot
|
||||
>>> sqlglot.parse_one("SELECT * FROM foo").except_("SELECT * FROM bla").sql()
|
||||
'SELECT * FROM foo EXCEPT SELECT * FROM bla'
|
||||
|
||||
Args:
|
||||
expression (str or Expression): the SQL code string.
|
||||
If an `Expression` instance is passed, it will be used as-is.
|
||||
distinct (bool): set the DISTINCT flag if and only if this is true.
|
||||
dialect (str): the dialect used to parse the input expression.
|
||||
opts (kwargs): other options to use to parse the input expressions.
|
||||
Returns:
|
||||
Except: the Except expression
|
||||
"""
|
||||
return except_(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
|
||||
|
||||
|
||||
class UDTF(DerivedTable, Unionable):
|
||||
pass
|
||||
|
||||
|
||||
|
@ -518,6 +580,10 @@ class Annotation(Expression):
|
|||
"expression": True,
|
||||
}
|
||||
|
||||
@property
|
||||
def alias(self):
|
||||
return self.expression.alias_or_name
|
||||
|
||||
|
||||
class Cache(Expression):
|
||||
arg_types = {
|
||||
|
@ -700,6 +766,10 @@ class Hint(Expression):
|
|||
arg_types = {"expressions": True}
|
||||
|
||||
|
||||
class JoinHint(Expression):
|
||||
arg_types = {"this": True, "expressions": True}
|
||||
|
||||
|
||||
class Identifier(Expression):
|
||||
arg_types = {"this": True, "quoted": False}
|
||||
|
||||
|
@ -971,7 +1041,7 @@ class Tuple(Expression):
|
|||
arg_types = {"expressions": False}
|
||||
|
||||
|
||||
class Subqueryable:
|
||||
class Subqueryable(Unionable):
|
||||
def subquery(self, alias=None, copy=True):
|
||||
"""
|
||||
Convert this expression to an aliased expression that can be used as a Subquery.
|
||||
|
@ -1654,7 +1724,7 @@ class Select(Subqueryable, Expression):
|
|||
return self.expressions
|
||||
|
||||
|
||||
class Subquery(DerivedTable):
|
||||
class Subquery(DerivedTable, Unionable):
|
||||
arg_types = {
|
||||
"this": True,
|
||||
"alias": False,
|
||||
|
@ -1731,7 +1801,7 @@ class Parameter(Expression):
|
|||
|
||||
|
||||
class Placeholder(Expression):
|
||||
arg_types = {}
|
||||
arg_types = {"this": False}
|
||||
|
||||
|
||||
class Null(Condition):
|
||||
|
@ -1791,6 +1861,8 @@ class DataType(Expression):
|
|||
IMAGE = auto()
|
||||
VARIANT = auto()
|
||||
OBJECT = auto()
|
||||
NULL = auto()
|
||||
UNKNOWN = auto() # Sentinel value, useful for type annotation
|
||||
|
||||
@classmethod
|
||||
def build(cls, dtype, **kwargs):
|
||||
|
@ -2007,7 +2079,7 @@ class Distinct(Expression):
|
|||
|
||||
|
||||
class In(Predicate):
|
||||
arg_types = {"this": True, "expressions": False, "query": False, "unnest": False}
|
||||
arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False}
|
||||
|
||||
|
||||
class TimeUnit(Expression):
|
||||
|
@ -2377,6 +2449,11 @@ class Map(Func):
|
|||
arg_types = {"keys": True, "values": True}
|
||||
|
||||
|
||||
class VarMap(Func):
|
||||
arg_types = {"keys": True, "values": True}
|
||||
is_var_len_args = True
|
||||
|
||||
|
||||
class Max(AggFunc):
|
||||
pass
|
||||
|
||||
|
@ -2449,7 +2526,7 @@ class Substring(Func):
|
|||
|
||||
|
||||
class StrPosition(Func):
|
||||
arg_types = {"this": True, "substr": True, "position": False}
|
||||
arg_types = {"substr": True, "this": True, "position": False}
|
||||
|
||||
|
||||
class StrToDate(Func):
|
||||
|
@ -2785,6 +2862,81 @@ def _wrap_operator(expression):
|
|||
return expression
|
||||
|
||||
|
||||
def union(left, right, distinct=True, dialect=None, **opts):
|
||||
"""
|
||||
Initializes a syntax tree from one UNION expression.
|
||||
|
||||
Example:
|
||||
>>> union("SELECT * FROM foo", "SELECT * FROM bla").sql()
|
||||
'SELECT * FROM foo UNION SELECT * FROM bla'
|
||||
|
||||
Args:
|
||||
left (str or Expression): the SQL code string corresponding to the left-hand side.
|
||||
If an `Expression` instance is passed, it will be used as-is.
|
||||
right (str or Expression): the SQL code string corresponding to the right-hand side.
|
||||
If an `Expression` instance is passed, it will be used as-is.
|
||||
distinct (bool): set the DISTINCT flag if and only if this is true.
|
||||
dialect (str): the dialect used to parse the input expression.
|
||||
opts (kwargs): other options to use to parse the input expressions.
|
||||
Returns:
|
||||
Union: the syntax tree for the UNION expression.
|
||||
"""
|
||||
left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
|
||||
right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
|
||||
|
||||
return Union(this=left, expression=right, distinct=distinct)
|
||||
|
||||
|
||||
def intersect(left, right, distinct=True, dialect=None, **opts):
|
||||
"""
|
||||
Initializes a syntax tree from one INTERSECT expression.
|
||||
|
||||
Example:
|
||||
>>> intersect("SELECT * FROM foo", "SELECT * FROM bla").sql()
|
||||
'SELECT * FROM foo INTERSECT SELECT * FROM bla'
|
||||
|
||||
Args:
|
||||
left (str or Expression): the SQL code string corresponding to the left-hand side.
|
||||
If an `Expression` instance is passed, it will be used as-is.
|
||||
right (str or Expression): the SQL code string corresponding to the right-hand side.
|
||||
If an `Expression` instance is passed, it will be used as-is.
|
||||
distinct (bool): set the DISTINCT flag if and only if this is true.
|
||||
dialect (str): the dialect used to parse the input expression.
|
||||
opts (kwargs): other options to use to parse the input expressions.
|
||||
Returns:
|
||||
Intersect: the syntax tree for the INTERSECT expression.
|
||||
"""
|
||||
left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
|
||||
right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
|
||||
|
||||
return Intersect(this=left, expression=right, distinct=distinct)
|
||||
|
||||
|
||||
def except_(left, right, distinct=True, dialect=None, **opts):
|
||||
"""
|
||||
Initializes a syntax tree from one EXCEPT expression.
|
||||
|
||||
Example:
|
||||
>>> except_("SELECT * FROM foo", "SELECT * FROM bla").sql()
|
||||
'SELECT * FROM foo EXCEPT SELECT * FROM bla'
|
||||
|
||||
Args:
|
||||
left (str or Expression): the SQL code string corresponding to the left-hand side.
|
||||
If an `Expression` instance is passed, it will be used as-is.
|
||||
right (str or Expression): the SQL code string corresponding to the right-hand side.
|
||||
If an `Expression` instance is passed, it will be used as-is.
|
||||
distinct (bool): set the DISTINCT flag if and only if this is true.
|
||||
dialect (str): the dialect used to parse the input expression.
|
||||
opts (kwargs): other options to use to parse the input expressions.
|
||||
Returns:
|
||||
Except: the syntax tree for the EXCEPT statement.
|
||||
"""
|
||||
left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
|
||||
right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
|
||||
|
||||
return Except(this=left, expression=right, distinct=distinct)
|
||||
|
||||
|
||||
def select(*expressions, dialect=None, **opts):
|
||||
"""
|
||||
Initializes a syntax tree from one or multiple SELECT expressions.
|
||||
|
@ -2991,7 +3143,7 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
|
|||
If an Expression instance is passed, this is used as-is.
|
||||
alias (str or Identifier): the alias name to use. If the name has
|
||||
special characters it is quoted.
|
||||
table (boolean): create a table alias, default false
|
||||
table (bool): create a table alias, default false
|
||||
dialect (str): the dialect used to parse the input expression.
|
||||
**opts: other options to use to parse the input expressions.
|
||||
|
||||
|
@ -3002,7 +3154,7 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
|
|||
alias = to_identifier(alias, quoted=quoted)
|
||||
alias = TableAlias(this=alias) if table else alias
|
||||
|
||||
if "alias" in exp.arg_types:
|
||||
if "alias" in exp.arg_types and not isinstance(exp, Window):
|
||||
exp = exp.copy()
|
||||
exp.set("alias", alias)
|
||||
return exp
|
||||
|
@ -3138,6 +3290,60 @@ def column_table_names(expression):
|
|||
return list(dict.fromkeys(column.table for column in expression.find_all(Column)))
|
||||
|
||||
|
||||
def table_name(table):
|
||||
"""Get the full name of a table as a string.
|
||||
|
||||
Args:
|
||||
table (exp.Table | str): Table expression node or string.
|
||||
|
||||
Examples:
|
||||
>>> from sqlglot import exp, parse_one
|
||||
>>> table_name(parse_one("select * from a.b.c").find(exp.Table))
|
||||
'a.b.c'
|
||||
|
||||
Returns:
|
||||
str: the table name
|
||||
"""
|
||||
|
||||
table = maybe_parse(table, into=Table)
|
||||
|
||||
return ".".join(
|
||||
part
|
||||
for part in (
|
||||
table.text("catalog"),
|
||||
table.text("db"),
|
||||
table.name,
|
||||
)
|
||||
if part
|
||||
)
|
||||
|
||||
|
||||
def replace_tables(expression, mapping):
|
||||
"""Replace all tables in expression according to the mapping.
|
||||
|
||||
Args:
|
||||
expression (sqlglot.Expression): Expression node to be transformed and replaced
|
||||
mapping (Dict[str, str]): Mapping of table names
|
||||
|
||||
Examples:
|
||||
>>> from sqlglot import exp, parse_one
|
||||
>>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql()
|
||||
'SELECT * FROM "c"'
|
||||
|
||||
Returns:
|
||||
The mapped expression
|
||||
"""
|
||||
|
||||
def _replace_tables(node):
|
||||
if isinstance(node, Table):
|
||||
new_name = mapping.get(table_name(node))
|
||||
if new_name:
|
||||
return table_(*reversed(new_name.split(".")), quoted=True)
|
||||
return node
|
||||
|
||||
return expression.transform(_replace_tables)
|
||||
|
||||
|
||||
TRUE = Boolean(this=True)
|
||||
FALSE = Boolean(this=False)
|
||||
NULL = Null()
|
||||
|
|
|
@ -48,8 +48,9 @@ class Generator:
|
|||
TRANSFORMS = {
|
||||
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
|
||||
exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
|
||||
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
|
||||
exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
|
||||
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
|
||||
exp.VarMap: lambda self, e: f"MAP({self.sql(e.args['keys'])}, {self.sql(e.args['values'])})",
|
||||
exp.LanguageProperty: lambda self, e: self.naked_property(e),
|
||||
exp.LocationProperty: lambda self, e: self.naked_property(e),
|
||||
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
|
||||
|
@ -57,7 +58,12 @@ class Generator:
|
|||
exp.VolatilityProperty: lambda self, e: self.sql(e.name),
|
||||
}
|
||||
|
||||
# whether or not null ordering is supported in order by
|
||||
NULL_ORDERING_SUPPORTED = True
|
||||
# always do union distinct or union all
|
||||
EXPLICIT_UNION = False
|
||||
# wrap derived values in parens, usually standard but spark doesn't support it
|
||||
WRAP_DERIVED_VALUES = True
|
||||
|
||||
TYPE_MAPPING = {
|
||||
exp.DataType.Type.NCHAR: "CHAR",
|
||||
|
@ -101,7 +107,6 @@ class Generator:
|
|||
"unsupported_messages",
|
||||
"null_ordering",
|
||||
"max_unsupported",
|
||||
"wrap_derived_values",
|
||||
"_indent",
|
||||
"_replace_backslash",
|
||||
"_escaped_quote_end",
|
||||
|
@ -130,7 +135,6 @@ class Generator:
|
|||
null_ordering=None,
|
||||
max_unsupported=3,
|
||||
leading_comma=False,
|
||||
wrap_derived_values=True,
|
||||
):
|
||||
import sqlglot
|
||||
|
||||
|
@ -154,7 +158,6 @@ class Generator:
|
|||
self.unsupported_messages = []
|
||||
self.max_unsupported = max_unsupported
|
||||
self.null_ordering = null_ordering
|
||||
self.wrap_derived_values = wrap_derived_values
|
||||
self._indent = indent
|
||||
self._replace_backslash = self.escape == "\\"
|
||||
self._escaped_quote_end = self.escape + self.quote_end
|
||||
|
@ -595,7 +598,7 @@ class Generator:
|
|||
if not alias:
|
||||
return f"VALUES{self.seg('')}{args}"
|
||||
alias = f" AS {alias}" if alias else alias
|
||||
if self.wrap_derived_values:
|
||||
if self.WRAP_DERIVED_VALUES:
|
||||
return f"(VALUES{self.seg('')}{args}){alias}"
|
||||
return f"VALUES{self.seg('')}{args}{alias}"
|
||||
|
||||
|
@ -779,8 +782,8 @@ class Generator:
|
|||
def parameter_sql(self, expression):
|
||||
return f"@{self.sql(expression, 'this')}"
|
||||
|
||||
def placeholder_sql(self, *_):
|
||||
return "?"
|
||||
def placeholder_sql(self, expression):
|
||||
return f":{expression.name}" if expression.name else "?"
|
||||
|
||||
def subquery_sql(self, expression):
|
||||
alias = self.sql(expression, "alias")
|
||||
|
@ -803,7 +806,9 @@ class Generator:
|
|||
)
|
||||
|
||||
def union_op(self, expression):
|
||||
return f"UNION{'' if expression.args.get('distinct') else ' ALL'}"
|
||||
kind = " DISTINCT" if self.EXPLICIT_UNION else ""
|
||||
kind = kind if expression.args.get("distinct") else " ALL"
|
||||
return f"UNION{kind}"
|
||||
|
||||
def unnest_sql(self, expression):
|
||||
args = self.expressions(expression, flat=True)
|
||||
|
@ -940,10 +945,13 @@ class Generator:
|
|||
def in_sql(self, expression):
|
||||
query = expression.args.get("query")
|
||||
unnest = expression.args.get("unnest")
|
||||
field = expression.args.get("field")
|
||||
if query:
|
||||
in_sql = self.wrap(query)
|
||||
elif unnest:
|
||||
in_sql = self.in_unnest_op(unnest)
|
||||
elif field:
|
||||
in_sql = self.sql(field)
|
||||
else:
|
||||
in_sql = f"({self.expressions(expression, flat=True)})"
|
||||
return f"{self.sql(expression, 'this')} IN {in_sql}"
|
||||
|
@ -1178,3 +1186,8 @@ class Generator:
|
|||
this = self.sql(expression, "this")
|
||||
kind = self.sql(expression, "kind")
|
||||
return f"{this} {kind}"
|
||||
|
||||
def joinhint_sql(self, expression):
|
||||
this = self.sql(expression, "this")
|
||||
expressions = self.expressions(expression, flat=True)
|
||||
return f"{this}({expressions})"
|
||||
|
|
|
@ -1,16 +1,20 @@
|
|||
from sqlglot import exp
|
||||
from sqlglot.helper import ensure_list, subclasses
|
||||
from sqlglot.optimizer.schema import ensure_schema
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
|
||||
|
||||
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
|
||||
"""
|
||||
Recursively infer & annotate types in an expression syntax tree against a schema.
|
||||
Assumes that we've already executed the optimizer's qualify_columns step.
|
||||
|
||||
(TODO -- replace this with a better example after adding some functionality)
|
||||
Example:
|
||||
>>> import sqlglot
|
||||
>>> annotated_expression = annotate_types(sqlglot.parse_one('5 + 5.3'))
|
||||
>>> annotated_expression.type
|
||||
>>> schema = {"y": {"cola": "SMALLINT"}}
|
||||
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
|
||||
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
|
||||
>>> annotated_expr.expressions[0].type # Get the type of "x.cola + 2.5 AS cola"
|
||||
<Type.DOUBLE: 'DOUBLE'>
|
||||
|
||||
Args:
|
||||
|
@ -22,6 +26,8 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
|
|||
sqlglot.Expression: expression annotated with types
|
||||
"""
|
||||
|
||||
schema = ensure_schema(schema)
|
||||
|
||||
return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
|
||||
|
||||
|
||||
|
@ -35,10 +41,81 @@ class TypeAnnotator:
|
|||
expr_type: lambda self, expr: self._annotate_binary(expr)
|
||||
for expr_type in subclasses(exp.__name__, exp.Binary)
|
||||
},
|
||||
exp.Cast: lambda self, expr: self._annotate_cast(expr),
|
||||
exp.DataType: lambda self, expr: self._annotate_data_type(expr),
|
||||
exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this),
|
||||
exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this),
|
||||
exp.Alias: lambda self, expr: self._annotate_unary(expr),
|
||||
exp.Literal: lambda self, expr: self._annotate_literal(expr),
|
||||
exp.Boolean: lambda self, expr: self._annotate_boolean(expr),
|
||||
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
||||
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
|
||||
exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
|
||||
exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
|
||||
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
|
||||
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
|
||||
exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
|
||||
exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
|
||||
exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
|
||||
exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
||||
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
|
||||
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
|
||||
exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
||||
exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
|
||||
exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
|
||||
}
|
||||
|
||||
# Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
|
||||
|
@ -97,43 +174,82 @@ class TypeAnnotator:
|
|||
},
|
||||
}
|
||||
|
||||
TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
|
||||
|
||||
def __init__(self, schema=None, annotators=None, coerces_to=None):
|
||||
self.schema = schema
|
||||
self.annotators = annotators or self.ANNOTATORS
|
||||
self.coerces_to = coerces_to or self.COERCES_TO
|
||||
|
||||
def annotate(self, expression):
|
||||
if isinstance(expression, self.TRAVERSABLES):
|
||||
for scope in traverse_scope(expression):
|
||||
subscope_selects = {
|
||||
name: {select.alias_or_name: select for select in source.selects}
|
||||
for name, source in scope.sources.items()
|
||||
if isinstance(source, Scope)
|
||||
}
|
||||
|
||||
# First annotate the current scope's column references
|
||||
for col in scope.columns:
|
||||
source = scope.sources[col.table]
|
||||
if isinstance(source, exp.Table):
|
||||
col.type = self.schema.get_column_type(source, col)
|
||||
else:
|
||||
col.type = subscope_selects[col.table][col.name].type
|
||||
|
||||
# Then (possibly) annotate the remaining expressions in the scope
|
||||
self._maybe_annotate(scope.expression)
|
||||
|
||||
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
|
||||
|
||||
def _maybe_annotate(self, expression):
|
||||
if not isinstance(expression, exp.Expression):
|
||||
return None
|
||||
|
||||
if expression.type:
|
||||
return expression # We've already inferred the expression's type
|
||||
|
||||
annotator = self.annotators.get(expression.__class__)
|
||||
return annotator(self, expression) if annotator else self._annotate_args(expression)
|
||||
return (
|
||||
annotator(self, expression)
|
||||
if annotator
|
||||
else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
|
||||
)
|
||||
|
||||
def _annotate_args(self, expression):
|
||||
for value in expression.args.values():
|
||||
for v in ensure_list(value):
|
||||
self.annotate(v)
|
||||
self._maybe_annotate(v)
|
||||
|
||||
return expression
|
||||
|
||||
def _annotate_cast(self, expression):
|
||||
expression.type = expression.args["to"].this
|
||||
return self._annotate_args(expression)
|
||||
|
||||
def _annotate_data_type(self, expression):
|
||||
expression.type = expression.this
|
||||
return self._annotate_args(expression)
|
||||
|
||||
def _maybe_coerce(self, type1, type2):
|
||||
# We propagate the NULL / UNKNOWN types upwards if found
|
||||
if exp.DataType.Type.NULL in (type1, type2):
|
||||
return exp.DataType.Type.NULL
|
||||
if exp.DataType.Type.UNKNOWN in (type1, type2):
|
||||
return exp.DataType.Type.UNKNOWN
|
||||
|
||||
return type2 if type2 in self.coerces_to[type1] else type1
|
||||
|
||||
def _annotate_binary(self, expression):
|
||||
self._annotate_args(expression)
|
||||
|
||||
if isinstance(expression, (exp.Condition, exp.Predicate)):
|
||||
left_type = expression.left.type
|
||||
right_type = expression.right.type
|
||||
|
||||
if isinstance(expression, (exp.And, exp.Or)):
|
||||
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
|
||||
expression.type = exp.DataType.Type.NULL
|
||||
elif exp.DataType.Type.NULL in (left_type, right_type):
|
||||
expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
|
||||
else:
|
||||
expression.type = exp.DataType.Type.BOOLEAN
|
||||
elif isinstance(expression, (exp.Condition, exp.Predicate)):
|
||||
expression.type = exp.DataType.Type.BOOLEAN
|
||||
else:
|
||||
expression.type = self._maybe_coerce(expression.left.type, expression.right.type)
|
||||
expression.type = self._maybe_coerce(left_type, right_type)
|
||||
|
||||
return expression
|
||||
|
||||
|
@ -157,6 +273,6 @@ class TypeAnnotator:
|
|||
|
||||
return expression
|
||||
|
||||
def _annotate_boolean(self, expression):
|
||||
expression.type = exp.DataType.Type.BOOLEAN
|
||||
return expression
|
||||
def _annotate_with_type(self, expression, target_type):
|
||||
expression.type = target_type
|
||||
return self._annotate_args(expression)
|
||||
|
|
|
@ -44,6 +44,7 @@ UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
|
|||
"joins",
|
||||
"where",
|
||||
"order",
|
||||
"hint",
|
||||
}
|
||||
|
||||
|
||||
|
@ -67,21 +68,22 @@ def merge_ctes(expression, leave_tables_isolated=False):
|
|||
singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
|
||||
for outer_scope, inner_scope, table in singular_cte_selections:
|
||||
inner_select = inner_scope.expression.unnest()
|
||||
if _mergeable(outer_scope, inner_select, leave_tables_isolated):
|
||||
from_or_join = table.find_ancestor(exp.From, exp.Join)
|
||||
|
||||
from_or_join = table.find_ancestor(exp.From, exp.Join)
|
||||
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
||||
node_to_replace = table
|
||||
if isinstance(node_to_replace.parent, exp.Alias):
|
||||
node_to_replace = node_to_replace.parent
|
||||
alias = node_to_replace.alias
|
||||
else:
|
||||
alias = table.name
|
||||
|
||||
_rename_inner_sources(outer_scope, inner_scope, alias)
|
||||
_merge_from(outer_scope, inner_scope, node_to_replace, alias)
|
||||
_merge_expressions(outer_scope, inner_scope, alias)
|
||||
_merge_joins(outer_scope, inner_scope, from_or_join)
|
||||
_merge_where(outer_scope, inner_scope, from_or_join)
|
||||
_merge_order(outer_scope, inner_scope)
|
||||
_merge_hints(outer_scope, inner_scope)
|
||||
_pop_cte(inner_scope)
|
||||
return expression
|
||||
|
||||
|
@ -90,9 +92,9 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
|
|||
for outer_scope in traverse_scope(expression):
|
||||
for subquery in outer_scope.derived_tables:
|
||||
inner_select = subquery.unnest()
|
||||
if _mergeable(outer_scope, inner_select, leave_tables_isolated):
|
||||
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
|
||||
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
||||
alias = subquery.alias_or_name
|
||||
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
|
||||
inner_scope = outer_scope.sources[alias]
|
||||
|
||||
_rename_inner_sources(outer_scope, inner_scope, alias)
|
||||
|
@ -101,10 +103,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
|
|||
_merge_joins(outer_scope, inner_scope, from_or_join)
|
||||
_merge_where(outer_scope, inner_scope, from_or_join)
|
||||
_merge_order(outer_scope, inner_scope)
|
||||
_merge_hints(outer_scope, inner_scope)
|
||||
return expression
|
||||
|
||||
|
||||
def _mergeable(outer_scope, inner_select, leave_tables_isolated):
|
||||
def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
||||
"""
|
||||
Return True if `inner_select` can be merged into outer query.
|
||||
|
||||
|
@ -112,6 +115,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated):
|
|||
outer_scope (Scope)
|
||||
inner_select (exp.Select)
|
||||
leave_tables_isolated (bool)
|
||||
from_or_join (exp.From|exp.Join)
|
||||
Returns:
|
||||
bool: True if can be merged
|
||||
"""
|
||||
|
@ -123,6 +127,16 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated):
|
|||
and inner_select.args.get("from")
|
||||
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
|
||||
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
|
||||
and not (
|
||||
isinstance(from_or_join, exp.Join)
|
||||
and inner_select.args.get("where")
|
||||
and from_or_join.side in {"FULL", "LEFT", "RIGHT"}
|
||||
)
|
||||
and not (
|
||||
isinstance(from_or_join, exp.From)
|
||||
and inner_select.args.get("where")
|
||||
and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
@ -170,6 +184,12 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
|
|||
"""
|
||||
new_subquery = inner_scope.expression.args.get("from").expressions[0]
|
||||
node_to_replace.replace(new_subquery)
|
||||
for join_hint in outer_scope.join_hints:
|
||||
tables = join_hint.find_all(exp.Table)
|
||||
for table in tables:
|
||||
if table.alias_or_name == node_to_replace.alias_or_name:
|
||||
new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery
|
||||
table.set("this", exp.to_identifier(new_table.alias_or_name))
|
||||
outer_scope.remove_source(alias)
|
||||
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
|
||||
|
||||
|
@ -273,6 +293,18 @@ def _merge_order(outer_scope, inner_scope):
|
|||
outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
|
||||
|
||||
|
||||
def _merge_hints(outer_scope, inner_scope):
|
||||
inner_scope_hint = inner_scope.expression.args.get("hint")
|
||||
if not inner_scope_hint:
|
||||
return
|
||||
outer_scope_hint = outer_scope.expression.args.get("hint")
|
||||
if outer_scope_hint:
|
||||
for hint_expression in inner_scope_hint.expressions:
|
||||
outer_scope_hint.append("expressions", hint_expression)
|
||||
else:
|
||||
outer_scope.expression.set("hint", inner_scope_hint)
|
||||
|
||||
|
||||
def _pop_cte(inner_scope):
|
||||
"""
|
||||
Remove CTE from the AST.
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from collections import defaultdict
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.optimizer.normalize import normalized
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
|
@ -20,22 +22,30 @@ def pushdown_predicates(expression):
|
|||
Returns:
|
||||
sqlglot.Expression: optimized expression
|
||||
"""
|
||||
for scope in reversed(traverse_scope(expression)):
|
||||
scope_ref_count = defaultdict(lambda: 0)
|
||||
scopes = traverse_scope(expression)
|
||||
scopes.reverse()
|
||||
|
||||
for scope in scopes:
|
||||
for _, source in scope.selected_sources.values():
|
||||
scope_ref_count[id(source)] += 1
|
||||
|
||||
for scope in scopes:
|
||||
select = scope.expression
|
||||
where = select.args.get("where")
|
||||
if where:
|
||||
pushdown(where.this, scope.selected_sources)
|
||||
pushdown(where.this, scope.selected_sources, scope_ref_count)
|
||||
|
||||
# joins should only pushdown into itself, not to other joins
|
||||
# so we limit the selected sources to only itself
|
||||
for join in select.args.get("joins") or []:
|
||||
name = join.this.alias_or_name
|
||||
pushdown(join.args.get("on"), {name: scope.selected_sources[name]})
|
||||
pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
def pushdown(condition, sources):
|
||||
def pushdown(condition, sources, scope_ref_count):
|
||||
if not condition:
|
||||
return
|
||||
|
||||
|
@ -45,17 +55,17 @@ def pushdown(condition, sources):
|
|||
predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
|
||||
|
||||
if cnf_like:
|
||||
pushdown_cnf(predicates, sources)
|
||||
pushdown_cnf(predicates, sources, scope_ref_count)
|
||||
else:
|
||||
pushdown_dnf(predicates, sources)
|
||||
pushdown_dnf(predicates, sources, scope_ref_count)
|
||||
|
||||
|
||||
def pushdown_cnf(predicates, scope):
|
||||
def pushdown_cnf(predicates, scope, scope_ref_count):
|
||||
"""
|
||||
If the predicates are in CNF like form, we can simply replace each block in the parent.
|
||||
"""
|
||||
for predicate in predicates:
|
||||
for node in nodes_for_predicate(predicate, scope).values():
|
||||
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
|
||||
if isinstance(node, exp.Join):
|
||||
predicate.replace(exp.TRUE)
|
||||
node.on(predicate, copy=False)
|
||||
|
@ -65,7 +75,7 @@ def pushdown_cnf(predicates, scope):
|
|||
node.where(replace_aliases(node, predicate), copy=False)
|
||||
|
||||
|
||||
def pushdown_dnf(predicates, scope):
|
||||
def pushdown_dnf(predicates, scope, scope_ref_count):
|
||||
"""
|
||||
If the predicates are in DNF form, we can only push down conditions that are in all blocks.
|
||||
Additionally, we can't remove predicates from their original form.
|
||||
|
@ -91,7 +101,7 @@ def pushdown_dnf(predicates, scope):
|
|||
# (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
|
||||
for table in sorted(pushdown_tables):
|
||||
for predicate in predicates:
|
||||
nodes = nodes_for_predicate(predicate, scope)
|
||||
nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
|
||||
|
||||
if table not in nodes:
|
||||
continue
|
||||
|
@ -120,7 +130,7 @@ def pushdown_dnf(predicates, scope):
|
|||
node.where(replace_aliases(node, predicate), copy=False)
|
||||
|
||||
|
||||
def nodes_for_predicate(predicate, sources):
|
||||
def nodes_for_predicate(predicate, sources, scope_ref_count):
|
||||
nodes = {}
|
||||
tables = exp.column_table_names(predicate)
|
||||
where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
|
||||
|
@ -133,7 +143,7 @@ def nodes_for_predicate(predicate, sources):
|
|||
if node and where_condition:
|
||||
node = node.find_ancestor(exp.Join, exp.From)
|
||||
|
||||
# a node can reference a CTE which should be push down
|
||||
# a node can reference a CTE which should be pushed down
|
||||
if isinstance(node, exp.From) and not isinstance(source, exp.Table):
|
||||
node = source.expression
|
||||
|
||||
|
@ -142,7 +152,9 @@ def nodes_for_predicate(predicate, sources):
|
|||
return {}
|
||||
nodes[table] = node
|
||||
elif isinstance(node, exp.Select) and len(tables) == 1:
|
||||
if not node.args.get("group"):
|
||||
# we can't push down predicates to select statements if they are referenced in
|
||||
# multiple places.
|
||||
if not node.args.get("group") and scope_ref_count[id(source)] < 2:
|
||||
nodes[table] = node
|
||||
return nodes
|
||||
|
||||
|
|
|
@ -31,8 +31,8 @@ def qualify_columns(expression, schema):
|
|||
_pop_table_column_aliases(scope.derived_tables)
|
||||
_expand_using(scope, resolver)
|
||||
_expand_group_by(scope, resolver)
|
||||
_expand_order_by(scope)
|
||||
_qualify_columns(scope, resolver)
|
||||
_expand_order_by(scope)
|
||||
if not isinstance(scope.expression, exp.UDTF):
|
||||
_expand_stars(scope, resolver)
|
||||
_qualify_outputs(scope)
|
||||
|
@ -235,7 +235,7 @@ def _expand_stars(scope, resolver):
|
|||
for table in tables:
|
||||
if table not in scope.sources:
|
||||
raise OptimizeError(f"Unknown table: {table}")
|
||||
columns = resolver.get_source_columns(table)
|
||||
columns = resolver.get_source_columns(table, only_visible=True)
|
||||
table_id = id(table)
|
||||
for name in columns:
|
||||
if name not in except_columns.get(table_id, set()):
|
||||
|
@ -332,7 +332,7 @@ class _Resolver:
|
|||
self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
|
||||
return self._all_columns
|
||||
|
||||
def get_source_columns(self, name):
|
||||
def get_source_columns(self, name, only_visible=False):
|
||||
"""Resolve the source columns for a given source `name`"""
|
||||
if name not in self.scope.sources:
|
||||
raise OptimizeError(f"Unknown table: {name}")
|
||||
|
@ -342,7 +342,7 @@ class _Resolver:
|
|||
# If referencing a table, return the columns from the schema
|
||||
if isinstance(source, exp.Table):
|
||||
try:
|
||||
return self.schema.column_names(source)
|
||||
return self.schema.column_names(source, only_visible)
|
||||
except Exception as e:
|
||||
raise OptimizeError(str(e)) from e
|
||||
|
||||
|
|
|
@ -9,16 +9,28 @@ class Schema(abc.ABC):
|
|||
"""Abstract base class for database schemas"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def column_names(self, table):
|
||||
def column_names(self, table, only_visible=False):
|
||||
"""
|
||||
Get the column names for a table.
|
||||
|
||||
Args:
|
||||
table (sqlglot.expressions.Table): Table expression instance
|
||||
only_visible (bool): Whether to include invisible columns
|
||||
Returns:
|
||||
list[str]: list of column names
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_column_type(self, table, column):
|
||||
"""
|
||||
Get the exp.DataType type of a column in the schema.
|
||||
|
||||
Args:
|
||||
table (sqlglot.expressions.Table): The source table.
|
||||
column (sqlglot.expressions.Column): The target column.
|
||||
Returns:
|
||||
sqlglot.expressions.DataType.Type: The resulting column type.
|
||||
"""
|
||||
|
||||
|
||||
class MappingSchema(Schema):
|
||||
"""
|
||||
|
@ -29,10 +41,19 @@ class MappingSchema(Schema):
|
|||
1. {table: {col: type}}
|
||||
2. {db: {table: {col: type}}}
|
||||
3. {catalog: {db: {table: {col: type}}}}
|
||||
visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
|
||||
are assumed to be visible. The nesting should mirror that of the schema:
|
||||
1. {table: set(*cols)}}
|
||||
2. {db: {table: set(*cols)}}}
|
||||
3. {catalog: {db: {table: set(*cols)}}}}
|
||||
dialect (str): The dialect to be used for custom type mappings.
|
||||
"""
|
||||
|
||||
def __init__(self, schema):
|
||||
def __init__(self, schema, visible=None, dialect=None):
|
||||
self.schema = schema
|
||||
self.visible = visible
|
||||
self.dialect = dialect
|
||||
self._type_mapping_cache = {}
|
||||
|
||||
depth = _dict_depth(schema)
|
||||
|
||||
|
@ -49,7 +70,7 @@ class MappingSchema(Schema):
|
|||
|
||||
self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)
|
||||
|
||||
def column_names(self, table):
|
||||
def column_names(self, table, only_visible=False):
|
||||
if not isinstance(table.this, exp.Identifier):
|
||||
return fs_get(table)
|
||||
|
||||
|
@ -58,7 +79,39 @@ class MappingSchema(Schema):
|
|||
for forbidden in self.forbidden_args:
|
||||
if table.text(forbidden):
|
||||
raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
|
||||
return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
|
||||
|
||||
columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
|
||||
if not only_visible or not self.visible:
|
||||
return columns
|
||||
|
||||
visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
|
||||
return [col for col in columns if col in visible]
|
||||
|
||||
def get_column_type(self, table, column):
|
||||
try:
|
||||
schema_type = self.schema.get(table.name, {}).get(column.name).upper()
|
||||
return self._convert_type(schema_type)
|
||||
except:
|
||||
raise OptimizeError(f"Failed to get type for column {column.sql()}")
|
||||
|
||||
def _convert_type(self, schema_type):
|
||||
"""
|
||||
Convert a type represented as a string to the corresponding exp.DataType.Type object.
|
||||
|
||||
Args:
|
||||
schema_type (str): The type we want to convert.
|
||||
Returns:
|
||||
sqlglot.expressions.DataType.Type: The resulting expression type.
|
||||
"""
|
||||
if schema_type not in self._type_mapping_cache:
|
||||
try:
|
||||
self._type_mapping_cache[schema_type] = exp.maybe_parse(
|
||||
schema_type, into=exp.DataType, dialect=self.dialect
|
||||
).this
|
||||
except AttributeError:
|
||||
raise OptimizeError(f"Failed to convert type {schema_type}")
|
||||
|
||||
return self._type_mapping_cache[schema_type]
|
||||
|
||||
|
||||
def ensure_schema(schema):
|
||||
|
|
|
@ -68,6 +68,7 @@ class Scope:
|
|||
self._selected_sources = None
|
||||
self._columns = None
|
||||
self._external_columns = None
|
||||
self._join_hints = None
|
||||
|
||||
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
|
||||
"""Branch from the current scope to a new, inner scope"""
|
||||
|
@ -85,14 +86,17 @@ class Scope:
|
|||
self._subqueries = []
|
||||
self._derived_tables = []
|
||||
self._raw_columns = []
|
||||
self._join_hints = []
|
||||
|
||||
for node, parent, _ in self.walk(bfs=False):
|
||||
if node is self.expression:
|
||||
continue
|
||||
elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
|
||||
self._raw_columns.append(node)
|
||||
elif isinstance(node, exp.Table):
|
||||
elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
|
||||
self._tables.append(node)
|
||||
elif isinstance(node, exp.JoinHint):
|
||||
self._join_hints.append(node)
|
||||
elif isinstance(node, exp.UDTF):
|
||||
self._derived_tables.append(node)
|
||||
elif isinstance(node, exp.CTE):
|
||||
|
@ -246,7 +250,7 @@ class Scope:
|
|||
table only becomes a selected source if it's included in a FROM or JOIN clause.
|
||||
|
||||
Returns:
|
||||
dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes
|
||||
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
|
||||
"""
|
||||
if self._selected_sources is None:
|
||||
referenced_names = []
|
||||
|
@ -310,6 +314,18 @@ class Scope:
|
|||
self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
|
||||
return self._external_columns
|
||||
|
||||
@property
|
||||
def join_hints(self):
|
||||
"""
|
||||
Hints that exist in the scope that reference tables
|
||||
|
||||
Returns:
|
||||
list[exp.JoinHint]: Join hints that are referenced within the scope
|
||||
"""
|
||||
if self._join_hints is None:
|
||||
return []
|
||||
return self._join_hints
|
||||
|
||||
def source_columns(self, source_name):
|
||||
"""
|
||||
Get all columns in the current scope for a particular source.
|
||||
|
|
|
@ -56,12 +56,16 @@ def simplify_not(expression):
|
|||
NOT (x AND y) -> NOT x OR NOT y
|
||||
"""
|
||||
if isinstance(expression, exp.Not):
|
||||
if isinstance(expression.this, exp.Null):
|
||||
return NULL
|
||||
if isinstance(expression.this, exp.Paren):
|
||||
condition = expression.this.unnest()
|
||||
if isinstance(condition, exp.And):
|
||||
return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
|
||||
if isinstance(condition, exp.Or):
|
||||
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
|
||||
if isinstance(condition, exp.Null):
|
||||
return NULL
|
||||
if always_true(expression.this):
|
||||
return FALSE
|
||||
if expression.this == FALSE:
|
||||
|
@ -95,10 +99,10 @@ def simplify_connectors(expression):
|
|||
return left
|
||||
|
||||
if isinstance(expression, exp.And):
|
||||
if NULL in (left, right):
|
||||
return NULL
|
||||
if FALSE in (left, right):
|
||||
return FALSE
|
||||
if NULL in (left, right):
|
||||
return NULL
|
||||
if always_true(left) and always_true(right):
|
||||
return TRUE
|
||||
if always_true(left):
|
||||
|
|
|
@ -8,6 +8,18 @@ from sqlglot.tokens import Token, Tokenizer, TokenType
|
|||
logger = logging.getLogger("sqlglot")
|
||||
|
||||
|
||||
def parse_var_map(args):
|
||||
keys = []
|
||||
values = []
|
||||
for i in range(0, len(args), 2):
|
||||
keys.append(args[i])
|
||||
values.append(args[i + 1])
|
||||
return exp.VarMap(
|
||||
keys=exp.Array(expressions=keys),
|
||||
values=exp.Array(expressions=values),
|
||||
)
|
||||
|
||||
|
||||
class Parser:
|
||||
"""
|
||||
Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer`
|
||||
|
@ -48,6 +60,7 @@ class Parser:
|
|||
start=exp.Literal.number(1),
|
||||
length=exp.Literal.number(10),
|
||||
),
|
||||
"VAR_MAP": parse_var_map,
|
||||
}
|
||||
|
||||
NO_PAREN_FUNCTIONS = {
|
||||
|
@ -117,6 +130,7 @@ class Parser:
|
|||
TokenType.VAR,
|
||||
TokenType.ALTER,
|
||||
TokenType.ALWAYS,
|
||||
TokenType.ANTI,
|
||||
TokenType.BEGIN,
|
||||
TokenType.BOTH,
|
||||
TokenType.BUCKET,
|
||||
|
@ -164,6 +178,7 @@ class Parser:
|
|||
TokenType.ROWS,
|
||||
TokenType.SCHEMA_COMMENT,
|
||||
TokenType.SEED,
|
||||
TokenType.SEMI,
|
||||
TokenType.SET,
|
||||
TokenType.SHOW,
|
||||
TokenType.STABLE,
|
||||
|
@ -273,6 +288,8 @@ class Parser:
|
|||
TokenType.INNER,
|
||||
TokenType.OUTER,
|
||||
TokenType.CROSS,
|
||||
TokenType.SEMI,
|
||||
TokenType.ANTI,
|
||||
}
|
||||
|
||||
COLUMN_OPERATORS = {
|
||||
|
@ -318,6 +335,8 @@ class Parser:
|
|||
exp.Properties: lambda self: self._parse_properties(),
|
||||
exp.Where: lambda self: self._parse_where(),
|
||||
exp.Ordered: lambda self: self._parse_ordered(),
|
||||
exp.Having: lambda self: self._parse_having(),
|
||||
exp.With: lambda self: self._parse_with(),
|
||||
"JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
|
||||
}
|
||||
|
||||
|
@ -338,7 +357,6 @@ class Parser:
|
|||
TokenType.NULL: lambda *_: exp.Null(),
|
||||
TokenType.TRUE: lambda *_: exp.Boolean(this=True),
|
||||
TokenType.FALSE: lambda *_: exp.Boolean(this=False),
|
||||
TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(),
|
||||
TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()),
|
||||
TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
|
||||
TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
|
||||
|
@ -910,7 +928,20 @@ class Parser:
|
|||
return self.expression(exp.Tuple, expressions=expressions)
|
||||
|
||||
def _parse_select(self, nested=False, table=False):
|
||||
if self._match(TokenType.SELECT):
|
||||
cte = self._parse_with()
|
||||
if cte:
|
||||
this = self._parse_statement()
|
||||
|
||||
if not this:
|
||||
self.raise_error("Failed to parse any statement following CTE")
|
||||
return cte
|
||||
|
||||
if "with" in this.arg_types:
|
||||
this.set("with", cte)
|
||||
else:
|
||||
self.raise_error(f"{this.key} does not support CTE")
|
||||
this = cte
|
||||
elif self._match(TokenType.SELECT):
|
||||
hint = self._parse_hint()
|
||||
all_ = self._match(TokenType.ALL)
|
||||
distinct = self._match(TokenType.DISTINCT)
|
||||
|
@ -938,39 +969,6 @@ class Parser:
|
|||
if from_:
|
||||
this.set("from", from_)
|
||||
self._parse_query_modifiers(this)
|
||||
elif self._match(TokenType.WITH):
|
||||
recursive = self._match(TokenType.RECURSIVE)
|
||||
|
||||
expressions = []
|
||||
|
||||
while True:
|
||||
expressions.append(self._parse_cte())
|
||||
|
||||
if not self._match(TokenType.COMMA):
|
||||
break
|
||||
|
||||
cte = self.expression(
|
||||
exp.With,
|
||||
expressions=expressions,
|
||||
recursive=recursive,
|
||||
)
|
||||
this = self._parse_statement()
|
||||
|
||||
if not this:
|
||||
self.raise_error("Failed to parse any statement following CTE")
|
||||
return cte
|
||||
|
||||
if "with" in this.arg_types:
|
||||
this.set(
|
||||
"with",
|
||||
self.expression(
|
||||
exp.With,
|
||||
expressions=expressions,
|
||||
recursive=recursive,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.raise_error(f"{this.key} does not support CTE")
|
||||
elif (table or nested) and self._match(TokenType.L_PAREN):
|
||||
this = self._parse_table() if table else self._parse_select(nested=True)
|
||||
self._parse_query_modifiers(this)
|
||||
|
@ -986,6 +984,26 @@ class Parser:
|
|||
|
||||
return self._parse_set_operations(this) if this else None
|
||||
|
||||
def _parse_with(self):
|
||||
if not self._match(TokenType.WITH):
|
||||
return None
|
||||
|
||||
recursive = self._match(TokenType.RECURSIVE)
|
||||
|
||||
expressions = []
|
||||
|
||||
while True:
|
||||
expressions.append(self._parse_cte())
|
||||
|
||||
if not self._match(TokenType.COMMA):
|
||||
break
|
||||
|
||||
return self.expression(
|
||||
exp.With,
|
||||
expressions=expressions,
|
||||
recursive=recursive,
|
||||
)
|
||||
|
||||
def _parse_cte(self):
|
||||
alias = self._parse_table_alias()
|
||||
if not alias or not alias.this:
|
||||
|
@ -1485,8 +1503,7 @@ class Parser:
|
|||
unnest = self._parse_unnest()
|
||||
if unnest:
|
||||
this = self.expression(exp.In, this=this, unnest=unnest)
|
||||
else:
|
||||
self._match_l_paren()
|
||||
elif self._match(TokenType.L_PAREN):
|
||||
expressions = self._parse_csv(self._parse_select_or_expression)
|
||||
|
||||
if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
|
||||
|
@ -1495,6 +1512,9 @@ class Parser:
|
|||
this = self.expression(exp.In, this=this, expressions=expressions)
|
||||
|
||||
self._match_r_paren()
|
||||
else:
|
||||
this = self.expression(exp.In, this=this, field=self._parse_field())
|
||||
|
||||
return this
|
||||
|
||||
def _parse_between(self, this):
|
||||
|
@ -1591,7 +1611,7 @@ class Parser:
|
|||
elif nested:
|
||||
expressions = self._parse_csv(self._parse_types)
|
||||
else:
|
||||
expressions = self._parse_csv(self._parse_number)
|
||||
expressions = self._parse_csv(self._parse_type)
|
||||
|
||||
if not expressions:
|
||||
self._retreat(index)
|
||||
|
@ -1706,7 +1726,7 @@ class Parser:
|
|||
def _parse_field(self, any_token=False):
|
||||
return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token)
|
||||
|
||||
def _parse_function(self):
|
||||
def _parse_function(self, functions=None):
|
||||
if not self._curr:
|
||||
return None
|
||||
|
||||
|
@ -1742,7 +1762,9 @@ class Parser:
|
|||
self._match_r_paren()
|
||||
return this
|
||||
|
||||
function = self.FUNCTIONS.get(upper)
|
||||
if functions is None:
|
||||
functions = self.FUNCTIONS
|
||||
function = functions.get(upper)
|
||||
args = self._parse_csv(self._parse_lambda)
|
||||
|
||||
if function:
|
||||
|
@ -2025,10 +2047,20 @@ class Parser:
|
|||
return self.expression(exp.Cast, this=this, to=to)
|
||||
|
||||
def _parse_position(self):
|
||||
substr = self._parse_bitwise()
|
||||
args = self._parse_csv(self._parse_bitwise)
|
||||
|
||||
if self._match(TokenType.IN):
|
||||
string = self._parse_bitwise()
|
||||
return self.expression(exp.StrPosition, this=string, substr=substr)
|
||||
args.append(self._parse_bitwise())
|
||||
|
||||
# Note: we're parsing in order needle, haystack, position
|
||||
this = exp.StrPosition.from_arg_list(args)
|
||||
self.validate_expression(this, args)
|
||||
|
||||
return this
|
||||
|
||||
def _parse_join_hint(self, func_name):
|
||||
args = self._parse_csv(self._parse_table)
|
||||
return exp.JoinHint(this=func_name.upper(), expressions=args)
|
||||
|
||||
def _parse_substring(self):
|
||||
# Postgres supports the form: substring(string [from int] [for int])
|
||||
|
@ -2247,6 +2279,9 @@ class Parser:
|
|||
def _parse_placeholder(self):
|
||||
if self._match(TokenType.PLACEHOLDER):
|
||||
return exp.Placeholder()
|
||||
elif self._match(TokenType.COLON):
|
||||
self._advance()
|
||||
return exp.Placeholder(this=self._prev.text)
|
||||
return None
|
||||
|
||||
def _parse_except(self):
|
||||
|
|
|
@ -104,6 +104,7 @@ class TokenType(AutoName):
|
|||
ALL = auto()
|
||||
ALTER = auto()
|
||||
ANALYZE = auto()
|
||||
ANTI = auto()
|
||||
ANY = auto()
|
||||
ARRAY = auto()
|
||||
ASC = auto()
|
||||
|
@ -236,6 +237,7 @@ class TokenType(AutoName):
|
|||
SCHEMA_COMMENT = auto()
|
||||
SEED = auto()
|
||||
SELECT = auto()
|
||||
SEMI = auto()
|
||||
SEPARATOR = auto()
|
||||
SET = auto()
|
||||
SHOW = auto()
|
||||
|
@ -262,6 +264,7 @@ class TokenType(AutoName):
|
|||
USE = auto()
|
||||
USING = auto()
|
||||
VALUES = auto()
|
||||
VACUUM = auto()
|
||||
VIEW = auto()
|
||||
VOLATILE = auto()
|
||||
WHEN = auto()
|
||||
|
@ -406,6 +409,7 @@ class Tokenizer(metaclass=_Tokenizer):
|
|||
"ALTER": TokenType.ALTER,
|
||||
"ANALYZE": TokenType.ANALYZE,
|
||||
"AND": TokenType.AND,
|
||||
"ANTI": TokenType.ANTI,
|
||||
"ANY": TokenType.ANY,
|
||||
"ASC": TokenType.ASC,
|
||||
"AS": TokenType.ALIAS,
|
||||
|
@ -528,6 +532,7 @@ class Tokenizer(metaclass=_Tokenizer):
|
|||
"ROWS": TokenType.ROWS,
|
||||
"SEED": TokenType.SEED,
|
||||
"SELECT": TokenType.SELECT,
|
||||
"SEMI": TokenType.SEMI,
|
||||
"SET": TokenType.SET,
|
||||
"SHOW": TokenType.SHOW,
|
||||
"SOME": TokenType.SOME,
|
||||
|
@ -551,6 +556,7 @@ class Tokenizer(metaclass=_Tokenizer):
|
|||
"UPDATE": TokenType.UPDATE,
|
||||
"USE": TokenType.USE,
|
||||
"USING": TokenType.USING,
|
||||
"VACUUM": TokenType.VACUUM,
|
||||
"VALUES": TokenType.VALUES,
|
||||
"VIEW": TokenType.VIEW,
|
||||
"VOLATILE": TokenType.VOLATILE,
|
||||
|
@ -577,6 +583,7 @@ class Tokenizer(metaclass=_Tokenizer):
|
|||
"INT8": TokenType.BIGINT,
|
||||
"DECIMAL": TokenType.DECIMAL,
|
||||
"MAP": TokenType.MAP,
|
||||
"NULLABLE": TokenType.NULLABLE,
|
||||
"NUMBER": TokenType.DECIMAL,
|
||||
"NUMERIC": TokenType.DECIMAL,
|
||||
"FIXED": TokenType.DECIMAL,
|
||||
|
@ -629,6 +636,7 @@ class Tokenizer(metaclass=_Tokenizer):
|
|||
TokenType.SHOW,
|
||||
TokenType.TRUNCATE,
|
||||
TokenType.USE,
|
||||
TokenType.VACUUM,
|
||||
}
|
||||
|
||||
# handle numeric literals like in hive (3L = BIGINT)
|
||||
|
|
|
@ -152,6 +152,10 @@ class TestBigQuery(Validator):
|
|||
"SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)"
|
||||
)
|
||||
|
||||
self.validate_identity(
|
||||
"SELECT item, purchases, LAST_VALUE(item) OVER (item_window ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING) AS most_popular FROM Produce WINDOW item_window AS (ORDER BY purchases)"
|
||||
)
|
||||
|
||||
self.validate_identity(
|
||||
"SELECT LAST_VALUE(a IGNORE NULLS) OVER y FROM x WINDOW y AS (PARTITION BY CATEGORY)",
|
||||
)
|
||||
|
@ -222,6 +226,20 @@ class TestBigQuery(Validator):
|
|||
"spark": "DATE_ADD(CURRENT_DATE, 1)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"DATE_DIFF(DATE '2010-07-07', DATE '2008-12-25', DAY)",
|
||||
write={
|
||||
"bigquery": "DATE_DIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE), DAY)",
|
||||
"mysql": "DATEDIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE))",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"DATE_DIFF(DATE '2010-07-07', DATE '2008-12-25', MINUTE)",
|
||||
write={
|
||||
"bigquery": "DATE_DIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE), MINUTE)",
|
||||
"mysql": "DATEDIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE))",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"CURRENT_DATE('UTC')",
|
||||
write={
|
||||
|
|
|
@ -8,6 +8,8 @@ class TestClickhouse(Validator):
|
|||
self.validate_identity("dictGet(x, 'y')")
|
||||
self.validate_identity("SELECT * FROM x FINAL")
|
||||
self.validate_identity("SELECT * FROM x AS y FINAL")
|
||||
self.validate_identity("'a' IN mapKeys(map('a', 1, 'b', 2))")
|
||||
self.validate_identity("CAST((1, 2) AS Tuple(a Int8, b Int16))")
|
||||
|
||||
self.validate_all(
|
||||
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
|
||||
|
@ -20,6 +22,12 @@ class TestClickhouse(Validator):
|
|||
self.validate_all(
|
||||
"CAST(1 AS NULLABLE(Int64))",
|
||||
write={
|
||||
"clickhouse": "CAST(1 AS Nullable(BIGINT))",
|
||||
"clickhouse": "CAST(1 AS Nullable(Int64))",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"CAST(1 AS Nullable(DateTime64(6, 'UTC')))",
|
||||
write={
|
||||
"clickhouse": "CAST(1 AS Nullable(DateTime64(6, 'UTC')))",
|
||||
},
|
||||
)
|
||||
|
|
|
@ -81,6 +81,24 @@ class TestDialect(Validator):
|
|||
"starrocks": "CAST(a AS STRING)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"CAST(MAP('a', '1') AS MAP(TEXT, TEXT))",
|
||||
write={
|
||||
"clickhouse": "CAST(map('a', '1') AS Map(TEXT, TEXT))",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"CAST(ARRAY(1, 2) AS ARRAY<TINYINT>)",
|
||||
write={
|
||||
"clickhouse": "CAST([1, 2] AS Array(Int8))",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"CAST((1, 2) AS STRUCT<a: TINYINT, b: SMALLINT, c: INT, d: BIGINT>)",
|
||||
write={
|
||||
"clickhouse": "CAST((1, 2) AS Tuple(a Int8, b Int16, c Int32, d Int64))",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"CAST(a AS DATETIME)",
|
||||
write={
|
||||
|
@ -170,7 +188,7 @@ class TestDialect(Validator):
|
|||
"CAST(a AS DOUBLE)",
|
||||
write={
|
||||
"bigquery": "CAST(a AS FLOAT64)",
|
||||
"clickhouse": "CAST(a AS DOUBLE)",
|
||||
"clickhouse": "CAST(a AS Float64)",
|
||||
"duckdb": "CAST(a AS DOUBLE)",
|
||||
"mysql": "CAST(a AS DOUBLE)",
|
||||
"hive": "CAST(a AS DOUBLE)",
|
||||
|
@ -234,6 +252,8 @@ class TestDialect(Validator):
|
|||
write={
|
||||
"duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')",
|
||||
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
|
||||
"oracle": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
|
||||
"postgres": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
|
||||
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d')",
|
||||
"redshift": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
|
||||
"spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
|
||||
|
@ -245,6 +265,8 @@ class TestDialect(Validator):
|
|||
"duckdb": "STRPTIME(x, '%y')",
|
||||
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)",
|
||||
"presto": "DATE_PARSE(x, '%y')",
|
||||
"oracle": "TO_TIMESTAMP(x, 'YY')",
|
||||
"postgres": "TO_TIMESTAMP(x, 'YY')",
|
||||
"redshift": "TO_TIMESTAMP(x, 'YY')",
|
||||
"spark": "TO_TIMESTAMP(x, 'yy')",
|
||||
},
|
||||
|
@ -288,6 +310,8 @@ class TestDialect(Validator):
|
|||
write={
|
||||
"duckdb": "STRFTIME(x, '%Y-%m-%d')",
|
||||
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
|
||||
"oracle": "TO_CHAR(x, 'YYYY-MM-DD')",
|
||||
"postgres": "TO_CHAR(x, 'YYYY-MM-DD')",
|
||||
"presto": "DATE_FORMAT(x, '%Y-%m-%d')",
|
||||
"redshift": "TO_CHAR(x, 'YYYY-MM-DD')",
|
||||
},
|
||||
|
@ -348,6 +372,8 @@ class TestDialect(Validator):
|
|||
write={
|
||||
"duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))",
|
||||
"hive": "FROM_UNIXTIME(x)",
|
||||
"oracle": "TO_DATE('1970-01-01','YYYY-MM-DD') + (x / 86400)",
|
||||
"postgres": "TO_TIMESTAMP(x)",
|
||||
"presto": "FROM_UNIXTIME(x)",
|
||||
"starrocks": "FROM_UNIXTIME(x)",
|
||||
},
|
||||
|
@ -704,6 +730,7 @@ class TestDialect(Validator):
|
|||
"SELECT * FROM a UNION SELECT * FROM b",
|
||||
read={
|
||||
"bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b",
|
||||
"clickhouse": "SELECT * FROM a UNION DISTINCT SELECT * FROM b",
|
||||
"duckdb": "SELECT * FROM a UNION SELECT * FROM b",
|
||||
"presto": "SELECT * FROM a UNION SELECT * FROM b",
|
||||
"spark": "SELECT * FROM a UNION SELECT * FROM b",
|
||||
|
@ -719,6 +746,7 @@ class TestDialect(Validator):
|
|||
"SELECT * FROM a UNION ALL SELECT * FROM b",
|
||||
read={
|
||||
"bigquery": "SELECT * FROM a UNION ALL SELECT * FROM b",
|
||||
"clickhouse": "SELECT * FROM a UNION ALL SELECT * FROM b",
|
||||
"duckdb": "SELECT * FROM a UNION ALL SELECT * FROM b",
|
||||
"presto": "SELECT * FROM a UNION ALL SELECT * FROM b",
|
||||
"spark": "SELECT * FROM a UNION ALL SELECT * FROM b",
|
||||
|
@ -848,15 +876,28 @@ class TestDialect(Validator):
|
|||
"postgres": "STRPOS(x, ' ')",
|
||||
"presto": "STRPOS(x, ' ')",
|
||||
"spark": "LOCATE(' ', x)",
|
||||
"clickhouse": "position(x, ' ')",
|
||||
"snowflake": "POSITION(' ', x)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"STR_POSITION(x, 'a')",
|
||||
"STR_POSITION('a', x)",
|
||||
write={
|
||||
"duckdb": "STRPOS(x, 'a')",
|
||||
"postgres": "STRPOS(x, 'a')",
|
||||
"presto": "STRPOS(x, 'a')",
|
||||
"spark": "LOCATE('a', x)",
|
||||
"clickhouse": "position(x, 'a')",
|
||||
"snowflake": "POSITION('a', x)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"POSITION('a', x, 3)",
|
||||
write={
|
||||
"presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
|
||||
"spark": "LOCATE('a', x, 3)",
|
||||
"clickhouse": "position(x, 'a', 3)",
|
||||
"snowflake": "POSITION('a', x, 3)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
|
|
|
@ -247,7 +247,7 @@ class TestHive(Validator):
|
|||
"presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(b AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(a AS VARCHAR), 1, 10) AS DATE))",
|
||||
"hive": "DATEDIFF(TO_DATE(a), TO_DATE(b))",
|
||||
"spark": "DATEDIFF(TO_DATE(a), TO_DATE(b))",
|
||||
"": "DATE_DIFF(TS_OR_DS_TO_DATE(a), TS_OR_DS_TO_DATE(b))",
|
||||
"": "DATEDIFF(TS_OR_DS_TO_DATE(a), TS_OR_DS_TO_DATE(b))",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
|
@ -295,7 +295,7 @@ class TestHive(Validator):
|
|||
"presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(CAST(SUBSTR(CAST(y AS VARCHAR), 1, 10) AS DATE) AS VARCHAR), 1, 10) AS DATE))",
|
||||
"hive": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))",
|
||||
"spark": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))",
|
||||
"": "DATE_DIFF(TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE(y)), TS_OR_DS_TO_DATE(x))",
|
||||
"": "DATEDIFF(TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE(y)), TS_OR_DS_TO_DATE(x))",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
|
@ -450,11 +450,21 @@ class TestHive(Validator):
|
|||
)
|
||||
self.validate_all(
|
||||
"MAP(a, b, c, d)",
|
||||
read={
|
||||
"": "VAR_MAP(a, b, c, d)",
|
||||
"clickhouse": "map(a, b, c, d)",
|
||||
"duckdb": "MAP(LIST_VALUE(a, c), LIST_VALUE(b, d))",
|
||||
"hive": "MAP(a, b, c, d)",
|
||||
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
|
||||
"spark": "MAP(a, b, c, d)",
|
||||
},
|
||||
write={
|
||||
"": "MAP(ARRAY(a, c), ARRAY(b, d))",
|
||||
"clickhouse": "map(a, b, c, d)",
|
||||
"duckdb": "MAP(LIST_VALUE(a, c), LIST_VALUE(b, d))",
|
||||
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
|
||||
"hive": "MAP(a, b, c, d)",
|
||||
"spark": "MAP_FROM_ARRAYS(ARRAY(a, c), ARRAY(b, d))",
|
||||
"spark": "MAP(a, b, c, d)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
|
@ -463,7 +473,7 @@ class TestHive(Validator):
|
|||
"duckdb": "MAP(LIST_VALUE(a), LIST_VALUE(b))",
|
||||
"presto": "MAP(ARRAY[a], ARRAY[b])",
|
||||
"hive": "MAP(a, b)",
|
||||
"spark": "MAP_FROM_ARRAYS(ARRAY(a), ARRAY(b))",
|
||||
"spark": "MAP(a, b)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
|
|
|
@ -67,6 +67,7 @@ class TestPostgres(Validator):
|
|||
self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))")
|
||||
self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
|
||||
self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
|
||||
self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')")
|
||||
|
||||
self.validate_all(
|
||||
"CREATE TABLE x (a UUID, b BYTEA)",
|
||||
|
|
|
@ -305,3 +305,35 @@ class TestSnowflake(Validator):
|
|||
self.validate_identity(
|
||||
"CREATE PROCEDURE a.b.c(x INT, y VARIANT) RETURNS OBJECT EXECUTE AS CALLER AS 'BEGIN SELECT 1; END;'"
|
||||
)
|
||||
|
||||
def test_table_literal(self):
|
||||
# All examples from https://docs.snowflake.com/en/sql-reference/literals-table.html
|
||||
self.validate_all(
|
||||
r"""SELECT * FROM TABLE('MYTABLE')""", write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""}
|
||||
)
|
||||
|
||||
self.validate_all(
|
||||
r"""SELECT * FROM TABLE('MYDB."MYSCHEMA"."MYTABLE"')""",
|
||||
write={"snowflake": r"""SELECT * FROM TABLE('MYDB."MYSCHEMA"."MYTABLE"')"""},
|
||||
)
|
||||
|
||||
# Per Snowflake documentation at https://docs.snowflake.com/en/sql-reference/literals-table.html
|
||||
# one can use either a " ' " or " $$ " to enclose the object identifier.
|
||||
# Capturing the single tokens seems like lot of work. Hence adjusting tests to use these interchangeably,
|
||||
self.validate_all(
|
||||
r"""SELECT * FROM TABLE($$MYDB. "MYSCHEMA"."MYTABLE"$$)""",
|
||||
write={"snowflake": r"""SELECT * FROM TABLE('MYDB. "MYSCHEMA"."MYTABLE"')"""},
|
||||
)
|
||||
|
||||
self.validate_all(r"""SELECT * FROM TABLE($MYVAR)""", write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""})
|
||||
|
||||
self.validate_all(r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""})
|
||||
|
||||
self.validate_all(
|
||||
r"""SELECT * FROM TABLE(:BINDING)""", write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""}
|
||||
)
|
||||
|
||||
self.validate_all(
|
||||
r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10""",
|
||||
write={"snowflake": r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10"""},
|
||||
)
|
||||
|
|
|
@ -111,12 +111,70 @@ TBLPROPERTIES (
|
|||
"SELECT /*+ COALESCE(3) */ * FROM x",
|
||||
write={
|
||||
"spark": "SELECT /*+ COALESCE(3) */ * FROM x",
|
||||
"bigquery": "SELECT * FROM x",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT /*+ COALESCE(3), REPARTITION(1) */ * FROM x",
|
||||
write={
|
||||
"spark": "SELECT /*+ COALESCE(3), REPARTITION(1) */ * FROM x",
|
||||
"bigquery": "SELECT * FROM x",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT /*+ BROADCAST(table) */ cola FROM table",
|
||||
write={
|
||||
"spark": "SELECT /*+ BROADCAST(table) */ cola FROM table",
|
||||
"bigquery": "SELECT cola FROM table",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT /*+ BROADCASTJOIN(table) */ cola FROM table",
|
||||
write={
|
||||
"spark": "SELECT /*+ BROADCASTJOIN(table) */ cola FROM table",
|
||||
"bigquery": "SELECT cola FROM table",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT /*+ MAPJOIN(table) */ cola FROM table",
|
||||
write={
|
||||
"spark": "SELECT /*+ MAPJOIN(table) */ cola FROM table",
|
||||
"bigquery": "SELECT cola FROM table",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT /*+ MERGE(table) */ cola FROM table",
|
||||
write={
|
||||
"spark": "SELECT /*+ MERGE(table) */ cola FROM table",
|
||||
"bigquery": "SELECT cola FROM table",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT /*+ SHUFFLEMERGE(table) */ cola FROM table",
|
||||
write={
|
||||
"spark": "SELECT /*+ SHUFFLEMERGE(table) */ cola FROM table",
|
||||
"bigquery": "SELECT cola FROM table",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT /*+ MERGEJOIN(table) */ cola FROM table",
|
||||
write={
|
||||
"spark": "SELECT /*+ MERGEJOIN(table) */ cola FROM table",
|
||||
"bigquery": "SELECT cola FROM table",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT /*+ SHUFFLE_HASH(table) */ cola FROM table",
|
||||
write={
|
||||
"spark": "SELECT /*+ SHUFFLE_HASH(table) */ cola FROM table",
|
||||
"bigquery": "SELECT cola FROM table",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT /*+ SHUFFLE_REPLICATE_NL(table) */ cola FROM table",
|
||||
write={
|
||||
"spark": "SELECT /*+ SHUFFLE_REPLICATE_NL(table) */ cola FROM table",
|
||||
"bigquery": "SELECT cola FROM table",
|
||||
},
|
||||
)
|
||||
|
||||
|
|
6
tests/fixtures/identity.sql
vendored
6
tests/fixtures/identity.sql
vendored
|
@ -321,6 +321,10 @@ SELECT 1 FROM a INNER JOIN b ON a.x = b.x
|
|||
SELECT 1 FROM a LEFT JOIN b ON a.x = b.x
|
||||
SELECT 1 FROM a RIGHT JOIN b ON a.x = b.x
|
||||
SELECT 1 FROM a CROSS JOIN b ON a.x = b.x
|
||||
SELECT 1 FROM a LEFT SEMI JOIN b ON a.x = b.x
|
||||
SELECT 1 FROM a LEFT ANTI JOIN b ON a.x = b.x
|
||||
SELECT 1 FROM a RIGHT SEMI JOIN b ON a.x = b.x
|
||||
SELECT 1 FROM a RIGHT ANTI JOIN b ON a.x = b.x
|
||||
SELECT 1 FROM a JOIN b USING (x)
|
||||
SELECT 1 FROM a JOIN b USING (x, y, z)
|
||||
SELECT 1 FROM a JOIN (SELECT a FROM c) AS b ON a.x = b.x AND a.x < 2
|
||||
|
@ -529,12 +533,14 @@ UPDATE db.tbl_name SET foo = 123 WHERE tbl_name.bar = 234
|
|||
UPDATE db.tbl_name SET foo = 123, foo_1 = 234 WHERE tbl_name.bar = 234
|
||||
TRUNCATE TABLE x
|
||||
OPTIMIZE TABLE y
|
||||
VACUUM FREEZE my_table
|
||||
WITH a AS (SELECT 1) INSERT INTO b SELECT * FROM a
|
||||
WITH a AS (SELECT * FROM b) UPDATE a SET col = 1
|
||||
WITH a AS (SELECT * FROM b) CREATE TABLE b AS SELECT * FROM a
|
||||
WITH a AS (SELECT * FROM b) DELETE FROM a
|
||||
WITH a AS (SELECT * FROM b) CACHE TABLE a
|
||||
SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ?
|
||||
SELECT :hello, ? FROM x LIMIT :my_limit
|
||||
WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS foo, c.bar AS bar FROM c)) SELECT * FROM a
|
||||
WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a
|
||||
SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z
|
||||
|
|
168
tests/fixtures/optimizer/merge_subqueries.sql
vendored
168
tests/fixtures/optimizer/merge_subqueries.sql
vendored
|
@ -1,107 +1,189 @@
|
|||
-- Simple
|
||||
# title: Simple
|
||||
SELECT a, b FROM (SELECT a, b FROM x);
|
||||
SELECT x.a AS a, x.b AS b FROM x AS x;
|
||||
|
||||
-- Inner table alias is merged
|
||||
# title: Inner table alias is merged
|
||||
SELECT a, b FROM (SELECT a, b FROM x AS q) AS r;
|
||||
SELECT q.a AS a, q.b AS b FROM x AS q;
|
||||
|
||||
-- Double nesting
|
||||
# title: Double nesting
|
||||
SELECT a, b FROM (SELECT a, b FROM (SELECT a, b FROM x));
|
||||
SELECT x.a AS a, x.b AS b FROM x AS x;
|
||||
|
||||
-- WHERE clause is merged
|
||||
SELECT a, SUM(b) FROM (SELECT a, b FROM x WHERE a > 1) GROUP BY a;
|
||||
SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a;
|
||||
# title: WHERE clause is merged
|
||||
SELECT a, SUM(b) AS b FROM (SELECT a, b FROM x WHERE a > 1) GROUP BY a;
|
||||
SELECT x.a AS a, SUM(x.b) AS b FROM x AS x WHERE x.a > 1 GROUP BY x.a;
|
||||
|
||||
-- Outer query has join
|
||||
SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
|
||||
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;
|
||||
|
||||
-- Outer query has join
|
||||
# title: Outer query has join
|
||||
SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
|
||||
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;
|
||||
|
||||
# title: Leave tables isolated
|
||||
# leave_tables_isolated: true
|
||||
SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
|
||||
SELECT x.a AS a, y.c AS c FROM (SELECT x.a AS a, x.b AS b FROM x AS x WHERE x.a > 1) AS x JOIN y AS y ON x.b = y.b;
|
||||
|
||||
-- Join on derived table
|
||||
# title: Join on derived table
|
||||
SELECT a, c FROM x JOIN (SELECT b, c FROM y) AS y ON x.b = y.b;
|
||||
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
|
||||
|
||||
-- Inner query has a join
|
||||
# title: Inner query has a join
|
||||
SELECT a, c FROM (SELECT a, c FROM x JOIN y ON x.b = y.b);
|
||||
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
|
||||
|
||||
-- Inner query has conflicting name in outer query
|
||||
# title: Inner query has conflicting name in outer query
|
||||
SELECT a, c FROM (SELECT q.a, q.b FROM x AS q) AS x JOIN y AS q ON x.b = q.b;
|
||||
SELECT q_2.a AS a, q.c AS c FROM x AS q_2 JOIN y AS q ON q_2.b = q.b;
|
||||
|
||||
-- Inner query has conflicting name in joined source
|
||||
# title: Inner query has conflicting name in joined source
|
||||
SELECT x.a, q.c FROM (SELECT a, x.b FROM x JOIN y AS q ON x.b = q.b) AS x JOIN y AS q ON x.b = q.b;
|
||||
SELECT x.a AS a, q.c AS c FROM x AS x JOIN y AS q_2 ON x.b = q_2.b JOIN y AS q ON x.b = q.b;
|
||||
|
||||
-- Inner query has multiple conflicting names
|
||||
SELECT x.a, q.c, r.c FROM (SELECT q.a, r.b FROM x AS q JOIN y AS r ON q.b = r.b) AS x JOIN y AS q ON x.b = q.b JOIN y AS r ON x.b = r.b;
|
||||
SELECT q_2.a AS a, q.c AS c, r.c AS c FROM x AS q_2 JOIN y AS r_2 ON q_2.b = r_2.b JOIN y AS q ON r_2.b = q.b JOIN y AS r ON r_2.b = r.b;
|
||||
# title: Inner query has multiple conflicting names
|
||||
SELECT x.a, q.c, r.c FROM (SELECT q.a, r.b FROM x AS q JOIN y AS r ON q.b = r.b) AS x JOIN y AS q ON x.b = q.b JOIN y AS r ON x.b = r.b ORDER BY x.a, q.c, r.c;
|
||||
SELECT q_2.a AS a, q.c AS c, r.c AS c FROM x AS q_2 JOIN y AS r_2 ON q_2.b = r_2.b JOIN y AS q ON r_2.b = q.b JOIN y AS r ON r_2.b = r.b ORDER BY q_2.a, q.c, r.c;
|
||||
|
||||
-- Inner queries have conflicting names with each other
|
||||
# title: Inner queries have conflicting names with each other
|
||||
SELECT r.b FROM (SELECT b FROM x AS x) AS q JOIN (SELECT b FROM x) AS r ON q.b = r.b;
|
||||
SELECT x_2.b AS b FROM x AS x JOIN x AS x_2 ON x.b = x_2.b;
|
||||
|
||||
-- WHERE clause in joined derived table is merged to ON clause
|
||||
SELECT x.a, y.c FROM x JOIN (SELECT b, c FROM y WHERE c > 1) AS y;
|
||||
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON y.c > 1;
|
||||
# title: WHERE clause in joined derived table is merged to ON clause
|
||||
SELECT x.a, y.c FROM x JOIN (SELECT b, c FROM y WHERE c > 1) AS y ON x.b = y.b;
|
||||
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b AND y.c > 1;
|
||||
|
||||
-- Comma JOIN in outer query
|
||||
# title: Comma JOIN in outer query
|
||||
SELECT x.a, y.c FROM (SELECT a FROM x) AS x, (SELECT c FROM y) AS y;
|
||||
SELECT x.a AS a, y.c AS c FROM x AS x, y AS y;
|
||||
|
||||
-- Comma JOIN in inner query
|
||||
# title: Comma JOIN in inner query
|
||||
SELECT x.a, x.c FROM (SELECT x.a, z.c FROM x, y AS z) AS x;
|
||||
SELECT x.a AS a, z.c AS c FROM x AS x CROSS JOIN y AS z;
|
||||
|
||||
-- (Regression) Column in ORDER BY
|
||||
# title: (Regression) Column in ORDER BY
|
||||
SELECT * FROM (SELECT * FROM (SELECT * FROM x)) ORDER BY a LIMIT 1;
|
||||
SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a LIMIT 1;
|
||||
|
||||
-- CTE
|
||||
# title: CTE
|
||||
WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x;
|
||||
SELECT x.a AS a, x.b AS b FROM x AS x;
|
||||
|
||||
-- CTE with outer table alias
|
||||
# title: CTE with outer table alias
|
||||
WITH y AS (SELECT a, b FROM x) SELECT a, b FROM y AS z;
|
||||
SELECT x.a AS a, x.b AS b FROM x AS x;
|
||||
|
||||
-- Nested CTE
|
||||
WITH x AS (SELECT a FROM x), x2 AS (SELECT a FROM x) SELECT a FROM x2;
|
||||
# title: Nested CTE
|
||||
WITH x2 AS (SELECT a FROM x), x3 AS (SELECT a FROM x2) SELECT a FROM x3;
|
||||
SELECT x.a AS a FROM x AS x;
|
||||
|
||||
-- CTE WHERE clause is merged
|
||||
WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, SUM(b) FROM x GROUP BY a;
|
||||
SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a;
|
||||
# title: CTE WHERE clause is merged
|
||||
WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, SUM(b) AS b FROM x GROUP BY a;
|
||||
SELECT x.a AS a, SUM(x.b) AS b FROM x AS x WHERE x.a > 1 GROUP BY x.a;
|
||||
|
||||
-- CTE Outer query has join
|
||||
WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, c FROM x AS x JOIN y ON x.b = y.b;
|
||||
# title: CTE Outer query has join
|
||||
WITH x2 AS (SELECT a, b FROM x WHERE a > 1) SELECT a, c FROM x2 AS x JOIN y ON x.b = y.b;
|
||||
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;
|
||||
|
||||
-- CTE with inner table alias
|
||||
# title: CTE with inner table alias
|
||||
WITH y AS (SELECT a, b FROM x AS q) SELECT a, b FROM y AS z;
|
||||
SELECT q.a AS a, q.b AS b FROM x AS q;
|
||||
|
||||
-- Duplicate queries to CTE
|
||||
WITH x AS (SELECT a, b FROM x) SELECT x.a, y.b FROM x JOIN x AS y;
|
||||
WITH x AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT x.a AS a, y.b AS b FROM x JOIN x AS y;
|
||||
|
||||
-- Nested CTE
|
||||
# title: Nested CTE
|
||||
SELECT * FROM (WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x);
|
||||
SELECT x.a AS a, x.b AS b FROM x AS x;
|
||||
|
||||
-- Inner select is an expression
|
||||
# title: Inner select is an expression
|
||||
SELECT a FROM (SELECT a FROM (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) AS x) AS x;
|
||||
SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b;
|
||||
|
||||
-- CTE select is an expression
|
||||
WITH x AS (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) SELECT a FROM (SELECT a FROM x AS x) AS x;
|
||||
# title: CTE select is an expression
|
||||
WITH x2 AS (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) SELECT a FROM (SELECT a FROM x2 AS x) AS x;
|
||||
SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b;
|
||||
|
||||
# title: Full outer join
|
||||
SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x FULL OUTER JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
|
||||
SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x FULL OUTER JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
|
||||
|
||||
# title: Full outer join, no predicates
|
||||
SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x FULL OUTER JOIN (SELECT y.b AS b FROM y AS y) AS y ON x.b = y.b;
|
||||
SELECT x.b AS b, y.b AS b2 FROM x AS x FULL OUTER JOIN y AS y ON x.b = y.b;
|
||||
|
||||
# title: Left join
|
||||
SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x LEFT JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
|
||||
SELECT x.b AS b, y.b AS b2 FROM x AS x LEFT JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b WHERE x.b = 1;
|
||||
|
||||
# title: Left join, no predicates
|
||||
SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x LEFT JOIN (SELECT y.b AS b FROM y AS y) AS y ON x.b = y.b;
|
||||
SELECT x.b AS b, y.b AS b2 FROM x AS x LEFT JOIN y AS y ON x.b = y.b;
|
||||
|
||||
# title: Right join
|
||||
SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x RIGHT JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
|
||||
SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x RIGHT JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
|
||||
|
||||
# title: Right join, no predicates
|
||||
SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x RIGHT JOIN (SELECT y.b AS b FROM y AS y) AS y ON x.b = y.b;
|
||||
SELECT x.b AS b, y.b AS b2 FROM x AS x RIGHT JOIN y AS y ON x.b = y.b;
|
||||
|
||||
# title: Inner join
|
||||
SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x INNER JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
|
||||
SELECT x.b AS b, y.b AS b2 FROM x AS x INNER JOIN y AS y ON x.b = y.b AND y.b = 2 WHERE x.b = 1;
|
||||
|
||||
# title: Inner join, no predicates
|
||||
SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x INNER JOIN (SELECT y.b AS b FROM y AS y) AS y ON x.b = y.b;
|
||||
SELECT x.b AS b, y.b AS b2 FROM x AS x INNER JOIN y AS y ON x.b = y.b;
|
||||
|
||||
# title: Cross join
|
||||
SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x CROSS JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y;
|
||||
SELECT x.b AS b, y.b AS b2 FROM x AS x JOIN y AS y ON y.b = 2 WHERE x.b = 1;
|
||||
|
||||
# title: Cross join, no predicates
|
||||
SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x CROSS JOIN (SELECT y.b AS b FROM y AS y) AS y;
|
||||
SELECT x.b AS b, y.b AS b2 FROM x AS x CROSS JOIN y AS y;
|
||||
|
||||
# title: Broadcast hint
|
||||
# dialect: spark
|
||||
WITH m AS (SELECT x.a, x.b FROM x), n AS (SELECT y.b, y.c FROM y), joined as (SELECT /*+ BROADCAST(k) */ m.a, k.c FROM m JOIN n AS k ON m.b = k.b) SELECT joined.a, joined.c FROM joined;
|
||||
SELECT /*+ BROADCAST(y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
|
||||
|
||||
# title: Broadcast hint multiple tables
|
||||
# dialect: spark
|
||||
WITH m AS (SELECT x.a, x.b FROM x), n AS (SELECT y.b, y.c FROM y), joined as (SELECT /*+ BROADCAST(m, n) */ m.a, n.c FROM m JOIN n ON m.b = n.b) SELECT joined.a, joined.c FROM joined;
|
||||
SELECT /*+ BROADCAST(x, y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
|
||||
|
||||
# title: Multiple Table Hints
|
||||
# dialect: spark
|
||||
WITH m AS (SELECT x.a, x.b FROM x), n AS (SELECT y.b, y.c FROM y), joined as (SELECT /*+ BROADCAST(m), MERGE(m, n) */ m.a, n.c FROM m JOIN n ON m.b = n.b) SELECT joined.a, joined.c FROM joined;
|
||||
SELECT /*+ BROADCAST(x), MERGE(x, y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
|
||||
|
||||
# title: Mix Table and Column Hints
|
||||
# dialect: spark
|
||||
WITH m AS (SELECT x.a, x.b FROM x), n AS (SELECT y.b, y.c FROM y), joined as (SELECT /*+ BROADCAST(m), MERGE(m, n) */ m.a, n.c FROM m JOIN n ON m.b = n.b) SELECT /*+ COALESCE(3) */ joined.a, joined.c FROM joined;
|
||||
SELECT /*+ COALESCE(3), BROADCAST(x), MERGE(x, y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
|
||||
|
||||
# title: Hint Subquery
|
||||
# dialect: spark
|
||||
SELECT
|
||||
subquery.a,
|
||||
subquery.c
|
||||
FROM (
|
||||
SELECT /*+ BROADCAST(m), MERGE(m, n) */ m.a, n.c FROM (SELECT x.a, x.b FROM x) AS m JOIN (SELECT y.b, y.c FROM y) AS n ON m.b = n.b
|
||||
) AS subquery;
|
||||
SELECT /*+ BROADCAST(x), MERGE(x, y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
|
||||
|
||||
# title: Subquery Test
|
||||
# dialect: spark
|
||||
SELECT /*+ BROADCAST(x) */
|
||||
x.a,
|
||||
x.c
|
||||
FROM (
|
||||
SELECT
|
||||
x.a,
|
||||
x.c
|
||||
FROM (
|
||||
SELECT
|
||||
x.a,
|
||||
COUNT(1) AS c
|
||||
FROM x
|
||||
GROUP BY x.a
|
||||
) AS x
|
||||
) AS x;
|
||||
SELECT /*+ BROADCAST(x) */ x.a AS a, x.c AS c FROM (SELECT x.a AS a, COUNT(1) AS c FROM x AS x GROUP BY x.a) AS x;
|
||||
|
|
140
tests/fixtures/optimizer/optimizer.sql
vendored
140
tests/fixtures/optimizer/optimizer.sql
vendored
|
@ -1,3 +1,5 @@
|
|||
# title: lateral
|
||||
# execute: false
|
||||
SELECT a, m FROM z LATERAL VIEW EXPLODE([1, 2]) q AS m;
|
||||
SELECT
|
||||
"z"."a" AS "a",
|
||||
|
@ -6,11 +8,13 @@ FROM "z" AS "z"
|
|||
LATERAL VIEW
|
||||
EXPLODE(ARRAY(1, 2)) q AS "m";
|
||||
|
||||
# title: unnest
|
||||
SELECT x FROM UNNEST([1, 2]) AS q(x, y);
|
||||
SELECT
|
||||
"q"."x" AS "x"
|
||||
FROM UNNEST(ARRAY(1, 2)) AS "q"("x", "y");
|
||||
|
||||
# title: Union in CTE
|
||||
WITH cte AS (
|
||||
(
|
||||
SELECT
|
||||
|
@ -21,7 +25,7 @@ WITH cte AS (
|
|||
UNION ALL
|
||||
(
|
||||
SELECT
|
||||
a
|
||||
b AS a
|
||||
FROM
|
||||
y
|
||||
)
|
||||
|
@ -39,7 +43,7 @@ WITH "cte" AS (
|
|||
UNION ALL
|
||||
(
|
||||
SELECT
|
||||
"y"."a" AS "a"
|
||||
"y"."b" AS "a"
|
||||
FROM "y" AS "y"
|
||||
)
|
||||
)
|
||||
|
@ -47,6 +51,7 @@ SELECT
|
|||
"cte"."a" AS "a"
|
||||
FROM "cte";
|
||||
|
||||
# title: Chained CTEs
|
||||
WITH cte1 AS (
|
||||
SELECT a
|
||||
FROM x
|
||||
|
@ -74,30 +79,31 @@ SELECT
|
|||
"cte1"."a" + 1 AS "a"
|
||||
FROM "cte1";
|
||||
|
||||
SELECT a, SUM(b)
|
||||
# title: Correlated subquery
|
||||
SELECT a, SUM(b) AS sum_b
|
||||
FROM (
|
||||
SELECT x.a, y.b
|
||||
FROM x, y
|
||||
WHERE (SELECT max(b) FROM y WHERE x.a = y.a) >= 0 AND x.a = y.a
|
||||
WHERE (SELECT max(b) FROM y WHERE x.b = y.b) >= 0 AND x.b = y.b
|
||||
) d
|
||||
WHERE (TRUE AND TRUE OR 'a' = 'b') AND a > 1
|
||||
GROUP BY a;
|
||||
WITH "_u_0" AS (
|
||||
SELECT
|
||||
MAX("y"."b") AS "_col_0",
|
||||
"y"."a" AS "_u_1"
|
||||
"y"."b" AS "_u_1"
|
||||
FROM "y" AS "y"
|
||||
GROUP BY
|
||||
"y"."a"
|
||||
"y"."b"
|
||||
)
|
||||
SELECT
|
||||
"x"."a" AS "a",
|
||||
SUM("y"."b") AS "_col_1"
|
||||
SUM("y"."b") AS "sum_b"
|
||||
FROM "x" AS "x"
|
||||
LEFT JOIN "_u_0" AS "_u_0"
|
||||
ON "x"."a" = "_u_0"."_u_1"
|
||||
ON "x"."b" = "_u_0"."_u_1"
|
||||
JOIN "y" AS "y"
|
||||
ON "x"."a" = "y"."a"
|
||||
ON "x"."b" = "y"."b"
|
||||
WHERE
|
||||
"_u_0"."_col_0" >= 0
|
||||
AND "x"."a" > 1
|
||||
|
@ -105,6 +111,7 @@ WHERE
|
|||
GROUP BY
|
||||
"x"."a";
|
||||
|
||||
# title: Root subquery
|
||||
(SELECT a FROM x) LIMIT 1;
|
||||
(
|
||||
SELECT
|
||||
|
@ -113,6 +120,7 @@ GROUP BY
|
|||
)
|
||||
LIMIT 1;
|
||||
|
||||
# title: Root subquery is union
|
||||
(SELECT b FROM x UNION SELECT b FROM y) LIMIT 1;
|
||||
(
|
||||
SELECT
|
||||
|
@ -125,6 +133,7 @@ LIMIT 1;
|
|||
)
|
||||
LIMIT 1;
|
||||
|
||||
# title: broadcast
|
||||
# dialect: spark
|
||||
SELECT /*+ BROADCAST(y) */ x.b FROM x JOIN y ON x.b = y.b;
|
||||
SELECT /*+ BROADCAST(`y`) */
|
||||
|
@ -133,11 +142,14 @@ FROM `x` AS `x`
|
|||
JOIN `y` AS `y`
|
||||
ON `x`.`b` = `y`.`b`;
|
||||
|
||||
# title: aggregate
|
||||
# execute: false
|
||||
SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
|
||||
SELECT
|
||||
AGGREGATE(ARRAY("x"."a", "x"."b"), 0, ("x", "acc") -> "x" + "acc" + "x"."a") AS "sum_agg"
|
||||
FROM "x" AS "x";
|
||||
|
||||
# title: values
|
||||
SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb);
|
||||
SELECT
|
||||
"tab"."cola" AS "cola",
|
||||
|
@ -146,6 +158,7 @@ FROM (VALUES
|
|||
(1, 'test'),
|
||||
(2, 'test2')) AS "tab"("cola", "colb");
|
||||
|
||||
# title: spark values
|
||||
# dialect: spark
|
||||
SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb);
|
||||
SELECT
|
||||
|
@ -154,3 +167,112 @@ SELECT
|
|||
FROM VALUES
|
||||
(1, 'test'),
|
||||
(2, 'test2') AS `tab`(`cola`, `colb`);
|
||||
|
||||
# title: complex CTE dependencies
|
||||
WITH m AS (
|
||||
SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)
|
||||
), n AS (
|
||||
SELECT a, b FROM m WHERE m.a = 1
|
||||
), o AS (
|
||||
SELECT a, b FROM m WHERE m.a = 2
|
||||
) SELECT
|
||||
n.a,
|
||||
n.b,
|
||||
o.b
|
||||
FROM n
|
||||
FULL OUTER JOIN o ON n.a = o.a
|
||||
CROSS JOIN n AS n2
|
||||
WHERE o.b > 0 AND n.a = n2.a;
|
||||
WITH "m" AS (
|
||||
SELECT
|
||||
"a1"."a" AS "a",
|
||||
"a1"."b" AS "b"
|
||||
FROM (VALUES
|
||||
(1, 2)) AS "a1"("a", "b")
|
||||
), "n" AS (
|
||||
SELECT
|
||||
"m"."a" AS "a",
|
||||
"m"."b" AS "b"
|
||||
FROM "m"
|
||||
WHERE
|
||||
"m"."a" = 1
|
||||
), "o" AS (
|
||||
SELECT
|
||||
"m"."a" AS "a",
|
||||
"m"."b" AS "b"
|
||||
FROM "m"
|
||||
WHERE
|
||||
"m"."a" = 2
|
||||
)
|
||||
SELECT
|
||||
"n"."a" AS "a",
|
||||
"n"."b" AS "b",
|
||||
"o"."b" AS "b"
|
||||
FROM "n"
|
||||
FULL JOIN "o"
|
||||
ON "n"."a" = "o"."a"
|
||||
JOIN "n" AS "n2"
|
||||
ON "n"."a" = "n2"."a"
|
||||
WHERE
|
||||
"o"."b" > 0;
|
||||
|
||||
# title: Broadcast hint
|
||||
# dialect: spark
|
||||
WITH m AS (
|
||||
SELECT
|
||||
x.a,
|
||||
x.b
|
||||
FROM x
|
||||
), n AS (
|
||||
SELECT
|
||||
y.b,
|
||||
y.c
|
||||
FROM y
|
||||
), joined as (
|
||||
SELECT /*+ BROADCAST(n) */
|
||||
m.a,
|
||||
n.c
|
||||
FROM m JOIN n ON m.b = n.b
|
||||
)
|
||||
SELECT
|
||||
joined.a,
|
||||
joined.c
|
||||
FROM joined;
|
||||
SELECT /*+ BROADCAST(`y`) */
|
||||
`x`.`a` AS `a`,
|
||||
`y`.`c` AS `c`
|
||||
FROM `x` AS `x`
|
||||
JOIN `y` AS `y`
|
||||
ON `x`.`b` = `y`.`b`;
|
||||
|
||||
# title: Mix Table and Column Hints
|
||||
# dialect: spark
|
||||
WITH m AS (
|
||||
SELECT
|
||||
x.a,
|
||||
x.b
|
||||
FROM x
|
||||
), n AS (
|
||||
SELECT
|
||||
y.b,
|
||||
y.c
|
||||
FROM y
|
||||
), joined as (
|
||||
SELECT /*+ BROADCAST(m), MERGE(m, n) */
|
||||
m.a,
|
||||
n.c
|
||||
FROM m JOIN n ON m.b = n.b
|
||||
)
|
||||
SELECT
|
||||
/*+ COALESCE(3) */
|
||||
joined.a,
|
||||
joined.c
|
||||
FROM joined;
|
||||
SELECT /*+ COALESCE(3),
|
||||
BROADCAST(`x`),
|
||||
MERGE(`x`, `y`) */
|
||||
`x`.`a` AS `a`,
|
||||
`y`.`c` AS `c`
|
||||
FROM `x` AS `x`
|
||||
JOIN `y` AS `y`
|
||||
ON `x`.`b` = `y`.`b`;
|
||||
|
|
43
tests/fixtures/optimizer/qualify_columns.sql
vendored
43
tests/fixtures/optimizer/qualify_columns.sql
vendored
|
@ -19,38 +19,49 @@ SELECT x.a AS a FROM x AS x;
|
|||
SELECT a AS b FROM x;
|
||||
SELECT x.a AS b FROM x AS x;
|
||||
|
||||
# execute: false
|
||||
SELECT 1, 2 FROM x;
|
||||
SELECT 1 AS "_col_0", 2 AS "_col_1" FROM x AS x;
|
||||
|
||||
# execute: false
|
||||
SELECT a + b FROM x;
|
||||
SELECT x.a + x.b AS "_col_0" FROM x AS x;
|
||||
|
||||
SELECT a + b FROM x;
|
||||
SELECT x.a + x.b AS "_col_0" FROM x AS x;
|
||||
|
||||
# execute: false
|
||||
SELECT a, SUM(b) FROM x WHERE a > 1 AND b > 1 GROUP BY a;
|
||||
SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 AND x.b > 1 GROUP BY x.a;
|
||||
|
||||
SELECT a AS j, b FROM x ORDER BY j;
|
||||
SELECT x.a AS j, x.b AS b FROM x AS x ORDER BY j;
|
||||
|
||||
SELECT a AS j, b FROM x GROUP BY j;
|
||||
SELECT x.a AS j, x.b AS b FROM x AS x GROUP BY x.a;
|
||||
SELECT a AS j, b AS a FROM x ORDER BY 1;
|
||||
SELECT x.a AS j, x.b AS a FROM x AS x ORDER BY x.a;
|
||||
|
||||
SELECT SUM(a) AS c, SUM(b) AS d FROM x ORDER BY 1, 2;
|
||||
SELECT SUM(x.a) AS c, SUM(x.b) AS d FROM x AS x ORDER BY SUM(x.a), SUM(x.b);
|
||||
|
||||
# execute: false
|
||||
SELECT SUM(a), SUM(b) AS c FROM x ORDER BY 1, 2;
|
||||
SELECT SUM(x.a) AS "_col_0", SUM(x.b) AS c FROM x AS x ORDER BY SUM(x.a), SUM(x.b);
|
||||
|
||||
SELECT a AS j, b FROM x GROUP BY j, b;
|
||||
SELECT x.a AS j, x.b AS b FROM x AS x GROUP BY x.a, x.b;
|
||||
|
||||
SELECT a, b FROM x GROUP BY 1, 2;
|
||||
SELECT x.a AS a, x.b AS b FROM x AS x GROUP BY x.a, x.b;
|
||||
|
||||
SELECT a, b FROM x ORDER BY 1, 2;
|
||||
SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a, b;
|
||||
SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a, x.b;
|
||||
|
||||
# execute: false
|
||||
SELECT DATE(a), DATE(b) AS c FROM x GROUP BY 1, 2;
|
||||
SELECT DATE(x.a) AS "_col_0", DATE(x.b) AS c FROM x AS x GROUP BY DATE(x.a), DATE(x.b);
|
||||
|
||||
SELECT x.a AS c FROM x JOIN y ON x.b = y.b GROUP BY c;
|
||||
SELECT x.a AS c FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY y.c;
|
||||
SELECT SUM(x.a) AS c FROM x JOIN y ON x.b = y.b GROUP BY c;
|
||||
SELECT SUM(x.a) AS c FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY y.c;
|
||||
|
||||
SELECT DATE(x.a) AS d FROM x JOIN y ON x.b = y.b GROUP BY d;
|
||||
SELECT DATE(x.a) AS d FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY DATE(x.a);
|
||||
SELECT COALESCE(x.a) AS d FROM x JOIN y ON x.b = y.b GROUP BY d;
|
||||
SELECT COALESCE(x.a) AS d FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY COALESCE(x.a);
|
||||
|
||||
SELECT a AS a, b FROM x ORDER BY a;
|
||||
SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a;
|
||||
|
@ -69,6 +80,7 @@ SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x
|
|||
SELECT x.b, x.a FROM x LEFT JOIN y ON x.b = y.b QUALIFY ROW_NUMBER() OVER(PARTITION BY x.b ORDER BY x.a DESC) = 1;
|
||||
SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a DESC) = 1;
|
||||
|
||||
# execute: false
|
||||
SELECT AGGREGATE(ARRAY(a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
|
||||
SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + x.a) AS sum_agg FROM x AS x;
|
||||
|
||||
|
@ -93,8 +105,8 @@ SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
|
|||
SELECT a FROM (SELECT a FROM (SELECT a FROM x));
|
||||
SELECT "_q_1".a AS a FROM (SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0") AS "_q_1";
|
||||
|
||||
SELECT x.a FROM x AS x JOIN (SELECT * FROM x);
|
||||
SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0";
|
||||
SELECT x.a FROM x AS x JOIN (SELECT * FROM x) AS y ON x.a = y.a;
|
||||
SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a, x.b AS b FROM x AS x) AS y ON x.a = y.a;
|
||||
|
||||
--------------------------------------
|
||||
-- Joins
|
||||
|
@ -123,6 +135,7 @@ SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FRO
|
|||
SELECT a FROM x WHERE b IN (SELECT c FROM y);
|
||||
SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT y.c AS c FROM y AS y);
|
||||
|
||||
# execute: false
|
||||
SELECT (SELECT c FROM y) FROM x;
|
||||
SELECT (SELECT y.c AS c FROM y AS y) AS "_col_0" FROM x AS x;
|
||||
|
||||
|
@ -144,10 +157,12 @@ SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT x.b AS b FROM y AS x);
|
|||
SELECT a FROM x AS i WHERE b IN (SELECT b FROM y AS j WHERE j.b IN (SELECT c FROM y AS k WHERE k.b = j.b));
|
||||
SELECT i.a AS a FROM x AS i WHERE i.b IN (SELECT j.b AS b FROM y AS j WHERE j.b IN (SELECT k.c AS c FROM y AS k WHERE k.b = j.b));
|
||||
|
||||
# execute: false
|
||||
# dialect: bigquery
|
||||
SELECT aa FROM x, UNNEST(a) AS aa;
|
||||
SELECT aa AS aa FROM x AS x, UNNEST(x.a) AS aa;
|
||||
|
||||
# execute: false
|
||||
SELECT aa FROM x, UNNEST(a) AS t(aa);
|
||||
SELECT t.aa AS aa FROM x AS x, UNNEST(x.a) AS t(aa);
|
||||
|
||||
|
@ -205,15 +220,19 @@ WITH z AS ((SELECT x.b AS b FROM x AS x UNION ALL SELECT y.b AS b FROM y AS y) O
|
|||
--------------------------------------
|
||||
-- Except and Replace
|
||||
--------------------------------------
|
||||
# execute: false
|
||||
SELECT * REPLACE(a AS d) FROM x;
|
||||
SELECT x.a AS d, x.b AS b FROM x AS x;
|
||||
|
||||
# execute: false
|
||||
SELECT * EXCEPT(b) REPLACE(a AS d) FROM x;
|
||||
SELECT x.a AS d FROM x AS x;
|
||||
|
||||
# execute: false
|
||||
SELECT x.* EXCEPT(a), y.* FROM x, y;
|
||||
SELECT x.b AS b, y.b AS b, y.c AS c FROM x AS x, y AS y;
|
||||
|
||||
# execute: false
|
||||
SELECT * EXCEPT(a) FROM x;
|
||||
SELECT x.b AS b FROM x AS x;
|
||||
|
||||
|
|
35
tests/fixtures/optimizer/qualify_columns__with_invisible.sql
vendored
Normal file
35
tests/fixtures/optimizer/qualify_columns__with_invisible.sql
vendored
Normal file
|
@ -0,0 +1,35 @@
|
|||
--------------------------------------
|
||||
-- Qualify columns
|
||||
--------------------------------------
|
||||
SELECT a FROM x;
|
||||
SELECT x.a AS a FROM x AS x;
|
||||
|
||||
SELECT b FROM x;
|
||||
SELECT x.b AS b FROM x AS x;
|
||||
|
||||
--------------------------------------
|
||||
-- Derived tables
|
||||
--------------------------------------
|
||||
SELECT x.a FROM x AS x JOIN (SELECT * FROM x);
|
||||
SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a FROM x AS x) AS "_q_0";
|
||||
|
||||
SELECT x.b FROM x AS x JOIN (SELECT b FROM x);
|
||||
SELECT x.b AS b FROM x AS x JOIN (SELECT x.b AS b FROM x AS x) AS "_q_0";
|
||||
|
||||
--------------------------------------
|
||||
-- Expand *
|
||||
--------------------------------------
|
||||
SELECT * FROM x;
|
||||
SELECT x.a AS a FROM x AS x;
|
||||
|
||||
SELECT * FROM y JOIN z ON y.b = z.b;
|
||||
SELECT y.b AS b, z.b AS b FROM y AS y JOIN z AS z ON y.b = z.b;
|
||||
|
||||
SELECT * FROM y JOIN z ON y.c = z.c;
|
||||
SELECT y.b AS b, z.b AS b FROM y AS y JOIN z AS z ON y.c = z.c;
|
||||
|
||||
SELECT a FROM (SELECT * FROM x);
|
||||
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
|
||||
|
||||
SELECT * FROM (SELECT a FROM x);
|
||||
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
|
6
tests/fixtures/optimizer/simplify.sql
vendored
6
tests/fixtures/optimizer/simplify.sql
vendored
|
@ -52,6 +52,9 @@ TRUE;
|
|||
NULL AND TRUE;
|
||||
NULL;
|
||||
|
||||
NULL AND FALSE;
|
||||
FALSE;
|
||||
|
||||
NULL AND NULL;
|
||||
NULL;
|
||||
|
||||
|
@ -70,6 +73,9 @@ FALSE;
|
|||
NOT FALSE;
|
||||
TRUE;
|
||||
|
||||
NOT NULL;
|
||||
NULL;
|
||||
|
||||
NULL = NULL;
|
||||
NULL;
|
||||
|
||||
|
|
13
tests/fixtures/optimizer/tpc-h/tpc-h.sql
vendored
13
tests/fixtures/optimizer/tpc-h/tpc-h.sql
vendored
|
@ -769,13 +769,20 @@ group by
|
|||
order by
|
||||
custdist desc,
|
||||
c_count desc;
|
||||
WITH "c_orders" AS (
|
||||
WITH "orders_2" AS (
|
||||
SELECT
|
||||
"orders"."o_orderkey" AS "o_orderkey",
|
||||
"orders"."o_custkey" AS "o_custkey",
|
||||
"orders"."o_comment" AS "o_comment"
|
||||
FROM "orders" AS "orders"
|
||||
WHERE
|
||||
NOT "orders"."o_comment" LIKE '%special%requests%'
|
||||
), "c_orders" AS (
|
||||
SELECT
|
||||
COUNT("orders"."o_orderkey") AS "c_count"
|
||||
FROM "customer" AS "customer"
|
||||
LEFT JOIN "orders" AS "orders"
|
||||
LEFT JOIN "orders_2" AS "orders"
|
||||
ON "customer"."c_custkey" = "orders"."o_custkey"
|
||||
AND NOT "orders"."o_comment" LIKE '%special%requests%'
|
||||
GROUP BY
|
||||
"customer"."c_custkey"
|
||||
)
|
||||
|
|
|
@ -45,6 +45,14 @@ def load_sql_fixture_pairs(filename):
|
|||
yield meta, sql, expected
|
||||
|
||||
|
||||
def string_to_bool(string):
|
||||
if string is None:
|
||||
return False
|
||||
if string in (True, False):
|
||||
return string
|
||||
return string and string.lower() in ("true", "1")
|
||||
|
||||
|
||||
TPCH_SCHEMA = {
|
||||
"lineitem": {
|
||||
"l_orderkey": "uint64",
|
||||
|
|
|
@ -1,6 +1,19 @@
|
|||
import unittest
|
||||
|
||||
from sqlglot import and_, condition, exp, from_, not_, or_, parse_one, select
|
||||
from sqlglot import (
|
||||
alias,
|
||||
and_,
|
||||
condition,
|
||||
except_,
|
||||
exp,
|
||||
from_,
|
||||
intersect,
|
||||
not_,
|
||||
or_,
|
||||
parse_one,
|
||||
select,
|
||||
union,
|
||||
)
|
||||
|
||||
|
||||
class TestBuild(unittest.TestCase):
|
||||
|
@ -320,6 +333,54 @@ class TestBuild(unittest.TestCase):
|
|||
lambda: exp.update("tbl", {"x": 1}, from_="tbl2"),
|
||||
"UPDATE tbl SET x = 1 FROM tbl2",
|
||||
),
|
||||
(
|
||||
lambda: union("SELECT * FROM foo", "SELECT * FROM bla"),
|
||||
"SELECT * FROM foo UNION SELECT * FROM bla",
|
||||
),
|
||||
(
|
||||
lambda: parse_one("SELECT * FROM foo").union("SELECT * FROM bla"),
|
||||
"SELECT * FROM foo UNION SELECT * FROM bla",
|
||||
),
|
||||
(
|
||||
lambda: intersect("SELECT * FROM foo", "SELECT * FROM bla"),
|
||||
"SELECT * FROM foo INTERSECT SELECT * FROM bla",
|
||||
),
|
||||
(
|
||||
lambda: parse_one("SELECT * FROM foo").intersect("SELECT * FROM bla"),
|
||||
"SELECT * FROM foo INTERSECT SELECT * FROM bla",
|
||||
),
|
||||
(
|
||||
lambda: except_("SELECT * FROM foo", "SELECT * FROM bla"),
|
||||
"SELECT * FROM foo EXCEPT SELECT * FROM bla",
|
||||
),
|
||||
(
|
||||
lambda: parse_one("SELECT * FROM foo").except_("SELECT * FROM bla"),
|
||||
"SELECT * FROM foo EXCEPT SELECT * FROM bla",
|
||||
),
|
||||
(
|
||||
lambda: parse_one("(SELECT * FROM foo)").union("SELECT * FROM bla"),
|
||||
"(SELECT * FROM foo) UNION SELECT * FROM bla",
|
||||
),
|
||||
(
|
||||
lambda: parse_one("(SELECT * FROM foo)").union("SELECT * FROM bla", distinct=False),
|
||||
"(SELECT * FROM foo) UNION ALL SELECT * FROM bla",
|
||||
),
|
||||
(
|
||||
lambda: alias(parse_one("LAG(x) OVER (PARTITION BY y)"), "a"),
|
||||
"LAG(x) OVER (PARTITION BY y) AS a",
|
||||
),
|
||||
(
|
||||
lambda: alias(parse_one("LAG(x) OVER (ORDER BY z)"), "a"),
|
||||
"LAG(x) OVER (ORDER BY z) AS a",
|
||||
),
|
||||
(
|
||||
lambda: alias(parse_one("LAG(x) OVER (PARTITION BY y ORDER BY z)"), "a"),
|
||||
"LAG(x) OVER (PARTITION BY y ORDER BY z) AS a",
|
||||
),
|
||||
(
|
||||
lambda: alias(parse_one("LAG(x) OVER ()"), "a"),
|
||||
"LAG(x) OVER () AS a",
|
||||
),
|
||||
]:
|
||||
with self.subTest(sql):
|
||||
self.assertEqual(expression().sql(dialect[0] if dialect else None), sql)
|
||||
|
|
|
@ -115,6 +115,21 @@ class TestExpressions(unittest.TestCase):
|
|||
["first", "second", "third"],
|
||||
)
|
||||
|
||||
def test_table_name(self):
|
||||
self.assertEqual(exp.table_name(parse_one("a", into=exp.Table)), "a")
|
||||
self.assertEqual(exp.table_name(parse_one("a.b", into=exp.Table)), "a.b")
|
||||
self.assertEqual(exp.table_name(parse_one("a.b.c", into=exp.Table)), "a.b.c")
|
||||
self.assertEqual(exp.table_name("a.b.c"), "a.b.c")
|
||||
|
||||
def test_replace_tables(self):
|
||||
self.assertEqual(
|
||||
exp.replace_tables(
|
||||
parse_one("select * from a join b join c.a join d.a join e.a"),
|
||||
{"a": "a1", "b": "b.a", "c.a": "c.a2", "d.a": "d2"},
|
||||
).sql(),
|
||||
'SELECT * FROM "a1" JOIN "b"."a" JOIN "c"."a2" JOIN "d2" JOIN e.a',
|
||||
)
|
||||
|
||||
def test_named_selects(self):
|
||||
expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
|
||||
self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"])
|
||||
|
@ -474,3 +489,10 @@ class TestExpressions(unittest.TestCase):
|
|||
]:
|
||||
with self.subTest(value):
|
||||
self.assertEqual(exp.convert(value).sql(), expected)
|
||||
|
||||
def test_annotation_alias(self):
|
||||
expression = parse_one("SELECT a, b AS B, c #comment, d AS D #another_comment FROM foo")
|
||||
self.assertEqual(
|
||||
[e.alias_or_name for e in expression.expressions],
|
||||
["a", "B", "c", "D"],
|
||||
)
|
||||
|
|
|
@ -1,17 +1,55 @@
|
|||
import unittest
|
||||
from functools import partial
|
||||
|
||||
import duckdb
|
||||
from pandas.testing import assert_frame_equal
|
||||
|
||||
import sqlglot
|
||||
from sqlglot import exp, optimizer, parse_one, table
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
from sqlglot.optimizer.schema import MappingSchema, ensure_schema
|
||||
from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
|
||||
from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures
|
||||
from tests.helpers import (
|
||||
TPCH_SCHEMA,
|
||||
load_sql_fixture_pairs,
|
||||
load_sql_fixtures,
|
||||
string_to_bool,
|
||||
)
|
||||
|
||||
|
||||
class TestOptimizer(unittest.TestCase):
|
||||
maxDiff = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.conn = duckdb.connect()
|
||||
cls.conn.execute(
|
||||
"""
|
||||
CREATE TABLE x (a INT, b INT);
|
||||
CREATE TABLE y (b INT, c INT);
|
||||
CREATE TABLE z (b INT, c INT);
|
||||
|
||||
INSERT INTO x VALUES (1, 1);
|
||||
INSERT INTO x VALUES (2, 2);
|
||||
INSERT INTO x VALUES (2, 2);
|
||||
INSERT INTO x VALUES (3, 3);
|
||||
INSERT INTO x VALUES (null, null);
|
||||
|
||||
INSERT INTO y VALUES (2, 2);
|
||||
INSERT INTO y VALUES (2, 2);
|
||||
INSERT INTO y VALUES (3, 3);
|
||||
INSERT INTO y VALUES (4, 4);
|
||||
INSERT INTO y VALUES (null, null);
|
||||
|
||||
INSERT INTO y VALUES (3, 3);
|
||||
INSERT INTO y VALUES (3, 3);
|
||||
INSERT INTO y VALUES (4, 4);
|
||||
INSERT INTO y VALUES (5, 5);
|
||||
INSERT INTO y VALUES (null, null);
|
||||
"""
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.schema = {
|
||||
"x": {
|
||||
|
@ -28,29 +66,42 @@ class TestOptimizer(unittest.TestCase):
|
|||
},
|
||||
}
|
||||
|
||||
def check_file(self, file, func, pretty=False, **kwargs):
|
||||
def check_file(self, file, func, pretty=False, execute=False, **kwargs):
|
||||
for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1):
|
||||
title = meta.get("title") or f"{i}, {sql}"
|
||||
dialect = meta.get("dialect")
|
||||
leave_tables_isolated = meta.get("leave_tables_isolated")
|
||||
|
||||
func_kwargs = {**kwargs}
|
||||
if leave_tables_isolated is not None:
|
||||
func_kwargs["leave_tables_isolated"] = leave_tables_isolated.lower() in ("true", "1")
|
||||
func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated)
|
||||
|
||||
with self.subTest(f"{i}, {sql}"):
|
||||
optimized = func(parse_one(sql, read=dialect), **func_kwargs)
|
||||
|
||||
with self.subTest(title):
|
||||
self.assertEqual(
|
||||
func(parse_one(sql, read=dialect), **func_kwargs).sql(pretty=pretty, dialect=dialect),
|
||||
optimized.sql(pretty=pretty, dialect=dialect),
|
||||
expected,
|
||||
)
|
||||
|
||||
should_execute = meta.get("execute")
|
||||
if should_execute is None:
|
||||
should_execute = execute
|
||||
|
||||
if string_to_bool(should_execute):
|
||||
with self.subTest(f"(execute) {title}"):
|
||||
df1 = self.conn.execute(sqlglot.transpile(sql, read=dialect, write="duckdb")[0]).df()
|
||||
df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df()
|
||||
assert_frame_equal(df1, df2)
|
||||
|
||||
def test_optimize(self):
|
||||
schema = {
|
||||
"x": {"a": "INT", "b": "INT"},
|
||||
"y": {"a": "INT", "b": "INT"},
|
||||
"y": {"b": "INT", "c": "INT"},
|
||||
"z": {"a": "INT", "c": "INT"},
|
||||
}
|
||||
|
||||
self.check_file("optimizer", optimizer.optimize, pretty=True, schema=schema)
|
||||
self.check_file("optimizer", optimizer.optimize, pretty=True, execute=True, schema=schema)
|
||||
|
||||
def test_isolate_table_selects(self):
|
||||
self.check_file(
|
||||
|
@ -86,7 +137,16 @@ class TestOptimizer(unittest.TestCase):
|
|||
expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
|
||||
return expression
|
||||
|
||||
self.check_file("qualify_columns", qualify_columns, schema=self.schema)
|
||||
self.check_file("qualify_columns", qualify_columns, execute=True, schema=self.schema)
|
||||
|
||||
def test_qualify_columns__with_invisible(self):
|
||||
def qualify_columns(expression, **kwargs):
|
||||
expression = optimizer.qualify_tables.qualify_tables(expression)
|
||||
expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
|
||||
return expression
|
||||
|
||||
schema = MappingSchema(self.schema, {"x": {"a"}, "y": {"b"}, "z": {"b"}})
|
||||
self.check_file("qualify_columns__with_invisible", qualify_columns, schema=schema)
|
||||
|
||||
def test_qualify_columns__invalid(self):
|
||||
for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"):
|
||||
|
@ -141,7 +201,7 @@ class TestOptimizer(unittest.TestCase):
|
|||
],
|
||||
)
|
||||
|
||||
self.check_file("merge_subqueries", optimize, schema=self.schema)
|
||||
self.check_file("merge_subqueries", optimize, execute=True, schema=self.schema)
|
||||
|
||||
def test_eliminate_subqueries(self):
|
||||
self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries)
|
||||
|
@ -301,10 +361,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
|
|||
}
|
||||
|
||||
for sql, target_type in tests.items():
|
||||
expression = parse_one(sql)
|
||||
annotated_expression = annotate_types(expression)
|
||||
|
||||
self.assertEqual(annotated_expression.find(exp.Literal).type, target_type)
|
||||
expression = annotate_types(parse_one(sql))
|
||||
self.assertEqual(expression.find(exp.Literal).type, target_type)
|
||||
|
||||
def test_boolean_type_annotation(self):
|
||||
tests = {
|
||||
|
@ -313,14 +371,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
|
|||
}
|
||||
|
||||
for sql, target_type in tests.items():
|
||||
expression = parse_one(sql)
|
||||
annotated_expression = annotate_types(expression)
|
||||
|
||||
self.assertEqual(annotated_expression.find(exp.Boolean).type, target_type)
|
||||
expression = annotate_types(parse_one(sql))
|
||||
self.assertEqual(expression.find(exp.Boolean).type, target_type)
|
||||
|
||||
def test_cast_type_annotation(self):
|
||||
expression = parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")
|
||||
annotate_types(expression)
|
||||
expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))"))
|
||||
|
||||
self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ)
|
||||
self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR)
|
||||
|
@ -328,16 +383,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
|
|||
self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)
|
||||
|
||||
def test_cache_annotation(self):
|
||||
expression = parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
|
||||
annotated_expression = annotate_types(expression)
|
||||
|
||||
self.assertEqual(annotated_expression.expression.expressions[0].type, exp.DataType.Type.INT)
|
||||
expression = annotate_types(parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1"))
|
||||
self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT)
|
||||
|
||||
def test_binary_annotation(self):
|
||||
expression = parse_one("SELECT 0.0 + (2 + 3)")
|
||||
annotate_types(expression)
|
||||
|
||||
expression = expression.expressions[0]
|
||||
expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0]
|
||||
|
||||
self.assertEqual(expression.type, exp.DataType.Type.DOUBLE)
|
||||
self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE)
|
||||
|
@ -345,3 +395,124 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
|
|||
self.assertEqual(expression.right.this.type, exp.DataType.Type.INT)
|
||||
self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT)
|
||||
self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT)
|
||||
|
||||
def test_derived_tables_column_annotation(self):
|
||||
schema = {"x": {"cola": "INT"}, "y": {"cola": "FLOAT"}}
|
||||
sql = """
|
||||
SELECT a.cola AS cola
|
||||
FROM (
|
||||
SELECT x.cola + y.cola AS cola
|
||||
FROM (
|
||||
SELECT x.cola AS cola
|
||||
FROM x AS x
|
||||
) AS x
|
||||
JOIN (
|
||||
SELECT y.cola AS cola
|
||||
FROM y AS y
|
||||
) AS y
|
||||
) AS a
|
||||
"""
|
||||
|
||||
expression = annotate_types(parse_one(sql), schema=schema)
|
||||
self.assertEqual(expression.expressions[0].type, exp.DataType.Type.FLOAT) # a.cola AS cola
|
||||
|
||||
addition_alias = expression.args["from"].expressions[0].this.expressions[0]
|
||||
self.assertEqual(addition_alias.type, exp.DataType.Type.FLOAT) # x.cola + y.cola AS cola
|
||||
|
||||
addition = addition_alias.this
|
||||
self.assertEqual(addition.type, exp.DataType.Type.FLOAT)
|
||||
self.assertEqual(addition.this.type, exp.DataType.Type.INT)
|
||||
self.assertEqual(addition.expression.type, exp.DataType.Type.FLOAT)
|
||||
|
||||
def test_cte_column_annotation(self):
|
||||
schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT"}}
|
||||
sql = """
|
||||
WITH tbl AS (
|
||||
SELECT x.cola + 'bla' AS cola, y.colb AS colb
|
||||
FROM (
|
||||
SELECT x.cola AS cola
|
||||
FROM x AS x
|
||||
) AS x
|
||||
JOIN (
|
||||
SELECT y.colb AS colb
|
||||
FROM y AS y
|
||||
) AS y
|
||||
)
|
||||
SELECT tbl.cola + tbl.colb + 'foo' AS col
|
||||
FROM tbl AS tbl
|
||||
"""
|
||||
|
||||
expression = annotate_types(parse_one(sql), schema=schema)
|
||||
self.assertEqual(expression.expressions[0].type, exp.DataType.Type.TEXT) # tbl.cola + tbl.colb + 'foo' AS col
|
||||
|
||||
outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo'
|
||||
self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT)
|
||||
self.assertEqual(outer_addition.left.type, exp.DataType.Type.TEXT)
|
||||
self.assertEqual(outer_addition.right.type, exp.DataType.Type.VARCHAR)
|
||||
|
||||
inner_addition = expression.expressions[0].this.left # tbl.cola + tbl.colb
|
||||
self.assertEqual(inner_addition.left.type, exp.DataType.Type.VARCHAR)
|
||||
self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT)
|
||||
|
||||
cte_select = expression.args["with"].expressions[0].this
|
||||
self.assertEqual(cte_select.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola + 'bla' AS cola
|
||||
self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb
|
||||
|
||||
cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla'
|
||||
self.assertEqual(cte_select_addition.type, exp.DataType.Type.VARCHAR)
|
||||
self.assertEqual(cte_select_addition.left.type, exp.DataType.Type.CHAR)
|
||||
self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR)
|
||||
|
||||
# Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
|
||||
for d, t in zip(cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]):
|
||||
self.assertEqual(d.this.expressions[0].this.type, t)
|
||||
|
||||
def test_function_annotation(self):
|
||||
schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}}
|
||||
sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x"
|
||||
|
||||
concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
|
||||
self.assertEqual(concat_expr_alias.type, exp.DataType.Type.VARCHAR)
|
||||
|
||||
concat_expr = concat_expr_alias.this
|
||||
self.assertEqual(concat_expr.type, exp.DataType.Type.VARCHAR)
|
||||
self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola
|
||||
self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb)
|
||||
self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb
|
||||
|
||||
def test_unknown_annotation(self):
|
||||
schema = {"x": {"cola": "VARCHAR"}}
|
||||
sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x"
|
||||
|
||||
concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
|
||||
self.assertEqual(concat_expr_alias.type, exp.DataType.Type.UNKNOWN)
|
||||
|
||||
concat_expr = concat_expr_alias.this
|
||||
self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN)
|
||||
self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola
|
||||
self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) # SOME_ANONYMOUS_FUNC(x.cola)
|
||||
self.assertEqual(concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola (arg)
|
||||
|
||||
def test_null_annotation(self):
|
||||
expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this
|
||||
self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
|
||||
self.assertEqual(expression.right.type, exp.DataType.Type.INT)
|
||||
|
||||
# NULL <op> UNKNOWN should yield NULL
|
||||
sql = "SELECT NULL || SOME_ANONYMOUS_FUNC() AS result"
|
||||
|
||||
concat_expr_alias = annotate_types(parse_one(sql)).expressions[0]
|
||||
self.assertEqual(concat_expr_alias.type, exp.DataType.Type.NULL)
|
||||
|
||||
concat_expr = concat_expr_alias.this
|
||||
self.assertEqual(concat_expr.type, exp.DataType.Type.NULL)
|
||||
self.assertEqual(concat_expr.left.type, exp.DataType.Type.NULL)
|
||||
self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN)
|
||||
|
||||
def test_nullable_annotation(self):
|
||||
nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
|
||||
expression = annotate_types(parse_one("NULL AND FALSE"))
|
||||
|
||||
self.assertEqual(expression.type, nullable)
|
||||
self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
|
||||
self.assertEqual(expression.right.type, exp.DataType.Type.BOOLEAN)
|
||||
|
|
|
@ -338,7 +338,7 @@ class TestTranspile(unittest.TestCase):
|
|||
unsupported_level=level,
|
||||
)
|
||||
|
||||
error = "Cannot convert array columns into map use SparkSQL instead."
|
||||
error = "Cannot convert array columns into map."
|
||||
|
||||
unsupported(ErrorLevel.WARN)
|
||||
assert_logger_contains("\n".join([error] * 4), logger, level="warning")
|
||||
|
|
Loading…
Add table
Reference in a new issue