
Merging upstream version 6.3.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:45:11 +01:00
parent 81e6900b0a
commit 393757f998
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
41 changed files with 1558 additions and 267 deletions

CHANGELOG.md

@@ -1,6 +1,28 @@
 Changelog
 =========
+
+v6.3.0
+------
+
+Changes:
+
+- New: Snowflake [table literals](https://docs.snowflake.com/en/sql-reference/literals-table.html)
+- New: Anti and semi joins
+- New: Vacuum as a command
+- New: Stored procedures
+- New: Rewriting derived tables as CTEs
+- Improvement: Various ClickHouse improvements
+- Improvement: Optimizer predicate pushdown
+- Breaking: DATE\_DIFF default renamed to DATEDIFF
+
 v6.2.0
 ------
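For a quick taste of the new set-operation builders shipped in this release, here is the usage that also appears as a doctest in the expressions.py diff further down:

import sqlglot

print(sqlglot.parse_one("SELECT * FROM foo").union("SELECT * FROM bla").sql())
# 'SELECT * FROM foo UNION SELECT * FROM bla'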

sqlglot/__init__.py

@@ -8,7 +8,9 @@ from sqlglot.expressions import (
     and_,
     column,
     condition,
+    except_,
     from_,
+    intersect,
     maybe_parse,
     not_,
     or_,
@@ -16,11 +18,12 @@ from sqlglot.expressions import (
     subquery,
 )
 from sqlglot.expressions import table_ as table
+from sqlglot.expressions import union
 from sqlglot.generator import Generator
 from sqlglot.parser import Parser
 from sqlglot.tokens import Tokenizer, TokenType

-__version__ = "6.2.8"
+__version__ = "6.3.1"

 pretty = False

sqlglot/dialects/bigquery.py

@@ -135,6 +135,7 @@ class BigQuery(Dialect):
             exp.DateSub: _date_add_sql("DATE", "SUB"),
             exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
             exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
+            exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
             exp.ILike: no_ilike_sql,
             exp.TimeAdd: _date_add_sql("TIME", "ADD"),
             exp.TimeSub: _date_add_sql("TIME", "SUB"),
@@ -172,12 +173,11 @@ class BigQuery(Dialect):
             exp.AnonymousProperty,
         }

+        EXPLICIT_UNION = True
+
         def in_unnest_op(self, unnest):
             return self.sql(unnest)

-        def union_op(self, expression):
-            return f"UNION{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
-
         def except_op(self, expression):
             if not expression.args.get("distinct", False):
                 self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery")

sqlglot/dialects/clickhouse.py

@@ -1,10 +1,16 @@
 from sqlglot import exp
-from sqlglot.dialects.dialect import Dialect, inline_array_sql
+from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
 from sqlglot.generator import Generator
-from sqlglot.parser import Parser
+from sqlglot.helper import csv
+from sqlglot.parser import Parser, parse_var_map
 from sqlglot.tokens import Tokenizer, TokenType


+def _lower_func(sql):
+    index = sql.index("(")
+    return sql[:index].lower() + sql[index:]
+
+
 class ClickHouse(Dialect):
     normalize_functions = None
     null_ordering = "nulls_are_last"
@@ -14,17 +20,23 @@ class ClickHouse(Dialect):
         KEYWORDS = {
             **Tokenizer.KEYWORDS,
+            "NULLABLE": TokenType.NULLABLE,
             "FINAL": TokenType.FINAL,
+            "DATETIME64": TokenType.DATETIME,
             "INT8": TokenType.TINYINT,
             "INT16": TokenType.SMALLINT,
             "INT32": TokenType.INT,
             "INT64": TokenType.BIGINT,
             "FLOAT32": TokenType.FLOAT,
             "FLOAT64": TokenType.DOUBLE,
+            "TUPLE": TokenType.STRUCT,
         }

     class Parser(Parser):
+        FUNCTIONS = {
+            **Parser.FUNCTIONS,
+            "MAP": parse_var_map,
+        }
+
         def _parse_table(self, schema=False):
             this = super()._parse_table(schema)
@@ -39,10 +51,25 @@ class ClickHouse(Dialect):
         TYPE_MAPPING = {
             **Generator.TYPE_MAPPING,
             exp.DataType.Type.NULLABLE: "Nullable",
+            exp.DataType.Type.DATETIME: "DateTime64",
+            exp.DataType.Type.MAP: "Map",
+            exp.DataType.Type.ARRAY: "Array",
+            exp.DataType.Type.STRUCT: "Tuple",
+            exp.DataType.Type.TINYINT: "Int8",
+            exp.DataType.Type.SMALLINT: "Int16",
+            exp.DataType.Type.INT: "Int32",
+            exp.DataType.Type.BIGINT: "Int64",
+            exp.DataType.Type.FLOAT: "Float32",
+            exp.DataType.Type.DOUBLE: "Float64",
         }

         TRANSFORMS = {
             **Generator.TRANSFORMS,
             exp.Array: inline_array_sql,
+            exp.StrPosition: lambda self, e: f"position({csv(self.sql(e, 'this'), self.sql(e, 'substr'), self.sql(e, 'position'))})",
             exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
+            exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
+            exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
         }
+
+        EXPLICIT_UNION = True
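A hedged illustration of the new type and map handling; the expected strings follow from TYPE_MAPPING and _lower_func above and are not verified against the release:

import sqlglot

# INT should become Int32, and the variadic MAP should be lowercased via _lower_func
print(sqlglot.transpile("SELECT CAST(x AS INT), MAP('a', 1)", read="hive", write="clickhouse")[0])
# expected: SELECT CAST(x AS Int32), map('a', 1)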

sqlglot/dialects/dialect.py

@@ -77,7 +77,6 @@ class Dialect(metaclass=_Dialect):
     alias_post_tablesample = False
     normalize_functions = "upper"
     null_ordering = "nulls_are_small"
-    wrap_derived_values = True

     date_format = "'%Y-%m-%d'"
     dateint_format = "'%Y%m%d'"
@@ -170,7 +169,6 @@ class Dialect(metaclass=_Dialect):
                 "alias_post_tablesample": self.alias_post_tablesample,
                 "normalize_functions": self.normalize_functions,
                 "null_ordering": self.null_ordering,
-                "wrap_derived_values": self.wrap_derived_values,
                 **opts,
             }
         )
@@ -271,6 +269,21 @@ def struct_extract_sql(self, expression):
     return f"{this}.{struct_key}"


+def var_map_sql(self, expression):
+    keys = expression.args["keys"]
+    values = expression.args["values"]
+
+    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
+        self.unsupported("Cannot convert array columns into map.")
+        return f"MAP({self.sql(keys)}, {self.sql(values)})"
+
+    args = []
+    for key, value in zip(keys.expressions, values.expressions):
+        args.append(self.sql(key))
+        args.append(self.sql(value))
+    return f"MAP({csv(*args)})"
+
+
 def format_time_lambda(exp_class, dialect, default=None):
     """Helper used for time expressions.

sqlglot/dialects/hive.py

@@ -11,40 +11,14 @@ from sqlglot.dialects.dialect import (
     no_trycast_sql,
     rename_func,
     struct_extract_sql,
+    var_map_sql,
 )
 from sqlglot.generator import Generator
 from sqlglot.helper import csv, list_get
-from sqlglot.parser import Parser
+from sqlglot.parser import Parser, parse_var_map
 from sqlglot.tokens import Tokenizer


-def _parse_map(args):
-    keys = []
-    values = []
-
-    for i in range(0, len(args), 2):
-        keys.append(args[i])
-        values.append(args[i + 1])
-
-    return HiveMap(
-        keys=exp.Array(expressions=keys),
-        values=exp.Array(expressions=values),
-    )
-
-
-def _map_sql(self, expression):
-    keys = expression.args["keys"]
-    values = expression.args["values"]
-
-    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
-        self.unsupported("Cannot convert array columns into map use SparkSQL instead.")
-        return f"MAP({self.sql(keys)}, {self.sql(values)})"
-
-    args = []
-    for key, value in zip(keys.expressions, values.expressions):
-        args.append(self.sql(key))
-        args.append(self.sql(value))
-    return f"MAP({csv(*args)})"
-
-
 def _array_sort(self, expression):
     if expression.expression:
         self.unsupported("Hive SORT_ARRAY does not support a comparator")
@@ -122,10 +96,6 @@ def _index_sql(self, expression):
     return f"{this} ON TABLE {table} {columns}"


-class HiveMap(exp.Map):
-    is_var_len_args = True
-
-
 class Hive(Dialect):
     alias_post_tablesample = True
@@ -206,7 +176,7 @@ class Hive(Dialect):
                 position=list_get(args, 2),
             ),
             "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
-            "MAP": _parse_map,
+            "MAP": parse_var_map,
             "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
             "PERCENTILE": exp.Quantile.from_arg_list,
             "PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list,
@@ -245,8 +215,8 @@ class Hive(Dialect):
             exp.Join: _unnest_to_explode_sql,
             exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
             exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
-            exp.Map: _map_sql,
-            HiveMap: _map_sql,
+            exp.Map: var_map_sql,
+            exp.VarMap: var_map_sql,
             exp.Create: create_with_partitions_sql,
             exp.Quantile: rename_func("PERCENTILE"),
             exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
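The shared parse_var_map/var_map_sql helpers replace Hive's private _parse_map/_map_sql pair. A small sketch of the round trip through the new exp.VarMap node (expected behavior, not verified against the release):

import sqlglot
from sqlglot import exp

expression = sqlglot.parse_one("SELECT MAP('a', 1, 'b', 2)", read="hive")
print(expression.find(exp.VarMap) is not None)  # expected: True
print(expression.sql(dialect="hive"))  # expected: SELECT MAP('a', 1, 'b', 2)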

sqlglot/dialects/oracle.py

@@ -10,6 +10,32 @@ def _limit_sql(self, expression):
 class Oracle(Dialect):
+    # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
+    # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
+    time_mapping = {
+        "AM": "%p",  # Meridian indicator with or without periods
+        "A.M.": "%p",  # Meridian indicator with or without periods
+        "PM": "%p",  # Meridian indicator with or without periods
+        "P.M.": "%p",  # Meridian indicator with or without periods
+        "D": "%u",  # Day of week (1-7)
+        "DAY": "%A",  # name of day
+        "DD": "%d",  # day of month (1-31)
+        "DDD": "%j",  # day of year (1-366)
+        "DY": "%a",  # abbreviated name of day
+        "HH": "%I",  # Hour of day (1-12)
+        "HH12": "%I",  # alias for HH
+        "HH24": "%H",  # Hour of day (0-23)
+        "IW": "%V",  # Calendar week of year (1-52 or 1-53), as defined by the ISO 8601 standard
+        "MI": "%M",  # Minute (0-59)
+        "MM": "%m",  # Month (01-12; January = 01)
+        "MON": "%b",  # Abbreviated name of month
+        "MONTH": "%B",  # Name of month
+        "SS": "%S",  # Second (0-59)
+        "WW": "%W",  # Week of year (1-53)
+        "YY": "%y",  # 15
+        "YYYY": "%Y",  # 2015
+    }
+
     class Generator(Generator):
         TYPE_MAPPING = {
             **Generator.TYPE_MAPPING,
@@ -30,6 +56,9 @@ class Oracle(Dialect):
             **transforms.UNALIAS_GROUP,
             exp.ILike: no_ilike_sql,
             exp.Limit: _limit_sql,
+            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
         }

         def query_modifiers(self, expression, *sqls):
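Given the new time_mapping, strftime-style formats should now translate into Oracle's TO_CHAR/TO_TIMESTAMP format tokens. A sketch with an expected output derived from the mappings above, not verified against the release:

import sqlglot

print(sqlglot.transpile("SELECT STR_TO_TIME('2020-01-01', '%Y-%m-%d')", write="oracle")[0])
# expected: SELECT TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')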

sqlglot/dialects/postgres.py

@@ -118,13 +118,22 @@ def _serial_to_generated(expression):
     return expression


+def _to_timestamp(args):
+    # TO_TIMESTAMP accepts either a single double argument or (text, text)
+    if len(args) == 1 and args[0].is_number:
+        # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE
+        return exp.UnixToTime.from_arg_list(args)
+    # https://www.postgresql.org/docs/current/functions-formatting.html
+    return format_time_lambda(exp.StrToTime, "postgres")(args)
+
+
 class Postgres(Dialect):
     null_ordering = "nulls_are_large"
     time_format = "'YYYY-MM-DD HH24:MI:SS'"
     time_mapping = {
         "AM": "%p",
         "PM": "%p",
-        "D": "%w",  # 1-based day of week
+        "D": "%u",  # 1-based day of week
         "DD": "%d",  # day of month
         "DDD": "%j",  # zero padded day of year
         "FMDD": "%-d",  # - is no leading zero for Python; same for FM in postgres
@@ -172,7 +181,7 @@ class Postgres(Dialect):
         FUNCTIONS = {
             **Parser.FUNCTIONS,
-            "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"),
+            "TO_TIMESTAMP": _to_timestamp,
             "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
         }
@@ -211,4 +220,5 @@ class Postgres(Dialect):
             exp.TableSample: no_tablesample_sql,
             exp.Trim: _trim_sql,
             exp.TryCast: no_trycast_sql,
+            exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
         }
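A short sketch of the two TO_TIMESTAMP parse paths introduced by _to_timestamp; the node types are expectations based on the code above:

import sqlglot
from sqlglot import exp

# a single numeric argument is treated as a unix epoch
e1 = sqlglot.parse_one("SELECT TO_TIMESTAMP(1618088028)", read="postgres")
print(e1.find(exp.UnixToTime) is not None)  # expected: True

# the (text, text) form remains a string-to-time conversion
e2 = sqlglot.parse_one("SELECT TO_TIMESTAMP('2021', 'YYYY')", read="postgres")
print(e2.find(exp.StrToTime) is not None)  # expected: True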

sqlglot/dialects/snowflake.py

@@ -121,6 +121,7 @@ class Snowflake(Dialect):
         FUNC_TOKENS = {
             *Parser.FUNC_TOKENS,
             TokenType.RLIKE,
+            TokenType.TABLE,
         }

         COLUMN_OPERATORS = {
@@ -143,7 +144,7 @@ class Snowflake(Dialect):
         SINGLE_TOKENS = {
             **Tokenizer.SINGLE_TOKENS,
-            "$": TokenType.DOLLAR,  # needed to break for quotes
+            "$": TokenType.PARAMETER,
         }

         KEYWORDS = {
@@ -164,6 +165,8 @@ class Snowflake(Dialect):
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.UnixToTime: _unix_to_time,
             exp.Array: inline_array_sql,
+            exp.StrPosition: rename_func("POSITION"),
+            exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
         }
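A hedged round trip for the $ parameter and table-literal changes; outputs are expectations based on the tokenizer and transform changes above:

import sqlglot

print(sqlglot.transpile("SELECT $1 FROM t", read="snowflake", write="snowflake")[0])
# expected: SELECT $1 FROM t

# TABLE(...) should now parse because TokenType.TABLE was added to FUNC_TOKENS
sqlglot.parse_one("SELECT * FROM TABLE(foo)", read="snowflake")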

sqlglot/dialects/spark.py

@@ -4,8 +4,9 @@ from sqlglot.dialects.dialect import (
     no_ilike_sql,
     rename_func,
 )
-from sqlglot.dialects.hive import Hive, HiveMap
+from sqlglot.dialects.hive import Hive
 from sqlglot.helper import list_get
+from sqlglot.parser import Parser


 def _create_sql(self, e):
@@ -47,8 +48,6 @@ def _unix_to_time(self, expression):
 class Spark(Hive):
-    wrap_derived_values = False
-
     class Parser(Hive.Parser):
         FUNCTIONS = {
             **Hive.Parser.FUNCTIONS,
@@ -78,8 +77,19 @@ class Spark(Hive):
             "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
         }

+        FUNCTION_PARSERS = {
+            **Parser.FUNCTION_PARSERS,
+            "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
+            "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
+            "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
+            "MERGE": lambda self: self._parse_join_hint("MERGE"),
+            "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
+            "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
+            "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
+            "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
+        }
+
     class Generator(Hive.Generator):
         TYPE_MAPPING = {
             **Hive.Generator.TYPE_MAPPING,
             exp.DataType.Type.TINYINT: "BYTE",
@@ -102,8 +112,9 @@ class Spark(Hive):
             exp.Map: _map_sql,
             exp.Reduce: rename_func("AGGREGATE"),
             exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
-            HiveMap: _map_sql,
         }

+        WRAP_DERIVED_VALUES = False
+
     class Tokenizer(Hive.Tokenizer):
         HEX_STRINGS = [("X'", "'")]
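The new FUNCTION_PARSERS map Spark's /*+ ... */ join hints onto the exp.JoinHint node, which the generator diff below renders via joinhint_sql. A rough sketch; preservation of the hint is an expectation, not verified against the release:

import sqlglot

sql = "SELECT /*+ BROADCAST(b) */ a.x FROM a JOIN b ON a.x = b.x"
# expected: the BROADCAST(b) hint survives the Spark round trip
print(sqlglot.transpile(sql, read="spark", write="spark")[0])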

sqlglot/executor/python.py

@@ -326,6 +326,7 @@ class Python(Dialect):
             exp.Alias: lambda self, e: self.sql(e.this),
             exp.Array: inline_array_sql,
             exp.And: lambda self, e: self.binary(e, "and"),
+            exp.Boolean: lambda self, e: "True" if e.this else "False",
             exp.Cast: _cast_py,
             exp.Column: _column_py,
             exp.EQ: lambda self, e: self.binary(e, "=="),

sqlglot/expressions.py

@@ -508,7 +508,69 @@ class DerivedTable(Expression):
         return [select.alias_or_name for select in self.selects]


-class UDTF(DerivedTable):
+class Unionable:
+    def union(self, expression, distinct=True, dialect=None, **opts):
+        """
+        Builds a UNION expression.
+
+        Example:
+            >>> import sqlglot
+            >>> sqlglot.parse_one("SELECT * FROM foo").union("SELECT * FROM bla").sql()
+            'SELECT * FROM foo UNION SELECT * FROM bla'
+
+        Args:
+            expression (str or Expression): the SQL code string.
+                If an `Expression` instance is passed, it will be used as-is.
+            distinct (bool): set the DISTINCT flag if and only if this is true.
+            dialect (str): the dialect used to parse the input expression.
+            opts (kwargs): other options to use to parse the input expressions.
+
+        Returns:
+            Union: the Union expression.
+        """
+        return union(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
+
+    def intersect(self, expression, distinct=True, dialect=None, **opts):
+        """
+        Builds an INTERSECT expression.
+
+        Example:
+            >>> import sqlglot
+            >>> sqlglot.parse_one("SELECT * FROM foo").intersect("SELECT * FROM bla").sql()
+            'SELECT * FROM foo INTERSECT SELECT * FROM bla'
+
+        Args:
+            expression (str or Expression): the SQL code string.
+                If an `Expression` instance is passed, it will be used as-is.
+            distinct (bool): set the DISTINCT flag if and only if this is true.
+            dialect (str): the dialect used to parse the input expression.
+            opts (kwargs): other options to use to parse the input expressions.
+
+        Returns:
+            Intersect: the Intersect expression
+        """
+        return intersect(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
+
+    def except_(self, expression, distinct=True, dialect=None, **opts):
+        """
+        Builds an EXCEPT expression.
+
+        Example:
+            >>> import sqlglot
+            >>> sqlglot.parse_one("SELECT * FROM foo").except_("SELECT * FROM bla").sql()
+            'SELECT * FROM foo EXCEPT SELECT * FROM bla'
+
+        Args:
+            expression (str or Expression): the SQL code string.
+                If an `Expression` instance is passed, it will be used as-is.
+            distinct (bool): set the DISTINCT flag if and only if this is true.
+            dialect (str): the dialect used to parse the input expression.
+            opts (kwargs): other options to use to parse the input expressions.
+
+        Returns:
+            Except: the Except expression
+        """
+        return except_(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
+
+
+class UDTF(DerivedTable, Unionable):
     pass
@@ -518,6 +580,10 @@ class Annotation(Expression):
         "expression": True,
     }

+    @property
+    def alias(self):
+        return self.expression.alias_or_name
+

 class Cache(Expression):
     arg_types = {
@@ -700,6 +766,10 @@ class Hint(Expression):
     arg_types = {"expressions": True}


+class JoinHint(Expression):
+    arg_types = {"this": True, "expressions": True}
+
+
 class Identifier(Expression):
     arg_types = {"this": True, "quoted": False}
@@ -971,7 +1041,7 @@ class Tuple(Expression):
     arg_types = {"expressions": False}


-class Subqueryable:
+class Subqueryable(Unionable):
     def subquery(self, alias=None, copy=True):
         """
         Convert this expression to an aliased expression that can be used as a Subquery.
@@ -1654,7 +1724,7 @@ class Select(Subqueryable, Expression):
         return self.expressions


-class Subquery(DerivedTable):
+class Subquery(DerivedTable, Unionable):
     arg_types = {
         "this": True,
         "alias": False,
@@ -1731,7 +1801,7 @@ class Parameter(Expression):

 class Placeholder(Expression):
-    arg_types = {}
+    arg_types = {"this": False}


 class Null(Condition):
@@ -1791,6 +1861,8 @@ class DataType(Expression):
         IMAGE = auto()
         VARIANT = auto()
         OBJECT = auto()
+        NULL = auto()
+        UNKNOWN = auto()  # Sentinel value, useful for type annotation

     @classmethod
     def build(cls, dtype, **kwargs):
@@ -2007,7 +2079,7 @@ class Distinct(Expression):

 class In(Predicate):
-    arg_types = {"this": True, "expressions": False, "query": False, "unnest": False}
+    arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False}


 class TimeUnit(Expression):
@@ -2377,6 +2449,11 @@ class Map(Func):
     arg_types = {"keys": True, "values": True}


+class VarMap(Func):
+    arg_types = {"keys": True, "values": True}
+    is_var_len_args = True
+
+
 class Max(AggFunc):
     pass
@@ -2449,7 +2526,7 @@ class Substring(Func):

 class StrPosition(Func):
-    arg_types = {"this": True, "substr": True, "position": False}
+    arg_types = {"substr": True, "this": True, "position": False}


 class StrToDate(Func):
@@ -2785,6 +2862,81 @@ def _wrap_operator(expression):
     return expression


+def union(left, right, distinct=True, dialect=None, **opts):
+    """
+    Initializes a syntax tree from one UNION expression.
+
+    Example:
+        >>> union("SELECT * FROM foo", "SELECT * FROM bla").sql()
+        'SELECT * FROM foo UNION SELECT * FROM bla'
+
+    Args:
+        left (str or Expression): the SQL code string corresponding to the left-hand side.
+            If an `Expression` instance is passed, it will be used as-is.
+        right (str or Expression): the SQL code string corresponding to the right-hand side.
+            If an `Expression` instance is passed, it will be used as-is.
+        distinct (bool): set the DISTINCT flag if and only if this is true.
+        dialect (str): the dialect used to parse the input expression.
+        opts (kwargs): other options to use to parse the input expressions.
+
+    Returns:
+        Union: the syntax tree for the UNION expression.
+    """
+    left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
+    right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
+
+    return Union(this=left, expression=right, distinct=distinct)
+
+
+def intersect(left, right, distinct=True, dialect=None, **opts):
+    """
+    Initializes a syntax tree from one INTERSECT expression.
+
+    Example:
+        >>> intersect("SELECT * FROM foo", "SELECT * FROM bla").sql()
+        'SELECT * FROM foo INTERSECT SELECT * FROM bla'
+
+    Args:
+        left (str or Expression): the SQL code string corresponding to the left-hand side.
+            If an `Expression` instance is passed, it will be used as-is.
+        right (str or Expression): the SQL code string corresponding to the right-hand side.
+            If an `Expression` instance is passed, it will be used as-is.
+        distinct (bool): set the DISTINCT flag if and only if this is true.
+        dialect (str): the dialect used to parse the input expression.
+        opts (kwargs): other options to use to parse the input expressions.
+
+    Returns:
+        Intersect: the syntax tree for the INTERSECT expression.
+    """
+    left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
+    right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
+
+    return Intersect(this=left, expression=right, distinct=distinct)
+
+
+def except_(left, right, distinct=True, dialect=None, **opts):
+    """
+    Initializes a syntax tree from one EXCEPT expression.
+
+    Example:
+        >>> except_("SELECT * FROM foo", "SELECT * FROM bla").sql()
+        'SELECT * FROM foo EXCEPT SELECT * FROM bla'
+
+    Args:
+        left (str or Expression): the SQL code string corresponding to the left-hand side.
+            If an `Expression` instance is passed, it will be used as-is.
+        right (str or Expression): the SQL code string corresponding to the right-hand side.
+            If an `Expression` instance is passed, it will be used as-is.
+        distinct (bool): set the DISTINCT flag if and only if this is true.
+        dialect (str): the dialect used to parse the input expression.
+        opts (kwargs): other options to use to parse the input expressions.
+
+    Returns:
+        Except: the syntax tree for the EXCEPT statement.
+    """
+    left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
+    right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
+
+    return Except(this=left, expression=right, distinct=distinct)
+
+
 def select(*expressions, dialect=None, **opts):
     """
     Initializes a syntax tree from one or multiple SELECT expressions.
@@ -2991,7 +3143,7 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
             If an Expression instance is passed, this is used as-is.
         alias (str or Identifier): the alias name to use. If the name has
             special characters it is quoted.
-        table (boolean): create a table alias, default false
+        table (bool): create a table alias, default false
         dialect (str): the dialect used to parse the input expression.
         **opts: other options to use to parse the input expressions.
@@ -3002,7 +3154,7 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
     alias = to_identifier(alias, quoted=quoted)
     alias = TableAlias(this=alias) if table else alias

-    if "alias" in exp.arg_types:
+    if "alias" in exp.arg_types and not isinstance(exp, Window):
         exp = exp.copy()
         exp.set("alias", alias)
         return exp
@@ -3138,6 +3290,60 @@ def column_table_names(expression):
     return list(dict.fromkeys(column.table for column in expression.find_all(Column)))


+def table_name(table):
+    """Get the full name of a table as a string.
+
+    Args:
+        table (exp.Table | str): Table expression node or string.
+
+    Examples:
+        >>> from sqlglot import exp, parse_one
+        >>> table_name(parse_one("select * from a.b.c").find(exp.Table))
+        'a.b.c'
+
+    Returns:
+        str: the table name
+    """
+    table = maybe_parse(table, into=Table)
+
+    return ".".join(
+        part
+        for part in (
+            table.text("catalog"),
+            table.text("db"),
+            table.name,
+        )
+        if part
+    )
+
+
+def replace_tables(expression, mapping):
+    """Replace all tables in expression according to the mapping.
+
+    Args:
+        expression (sqlglot.Expression): Expression node to be transformed and replaced
+        mapping (Dict[str, str]): Mapping of table names
+
+    Examples:
+        >>> from sqlglot import exp, parse_one
+        >>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql()
+        'SELECT * FROM "c"'
+
+    Returns:
+        The mapped expression
+    """
+
+    def _replace_tables(node):
+        if isinstance(node, Table):
+            new_name = mapping.get(table_name(node))
+            if new_name:
+                return table_(*reversed(new_name.split(".")), quoted=True)
+        return node
+
+    return expression.transform(_replace_tables)
+
+
 TRUE = Boolean(this=True)
 FALSE = Boolean(this=False)
 NULL = Null()

sqlglot/generator.py

@@ -48,8 +48,9 @@ class Generator:
     TRANSFORMS = {
         exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
         exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
-        exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+        exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
         exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
+        exp.VarMap: lambda self, e: f"MAP({self.sql(e.args['keys'])}, {self.sql(e.args['values'])})",
         exp.LanguageProperty: lambda self, e: self.naked_property(e),
        exp.LocationProperty: lambda self, e: self.naked_property(e),
         exp.ReturnsProperty: lambda self, e: self.naked_property(e),
@@ -57,7 +58,12 @@ class Generator:
         exp.VolatilityProperty: lambda self, e: self.sql(e.name),
     }

+    # whether or not null ordering is supported in order by
     NULL_ORDERING_SUPPORTED = True
+    # always do union distinct or union all
+    EXPLICIT_UNION = False
+    # wrap derived values in parens, usually standard but spark doesn't support it
+    WRAP_DERIVED_VALUES = True

     TYPE_MAPPING = {
         exp.DataType.Type.NCHAR: "CHAR",
@@ -101,7 +107,6 @@ class Generator:
         "unsupported_messages",
         "null_ordering",
         "max_unsupported",
-        "wrap_derived_values",
         "_indent",
         "_replace_backslash",
         "_escaped_quote_end",
@@ -130,7 +135,6 @@ class Generator:
         null_ordering=None,
         max_unsupported=3,
         leading_comma=False,
-        wrap_derived_values=True,
     ):
         import sqlglot
@@ -154,7 +158,6 @@ class Generator:
         self.unsupported_messages = []
         self.max_unsupported = max_unsupported
         self.null_ordering = null_ordering
-        self.wrap_derived_values = wrap_derived_values
         self._indent = indent
         self._replace_backslash = self.escape == "\\"
         self._escaped_quote_end = self.escape + self.quote_end
@@ -595,7 +598,7 @@ class Generator:
         if not alias:
             return f"VALUES{self.seg('')}{args}"
         alias = f" AS {alias}" if alias else alias
-        if self.wrap_derived_values:
+        if self.WRAP_DERIVED_VALUES:
             return f"(VALUES{self.seg('')}{args}){alias}"
         return f"VALUES{self.seg('')}{args}{alias}"
@@ -779,8 +782,8 @@ class Generator:
     def parameter_sql(self, expression):
         return f"@{self.sql(expression, 'this')}"

-    def placeholder_sql(self, *_):
-        return "?"
+    def placeholder_sql(self, expression):
+        return f":{expression.name}" if expression.name else "?"

     def subquery_sql(self, expression):
         alias = self.sql(expression, "alias")
@@ -803,7 +806,9 @@ class Generator:
         )

     def union_op(self, expression):
-        return f"UNION{'' if expression.args.get('distinct') else ' ALL'}"
+        kind = " DISTINCT" if self.EXPLICIT_UNION else ""
+        kind = kind if expression.args.get("distinct") else " ALL"
+        return f"UNION{kind}"

     def unnest_sql(self, expression):
         args = self.expressions(expression, flat=True)
@@ -940,10 +945,13 @@ class Generator:
     def in_sql(self, expression):
         query = expression.args.get("query")
         unnest = expression.args.get("unnest")
+        field = expression.args.get("field")
         if query:
             in_sql = self.wrap(query)
         elif unnest:
             in_sql = self.in_unnest_op(unnest)
+        elif field:
+            in_sql = self.sql(field)
         else:
             in_sql = f"({self.expressions(expression, flat=True)})"
         return f"{self.sql(expression, 'this')} IN {in_sql}"
@@ -1178,3 +1186,8 @@ class Generator:
         this = self.sql(expression, "this")
         kind = self.sql(expression, "kind")
         return f"{this} {kind}"
+
+    def joinhint_sql(self, expression):
+        this = self.sql(expression, "this")
+        expressions = self.expressions(expression, flat=True)
+        return f"{this}({expressions})"

sqlglot/optimizer/annotate_types.py

@@ -1,16 +1,20 @@
 from sqlglot import exp
 from sqlglot.helper import ensure_list, subclasses
+from sqlglot.optimizer.schema import ensure_schema
+from sqlglot.optimizer.scope import Scope, traverse_scope


 def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
     """
     Recursively infer & annotate types in an expression syntax tree against a schema.
+    Assumes that we've already executed the optimizer's qualify_columns step.
+    (TODO -- replace this with a better example after adding some functionality)

     Example:
         >>> import sqlglot
-        >>> annotated_expression = annotate_types(sqlglot.parse_one('5 + 5.3'))
-        >>> annotated_expression.type
+        >>> schema = {"y": {"cola": "SMALLINT"}}
+        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
+        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
+        >>> annotated_expr.expressions[0].type  # Get the type of "x.cola + 2.5 AS cola"
         <Type.DOUBLE: 'DOUBLE'>

     Args:
@@ -22,6 +26,8 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
         sqlglot.Expression: expression annotated with types
     """

+    schema = ensure_schema(schema)
+
     return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
@@ -35,10 +41,81 @@ class TypeAnnotator:
             expr_type: lambda self, expr: self._annotate_binary(expr)
             for expr_type in subclasses(exp.__name__, exp.Binary)
         },
-        exp.Cast: lambda self, expr: self._annotate_cast(expr),
-        exp.DataType: lambda self, expr: self._annotate_data_type(expr),
+        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this),
+        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this),
+        exp.Alias: lambda self, expr: self._annotate_unary(expr),
         exp.Literal: lambda self, expr: self._annotate_literal(expr),
-        exp.Boolean: lambda self, expr: self._annotate_boolean(expr),
+        exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+        exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
+        exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
+        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
+        exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+        exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
+        exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
+        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
+        exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+        exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
+        exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
+        exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
+        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
+        exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+        exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+        exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+        exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+        exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+        exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+        exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
+        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+        exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+        exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
+        exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
+        exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+        exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+        exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
+        exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+        exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+        exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
+        exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+        exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+        exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+        exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
+        exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+        exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
+        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+        exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+        exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
+        exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+        exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+        exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+        exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+        exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
+        exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
     }

     # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
@@ -97,43 +174,82 @@ class TypeAnnotator:
         },
     }

+    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
+
     def __init__(self, schema=None, annotators=None, coerces_to=None):
         self.schema = schema
         self.annotators = annotators or self.ANNOTATORS
         self.coerces_to = coerces_to or self.COERCES_TO

     def annotate(self, expression):
+        if isinstance(expression, self.TRAVERSABLES):
+            for scope in traverse_scope(expression):
+                subscope_selects = {
+                    name: {select.alias_or_name: select for select in source.selects}
+                    for name, source in scope.sources.items()
+                    if isinstance(source, Scope)
+                }
+
+                # First annotate the current scope's column references
+                for col in scope.columns:
+                    source = scope.sources[col.table]
+                    if isinstance(source, exp.Table):
+                        col.type = self.schema.get_column_type(source, col)
+                    else:
+                        col.type = subscope_selects[col.table][col.name].type
+
+                # Then (possibly) annotate the remaining expressions in the scope
+                self._maybe_annotate(scope.expression)
+
+        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
+
+    def _maybe_annotate(self, expression):
         if not isinstance(expression, exp.Expression):
             return None
+
+        if expression.type:
+            return expression  # We've already inferred the expression's type
+
         annotator = self.annotators.get(expression.__class__)
-        return annotator(self, expression) if annotator else self._annotate_args(expression)
+
+        return (
+            annotator(self, expression)
+            if annotator
+            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
+        )

     def _annotate_args(self, expression):
         for value in expression.args.values():
             for v in ensure_list(value):
-                self.annotate(v)
+                self._maybe_annotate(v)

         return expression

-    def _annotate_cast(self, expression):
-        expression.type = expression.args["to"].this
-        return self._annotate_args(expression)
-
-    def _annotate_data_type(self, expression):
-        expression.type = expression.this
-        return self._annotate_args(expression)
-
     def _maybe_coerce(self, type1, type2):
+        # We propagate the NULL / UNKNOWN types upwards if found
+        if exp.DataType.Type.NULL in (type1, type2):
+            return exp.DataType.Type.NULL
+        if exp.DataType.Type.UNKNOWN in (type1, type2):
+            return exp.DataType.Type.UNKNOWN
+
         return type2 if type2 in self.coerces_to[type1] else type1

     def _annotate_binary(self, expression):
         self._annotate_args(expression)

-        if isinstance(expression, (exp.Condition, exp.Predicate)):
+        left_type = expression.left.type
+        right_type = expression.right.type
+
+        if isinstance(expression, (exp.And, exp.Or)):
+            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
+                expression.type = exp.DataType.Type.NULL
+            elif exp.DataType.Type.NULL in (left_type, right_type):
+                expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
+            else:
+                expression.type = exp.DataType.Type.BOOLEAN
+        elif isinstance(expression, (exp.Condition, exp.Predicate)):
             expression.type = exp.DataType.Type.BOOLEAN
         else:
-            expression.type = self._maybe_coerce(expression.left.type, expression.right.type)
+            expression.type = self._maybe_coerce(left_type, right_type)

         return expression
@@ -157,6 +273,6 @@ class TypeAnnotator:
         return expression

-    def _annotate_boolean(self, expression):
-        expression.type = exp.DataType.Type.BOOLEAN
-        return expression
+    def _annotate_with_type(self, expression, target_type):
+        expression.type = target_type
+        return self._annotate_args(expression)
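NULL now propagates through coercion; a small sketch (the resulting type is an expectation based on _maybe_coerce above, and no schema is needed for pure literals):

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

expression = annotate_types(sqlglot.parse_one("SELECT 1 + NULL"))
print(expression.expressions[0].type)  # expected: Type.NULL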

sqlglot/optimizer/merge_subqueries.py

@@ -44,6 +44,7 @@ UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
     "joins",
     "where",
     "order",
+    "hint",
 }
@@ -67,21 +68,22 @@ def merge_ctes(expression, leave_tables_isolated=False):
     singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
     for outer_scope, inner_scope, table in singular_cte_selections:
         inner_select = inner_scope.expression.unnest()
-        if _mergeable(outer_scope, inner_select, leave_tables_isolated):
-            from_or_join = table.find_ancestor(exp.From, exp.Join)
+        from_or_join = table.find_ancestor(exp.From, exp.Join)
+        if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
             node_to_replace = table
             if isinstance(node_to_replace.parent, exp.Alias):
                 node_to_replace = node_to_replace.parent
                 alias = node_to_replace.alias
             else:
                 alias = table.name

             _rename_inner_sources(outer_scope, inner_scope, alias)
             _merge_from(outer_scope, inner_scope, node_to_replace, alias)
             _merge_expressions(outer_scope, inner_scope, alias)
             _merge_joins(outer_scope, inner_scope, from_or_join)
             _merge_where(outer_scope, inner_scope, from_or_join)
             _merge_order(outer_scope, inner_scope)
+            _merge_hints(outer_scope, inner_scope)
             _pop_cte(inner_scope)
     return expression
@@ -90,9 +92,9 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
     for outer_scope in traverse_scope(expression):
         for subquery in outer_scope.derived_tables:
             inner_select = subquery.unnest()
-            if _mergeable(outer_scope, inner_select, leave_tables_isolated):
-                alias = subquery.alias_or_name
-                from_or_join = subquery.find_ancestor(exp.From, exp.Join)
+            from_or_join = subquery.find_ancestor(exp.From, exp.Join)
+            if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
+                alias = subquery.alias_or_name
                 inner_scope = outer_scope.sources[alias]

                 _rename_inner_sources(outer_scope, inner_scope, alias)
@@ -101,10 +103,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
                 _merge_joins(outer_scope, inner_scope, from_or_join)
                 _merge_where(outer_scope, inner_scope, from_or_join)
                 _merge_order(outer_scope, inner_scope)
+                _merge_hints(outer_scope, inner_scope)
     return expression


-def _mergeable(outer_scope, inner_select, leave_tables_isolated):
+def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
     """
     Return True if `inner_select` can be merged into outer query.
@@ -112,6 +115,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated):
         outer_scope (Scope)
         inner_select (exp.Select)
         leave_tables_isolated (bool)
+        from_or_join (exp.From|exp.Join)

     Returns:
         bool: True if can be merged
     """
@@ -123,6 +127,16 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated):
         and inner_select.args.get("from")
         and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
         and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
+        and not (
+            isinstance(from_or_join, exp.Join)
+            and inner_select.args.get("where")
+            and from_or_join.side in {"FULL", "LEFT", "RIGHT"}
+        )
+        and not (
+            isinstance(from_or_join, exp.From)
+            and inner_select.args.get("where")
+            and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
+        )
     )
@@ -170,6 +184,12 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
     """
     new_subquery = inner_scope.expression.args.get("from").expressions[0]
     node_to_replace.replace(new_subquery)
+    for join_hint in outer_scope.join_hints:
+        tables = join_hint.find_all(exp.Table)
+        for table in tables:
+            if table.alias_or_name == node_to_replace.alias_or_name:
+                new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery
+                table.set("this", exp.to_identifier(new_table.alias_or_name))
     outer_scope.remove_source(alias)
     outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
@@ -273,6 +293,18 @@ def _merge_order(outer_scope, inner_scope):
     outer_scope.expression.set("order", inner_scope.expression.args.get("order"))


+def _merge_hints(outer_scope, inner_scope):
+    inner_scope_hint = inner_scope.expression.args.get("hint")
+    if not inner_scope_hint:
+        return
+    outer_scope_hint = outer_scope.expression.args.get("hint")
+    if outer_scope_hint:
+        for hint_expression in inner_scope_hint.expressions:
+            outer_scope_hint.append("expressions", hint_expression)
+    else:
+        outer_scope.expression.set("hint", inner_scope_hint)
+
+
 def _pop_cte(inner_scope):
     """
     Remove CTE from the AST.
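The new guards keep an inner WHERE from being lifted across an outer join, since that would change which rows the LEFT JOIN preserves. A rough sketch, assuming merge_subqueries can be called directly on parsed SQL as it is inside the optimizer pipeline:

from sqlglot import parse_one
from sqlglot.optimizer.merge_subqueries import merge_subqueries

sql = "SELECT a.x FROM a LEFT JOIN (SELECT x FROM b WHERE b.x = 1) AS b ON a.x = b.x"
# expected: the subquery is left in place rather than merged
print(merge_subqueries(parse_one(sql)).sql())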

View file

@@ -1,3 +1,5 @@
+from collections import defaultdict
+
 from sqlglot import exp
 from sqlglot.optimizer.normalize import normalized
 from sqlglot.optimizer.scope import traverse_scope
@@ -20,22 +22,30 @@ def pushdown_predicates(expression):
     Returns:
         sqlglot.Expression: optimized expression
     """
-    for scope in reversed(traverse_scope(expression)):
+    scope_ref_count = defaultdict(lambda: 0)
+    scopes = traverse_scope(expression)
+    scopes.reverse()
+
+    for scope in scopes:
+        for _, source in scope.selected_sources.values():
+            scope_ref_count[id(source)] += 1
+
+    for scope in scopes:
         select = scope.expression
         where = select.args.get("where")
         if where:
-            pushdown(where.this, scope.selected_sources)
+            pushdown(where.this, scope.selected_sources, scope_ref_count)

         # joins should only pushdown into itself, not to other joins
         # so we limit the selected sources to only itself
         for join in select.args.get("joins") or []:
            name = join.this.alias_or_name
-            pushdown(join.args.get("on"), {name: scope.selected_sources[name]})
+            pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)

     return expression


-def pushdown(condition, sources):
+def pushdown(condition, sources, scope_ref_count):
     if not condition:
         return
@@ -45,17 +55,17 @@ def pushdown(condition, sources):
     predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])

     if cnf_like:
-        pushdown_cnf(predicates, sources)
+        pushdown_cnf(predicates, sources, scope_ref_count)
     else:
-        pushdown_dnf(predicates, sources)
+        pushdown_dnf(predicates, sources, scope_ref_count)


-def pushdown_cnf(predicates, scope):
+def pushdown_cnf(predicates, scope, scope_ref_count):
     """
     If the predicates are in CNF like form, we can simply replace each block in the parent.
     """
     for predicate in predicates:
-        for node in nodes_for_predicate(predicate, scope).values():
+        for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
             if isinstance(node, exp.Join):
                 predicate.replace(exp.TRUE)
                 node.on(predicate, copy=False)
@@ -65,7 +75,7 @@ def pushdown_cnf(predicates, scope):
             node.where(replace_aliases(node, predicate), copy=False)


-def pushdown_dnf(predicates, scope):
+def pushdown_dnf(predicates, scope, scope_ref_count):
     """
     If the predicates are in DNF form, we can only push down conditions that are in all blocks.
     Additionally, we can't remove predicates from their original form.
@@ -91,7 +101,7 @@ def pushdown_dnf(predicates, scope):
     # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
     for table in sorted(pushdown_tables):
         for predicate in predicates:
-            nodes = nodes_for_predicate(predicate, scope)
+            nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
             if table not in nodes:
                 continue
@@ -120,7 +130,7 @@ def pushdown_dnf(predicates, scope):
             node.where(replace_aliases(node, predicate), copy=False)


-def nodes_for_predicate(predicate, sources):
+def nodes_for_predicate(predicate, sources, scope_ref_count):
     nodes = {}
     tables = exp.column_table_names(predicate)
     where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
@@ -133,7 +143,7 @@ def nodes_for_predicate(predicate, sources):
         if node and where_condition:
             node = node.find_ancestor(exp.Join, exp.From)

-        # a node can reference a CTE which should be push down
+        # a node can reference a CTE which should be pushed down
         if isinstance(node, exp.From) and not isinstance(source, exp.Table):
             node = source.expression
@@ -142,7 +152,9 @@ def nodes_for_predicate(predicate, sources):
                 return {}
             nodes[table] = node
         elif isinstance(node, exp.Select) and len(tables) == 1:
-            if not node.args.get("group"):
+            # we can't push down predicates to select statements if they are referenced in
+            # multiple places.
+            if not node.args.get("group") and scope_ref_count[id(source)] < 2:
                 nodes[table] = node
     return nodes
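A minimal sketch of what the new ref counting buys, assuming this version of sqlglot is importable:

import sqlglot
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

# Single consumer: the filter sinks into the derived table and the outer
# predicate slot is replaced with TRUE by pushdown_cnf.
one_ref = sqlglot.parse_one("SELECT y.a FROM (SELECT x.a FROM x) AS y WHERE y.a = 1")
print(pushdown_predicates(one_ref).sql())

# Two consumers of the same CTE: scope_ref_count[id(source)] is 2, so the
# filter on the first branch is not pushed into the shared CTE.
two_refs = sqlglot.parse_one(
    "WITH t AS (SELECT a FROM x) SELECT a FROM t WHERE a = 1 UNION ALL SELECT a FROM t"
)
print(pushdown_predicates(two_refs).sql())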
View file
@@ -31,8 +31,8 @@ def qualify_columns(expression, schema):
         _pop_table_column_aliases(scope.derived_tables)
         _expand_using(scope, resolver)
         _expand_group_by(scope, resolver)
-        _expand_order_by(scope)
         _qualify_columns(scope, resolver)
+        _expand_order_by(scope)
         if not isinstance(scope.expression, exp.UDTF):
             _expand_stars(scope, resolver)
             _qualify_outputs(scope)
@@ -235,7 +235,7 @@ def _expand_stars(scope, resolver):
     for table in tables:
         if table not in scope.sources:
             raise OptimizeError(f"Unknown table: {table}")
-        columns = resolver.get_source_columns(table)
+        columns = resolver.get_source_columns(table, only_visible=True)
         table_id = id(table)
         for name in columns:
             if name not in except_columns.get(table_id, set()):
@@ -332,7 +332,7 @@ class _Resolver:
             self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
         return self._all_columns

-    def get_source_columns(self, name):
+    def get_source_columns(self, name, only_visible=False):
         """Resolve the source columns for a given source `name`"""
         if name not in self.scope.sources:
             raise OptimizeError(f"Unknown table: {name}")
@@ -342,7 +342,7 @@ class _Resolver:
         # If referencing a table, return the columns from the schema
         if isinstance(source, exp.Table):
             try:
-                return self.schema.column_names(source)
+                return self.schema.column_names(source, only_visible)
             except Exception as e:
                 raise OptimizeError(str(e)) from e
View file
@@ -9,16 +9,28 @@ class Schema(abc.ABC):
     """Abstract base class for database schemas"""

     @abc.abstractmethod
-    def column_names(self, table):
+    def column_names(self, table, only_visible=False):
         """
         Get the column names for a table.

         Args:
             table (sqlglot.expressions.Table): Table expression instance
+            only_visible (bool): Whether to include invisible columns

         Returns:
             list[str]: list of column names
         """

+    @abc.abstractmethod
+    def get_column_type(self, table, column):
+        """
+        Get the exp.DataType type of a column in the schema.
+
+        Args:
+            table (sqlglot.expressions.Table): The source table.
+            column (sqlglot.expressions.Column): The target column.
+
+        Returns:
+            sqlglot.expressions.DataType.Type: The resulting column type.
+        """
+

 class MappingSchema(Schema):
     """
@@ -29,10 +41,19 @@ class MappingSchema(Schema):
             1. {table: {col: type}}
             2. {db: {table: {col: type}}}
             3. {catalog: {db: {table: {col: type}}}}
+        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
+            are assumed to be visible. The nesting should mirror that of the schema:
+            1. {table: set(*cols)}}
+            2. {db: {table: set(*cols)}}}
+            3. {catalog: {db: {table: set(*cols)}}}}
+        dialect (str): The dialect to be used for custom type mappings.
     """

-    def __init__(self, schema):
+    def __init__(self, schema, visible=None, dialect=None):
         self.schema = schema
+        self.visible = visible
+        self.dialect = dialect
+        self._type_mapping_cache = {}

         depth = _dict_depth(schema)
@@ -49,7 +70,7 @@ class MappingSchema(Schema):
         self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)

-    def column_names(self, table):
+    def column_names(self, table, only_visible=False):
         if not isinstance(table.this, exp.Identifier):
             return fs_get(table)
@@ -58,7 +79,39 @@ class MappingSchema(Schema):
         for forbidden in self.forbidden_args:
             if table.text(forbidden):
                 raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
-        return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
+
+        columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
+        if not only_visible or not self.visible:
+            return columns
+
+        visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
+        return [col for col in columns if col in visible]
+
+    def get_column_type(self, table, column):
+        try:
+            schema_type = self.schema.get(table.name, {}).get(column.name).upper()
+            return self._convert_type(schema_type)
+        except:
+            raise OptimizeError(f"Failed to get type for column {column.sql()}")
+
+    def _convert_type(self, schema_type):
+        """
+        Convert a type represented as a string to the corresponding exp.DataType.Type object.
+
+        Args:
+            schema_type (str): The type we want to convert.
+
+        Returns:
+            sqlglot.expressions.DataType.Type: The resulting expression type.
+        """
+        if schema_type not in self._type_mapping_cache:
+            try:
+                self._type_mapping_cache[schema_type] = exp.maybe_parse(
+                    schema_type, into=exp.DataType, dialect=self.dialect
+                ).this
+            except AttributeError:
+                raise OptimizeError(f"Failed to convert type {schema_type}")
+
+        return self._type_mapping_cache[schema_type]


 def ensure_schema(schema):
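A hedged usage sketch of the new schema surface (assuming the MappingSchema import path in this release is sqlglot.optimizer.schema):

from sqlglot import exp
from sqlglot.optimizer.schema import MappingSchema

schema = MappingSchema(
    {"x": {"a": "INT", "b": "TEXT"}},
    visible={"x": {"a"}},  # "b" exists but is hidden from SELECT * expansion
)
table = exp.table_("x")
print(schema.column_names(table))                      # ['a', 'b']
print(schema.column_names(table, only_visible=True))   # ['a']
print(schema.get_column_type(table, exp.column("a")))  # a DataType.Type enum, e.g. Type.INT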
View file
@@ -68,6 +68,7 @@ class Scope:
         self._selected_sources = None
         self._columns = None
         self._external_columns = None
+        self._join_hints = None

     def branch(self, expression, scope_type, chain_sources=None, **kwargs):
         """Branch from the current scope to a new, inner scope"""
@@ -85,14 +86,17 @@ class Scope:
         self._subqueries = []
         self._derived_tables = []
         self._raw_columns = []
+        self._join_hints = []

         for node, parent, _ in self.walk(bfs=False):
             if node is self.expression:
                 continue
             elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                 self._raw_columns.append(node)
-            elif isinstance(node, exp.Table):
+            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                 self._tables.append(node)
+            elif isinstance(node, exp.JoinHint):
+                self._join_hints.append(node)
             elif isinstance(node, exp.UDTF):
                 self._derived_tables.append(node)
             elif isinstance(node, exp.CTE):
@@ -246,7 +250,7 @@ class Scope:
         table only becomes a selected source if it's included in a FROM or JOIN clause.

         Returns:
-            dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes
+            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
         """
         if self._selected_sources is None:
             referenced_names = []
@@ -310,6 +314,18 @@ class Scope:
             self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
         return self._external_columns

+    @property
+    def join_hints(self):
+        """
+        Hints that exist in the scope that reference tables
+
+        Returns:
+            list[exp.JoinHint]: Join hints that are referenced within the scope
+        """
+        if self._join_hints is None:
+            return []
+        return self._join_hints
+
     def source_columns(self, source_name):
         """
         Get all columns in the current scope for a particular source.
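A rough sketch of reading the collected hints back out of a scope; the exact repr of the hint nodes is an assumption, but the split between join_hints and tables is what the change above establishes:

from sqlglot import parse_one
from sqlglot.optimizer.scope import traverse_scope

sql = "SELECT /*+ BROADCAST(y) */ x.b FROM x JOIN y ON x.b = y.b"
scope = traverse_scope(parse_one(sql, read="spark"))[0]
print(scope.join_hints)  # the exp.JoinHint nodes found while walking the scope
print(scope.tables)      # hint arguments no longer leak into the table list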
View file
@@ -56,12 +56,16 @@ def simplify_not(expression):
     NOT (x AND y) -> NOT x OR NOT y
     """
     if isinstance(expression, exp.Not):
+        if isinstance(expression.this, exp.Null):
+            return NULL
         if isinstance(expression.this, exp.Paren):
             condition = expression.this.unnest()
             if isinstance(condition, exp.And):
                 return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
             if isinstance(condition, exp.Or):
                 return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
+            if isinstance(condition, exp.Null):
+                return NULL
         if always_true(expression.this):
             return FALSE
         if expression.this == FALSE:
@@ -95,10 +99,10 @@ def simplify_connectors(expression):
             return left

     if isinstance(expression, exp.And):
-        if NULL in (left, right):
-            return NULL
         if FALSE in (left, right):
             return FALSE
+        if NULL in (left, right):
+            return NULL
         if always_true(left) and always_true(right):
             return TRUE
         if always_true(left):
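The ordering change is observable: FALSE now wins over NULL in a conjunction, matching three-valued logic, and NOT NULL folds to NULL. For example:

import sqlglot
from sqlglot.optimizer.simplify import simplify

print(simplify(sqlglot.condition("FALSE AND NULL")).sql())  # FALSE, not NULL
print(simplify(sqlglot.condition("NOT (NULL)")).sql())      # NULL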
View file
@@ -8,6 +8,18 @@ from sqlglot.tokens import Token, Tokenizer, TokenType
 logger = logging.getLogger("sqlglot")


+def parse_var_map(args):
+    keys = []
+    values = []
+    for i in range(0, len(args), 2):
+        keys.append(args[i])
+        values.append(args[i + 1])
+    return exp.VarMap(
+        keys=exp.Array(expressions=keys),
+        values=exp.Array(expressions=values),
+    )
+
+
 class Parser:
     """
     Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer`
@@ -48,6 +60,7 @@
             start=exp.Literal.number(1),
             length=exp.Literal.number(10),
         ),
+        "VAR_MAP": parse_var_map,
     }

     NO_PAREN_FUNCTIONS = {
@@ -117,6 +130,7 @@
         TokenType.VAR,
         TokenType.ALTER,
         TokenType.ALWAYS,
+        TokenType.ANTI,
         TokenType.BEGIN,
         TokenType.BOTH,
         TokenType.BUCKET,
@@ -164,6 +178,7 @@
         TokenType.ROWS,
         TokenType.SCHEMA_COMMENT,
         TokenType.SEED,
+        TokenType.SEMI,
         TokenType.SET,
         TokenType.SHOW,
         TokenType.STABLE,
@@ -273,6 +288,8 @@
         TokenType.INNER,
         TokenType.OUTER,
         TokenType.CROSS,
+        TokenType.SEMI,
+        TokenType.ANTI,
     }

     COLUMN_OPERATORS = {
@@ -318,6 +335,8 @@
         exp.Properties: lambda self: self._parse_properties(),
         exp.Where: lambda self: self._parse_where(),
         exp.Ordered: lambda self: self._parse_ordered(),
+        exp.Having: lambda self: self._parse_having(),
+        exp.With: lambda self: self._parse_with(),
         "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
     }

@@ -338,7 +357,6 @@
         TokenType.NULL: lambda *_: exp.Null(),
         TokenType.TRUE: lambda *_: exp.Boolean(this=True),
         TokenType.FALSE: lambda *_: exp.Boolean(this=False),
-        TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(),
         TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()),
         TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
         TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
@@ -910,7 +928,20 @@
         return self.expression(exp.Tuple, expressions=expressions)

     def _parse_select(self, nested=False, table=False):
-        if self._match(TokenType.SELECT):
+        cte = self._parse_with()
+        if cte:
+            this = self._parse_statement()
+
+            if not this:
+                self.raise_error("Failed to parse any statement following CTE")
+                return cte
+
+            if "with" in this.arg_types:
+                this.set("with", cte)
+            else:
+                self.raise_error(f"{this.key} does not support CTE")
+                this = cte
+        elif self._match(TokenType.SELECT):
             hint = self._parse_hint()
             all_ = self._match(TokenType.ALL)
             distinct = self._match(TokenType.DISTINCT)
@@ -938,39 +969,6 @@

             if from_:
                 this.set("from", from_)
             self._parse_query_modifiers(this)
-        elif self._match(TokenType.WITH):
-            recursive = self._match(TokenType.RECURSIVE)
-
-            expressions = []
-
-            while True:
-                expressions.append(self._parse_cte())
-
-                if not self._match(TokenType.COMMA):
-                    break
-
-            cte = self.expression(
-                exp.With,
-                expressions=expressions,
-                recursive=recursive,
-            )
-            this = self._parse_statement()
-
-            if not this:
-                self.raise_error("Failed to parse any statement following CTE")
-                return cte
-
-            if "with" in this.arg_types:
-                this.set(
-                    "with",
-                    self.expression(
-                        exp.With,
-                        expressions=expressions,
-                        recursive=recursive,
-                    ),
-                )
-            else:
-                self.raise_error(f"{this.key} does not support CTE")
         elif (table or nested) and self._match(TokenType.L_PAREN):
             this = self._parse_table() if table else self._parse_select(nested=True)
             self._parse_query_modifiers(this)
@@ -986,6 +984,26 @@

         return self._parse_set_operations(this) if this else None

+    def _parse_with(self):
+        if not self._match(TokenType.WITH):
+            return None
+
+        recursive = self._match(TokenType.RECURSIVE)
+
+        expressions = []
+
+        while True:
+            expressions.append(self._parse_cte())
+
+            if not self._match(TokenType.COMMA):
+                break
+
+        return self.expression(
+            exp.With,
+            expressions=expressions,
+            recursive=recursive,
+        )
+
     def _parse_cte(self):
         alias = self._parse_table_alias()
         if not alias or not alias.this:
@@ -1485,8 +1503,7 @@
         unnest = self._parse_unnest()
         if unnest:
             this = self.expression(exp.In, this=this, unnest=unnest)
-        else:
-            self._match_l_paren()
+        elif self._match(TokenType.L_PAREN):
             expressions = self._parse_csv(self._parse_select_or_expression)

             if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
@@ -1495,6 +1512,9 @@
                 this = self.expression(exp.In, this=this, expressions=expressions)

             self._match_r_paren()
+        else:
+            this = self.expression(exp.In, this=this, field=self._parse_field())
+
         return this

     def _parse_between(self, this):
@@ -1591,7 +1611,7 @@
         elif nested:
             expressions = self._parse_csv(self._parse_types)
         else:
-            expressions = self._parse_csv(self._parse_number)
+            expressions = self._parse_csv(self._parse_type)

         if not expressions:
             self._retreat(index)
@@ -1706,7 +1726,7 @@
     def _parse_field(self, any_token=False):
         return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token)

-    def _parse_function(self):
+    def _parse_function(self, functions=None):
         if not self._curr:
             return None
@@ -1742,7 +1762,9 @@
             self._match_r_paren()
             return this

-        function = self.FUNCTIONS.get(upper)
+        if functions is None:
+            functions = self.FUNCTIONS
+        function = functions.get(upper)
         args = self._parse_csv(self._parse_lambda)

         if function:
@@ -2025,10 +2047,20 @@
         return self.expression(exp.Cast, this=this, to=to)

     def _parse_position(self):
-        substr = self._parse_bitwise()
+        args = self._parse_csv(self._parse_bitwise)

         if self._match(TokenType.IN):
-            string = self._parse_bitwise()
-            return self.expression(exp.StrPosition, this=string, substr=substr)
+            args.append(self._parse_bitwise())
+
+        # Note: we're parsing in order needle, haystack, position
+        this = exp.StrPosition.from_arg_list(args)
+        self.validate_expression(this, args)
+
+        return this
+
+    def _parse_join_hint(self, func_name):
+        args = self._parse_csv(self._parse_table)
+        return exp.JoinHint(this=func_name.upper(), expressions=args)

     def _parse_substring(self):
         # Postgres supports the form: substring(string [from int] [for int])
@@ -2247,6 +2279,9 @@
     def _parse_placeholder(self):
         if self._match(TokenType.PLACEHOLDER):
             return exp.Placeholder()
+        elif self._match(TokenType.COLON):
+            self._advance()
+            return exp.Placeholder(this=self._prev.text)
         return None

     def _parse_except(self):
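Two of these parser changes are easy to see from the public API; both examples below mirror the fixtures added in this release:

import sqlglot

# WITH is parsed by the reusable _parse_with, and ":name" tokens become
# exp.Placeholder(this="name"), so named placeholders round-trip
print(sqlglot.parse_one("SELECT :hello, ? FROM x LIMIT :my_limit").sql())

# POSITION now accepts an optional start argument, parsed needle-first
# into StrPosition and emulated where a dialect lacks a third argument
print(sqlglot.transpile("POSITION('a', x, 3)", write="presto")[0])
# STRPOS(SUBSTR(x, 3), 'a') + 3 - 1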
View file
@@ -104,6 +104,7 @@ class TokenType(AutoName):
     ALL = auto()
     ALTER = auto()
     ANALYZE = auto()
+    ANTI = auto()
     ANY = auto()
     ARRAY = auto()
     ASC = auto()
@@ -236,6 +237,7 @@ class TokenType(AutoName):
     SCHEMA_COMMENT = auto()
     SEED = auto()
     SELECT = auto()
+    SEMI = auto()
     SEPARATOR = auto()
     SET = auto()
     SHOW = auto()
@@ -262,6 +264,7 @@ class TokenType(AutoName):
     USE = auto()
     USING = auto()
     VALUES = auto()
+    VACUUM = auto()
     VIEW = auto()
     VOLATILE = auto()
     WHEN = auto()
@@ -406,6 +409,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "ALTER": TokenType.ALTER,
         "ANALYZE": TokenType.ANALYZE,
         "AND": TokenType.AND,
+        "ANTI": TokenType.ANTI,
         "ANY": TokenType.ANY,
         "ASC": TokenType.ASC,
         "AS": TokenType.ALIAS,
@@ -528,6 +532,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "ROWS": TokenType.ROWS,
         "SEED": TokenType.SEED,
         "SELECT": TokenType.SELECT,
+        "SEMI": TokenType.SEMI,
         "SET": TokenType.SET,
         "SHOW": TokenType.SHOW,
         "SOME": TokenType.SOME,
@@ -551,6 +556,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "UPDATE": TokenType.UPDATE,
         "USE": TokenType.USE,
         "USING": TokenType.USING,
+        "VACUUM": TokenType.VACUUM,
         "VALUES": TokenType.VALUES,
         "VIEW": TokenType.VIEW,
         "VOLATILE": TokenType.VOLATILE,
@@ -577,6 +583,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "INT8": TokenType.BIGINT,
         "DECIMAL": TokenType.DECIMAL,
         "MAP": TokenType.MAP,
+        "NULLABLE": TokenType.NULLABLE,
         "NUMBER": TokenType.DECIMAL,
         "NUMERIC": TokenType.DECIMAL,
         "FIXED": TokenType.DECIMAL,
@@ -629,6 +636,7 @@ class Tokenizer(metaclass=_Tokenizer):
         TokenType.SHOW,
         TokenType.TRUNCATE,
         TokenType.USE,
+        TokenType.VACUUM,
     }

     # handle numeric literals like in hive (3L = BIGINT)
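With the new tokens in place, VACUUM statements survive as opaque commands and SEMI/ANTI joins round-trip; both lines below come from the identity fixtures in this release:

import sqlglot

print(sqlglot.transpile("VACUUM FREEZE my_table")[0])
print(sqlglot.transpile("SELECT 1 FROM a LEFT SEMI JOIN b ON a.x = b.x")[0])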
View file
@@ -152,6 +152,10 @@ class TestBigQuery(Validator):
             "SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)"
         )

+        self.validate_identity(
+            "SELECT item, purchases, LAST_VALUE(item) OVER (item_window ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING) AS most_popular FROM Produce WINDOW item_window AS (ORDER BY purchases)"
+        )
+
         self.validate_identity(
             "SELECT LAST_VALUE(a IGNORE NULLS) OVER y FROM x WINDOW y AS (PARTITION BY CATEGORY)",
         )
@@ -222,6 +226,20 @@ class TestBigQuery(Validator):
                 "spark": "DATE_ADD(CURRENT_DATE, 1)",
             },
         )
+        self.validate_all(
+            "DATE_DIFF(DATE '2010-07-07', DATE '2008-12-25', DAY)",
+            write={
+                "bigquery": "DATE_DIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE), DAY)",
+                "mysql": "DATEDIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE))",
+            },
+        )
+        self.validate_all(
+            "DATE_DIFF(DATE '2010-07-07', DATE '2008-12-25', MINUTE)",
+            write={
+                "bigquery": "DATE_DIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE), MINUTE)",
+                "mysql": "DATEDIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE))",
+            },
+        )
         self.validate_all(
             "CURRENT_DATE('UTC')",
             write={
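These tests also pin down the breaking change from the changelog: the default-dialect function is now spelled DATEDIFF, while BigQuery keeps its three-argument DATE_DIFF. Transpiling between the two:

import sqlglot

sql = "DATE_DIFF(DATE '2010-07-07', DATE '2008-12-25', DAY)"
print(sqlglot.transpile(sql, read="bigquery", write="mysql")[0])
# DATEDIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE))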
View file
@@ -8,6 +8,8 @@ class TestClickhouse(Validator):
         self.validate_identity("dictGet(x, 'y')")
         self.validate_identity("SELECT * FROM x FINAL")
         self.validate_identity("SELECT * FROM x AS y FINAL")
+        self.validate_identity("'a' IN mapKeys(map('a', 1, 'b', 2))")
+        self.validate_identity("CAST((1, 2) AS Tuple(a Int8, b Int16))")

         self.validate_all(
             "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
@@ -20,6 +22,12 @@ class TestClickhouse(Validator):
         self.validate_all(
             "CAST(1 AS NULLABLE(Int64))",
             write={
-                "clickhouse": "CAST(1 AS Nullable(BIGINT))",
+                "clickhouse": "CAST(1 AS Nullable(Int64))",
+            },
+        )
+        self.validate_all(
+            "CAST(1 AS Nullable(DateTime64(6, 'UTC')))",
+            write={
+                "clickhouse": "CAST(1 AS Nullable(DateTime64(6, 'UTC')))",
             },
         )
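ClickHouse types are now rendered in their native casing instead of being coerced to generic SQL types, so parameterized types like DateTime64 survive a round trip:

import sqlglot

print(sqlglot.transpile("CAST(1 AS NULLABLE(Int64))", write="clickhouse")[0])
# CAST(1 AS Nullable(Int64))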
View file
@@ -81,6 +81,24 @@ class TestDialect(Validator):
                 "starrocks": "CAST(a AS STRING)",
             },
         )
+        self.validate_all(
+            "CAST(MAP('a', '1') AS MAP(TEXT, TEXT))",
+            write={
+                "clickhouse": "CAST(map('a', '1') AS Map(TEXT, TEXT))",
+            },
+        )
+        self.validate_all(
+            "CAST(ARRAY(1, 2) AS ARRAY<TINYINT>)",
+            write={
+                "clickhouse": "CAST([1, 2] AS Array(Int8))",
+            },
+        )
+        self.validate_all(
+            "CAST((1, 2) AS STRUCT<a: TINYINT, b: SMALLINT, c: INT, d: BIGINT>)",
+            write={
+                "clickhouse": "CAST((1, 2) AS Tuple(a Int8, b Int16, c Int32, d Int64))",
+            },
+        )
         self.validate_all(
             "CAST(a AS DATETIME)",
             write={
@@ -170,7 +188,7 @@ class TestDialect(Validator):
             "CAST(a AS DOUBLE)",
             write={
                 "bigquery": "CAST(a AS FLOAT64)",
-                "clickhouse": "CAST(a AS DOUBLE)",
+                "clickhouse": "CAST(a AS Float64)",
                 "duckdb": "CAST(a AS DOUBLE)",
                 "mysql": "CAST(a AS DOUBLE)",
                 "hive": "CAST(a AS DOUBLE)",
@@ -234,6 +252,8 @@ class TestDialect(Validator):
             write={
                 "duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')",
                 "hive": "CAST('2020-01-01' AS TIMESTAMP)",
+                "oracle": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
+                "postgres": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
                 "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d')",
                 "redshift": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
                 "spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
@@ -245,6 +265,8 @@ class TestDialect(Validator):
                 "duckdb": "STRPTIME(x, '%y')",
                 "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)",
                 "presto": "DATE_PARSE(x, '%y')",
+                "oracle": "TO_TIMESTAMP(x, 'YY')",
+                "postgres": "TO_TIMESTAMP(x, 'YY')",
                 "redshift": "TO_TIMESTAMP(x, 'YY')",
                 "spark": "TO_TIMESTAMP(x, 'yy')",
             },
@@ -288,6 +310,8 @@ class TestDialect(Validator):
             write={
                 "duckdb": "STRFTIME(x, '%Y-%m-%d')",
                 "hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
+                "oracle": "TO_CHAR(x, 'YYYY-MM-DD')",
+                "postgres": "TO_CHAR(x, 'YYYY-MM-DD')",
                 "presto": "DATE_FORMAT(x, '%Y-%m-%d')",
                 "redshift": "TO_CHAR(x, 'YYYY-MM-DD')",
             },
@@ -348,6 +372,8 @@ class TestDialect(Validator):
             write={
                 "duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))",
                 "hive": "FROM_UNIXTIME(x)",
+                "oracle": "TO_DATE('1970-01-01','YYYY-MM-DD') + (x / 86400)",
+                "postgres": "TO_TIMESTAMP(x)",
                 "presto": "FROM_UNIXTIME(x)",
                 "starrocks": "FROM_UNIXTIME(x)",
             },
@@ -704,6 +730,7 @@ class TestDialect(Validator):
             "SELECT * FROM a UNION SELECT * FROM b",
             read={
                 "bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b",
+                "clickhouse": "SELECT * FROM a UNION DISTINCT SELECT * FROM b",
                 "duckdb": "SELECT * FROM a UNION SELECT * FROM b",
                 "presto": "SELECT * FROM a UNION SELECT * FROM b",
                 "spark": "SELECT * FROM a UNION SELECT * FROM b",
@@ -719,6 +746,7 @@ class TestDialect(Validator):
             "SELECT * FROM a UNION ALL SELECT * FROM b",
             read={
                 "bigquery": "SELECT * FROM a UNION ALL SELECT * FROM b",
+                "clickhouse": "SELECT * FROM a UNION ALL SELECT * FROM b",
                 "duckdb": "SELECT * FROM a UNION ALL SELECT * FROM b",
                 "presto": "SELECT * FROM a UNION ALL SELECT * FROM b",
                 "spark": "SELECT * FROM a UNION ALL SELECT * FROM b",
@@ -848,15 +876,28 @@ class TestDialect(Validator):
                 "postgres": "STRPOS(x, ' ')",
                 "presto": "STRPOS(x, ' ')",
                 "spark": "LOCATE(' ', x)",
+                "clickhouse": "position(x, ' ')",
+                "snowflake": "POSITION(' ', x)",
             },
         )
         self.validate_all(
-            "STR_POSITION(x, 'a')",
+            "STR_POSITION('a', x)",
             write={
                 "duckdb": "STRPOS(x, 'a')",
                 "postgres": "STRPOS(x, 'a')",
                 "presto": "STRPOS(x, 'a')",
                 "spark": "LOCATE('a', x)",
+                "clickhouse": "position(x, 'a')",
+                "snowflake": "POSITION('a', x)",
+            },
+        )
+        self.validate_all(
+            "POSITION('a', x, 3)",
+            write={
+                "presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
+                "spark": "LOCATE('a', x, 3)",
+                "clickhouse": "position(x, 'a', 3)",
+                "snowflake": "POSITION('a', x, 3)",
             },
         )
         self.validate_all(
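The canonical STR_POSITION argument order flips to needle-first here, consistent with the new StrPosition parsing; a dialect with a native start offset uses it directly:

import sqlglot

print(sqlglot.transpile("STR_POSITION('a', x)", write="spark")[0])      # LOCATE('a', x)
print(sqlglot.transpile("POSITION('a', x, 3)", write="clickhouse")[0])  # position(x, 'a', 3)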
View file
@@ -247,7 +247,7 @@ class TestHive(Validator):
                 "presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(b AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(a AS VARCHAR), 1, 10) AS DATE))",
                 "hive": "DATEDIFF(TO_DATE(a), TO_DATE(b))",
                 "spark": "DATEDIFF(TO_DATE(a), TO_DATE(b))",
-                "": "DATE_DIFF(TS_OR_DS_TO_DATE(a), TS_OR_DS_TO_DATE(b))",
+                "": "DATEDIFF(TS_OR_DS_TO_DATE(a), TS_OR_DS_TO_DATE(b))",
             },
         )
         self.validate_all(
@@ -295,7 +295,7 @@ class TestHive(Validator):
                 "presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(CAST(SUBSTR(CAST(y AS VARCHAR), 1, 10) AS DATE) AS VARCHAR), 1, 10) AS DATE))",
                 "hive": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))",
                 "spark": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))",
-                "": "DATE_DIFF(TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE(y)), TS_OR_DS_TO_DATE(x))",
+                "": "DATEDIFF(TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE(y)), TS_OR_DS_TO_DATE(x))",
             },
         )
         self.validate_all(
@@ -450,11 +450,21 @@ class TestHive(Validator):
         )
         self.validate_all(
             "MAP(a, b, c, d)",
+            read={
+                "": "VAR_MAP(a, b, c, d)",
+                "clickhouse": "map(a, b, c, d)",
+                "duckdb": "MAP(LIST_VALUE(a, c), LIST_VALUE(b, d))",
+                "hive": "MAP(a, b, c, d)",
+                "presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
+                "spark": "MAP(a, b, c, d)",
+            },
             write={
+                "": "MAP(ARRAY(a, c), ARRAY(b, d))",
+                "clickhouse": "map(a, b, c, d)",
                 "duckdb": "MAP(LIST_VALUE(a, c), LIST_VALUE(b, d))",
                 "presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
                 "hive": "MAP(a, b, c, d)",
-                "spark": "MAP_FROM_ARRAYS(ARRAY(a, c), ARRAY(b, d))",
+                "spark": "MAP(a, b, c, d)",
             },
         )
         self.validate_all(
@@ -463,7 +473,7 @@ class TestHive(Validator):
                 "duckdb": "MAP(LIST_VALUE(a), LIST_VALUE(b))",
                 "presto": "MAP(ARRAY[a], ARRAY[b])",
                 "hive": "MAP(a, b)",
-                "spark": "MAP_FROM_ARRAYS(ARRAY(a), ARRAY(b))",
+                "spark": "MAP(a, b)",
             },
         )
         self.validate_all(
View file
@@ -67,6 +67,7 @@ class TestPostgres(Validator):
         self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))")
         self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
         self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
+        self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')")

         self.validate_all(
             "CREATE TABLE x (a UUID, b BYTEA)",
View file
@@ -305,3 +305,35 @@ class TestSnowflake(Validator):
         self.validate_identity(
             "CREATE PROCEDURE a.b.c(x INT, y VARIANT) RETURNS OBJECT EXECUTE AS CALLER AS 'BEGIN SELECT 1; END;'"
         )
+
+    def test_table_literal(self):
+        # All examples from https://docs.snowflake.com/en/sql-reference/literals-table.html
+        self.validate_all(
+            r"""SELECT * FROM TABLE('MYTABLE')""", write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""}
+        )
+
+        self.validate_all(
+            r"""SELECT * FROM TABLE('MYDB."MYSCHEMA"."MYTABLE"')""",
+            write={"snowflake": r"""SELECT * FROM TABLE('MYDB."MYSCHEMA"."MYTABLE"')"""},
+        )
+
+        # Per Snowflake documentation at https://docs.snowflake.com/en/sql-reference/literals-table.html
+        # one can use either a " ' " or " $$ " to enclose the object identifier.
+        # Capturing the single tokens seems like a lot of work, so the tests use these interchangeably.
+        self.validate_all(
+            r"""SELECT * FROM TABLE($$MYDB. "MYSCHEMA"."MYTABLE"$$)""",
+            write={"snowflake": r"""SELECT * FROM TABLE('MYDB. "MYSCHEMA"."MYTABLE"')"""},
+        )
+
+        self.validate_all(r"""SELECT * FROM TABLE($MYVAR)""", write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""})
+
+        self.validate_all(r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""})
+
+        self.validate_all(
+            r"""SELECT * FROM TABLE(:BINDING)""", write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""}
+        )
+
+        self.validate_all(
+            r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10""",
+            write={"snowflake": r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10"""},
+        )
View file
@@ -111,12 +111,70 @@ TBLPROPERTIES (
             "SELECT /*+ COALESCE(3) */ * FROM x",
             write={
                 "spark": "SELECT /*+ COALESCE(3) */ * FROM x",
+                "bigquery": "SELECT * FROM x",
             },
         )
         self.validate_all(
             "SELECT /*+ COALESCE(3), REPARTITION(1) */ * FROM x",
             write={
                 "spark": "SELECT /*+ COALESCE(3), REPARTITION(1) */ * FROM x",
+                "bigquery": "SELECT * FROM x",
+            },
+        )
+        self.validate_all(
+            "SELECT /*+ BROADCAST(table) */ cola FROM table",
+            write={
+                "spark": "SELECT /*+ BROADCAST(table) */ cola FROM table",
+                "bigquery": "SELECT cola FROM table",
+            },
+        )
+        self.validate_all(
+            "SELECT /*+ BROADCASTJOIN(table) */ cola FROM table",
+            write={
+                "spark": "SELECT /*+ BROADCASTJOIN(table) */ cola FROM table",
+                "bigquery": "SELECT cola FROM table",
+            },
+        )
+        self.validate_all(
+            "SELECT /*+ MAPJOIN(table) */ cola FROM table",
+            write={
+                "spark": "SELECT /*+ MAPJOIN(table) */ cola FROM table",
+                "bigquery": "SELECT cola FROM table",
+            },
+        )
+        self.validate_all(
+            "SELECT /*+ MERGE(table) */ cola FROM table",
+            write={
+                "spark": "SELECT /*+ MERGE(table) */ cola FROM table",
+                "bigquery": "SELECT cola FROM table",
+            },
+        )
+        self.validate_all(
+            "SELECT /*+ SHUFFLEMERGE(table) */ cola FROM table",
+            write={
+                "spark": "SELECT /*+ SHUFFLEMERGE(table) */ cola FROM table",
+                "bigquery": "SELECT cola FROM table",
+            },
+        )
+        self.validate_all(
+            "SELECT /*+ MERGEJOIN(table) */ cola FROM table",
+            write={
+                "spark": "SELECT /*+ MERGEJOIN(table) */ cola FROM table",
+                "bigquery": "SELECT cola FROM table",
+            },
+        )
+        self.validate_all(
+            "SELECT /*+ SHUFFLE_HASH(table) */ cola FROM table",
+            write={
+                "spark": "SELECT /*+ SHUFFLE_HASH(table) */ cola FROM table",
+                "bigquery": "SELECT cola FROM table",
+            },
+        )
+        self.validate_all(
+            "SELECT /*+ SHUFFLE_REPLICATE_NL(table) */ cola FROM table",
+            write={
+                "spark": "SELECT /*+ SHUFFLE_REPLICATE_NL(table) */ cola FROM table",
+                "bigquery": "SELECT cola FROM table",
             },
         )
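Dialects with no hint support simply drop the comment, which is why each case above also asserts the bigquery output:

import sqlglot

sql = "SELECT /*+ SHUFFLE_HASH(table) */ cola FROM table"
print(sqlglot.transpile(sql, read="spark", write="bigquery")[0])  # SELECT cola FROM table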
View file
@@ -321,6 +321,10 @@ SELECT 1 FROM a INNER JOIN b ON a.x = b.x
 SELECT 1 FROM a LEFT JOIN b ON a.x = b.x
 SELECT 1 FROM a RIGHT JOIN b ON a.x = b.x
 SELECT 1 FROM a CROSS JOIN b ON a.x = b.x
+SELECT 1 FROM a LEFT SEMI JOIN b ON a.x = b.x
+SELECT 1 FROM a LEFT ANTI JOIN b ON a.x = b.x
+SELECT 1 FROM a RIGHT SEMI JOIN b ON a.x = b.x
+SELECT 1 FROM a RIGHT ANTI JOIN b ON a.x = b.x
 SELECT 1 FROM a JOIN b USING (x)
 SELECT 1 FROM a JOIN b USING (x, y, z)
 SELECT 1 FROM a JOIN (SELECT a FROM c) AS b ON a.x = b.x AND a.x < 2
@@ -529,12 +533,14 @@ UPDATE db.tbl_name SET foo = 123 WHERE tbl_name.bar = 234
 UPDATE db.tbl_name SET foo = 123, foo_1 = 234 WHERE tbl_name.bar = 234
 TRUNCATE TABLE x
 OPTIMIZE TABLE y
+VACUUM FREEZE my_table
 WITH a AS (SELECT 1) INSERT INTO b SELECT * FROM a
 WITH a AS (SELECT * FROM b) UPDATE a SET col = 1
 WITH a AS (SELECT * FROM b) CREATE TABLE b AS SELECT * FROM a
 WITH a AS (SELECT * FROM b) DELETE FROM a
 WITH a AS (SELECT * FROM b) CACHE TABLE a
 SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ?
+SELECT :hello, ? FROM x LIMIT :my_limit
 WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS foo, c.bar AS bar FROM c)) SELECT * FROM a
 WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a
 SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z
View file
@@ -1,107 +1,189 @@
--- Simple
+# title: Simple
 SELECT a, b FROM (SELECT a, b FROM x);
 SELECT x.a AS a, x.b AS b FROM x AS x;

--- Inner table alias is merged
+# title: Inner table alias is merged
 SELECT a, b FROM (SELECT a, b FROM x AS q) AS r;
 SELECT q.a AS a, q.b AS b FROM x AS q;

--- Double nesting
+# title: Double nesting
 SELECT a, b FROM (SELECT a, b FROM (SELECT a, b FROM x));
 SELECT x.a AS a, x.b AS b FROM x AS x;

--- WHERE clause is merged
-SELECT a, SUM(b) FROM (SELECT a, b FROM x WHERE a > 1) GROUP BY a;
-SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a;
+# title: WHERE clause is merged
+SELECT a, SUM(b) AS b FROM (SELECT a, b FROM x WHERE a > 1) GROUP BY a;
+SELECT x.a AS a, SUM(x.b) AS b FROM x AS x WHERE x.a > 1 GROUP BY x.a;

--- Outer query has join
+# title: Outer query has join
+SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
+SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;
+
--- Outer query has join
 SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
 SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;

+# title: Leave tables isolated
 # leave_tables_isolated: true
 SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
 SELECT x.a AS a, y.c AS c FROM (SELECT x.a AS a, x.b AS b FROM x AS x WHERE x.a > 1) AS x JOIN y AS y ON x.b = y.b;

--- Join on derived table
+# title: Join on derived table
 SELECT a, c FROM x JOIN (SELECT b, c FROM y) AS y ON x.b = y.b;
 SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;

--- Inner query has a join
+# title: Inner query has a join
 SELECT a, c FROM (SELECT a, c FROM x JOIN y ON x.b = y.b);
 SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;

--- Inner query has conflicting name in outer query
+# title: Inner query has conflicting name in outer query
 SELECT a, c FROM (SELECT q.a, q.b FROM x AS q) AS x JOIN y AS q ON x.b = q.b;
 SELECT q_2.a AS a, q.c AS c FROM x AS q_2 JOIN y AS q ON q_2.b = q.b;

--- Inner query has conflicting name in joined source
+# title: Inner query has conflicting name in joined source
 SELECT x.a, q.c FROM (SELECT a, x.b FROM x JOIN y AS q ON x.b = q.b) AS x JOIN y AS q ON x.b = q.b;
 SELECT x.a AS a, q.c AS c FROM x AS x JOIN y AS q_2 ON x.b = q_2.b JOIN y AS q ON x.b = q.b;

--- Inner query has multiple conflicting names
-SELECT x.a, q.c, r.c FROM (SELECT q.a, r.b FROM x AS q JOIN y AS r ON q.b = r.b) AS x JOIN y AS q ON x.b = q.b JOIN y AS r ON x.b = r.b;
-SELECT q_2.a AS a, q.c AS c, r.c AS c FROM x AS q_2 JOIN y AS r_2 ON q_2.b = r_2.b JOIN y AS q ON r_2.b = q.b JOIN y AS r ON r_2.b = r.b;
+# title: Inner query has multiple conflicting names
+SELECT x.a, q.c, r.c FROM (SELECT q.a, r.b FROM x AS q JOIN y AS r ON q.b = r.b) AS x JOIN y AS q ON x.b = q.b JOIN y AS r ON x.b = r.b ORDER BY x.a, q.c, r.c;
+SELECT q_2.a AS a, q.c AS c, r.c AS c FROM x AS q_2 JOIN y AS r_2 ON q_2.b = r_2.b JOIN y AS q ON r_2.b = q.b JOIN y AS r ON r_2.b = r.b ORDER BY q_2.a, q.c, r.c;

--- Inner queries have conflicting names with each other
+# title: Inner queries have conflicting names with each other
 SELECT r.b FROM (SELECT b FROM x AS x) AS q JOIN (SELECT b FROM x) AS r ON q.b = r.b;
 SELECT x_2.b AS b FROM x AS x JOIN x AS x_2 ON x.b = x_2.b;

--- WHERE clause in joined derived table is merged to ON clause
-SELECT x.a, y.c FROM x JOIN (SELECT b, c FROM y WHERE c > 1) AS y;
-SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON y.c > 1;
+# title: WHERE clause in joined derived table is merged to ON clause
+SELECT x.a, y.c FROM x JOIN (SELECT b, c FROM y WHERE c > 1) AS y ON x.b = y.b;
+SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b AND y.c > 1;

--- Comma JOIN in outer query
+# title: Comma JOIN in outer query
 SELECT x.a, y.c FROM (SELECT a FROM x) AS x, (SELECT c FROM y) AS y;
 SELECT x.a AS a, y.c AS c FROM x AS x, y AS y;

--- Comma JOIN in inner query
+# title: Comma JOIN in inner query
 SELECT x.a, x.c FROM (SELECT x.a, z.c FROM x, y AS z) AS x;
 SELECT x.a AS a, z.c AS c FROM x AS x CROSS JOIN y AS z;

--- (Regression) Column in ORDER BY
+# title: (Regression) Column in ORDER BY
 SELECT * FROM (SELECT * FROM (SELECT * FROM x)) ORDER BY a LIMIT 1;
 SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a LIMIT 1;

--- CTE
+# title: CTE
 WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x;
 SELECT x.a AS a, x.b AS b FROM x AS x;

--- CTE with outer table alias
+# title: CTE with outer table alias
 WITH y AS (SELECT a, b FROM x) SELECT a, b FROM y AS z;
 SELECT x.a AS a, x.b AS b FROM x AS x;

--- Nested CTE
-WITH x AS (SELECT a FROM x), x2 AS (SELECT a FROM x) SELECT a FROM x2;
+# title: Nested CTE
+WITH x2 AS (SELECT a FROM x), x3 AS (SELECT a FROM x2) SELECT a FROM x3;
 SELECT x.a AS a FROM x AS x;

--- CTE WHERE clause is merged
-WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, SUM(b) FROM x GROUP BY a;
-SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a;
+# title: CTE WHERE clause is merged
+WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, SUM(b) AS b FROM x GROUP BY a;
+SELECT x.a AS a, SUM(x.b) AS b FROM x AS x WHERE x.a > 1 GROUP BY x.a;

--- CTE Outer query has join
-WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, c FROM x AS x JOIN y ON x.b = y.b;
+# title: CTE Outer query has join
+WITH x2 AS (SELECT a, b FROM x WHERE a > 1) SELECT a, c FROM x2 AS x JOIN y ON x.b = y.b;
 SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;

--- CTE with inner table alias
+# title: CTE with inner table alias
 WITH y AS (SELECT a, b FROM x AS q) SELECT a, b FROM y AS z;
 SELECT q.a AS a, q.b AS b FROM x AS q;

--- Duplicate queries to CTE
-WITH x AS (SELECT a, b FROM x) SELECT x.a, y.b FROM x JOIN x AS y;
-WITH x AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT x.a AS a, y.b AS b FROM x JOIN x AS y;
-
--- Nested CTE
+# title: Nested CTE
 SELECT * FROM (WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x);
 SELECT x.a AS a, x.b AS b FROM x AS x;

--- Inner select is an expression
+# title: Inner select is an expression
 SELECT a FROM (SELECT a FROM (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) AS x) AS x;
 SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b;

--- CTE select is an expression
-WITH x AS (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) SELECT a FROM (SELECT a FROM x AS x) AS x;
+# title: CTE select is an expression
+WITH x2 AS (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) SELECT a FROM (SELECT a FROM x2 AS x) AS x;
 SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b;
+
+# title: Full outer join
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x FULL OUTER JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x FULL OUTER JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
+
+# title: Full outer join, no predicates
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x FULL OUTER JOIN (SELECT y.b AS b FROM y AS y) AS y ON x.b = y.b;
+SELECT x.b AS b, y.b AS b2 FROM x AS x FULL OUTER JOIN y AS y ON x.b = y.b;
+
+# title: Left join
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x LEFT JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
+SELECT x.b AS b, y.b AS b2 FROM x AS x LEFT JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b WHERE x.b = 1;
+
+# title: Left join, no predicates
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x LEFT JOIN (SELECT y.b AS b FROM y AS y) AS y ON x.b = y.b;
+SELECT x.b AS b, y.b AS b2 FROM x AS x LEFT JOIN y AS y ON x.b = y.b;
+
+# title: Right join
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x RIGHT JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x RIGHT JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
+
+# title: Right join, no predicates
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x RIGHT JOIN (SELECT y.b AS b FROM y AS y) AS y ON x.b = y.b;
+SELECT x.b AS b, y.b AS b2 FROM x AS x RIGHT JOIN y AS y ON x.b = y.b;
+
+# title: Inner join
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x INNER JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
+SELECT x.b AS b, y.b AS b2 FROM x AS x INNER JOIN y AS y ON x.b = y.b AND y.b = 2 WHERE x.b = 1;
+
+# title: Inner join, no predicates
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x INNER JOIN (SELECT y.b AS b FROM y AS y) AS y ON x.b = y.b;
+SELECT x.b AS b, y.b AS b2 FROM x AS x INNER JOIN y AS y ON x.b = y.b;
+
+# title: Cross join
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x CROSS JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y;
+SELECT x.b AS b, y.b AS b2 FROM x AS x JOIN y AS y ON y.b = 2 WHERE x.b = 1;
+
+# title: Cross join, no predicates
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x CROSS JOIN (SELECT y.b AS b FROM y AS y) AS y;
+SELECT x.b AS b, y.b AS b2 FROM x AS x CROSS JOIN y AS y;
+
+# title: Broadcast hint
+# dialect: spark
+WITH m AS (SELECT x.a, x.b FROM x), n AS (SELECT y.b, y.c FROM y), joined as (SELECT /*+ BROADCAST(k) */ m.a, k.c FROM m JOIN n AS k ON m.b = k.b) SELECT joined.a, joined.c FROM joined;
+SELECT /*+ BROADCAST(y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
+
+# title: Broadcast hint multiple tables
+# dialect: spark
+WITH m AS (SELECT x.a, x.b FROM x), n AS (SELECT y.b, y.c FROM y), joined as (SELECT /*+ BROADCAST(m, n) */ m.a, n.c FROM m JOIN n ON m.b = n.b) SELECT joined.a, joined.c FROM joined;
+SELECT /*+ BROADCAST(x, y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
+
+# title: Multiple Table Hints
+# dialect: spark
+WITH m AS (SELECT x.a, x.b FROM x), n AS (SELECT y.b, y.c FROM y), joined as (SELECT /*+ BROADCAST(m), MERGE(m, n) */ m.a, n.c FROM m JOIN n ON m.b = n.b) SELECT joined.a, joined.c FROM joined;
+SELECT /*+ BROADCAST(x), MERGE(x, y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
+
+# title: Mix Table and Column Hints
+# dialect: spark
+WITH m AS (SELECT x.a, x.b FROM x), n AS (SELECT y.b, y.c FROM y), joined as (SELECT /*+ BROADCAST(m), MERGE(m, n) */ m.a, n.c FROM m JOIN n ON m.b = n.b) SELECT /*+ COALESCE(3) */ joined.a, joined.c FROM joined;
+SELECT /*+ COALESCE(3), BROADCAST(x), MERGE(x, y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
+
+# title: Hint Subquery
+# dialect: spark
+SELECT
+  subquery.a,
+  subquery.c
+FROM (
+  SELECT /*+ BROADCAST(m), MERGE(m, n) */ m.a, n.c FROM (SELECT x.a, x.b FROM x) AS m JOIN (SELECT y.b, y.c FROM y) AS n ON m.b = n.b
+) AS subquery;
+SELECT /*+ BROADCAST(x), MERGE(x, y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
+
+# title: Subquery Test
+# dialect: spark
+SELECT /*+ BROADCAST(x) */
+  x.a,
+  x.c
+FROM (
+  SELECT
+    x.a,
+    x.c
+  FROM (
+    SELECT
+      x.a,
+      COUNT(1) AS c
+    FROM x
+    GROUP BY x.a
+  ) AS x
+) AS x;
+SELECT /*+ BROADCAST(x) */ x.a AS a, x.c AS c FROM (SELECT x.a AS a, COUNT(1) AS c FROM x AS x GROUP BY x.a) AS x;
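These fixtures drive the optimizer test harness: each case is a `# title:`-annotated pair of input statement and expected output, with optional `# dialect:` and `# execute:` flags. A hedged sketch of invoking the pipeline directly; the exact quoting of the output is an assumption based on the fixtures:

import sqlglot
from sqlglot.optimizer import optimize

expression = sqlglot.parse_one("SELECT a, b FROM (SELECT a, b FROM x)")
print(optimize(expression, schema={"x": {"a": "INT", "b": "INT"}}).sql())
# roughly: SELECT "x"."a" AS "a", "x"."b" AS "b" FROM "x" AS "x"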
View file
@@ -1,3 +1,5 @@
# title: lateral
# execute: false
SELECT a, m FROM z LATERAL VIEW EXPLODE([1, 2]) q AS m;
SELECT
"z"."a" AS "a",
@@ -6,11 +8,13 @@ FROM "z" AS "z"
LATERAL VIEW
EXPLODE(ARRAY(1, 2)) q AS "m";
# title: unnest
SELECT x FROM UNNEST([1, 2]) AS q(x, y);
SELECT
"q"."x" AS "x"
FROM UNNEST(ARRAY(1, 2)) AS "q"("x", "y");
# title: Union in CTE
WITH cte AS (
(
SELECT
@@ -21,7 +25,7 @@ WITH cte AS (
UNION ALL
(
SELECT
-a
+b AS a
FROM
y
)
@@ -39,7 +43,7 @@ WITH "cte" AS (
UNION ALL
(
SELECT
-"y"."a" AS "a"
+"y"."b" AS "a"
FROM "y" AS "y"
)
)
@@ -47,6 +51,7 @@ SELECT
"cte"."a" AS "a"
FROM "cte";
# title: Chained CTEs
WITH cte1 AS (
SELECT a
FROM x
@@ -74,30 +79,31 @@ SELECT
"cte1"."a" + 1 AS "a"
FROM "cte1";
-SELECT a, SUM(b)
+# title: Correlated subquery
+SELECT a, SUM(b) AS sum_b
FROM (
SELECT x.a, y.b
FROM x, y
-WHERE (SELECT max(b) FROM y WHERE x.a = y.a) >= 0 AND x.a = y.a
+WHERE (SELECT max(b) FROM y WHERE x.b = y.b) >= 0 AND x.b = y.b
) d
WHERE (TRUE AND TRUE OR 'a' = 'b') AND a > 1
GROUP BY a;
WITH "_u_0" AS (
SELECT
MAX("y"."b") AS "_col_0",
-"y"."a" AS "_u_1"
+"y"."b" AS "_u_1"
FROM "y" AS "y"
GROUP BY
-"y"."a"
+"y"."b"
)
SELECT
"x"."a" AS "a",
-SUM("y"."b") AS "_col_1"
+SUM("y"."b") AS "sum_b"
FROM "x" AS "x"
LEFT JOIN "_u_0" AS "_u_0"
-ON "x"."a" = "_u_0"."_u_1"
+ON "x"."b" = "_u_0"."_u_1"
JOIN "y" AS "y"
-ON "x"."a" = "y"."a"
+ON "x"."b" = "y"."b"
WHERE
"_u_0"."_col_0" >= 0
AND "x"."a" > 1
@@ -105,6 +111,7 @@ WHERE
GROUP BY
"x"."a";
# title: Root subquery
(SELECT a FROM x) LIMIT 1;
(
SELECT
@@ -113,6 +120,7 @@ GROUP BY
)
LIMIT 1;
# title: Root subquery is union
(SELECT b FROM x UNION SELECT b FROM y) LIMIT 1;
(
SELECT
@@ -125,6 +133,7 @@ LIMIT 1;
)
LIMIT 1;
# title: broadcast
# dialect: spark
SELECT /*+ BROADCAST(y) */ x.b FROM x JOIN y ON x.b = y.b;
SELECT /*+ BROADCAST(`y`) */
@@ -133,11 +142,14 @@ FROM `x` AS `x`
JOIN `y` AS `y`
ON `x`.`b` = `y`.`b`;
# title: aggregate
# execute: false
SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
SELECT
AGGREGATE(ARRAY("x"."a", "x"."b"), 0, ("x", "acc") -> "x" + "acc" + "x"."a") AS "sum_agg"
FROM "x" AS "x";
# title: values
SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb);
SELECT
"tab"."cola" AS "cola",
@@ -146,6 +158,7 @@ FROM (VALUES
(1, 'test'),
(2, 'test2')) AS "tab"("cola", "colb");
# title: spark values
# dialect: spark
SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb);
SELECT
@@ -154,3 +167,112 @@ SELECT
FROM VALUES
(1, 'test'),
(2, 'test2') AS `tab`(`cola`, `colb`);
# title: complex CTE dependencies
WITH m AS (
SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)
), n AS (
SELECT a, b FROM m WHERE m.a = 1
), o AS (
SELECT a, b FROM m WHERE m.a = 2
) SELECT
n.a,
n.b,
o.b
FROM n
FULL OUTER JOIN o ON n.a = o.a
CROSS JOIN n AS n2
WHERE o.b > 0 AND n.a = n2.a;
WITH "m" AS (
SELECT
"a1"."a" AS "a",
"a1"."b" AS "b"
FROM (VALUES
(1, 2)) AS "a1"("a", "b")
), "n" AS (
SELECT
"m"."a" AS "a",
"m"."b" AS "b"
FROM "m"
WHERE
"m"."a" = 1
), "o" AS (
SELECT
"m"."a" AS "a",
"m"."b" AS "b"
FROM "m"
WHERE
"m"."a" = 2
)
SELECT
"n"."a" AS "a",
"n"."b" AS "b",
"o"."b" AS "b"
FROM "n"
FULL JOIN "o"
ON "n"."a" = "o"."a"
JOIN "n" AS "n2"
ON "n"."a" = "n2"."a"
WHERE
"o"."b" > 0;
# title: Broadcast hint
# dialect: spark
WITH m AS (
SELECT
x.a,
x.b
FROM x
), n AS (
SELECT
y.b,
y.c
FROM y
), joined as (
SELECT /*+ BROADCAST(n) */
m.a,
n.c
FROM m JOIN n ON m.b = n.b
)
SELECT
joined.a,
joined.c
FROM joined;
SELECT /*+ BROADCAST(`y`) */
`x`.`a` AS `a`,
`y`.`c` AS `c`
FROM `x` AS `x`
JOIN `y` AS `y`
ON `x`.`b` = `y`.`b`;
# title: Mix Table and Column Hints
# dialect: spark
WITH m AS (
SELECT
x.a,
x.b
FROM x
), n AS (
SELECT
y.b,
y.c
FROM y
), joined as (
SELECT /*+ BROADCAST(m), MERGE(m, n) */
m.a,
n.c
FROM m JOIN n ON m.b = n.b
)
SELECT
/*+ COALESCE(3) */
joined.a,
joined.c
FROM joined;
SELECT /*+ COALESCE(3),
BROADCAST(`x`),
MERGE(`x`, `y`) */
`x`.`a` AS `a`,
`y`.`c` AS `c`
FROM `x` AS `x`
JOIN `y` AS `y`
ON `x`.`b` = `y`.`b`;
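Each fixture above is an input/expected SQL pair: the harness (see check_file in test_optimizer.py further down) parses the input, applies the optimizer entry point, and compares the generated SQL against the expected text. A minimal sketch of that round trip, with an illustrative schema and query that are not part of the fixtures themselves:

from sqlglot import parse_one
from sqlglot.optimizer import optimize

# Hypothetical schema for the x/y tables the fixtures reference.
schema = {"x": {"a": "INT", "b": "INT"}, "y": {"b": "INT", "c": "INT"}}

sql = "WITH m AS (SELECT x.a, x.b FROM x) SELECT m.a FROM m"
optimized = optimize(parse_one(sql, read="spark"), schema=schema)
print(optimized.sql(pretty=True, dialect="spark"))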
View file
@@ -19,38 +19,49 @@ SELECT x.a AS a FROM x AS x;
SELECT a AS b FROM x;
SELECT x.a AS b FROM x AS x;
# execute: false
SELECT 1, 2 FROM x;
SELECT 1 AS "_col_0", 2 AS "_col_1" FROM x AS x;
# execute: false
SELECT a + b FROM x;
SELECT x.a + x.b AS "_col_0" FROM x AS x;
-SELECT a + b FROM x;
-SELECT x.a + x.b AS "_col_0" FROM x AS x;
+# execute: false
SELECT a, SUM(b) FROM x WHERE a > 1 AND b > 1 GROUP BY a;
SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 AND x.b > 1 GROUP BY x.a;
SELECT a AS j, b FROM x ORDER BY j;
SELECT x.a AS j, x.b AS b FROM x AS x ORDER BY j;
-SELECT a AS j, b FROM x GROUP BY j;
-SELECT x.a AS j, x.b AS b FROM x AS x GROUP BY x.a;
+SELECT a AS j, b AS a FROM x ORDER BY 1;
+SELECT x.a AS j, x.b AS a FROM x AS x ORDER BY x.a;
SELECT SUM(a) AS c, SUM(b) AS d FROM x ORDER BY 1, 2;
SELECT SUM(x.a) AS c, SUM(x.b) AS d FROM x AS x ORDER BY SUM(x.a), SUM(x.b);
# execute: false
SELECT SUM(a), SUM(b) AS c FROM x ORDER BY 1, 2;
SELECT SUM(x.a) AS "_col_0", SUM(x.b) AS c FROM x AS x ORDER BY SUM(x.a), SUM(x.b);
SELECT a AS j, b FROM x GROUP BY j, b;
SELECT x.a AS j, x.b AS b FROM x AS x GROUP BY x.a, x.b;
SELECT a, b FROM x GROUP BY 1, 2;
SELECT x.a AS a, x.b AS b FROM x AS x GROUP BY x.a, x.b;
SELECT a, b FROM x ORDER BY 1, 2;
-SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a, b;
+SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a, x.b;
# execute: false
SELECT DATE(a), DATE(b) AS c FROM x GROUP BY 1, 2;
SELECT DATE(x.a) AS "_col_0", DATE(x.b) AS c FROM x AS x GROUP BY DATE(x.a), DATE(x.b);
-SELECT x.a AS c FROM x JOIN y ON x.b = y.b GROUP BY c;
-SELECT x.a AS c FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY y.c;
+SELECT SUM(x.a) AS c FROM x JOIN y ON x.b = y.b GROUP BY c;
+SELECT SUM(x.a) AS c FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY y.c;
-SELECT DATE(x.a) AS d FROM x JOIN y ON x.b = y.b GROUP BY d;
-SELECT DATE(x.a) AS d FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY DATE(x.a);
+SELECT COALESCE(x.a) AS d FROM x JOIN y ON x.b = y.b GROUP BY d;
+SELECT COALESCE(x.a) AS d FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY COALESCE(x.a);
SELECT a AS a, b FROM x ORDER BY a;
SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a;
@@ -69,6 +80,7 @@ SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x
SELECT x.b, x.a FROM x LEFT JOIN y ON x.b = y.b QUALIFY ROW_NUMBER() OVER(PARTITION BY x.b ORDER BY x.a DESC) = 1;
SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a DESC) = 1;
# execute: false
SELECT AGGREGATE(ARRAY(a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + x.a) AS sum_agg FROM x AS x;
@@ -93,8 +105,8 @@ SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
SELECT a FROM (SELECT a FROM (SELECT a FROM x));
SELECT "_q_1".a AS a FROM (SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0") AS "_q_1";
-SELECT x.a FROM x AS x JOIN (SELECT * FROM x);
-SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0";
+SELECT x.a FROM x AS x JOIN (SELECT * FROM x) AS y ON x.a = y.a;
+SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a, x.b AS b FROM x AS x) AS y ON x.a = y.a;
--------------------------------------
-- Joins
@@ -123,6 +135,7 @@ SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FRO
SELECT a FROM x WHERE b IN (SELECT c FROM y);
SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT y.c AS c FROM y AS y);
# execute: false
SELECT (SELECT c FROM y) FROM x;
SELECT (SELECT y.c AS c FROM y AS y) AS "_col_0" FROM x AS x;
@@ -144,10 +157,12 @@ SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT x.b AS b FROM y AS x);
SELECT a FROM x AS i WHERE b IN (SELECT b FROM y AS j WHERE j.b IN (SELECT c FROM y AS k WHERE k.b = j.b));
SELECT i.a AS a FROM x AS i WHERE i.b IN (SELECT j.b AS b FROM y AS j WHERE j.b IN (SELECT k.c AS c FROM y AS k WHERE k.b = j.b));
# execute: false
# dialect: bigquery
SELECT aa FROM x, UNNEST(a) AS aa;
SELECT aa AS aa FROM x AS x, UNNEST(x.a) AS aa;
# execute: false
SELECT aa FROM x, UNNEST(a) AS t(aa);
SELECT t.aa AS aa FROM x AS x, UNNEST(x.a) AS t(aa);
@@ -205,15 +220,19 @@ WITH z AS ((SELECT x.b AS b FROM x AS x UNION ALL SELECT y.b AS b FROM y AS y) O
--------------------------------------
-- Except and Replace
--------------------------------------
# execute: false
SELECT * REPLACE(a AS d) FROM x;
SELECT x.a AS d, x.b AS b FROM x AS x;
# execute: false
SELECT * EXCEPT(b) REPLACE(a AS d) FROM x;
SELECT x.a AS d FROM x AS x;
# execute: false
SELECT x.* EXCEPT(a), y.* FROM x, y;
SELECT x.b AS b, y.b AS b, y.c AS c FROM x AS x, y AS y;
# execute: false
SELECT * EXCEPT(a) FROM x;
SELECT x.b AS b FROM x AS x;
View file
@@ -0,0 +1,35 @@
--------------------------------------
-- Qualify columns
--------------------------------------
SELECT a FROM x;
SELECT x.a AS a FROM x AS x;
SELECT b FROM x;
SELECT x.b AS b FROM x AS x;
--------------------------------------
-- Derived tables
--------------------------------------
SELECT x.a FROM x AS x JOIN (SELECT * FROM x);
SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a FROM x AS x) AS "_q_0";
SELECT x.b FROM x AS x JOIN (SELECT b FROM x);
SELECT x.b AS b FROM x AS x JOIN (SELECT x.b AS b FROM x AS x) AS "_q_0";
--------------------------------------
-- Expand *
--------------------------------------
SELECT * FROM x;
SELECT x.a AS a FROM x AS x;
SELECT * FROM y JOIN z ON y.b = z.b;
SELECT y.b AS b, z.b AS b FROM y AS y JOIN z AS z ON y.b = z.b;
SELECT * FROM y JOIN z ON y.c = z.c;
SELECT y.b AS b, z.b AS b FROM y AS y JOIN z AS z ON y.c = z.c;
SELECT a FROM (SELECT * FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
SELECT * FROM (SELECT a FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
View file
@@ -52,6 +52,9 @@ TRUE;
NULL AND TRUE;
NULL;
NULL AND FALSE;
FALSE;
NULL AND NULL;
NULL;
@@ -70,6 +73,9 @@ FALSE;
NOT FALSE;
TRUE;
NOT NULL;
NULL;
NULL = NULL;
NULL;
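The pairs above exercise boolean simplification: each input expression should fold to the expected literal. A minimal sketch of driving the rule directly, assuming the simplify entry point lives at sqlglot.optimizer.simplify:

from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

# NULL AND FALSE short-circuits to FALSE; NOT NULL stays NULL.
print(simplify(parse_one("NULL AND FALSE")).sql())  # FALSE
print(simplify(parse_one("NOT NULL")).sql())        # NULL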
View file
@@ -769,13 +769,20 @@ group by
order by
custdist desc,
c_count desc;
-WITH "c_orders" AS (
+WITH "orders_2" AS (
SELECT
"orders"."o_orderkey" AS "o_orderkey",
"orders"."o_custkey" AS "o_custkey",
"orders"."o_comment" AS "o_comment"
FROM "orders" AS "orders"
WHERE
NOT "orders"."o_comment" LIKE '%special%requests%'
), "c_orders" AS (
SELECT
COUNT("orders"."o_orderkey") AS "c_count"
FROM "customer" AS "customer"
-LEFT JOIN "orders" AS "orders"
+LEFT JOIN "orders_2" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
-AND NOT "orders"."o_comment" LIKE '%special%requests%'
GROUP BY
"customer"."c_custkey"
)
View file
@@ -45,6 +45,14 @@ def load_sql_fixture_pairs(filename):
        yield meta, sql, expected


def string_to_bool(string):
    if string is None:
        return False
    if string in (True, False):
        return string
    return string and string.lower() in ("true", "1")
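# Illustrative behavior of the helper above (annotation, not part of the diff):
#   string_to_bool(None)   -> False  (metadata key absent)
#   string_to_bool(True)   -> True   (booleans pass through unchanged)
#   string_to_bool("TRUE") -> True   (e.g. "# execute: true" in a fixture)
#   string_to_bool("0")    -> False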
TPCH_SCHEMA = {
    "lineitem": {
        "l_orderkey": "uint64",
View file
@@ -1,6 +1,19 @@
import unittest

-from sqlglot import and_, condition, exp, from_, not_, or_, parse_one, select
+from sqlglot import (
    alias,
    and_,
    condition,
    except_,
    exp,
    from_,
    intersect,
    not_,
    or_,
    parse_one,
    select,
    union,
)


class TestBuild(unittest.TestCase):
@@ -320,6 +333,54 @@ class TestBuild(unittest.TestCase):
                lambda: exp.update("tbl", {"x": 1}, from_="tbl2"),
                "UPDATE tbl SET x = 1 FROM tbl2",
            ),
            (
                lambda: union("SELECT * FROM foo", "SELECT * FROM bla"),
                "SELECT * FROM foo UNION SELECT * FROM bla",
            ),
            (
                lambda: parse_one("SELECT * FROM foo").union("SELECT * FROM bla"),
                "SELECT * FROM foo UNION SELECT * FROM bla",
            ),
            (
                lambda: intersect("SELECT * FROM foo", "SELECT * FROM bla"),
                "SELECT * FROM foo INTERSECT SELECT * FROM bla",
            ),
            (
                lambda: parse_one("SELECT * FROM foo").intersect("SELECT * FROM bla"),
                "SELECT * FROM foo INTERSECT SELECT * FROM bla",
            ),
            (
                lambda: except_("SELECT * FROM foo", "SELECT * FROM bla"),
                "SELECT * FROM foo EXCEPT SELECT * FROM bla",
            ),
            (
                lambda: parse_one("SELECT * FROM foo").except_("SELECT * FROM bla"),
                "SELECT * FROM foo EXCEPT SELECT * FROM bla",
            ),
            (
                lambda: parse_one("(SELECT * FROM foo)").union("SELECT * FROM bla"),
                "(SELECT * FROM foo) UNION SELECT * FROM bla",
            ),
            (
                lambda: parse_one("(SELECT * FROM foo)").union("SELECT * FROM bla", distinct=False),
                "(SELECT * FROM foo) UNION ALL SELECT * FROM bla",
            ),
            (
                lambda: alias(parse_one("LAG(x) OVER (PARTITION BY y)"), "a"),
                "LAG(x) OVER (PARTITION BY y) AS a",
            ),
            (
                lambda: alias(parse_one("LAG(x) OVER (ORDER BY z)"), "a"),
                "LAG(x) OVER (ORDER BY z) AS a",
            ),
            (
                lambda: alias(parse_one("LAG(x) OVER (PARTITION BY y ORDER BY z)"), "a"),
                "LAG(x) OVER (PARTITION BY y ORDER BY z) AS a",
            ),
            (
                lambda: alias(parse_one("LAG(x) OVER ()"), "a"),
                "LAG(x) OVER () AS a",
            ),
        ]:
            with self.subTest(sql):
                self.assertEqual(expression().sql(dialect[0] if dialect else None), sql)
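The new cases cover the set-operation and alias builders added in this release. A minimal standalone sketch of the same API, mirroring the expected strings in the tuples above:

from sqlglot import alias, parse_one, union

# Build a UNION from raw SQL strings, or fluently from a parsed expression.
print(union("SELECT * FROM foo", "SELECT * FROM bla").sql())
# SELECT * FROM foo UNION SELECT * FROM bla
print(parse_one("SELECT * FROM foo").union("SELECT * FROM bla", distinct=False).sql())
# SELECT * FROM foo UNION ALL SELECT * FROM bla

# Wrap a window expression in an alias.
print(alias(parse_one("LAG(x) OVER (PARTITION BY y)"), "a").sql())
# LAG(x) OVER (PARTITION BY y) AS a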
View file
@@ -115,6 +115,21 @@ class TestExpressions(unittest.TestCase):
            ["first", "second", "third"],
        )

    def test_table_name(self):
        self.assertEqual(exp.table_name(parse_one("a", into=exp.Table)), "a")
        self.assertEqual(exp.table_name(parse_one("a.b", into=exp.Table)), "a.b")
        self.assertEqual(exp.table_name(parse_one("a.b.c", into=exp.Table)), "a.b.c")
        self.assertEqual(exp.table_name("a.b.c"), "a.b.c")

    def test_replace_tables(self):
        self.assertEqual(
            exp.replace_tables(
                parse_one("select * from a join b join c.a join d.a join e.a"),
                {"a": "a1", "b": "b.a", "c.a": "c.a2", "d.a": "d2"},
            ).sql(),
            'SELECT * FROM "a1" JOIN "b"."a" JOIN "c"."a2" JOIN "d2" JOIN e.a',
        )

    def test_named_selects(self):
        expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
        self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"])
@@ -474,3 +489,10 @@ class TestExpressions(unittest.TestCase):
        ]:
            with self.subTest(value):
                self.assertEqual(exp.convert(value).sql(), expected)

    def test_annotation_alias(self):
        expression = parse_one("SELECT a, b AS B, c #comment, d AS D #another_comment FROM foo")
        self.assertEqual(
            [e.alias_or_name for e in expression.expressions],
            ["a", "B", "c", "D"],
        )
View file
@@ -1,17 +1,55 @@
import unittest
from functools import partial

import duckdb
from pandas.testing import assert_frame_equal

import sqlglot
from sqlglot import exp, optimizer, parse_one, table
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.schema import MappingSchema, ensure_schema
from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
-from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures
+from tests.helpers import (
    TPCH_SCHEMA,
    load_sql_fixture_pairs,
    load_sql_fixtures,
    string_to_bool,
)


class TestOptimizer(unittest.TestCase):
    maxDiff = None

    @classmethod
    def setUpClass(cls):
        cls.conn = duckdb.connect()
        cls.conn.execute(
            """
            CREATE TABLE x (a INT, b INT);
            CREATE TABLE y (b INT, c INT);
            CREATE TABLE z (b INT, c INT);
            INSERT INTO x VALUES (1, 1);
            INSERT INTO x VALUES (2, 2);
            INSERT INTO x VALUES (2, 2);
            INSERT INTO x VALUES (3, 3);
            INSERT INTO x VALUES (null, null);
            INSERT INTO y VALUES (2, 2);
            INSERT INTO y VALUES (2, 2);
            INSERT INTO y VALUES (3, 3);
            INSERT INTO y VALUES (4, 4);
            INSERT INTO y VALUES (null, null);
            INSERT INTO y VALUES (3, 3);
            INSERT INTO y VALUES (3, 3);
            INSERT INTO y VALUES (4, 4);
            INSERT INTO y VALUES (5, 5);
            INSERT INTO y VALUES (null, null);
            """
        )

    def setUp(self):
        self.schema = {
            "x": {
@@ -28,29 +66,42 @@ class TestOptimizer(unittest.TestCase):
            },
        }

-    def check_file(self, file, func, pretty=False, **kwargs):
+    def check_file(self, file, func, pretty=False, execute=False, **kwargs):
        for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1):
            title = meta.get("title") or f"{i}, {sql}"
            dialect = meta.get("dialect")
            leave_tables_isolated = meta.get("leave_tables_isolated")

            func_kwargs = {**kwargs}
            if leave_tables_isolated is not None:
-                func_kwargs["leave_tables_isolated"] = leave_tables_isolated.lower() in ("true", "1")
+                func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated)

-            with self.subTest(f"{i}, {sql}"):
+            optimized = func(parse_one(sql, read=dialect), **func_kwargs)
            with self.subTest(title):
                self.assertEqual(
-                    func(parse_one(sql, read=dialect), **func_kwargs).sql(pretty=pretty, dialect=dialect),
+                    optimized.sql(pretty=pretty, dialect=dialect),
                    expected,
                )

            should_execute = meta.get("execute")
            if should_execute is None:
                should_execute = execute
            if string_to_bool(should_execute):
                with self.subTest(f"(execute) {title}"):
                    df1 = self.conn.execute(sqlglot.transpile(sql, read=dialect, write="duckdb")[0]).df()
                    df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df()
                    assert_frame_equal(df1, df2)

    def test_optimize(self):
        schema = {
            "x": {"a": "INT", "b": "INT"},
-            "y": {"a": "INT", "b": "INT"},
+            "y": {"b": "INT", "c": "INT"},
            "z": {"a": "INT", "c": "INT"},
        }

-        self.check_file("optimizer", optimizer.optimize, pretty=True, schema=schema)
+        self.check_file("optimizer", optimizer.optimize, pretty=True, execute=True, schema=schema)

    def test_isolate_table_selects(self):
        self.check_file(
@@ -86,7 +137,16 @@ class TestOptimizer(unittest.TestCase):
            expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
            return expression

-        self.check_file("qualify_columns", qualify_columns, schema=self.schema)
+        self.check_file("qualify_columns", qualify_columns, execute=True, schema=self.schema)

    def test_qualify_columns__with_invisible(self):
        def qualify_columns(expression, **kwargs):
            expression = optimizer.qualify_tables.qualify_tables(expression)
            expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
            return expression

        schema = MappingSchema(self.schema, {"x": {"a"}, "y": {"b"}, "z": {"b"}})
        self.check_file("qualify_columns__with_invisible", qualify_columns, schema=schema)

    def test_qualify_columns__invalid(self):
        for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"):
@@ -141,7 +201,7 @@ class TestOptimizer(unittest.TestCase):
            ],
        )

-        self.check_file("merge_subqueries", optimize, schema=self.schema)
+        self.check_file("merge_subqueries", optimize, execute=True, schema=self.schema)

    def test_eliminate_subqueries(self):
        self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries)
@@ -301,10 +361,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        }

        for sql, target_type in tests.items():
-            expression = parse_one(sql)
-            annotated_expression = annotate_types(expression)
-            self.assertEqual(annotated_expression.find(exp.Literal).type, target_type)
+            expression = annotate_types(parse_one(sql))
+            self.assertEqual(expression.find(exp.Literal).type, target_type)

    def test_boolean_type_annotation(self):
        tests = {
@@ -313,14 +371,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        }

        for sql, target_type in tests.items():
-            expression = parse_one(sql)
-            annotated_expression = annotate_types(expression)
-            self.assertEqual(annotated_expression.find(exp.Boolean).type, target_type)
+            expression = annotate_types(parse_one(sql))
+            self.assertEqual(expression.find(exp.Boolean).type, target_type)

    def test_cast_type_annotation(self):
-        expression = parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")
-        annotate_types(expression)
+        expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))"))

        self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ)
        self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR)
@@ -328,16 +383,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)

    def test_cache_annotation(self):
-        expression = parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
-        annotated_expression = annotate_types(expression)
-        self.assertEqual(annotated_expression.expression.expressions[0].type, exp.DataType.Type.INT)
+        expression = annotate_types(parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1"))
+        self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT)

    def test_binary_annotation(self):
-        expression = parse_one("SELECT 0.0 + (2 + 3)")
-        annotate_types(expression)
-        expression = expression.expressions[0]
+        expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0]

        self.assertEqual(expression.type, exp.DataType.Type.DOUBLE)
        self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE)
@@ -345,3 +395,124 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        self.assertEqual(expression.right.this.type, exp.DataType.Type.INT)
        self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT)
        self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT)

    def test_derived_tables_column_annotation(self):
        schema = {"x": {"cola": "INT"}, "y": {"cola": "FLOAT"}}
        sql = """
            SELECT a.cola AS cola
            FROM (
                SELECT x.cola + y.cola AS cola
                FROM (
                    SELECT x.cola AS cola
                    FROM x AS x
                ) AS x
                JOIN (
                    SELECT y.cola AS cola
                    FROM y AS y
                ) AS y
            ) AS a
        """

        expression = annotate_types(parse_one(sql), schema=schema)
        self.assertEqual(expression.expressions[0].type, exp.DataType.Type.FLOAT)  # a.cola AS cola

        addition_alias = expression.args["from"].expressions[0].this.expressions[0]
        self.assertEqual(addition_alias.type, exp.DataType.Type.FLOAT)  # x.cola + y.cola AS cola

        addition = addition_alias.this
        self.assertEqual(addition.type, exp.DataType.Type.FLOAT)
        self.assertEqual(addition.this.type, exp.DataType.Type.INT)
        self.assertEqual(addition.expression.type, exp.DataType.Type.FLOAT)

    def test_cte_column_annotation(self):
        schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT"}}
        sql = """
            WITH tbl AS (
                SELECT x.cola + 'bla' AS cola, y.colb AS colb
                FROM (
                    SELECT x.cola AS cola
                    FROM x AS x
                ) AS x
                JOIN (
                    SELECT y.colb AS colb
                    FROM y AS y
                ) AS y
            )
            SELECT tbl.cola + tbl.colb + 'foo' AS col
            FROM tbl AS tbl
        """

        expression = annotate_types(parse_one(sql), schema=schema)
        self.assertEqual(expression.expressions[0].type, exp.DataType.Type.TEXT)  # tbl.cola + tbl.colb + 'foo' AS col

        outer_addition = expression.expressions[0].this  # (tbl.cola + tbl.colb) + 'foo'
        self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT)
        self.assertEqual(outer_addition.left.type, exp.DataType.Type.TEXT)
        self.assertEqual(outer_addition.right.type, exp.DataType.Type.VARCHAR)

        inner_addition = expression.expressions[0].this.left  # tbl.cola + tbl.colb
        self.assertEqual(inner_addition.left.type, exp.DataType.Type.VARCHAR)
        self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT)

        cte_select = expression.args["with"].expressions[0].this
        self.assertEqual(cte_select.expressions[0].type, exp.DataType.Type.VARCHAR)  # x.cola + 'bla' AS cola
        self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT)  # y.colb AS colb

        cte_select_addition = cte_select.expressions[0].this  # x.cola + 'bla'
        self.assertEqual(cte_select_addition.type, exp.DataType.Type.VARCHAR)
        self.assertEqual(cte_select_addition.left.type, exp.DataType.Type.CHAR)
        self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR)

        # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
        for d, t in zip(cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]):
            self.assertEqual(d.this.expressions[0].this.type, t)

    def test_function_annotation(self):
        schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}}
        sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x"

        concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
        self.assertEqual(concat_expr_alias.type, exp.DataType.Type.VARCHAR)

        concat_expr = concat_expr_alias.this
        self.assertEqual(concat_expr.type, exp.DataType.Type.VARCHAR)
        self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR)  # x.cola
        self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR)  # TRIM(x.colb)
        self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR)  # x.colb

    def test_unknown_annotation(self):
        schema = {"x": {"cola": "VARCHAR"}}
        sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x"

        concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
        self.assertEqual(concat_expr_alias.type, exp.DataType.Type.UNKNOWN)

        concat_expr = concat_expr_alias.this
        self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN)
        self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR)  # x.cola
        self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN)  # SOME_ANONYMOUS_FUNC(x.cola)
        self.assertEqual(concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR)  # x.cola (arg)

    def test_null_annotation(self):
        expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this
        self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
        self.assertEqual(expression.right.type, exp.DataType.Type.INT)

        # NULL <op> UNKNOWN should yield NULL
        sql = "SELECT NULL || SOME_ANONYMOUS_FUNC() AS result"

        concat_expr_alias = annotate_types(parse_one(sql)).expressions[0]
        self.assertEqual(concat_expr_alias.type, exp.DataType.Type.NULL)

        concat_expr = concat_expr_alias.this
        self.assertEqual(concat_expr.type, exp.DataType.Type.NULL)
        self.assertEqual(concat_expr.left.type, exp.DataType.Type.NULL)
        self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN)

    def test_nullable_annotation(self):
        nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
        expression = annotate_types(parse_one("NULL AND FALSE"))

        self.assertEqual(expression.type, nullable)
        self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
        self.assertEqual(expression.right.type, exp.DataType.Type.BOOLEAN)
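Taken together, these tests pin down annotate_types: it walks the parsed tree and stamps each node's type attribute, consulting the schema for column types and coercing across operators. A condensed sketch of the same flow outside the test harness, using the schema and query from test_function_annotation above:

from sqlglot import exp, parse_one
from sqlglot.optimizer.annotate_types import annotate_types

# CHAR || VARCHAR coerces to VARCHAR, per the function annotation test.
expression = annotate_types(
    parse_one("SELECT x.cola || TRIM(x.colb) AS col FROM x AS x"),
    schema={"x": {"cola": "VARCHAR", "colb": "CHAR"}},
)
assert expression.expressions[0].type == exp.DataType.Type.VARCHAR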
View file
@@ -338,7 +338,7 @@ class TestTranspile(unittest.TestCase):
                unsupported_level=level,
            )

-        error = "Cannot convert array columns into map use SparkSQL instead."
+        error = "Cannot convert array columns into map."
        unsupported(ErrorLevel.WARN)
        assert_logger_contains("\n".join([error] * 4), logger, level="warning")