
Adding upstream version 6.2.6.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:37:25 +01:00
parent 8425a9678d
commit d62bab68ae
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
49 changed files with 1741 additions and 566 deletions

View file

@ -1,6 +1,8 @@
#!/bin/bash -e
python -m autoflake -i -r \
[[ -z "${GITHUB_ACTIONS}" ]] && RETURN_ERROR_CODE='' || RETURN_ERROR_CODE='--check'
python -m autoflake -i -r ${RETURN_ERROR_CODE} \
--expand-star-imports \
--remove-all-unused-imports \
--ignore-init-module-imports \
@ -8,5 +10,5 @@ python -m autoflake -i -r \
--remove-unused-variables \
sqlglot/ tests/
python -m isort --profile black sqlglot/ tests/
python -m black --line-length 120 sqlglot/ tests/
python -m black ${RETURN_ERROR_CODE} --line-length 120 sqlglot/ tests/
python -m unittest

View file

@ -20,7 +20,7 @@ from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "6.2.1"
__version__ = "6.2.6"
pretty = False

View file

@ -33,6 +33,49 @@ def _date_add_sql(data_type, kind):
return func
def _subquery_to_unnest_if_values(self, expression):
if not isinstance(expression.this, exp.Values):
return self.subquery_sql(expression)
rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.this.find_all(exp.Tuple)]
structs = []
for row in rows:
aliases = [
exp.alias_(value, column_name) for value, column_name in zip(row, expression.args["alias"].args["columns"])
]
structs.append(exp.Struct(expressions=aliases))
unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
return self.unnest_sql(unnest_exp)
def _returnsproperty_sql(self, expression):
value = expression.args.get("value")
if isinstance(value, exp.Schema):
value = f"{value.this} <{self.expressions(value)}>"
else:
value = self.sql(value)
return f"RETURNS {value}"
def _create_sql(self, expression):
kind = expression.args.get("kind")
returns = expression.find(exp.ReturnsProperty)
if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
expression = expression.copy()
expression.set("kind", "TABLE FUNCTION")
if isinstance(
expression.expression,
(
exp.Subquery,
exp.Literal,
),
):
expression.set("expression", expression.expression.this)
return self.create_sql(expression)
return self.create_sql(expression)
class BigQuery(Dialect):
unnest_column_only = True
@ -77,8 +120,14 @@ class BigQuery(Dialect):
TokenType.CURRENT_TIME: exp.CurrentTime,
}
NESTED_TYPE_TOKENS = {
*Parser.NESTED_TYPE_TOKENS,
TokenType.TABLE,
}
class Generator(Generator):
TRANSFORMS = {
**Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
@ -91,6 +140,9 @@ class BigQuery(Dialect):
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.VariancePop: rename_func("VAR_POP"),
exp.Subquery: _subquery_to_unnest_if_values,
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
}
TYPE_MAPPING = {
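Note: the new _subquery_to_unnest_if_values transform exists because BigQuery does not accept a VALUES list as a derived table; each row is rewritten into a STRUCT named after the subquery's column aliases and wrapped in UNNEST. A minimal sketch of the intended round trip (output shape inferred from the transform above, so treat it as illustrative):

import sqlglot

sql = "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)"
# Each VALUES row becomes a STRUCT whose fields are named after the
# table alias columns, and the rows are wrapped in UNNEST([...]).
print(sqlglot.transpile(sql, write="bigquery")[0])
# SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)])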

View file

@ -245,6 +245,11 @@ def no_tablesample_sql(self, expression):
return self.sql(expression.this)
def no_pivot_sql(self, expression):
self.unsupported("PIVOT unsupported")
return self.sql(expression)
def no_trycast_sql(self, expression):
return self.cast_sql(expression)
@ -282,3 +287,30 @@ def format_time_lambda(exp_class, dialect, default=None):
)
return _format_time
def create_with_partitions_sql(self, expression):
"""
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
columns are removed from the create statement.
"""
has_schema = isinstance(expression.this, exp.Schema)
is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
if has_schema and is_partitionable:
expression = expression.copy()
prop = expression.find(exp.PartitionedByProperty)
value = prop and prop.args.get("value")
if prop and not isinstance(value, exp.Schema):
schema = expression.this
columns = {v.name.upper() for v in value.expressions}
partitions = [col for col in schema.expressions if col.name.upper() in columns]
schema.set(
"expressions",
[e for e in schema.expressions if e not in partitions],
)
prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions)))
expression.set("this", schema)
return self.create_sql(expression)
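Note: create_with_partitions_sql complements the parser change later in this commit, moving partition-column handling from parse time to generation time. A hedged example of what it enables (table and column names are illustrative):

import sqlglot

sql = "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])"
# The y and z columns should be cut from the column list and
# re-emitted as Hive's PARTITIONED BY schema.
print(sqlglot.transpile(sql, read="presto", write="hive")[0])
# CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)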

View file

@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
format_time_lambda,
no_pivot_sql,
no_safe_divide_sql,
no_tablesample_sql,
rename_func,
@ -122,6 +123,7 @@ class DuckDB(Dialect):
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.Pivot: no_pivot_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
exp.SafeDivide: no_safe_divide_sql,

View file

@ -2,6 +2,7 @@ from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
Dialect,
approx_count_distinct_sql,
create_with_partitions_sql,
format_time_lambda,
if_sql,
no_ilike_sql,
@ -53,7 +54,7 @@ def _array_sort(self, expression):
def _property_sql(self, expression):
key = expression.name
value = self.sql(expression, "value")
return f"'{key}' = {value}"
return f"'{key}'={value}"
def _str_to_unix(self, expression):
@ -218,15 +219,6 @@ class Hive(Dialect):
}
class Generator(Generator):
ROOT_PROPERTIES = [
exp.PartitionedByProperty,
exp.FileFormatProperty,
exp.SchemaCommentProperty,
exp.LocationProperty,
exp.TableFormatProperty,
]
WITH_PROPERTIES = [exp.AnonymousProperty]
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
@ -255,13 +247,13 @@ class Hive(Dialect):
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.Map: _map_sql,
HiveMap: _map_sql,
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}",
exp.Create: create_with_partitions_sql,
exp.Quantile: rename_func("PERCENTILE"),
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
exp.RegexpSplit: rename_func("SPLIT"),
exp.SafeDivide: no_safe_divide_sql,
exp.SchemaCommentProperty: lambda self, e: f"COMMENT {self.sql(e.args['value'])}",
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.SetAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
exp.StrPosition: lambda self, e: f"LOCATE({csv(self.sql(e, 'substr'), self.sql(e, 'this'), self.sql(e, 'position'))})",
@ -282,6 +274,17 @@ class Hive(Dialect):
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({csv(self.sql(e, 'this'), _time_format(self, e))})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}",
}
WITH_PROPERTIES = {exp.AnonymousProperty}
ROOT_PROPERTIES = {
exp.PartitionedByProperty,
exp.FileFormatProperty,
exp.SchemaCommentProperty,
exp.LocationProperty,
exp.TableFormatProperty,
}
def with_properties(self, properties):

View file

@ -172,6 +172,11 @@ class MySQL(Dialect):
),
}
PROPERTY_PARSERS = {
**Parser.PROPERTY_PARSERS,
TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
}
class Generator(Generator):
NULL_ORDERING_SUPPORTED = False
@ -190,3 +195,13 @@ class MySQL(Dialect):
exp.StrToTime: _str_to_date_sql,
exp.Trim: _trim_sql,
}
ROOT_PROPERTIES = {
exp.EngineProperty,
exp.AutoIncrementProperty,
exp.CharacterSetProperty,
exp.CollateProperty,
exp.SchemaCommentProperty,
}
WITH_PROPERTIES = {}
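Note: with TokenType.ENGINE wired into PROPERTY_PARSERS and EngineProperty listed in ROOT_PROPERTIES, MySQL table options should round-trip at the top level instead of inside a WITH (...) block. An illustrative sketch:

import sqlglot

# ENGINE parses into an exp.EngineProperty and is emitted as a root property.
print(sqlglot.transpile("CREATE TABLE z (a INT) ENGINE=InnoDB", read="mysql", write="mysql")[0])
# CREATE TABLE z (a INT) ENGINE=InnoDB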

View file

@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
str_position_sql,
)
from sqlglot.generator import Generator
from sqlglot.parser import Parser
@ -158,7 +159,6 @@ class Postgres(Dialect):
"ALWAYS": TokenType.ALWAYS,
"BY DEFAULT": TokenType.BY_DEFAULT,
"IDENTITY": TokenType.IDENTITY,
"FOR": TokenType.FOR,
"GENERATED": TokenType.GENERATED,
"DOUBLE PRECISION": TokenType.DOUBLE,
"BIGSERIAL": TokenType.BIGSERIAL,
@ -204,6 +204,7 @@ class Postgres(Dialect):
exp.DateAdd: _date_add_sql("+"),
exp.DateSub: _date_add_sql("-"),
exp.Lateral: _lateral_sql,
exp.StrPosition: str_position_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Substring: _substring_sql,
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",

View file

@ -146,13 +146,16 @@ class Presto(Dialect):
STRUCT_DELIMITER = ("(", ")")
WITH_PROPERTIES = [
ROOT_PROPERTIES = {
exp.SchemaCommentProperty,
}
WITH_PROPERTIES = {
exp.PartitionedByProperty,
exp.FileFormatProperty,
exp.SchemaCommentProperty,
exp.AnonymousProperty,
exp.TableFormatProperty,
]
}
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
@ -184,13 +187,11 @@ class Presto(Dialect):
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.FileFormatProperty: lambda self, e: self.property_sql(e),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
exp.Lateral: _explode_to_unnest_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}",
exp.Quantile: _quantile_sql,
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,

View file

@ -1,5 +1,10 @@
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, format_time_lambda, rename_func
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
inline_array_sql,
rename_func,
)
from sqlglot.expressions import Literal
from sqlglot.generator import Generator
from sqlglot.helper import list_get
@ -104,6 +109,8 @@ class Snowflake(Dialect):
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"IFF": exp.If.from_arg_list,
"TO_TIMESTAMP": _snowflake_to_timestamp,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
}
FUNCTION_PARSERS = {
@ -111,6 +118,11 @@ class Snowflake(Dialect):
"DATE_PART": lambda self: self._parse_extract(),
}
FUNC_TOKENS = {
*Parser.FUNC_TOKENS,
TokenType.RLIKE,
}
COLUMN_OPERATORS = {
**Parser.COLUMN_OPERATORS,
TokenType.COLON: lambda self, this, path: self.expression(
@ -120,6 +132,11 @@ class Snowflake(Dialect):
),
}
PROPERTY_PARSERS = {
**Parser.PROPERTY_PARSERS,
TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(),
}
class Tokenizer(Tokenizer):
QUOTES = ["'", "$$"]
ESCAPE = "\\"
@ -137,6 +154,7 @@ class Snowflake(Dialect):
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPNTZ": TokenType.TIMESTAMP,
"SAMPLE": TokenType.TABLE_SAMPLE,
}
class Generator(Generator):
@ -145,6 +163,8 @@ class Snowflake(Dialect):
exp.If: rename_func("IFF"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time,
exp.Array: inline_array_sql,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
}
TYPE_MAPPING = {
@ -152,6 +172,13 @@ class Snowflake(Dialect):
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
ROOT_PROPERTIES = {
exp.PartitionedByProperty,
exp.ReturnsProperty,
exp.LanguageProperty,
exp.SchemaCommentProperty,
}
def except_op(self, expression):
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT with All is not supported in Snowflake")
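Note: ARRAY_CONSTRUCT now parses into exp.Array, and the Snowflake generator renders arrays inline, so array literals should survive a round trip. An illustrative sketch:

import sqlglot

# ARRAY_CONSTRUCT -> exp.Array on the way in, inline_array_sql on the way out.
print(sqlglot.transpile("SELECT ARRAY_CONSTRUCT(1, 2, 3)", read="snowflake", write="snowflake")[0])
# SELECT [1, 2, 3]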

View file

@ -1,5 +1,9 @@
from sqlglot import exp
from sqlglot.dialects.dialect import no_ilike_sql, rename_func
from sqlglot.dialects.dialect import (
create_with_partitions_sql,
no_ilike_sql,
rename_func,
)
from sqlglot.dialects.hive import Hive, HiveMap
from sqlglot.helper import list_get
@ -10,7 +14,7 @@ def _create_sql(self, e):
if kind.upper() == "TABLE" and temporary is True:
return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
return self.create_sql(e)
return create_with_partitions_sql(self, e)
def _map_sql(self, expression):
@ -73,6 +77,7 @@ class Spark(Hive):
}
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "BYTE",

View file

@ -1,4 +1,5 @@
from sqlglot import exp
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.mysql import MySQL
@ -10,3 +11,12 @@ class StarRocks(MySQL):
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
}
TRANSFORMS = {
**MySQL.Generator.TRANSFORMS,
exp.DateDiff: rename_func("DATEDIFF"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
}

View file

@ -1,6 +1,7 @@
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
@ -17,6 +18,7 @@ class TSQL(Dialect):
"REAL": TokenType.FLOAT,
"NTEXT": TokenType.TEXT,
"SMALLDATETIME": TokenType.DATETIME,
"DATETIME2": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"TIME": TokenType.TIMESTAMP,
"VARBINARY": TokenType.BINARY,
@ -24,15 +26,24 @@ class TSQL(Dialect):
"MONEY": TokenType.MONEY,
"SMALLMONEY": TokenType.SMALLMONEY,
"ROWVERSION": TokenType.ROWVERSION,
"SQL_VARIANT": TokenType.SQL_VARIANT,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"XML": TokenType.XML,
"SQL_VARIANT": TokenType.VARIANT,
}
class Parser(Parser):
def _parse_convert(self):
to = self._parse_types()
self._match(TokenType.COMMA)
this = self._parse_field()
return self.expression(exp.Cast, this=this, to=to)
class Generator(Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "BIT",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.DATETIME: "DATETIME2",
exp.DataType.Type.VARIANT: "SQL_VARIANT",
}
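Note: the new _parse_convert maps T-SQL's CONVERT(type, expr) onto a plain cast, swapping the argument order. A minimal sketch:

import sqlglot

# CONVERT(INT, x) parses to exp.Cast(this=x, to=INT).
print(sqlglot.transpile("SELECT CONVERT(INT, x)", read="tsql")[0])
# SELECT CAST(x AS INT)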

View file

@ -3,17 +3,11 @@ import time
from sqlglot import parse_one
from sqlglot.executor.python import PythonExecutor
from sqlglot.optimizer import RULES, optimize
from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
from sqlglot.optimizer import optimize
from sqlglot.planner import Plan
logger = logging.getLogger("sqlglot")
OPTIMIZER_RULES = list(RULES)
# The executor needs isolated table selects
OPTIMIZER_RULES.remove(merge_derived_tables)
def execute(sql, schema, read=None):
"""
@ -34,7 +28,7 @@ def execute(sql, schema, read=None):
"""
expression = parse_one(sql, read=read)
now = time.time()
expression = optimize(expression, schema, rules=OPTIMIZER_RULES)
expression = optimize(expression, schema, leave_tables_isolated=True)
logger.debug("Optimization finished: %f", time.time() - now)
logger.debug("Optimized SQL: %s", expression.sql(pretty=True))
plan = Plan(expression)

View file

@ -1,13 +1,17 @@
import inspect
import numbers
import re
import sys
from collections import deque
from copy import deepcopy
from enum import auto
from sqlglot.errors import ParseError
from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list, list_get
from sqlglot.helper import (
AutoName,
camel_to_snake_case,
ensure_list,
list_get,
subclasses,
)
class _Expression(type):
@ -31,12 +35,13 @@ class Expression(metaclass=_Expression):
key = None
arg_types = {"this": True}
__slots__ = ("args", "parent", "arg_key")
__slots__ = ("args", "parent", "arg_key", "type")
def __init__(self, **args):
self.args = args
self.parent = None
self.arg_key = None
self.type = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
@ -384,7 +389,7 @@ class Expression(metaclass=_Expression):
'SELECT y FROM tbl'
Args:
expression (Expression): new node
expression (Expression|None): new node
Returns:
the new expression or expressions
@ -398,6 +403,12 @@ class Expression(metaclass=_Expression):
replace_children(parent, lambda child: expression if child is self else child)
return expression
def pop(self):
"""
Remove this expression from its AST.
"""
self.replace(None)
def assert_is(self, type_):
"""
Assert that this `Expression` is an instance of `type_`.
@ -527,9 +538,18 @@ class Create(Expression):
"temporary": False,
"replace": False,
"unique": False,
"materialized": False,
}
class UserDefinedFunction(Expression):
arg_types = {"this": True, "expressions": False}
class UserDefinedFunctionKwarg(Expression):
arg_types = {"this": True, "kind": True, "default": False}
class CharacterSet(Expression):
arg_types = {"this": True, "default": False}
@ -887,6 +907,14 @@ class AnonymousProperty(Property):
pass
class ReturnsProperty(Property):
arg_types = {"this": True, "value": True, "is_table": False}
class LanguageProperty(Property):
pass
class Properties(Expression):
arg_types = {"expressions": True}
@ -907,25 +935,9 @@ class Properties(Expression):
expressions = []
for key, value in properties_dict.items():
property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
expressions.append(property_cls(this=Literal.string(key), value=cls._convert_value(value)))
expressions.append(property_cls(this=Literal.string(key), value=convert(value)))
return cls(expressions=expressions)
@staticmethod
def _convert_value(value):
if value is None:
return NULL
if isinstance(value, Expression):
return value
if isinstance(value, bool):
return Boolean(this=value)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, numbers.Number):
return Literal.number(value)
if isinstance(value, list):
return Tuple(expressions=[Properties._convert_value(v) for v in value])
raise ValueError(f"Unsupported type '{type(value)}' for value '{value}'")
class Qualify(Expression):
pass
@ -1030,6 +1042,7 @@ class Subqueryable:
QUERY_MODIFIERS = {
"laterals": False,
"joins": False,
"pivots": False,
"where": False,
"group": False,
"having": False,
@ -1051,6 +1064,7 @@ class Table(Expression):
"catalog": False,
"laterals": False,
"joins": False,
"pivots": False,
}
@ -1643,6 +1657,16 @@ class TableSample(Expression):
"percent": False,
"rows": False,
"size": False,
"seed": False,
}
class Pivot(Expression):
arg_types = {
"this": False,
"expressions": True,
"field": True,
"unpivot": True,
}
@ -1741,7 +1765,8 @@ class DataType(Expression):
SMALLMONEY = auto()
ROWVERSION = auto()
IMAGE = auto()
SQL_VARIANT = auto()
VARIANT = auto()
OBJECT = auto()
@classmethod
def build(cls, dtype, **kwargs):
@ -2124,6 +2149,7 @@ class TryCast(Cast):
class Ceil(Func):
arg_types = {"this": True, "decimals": False}
_sql_names = ["CEIL", "CEILING"]
@ -2254,7 +2280,7 @@ class Explode(Func):
class Floor(Func):
pass
arg_types = {"this": True, "decimals": False}
class Greatest(Func):
@ -2371,7 +2397,7 @@ class Reduce(Func):
class RegexpLike(Func):
arg_types = {"this": True, "expression": True}
arg_types = {"this": True, "expression": True, "flag": False}
class RegexpSplit(Func):
@ -2540,6 +2566,8 @@ def _norm_args(expression):
for k, arg in expression.args.items():
if isinstance(arg, list):
arg = [_norm_arg(a) for a in arg]
if not arg:
arg = None
else:
arg = _norm_arg(arg)
@ -2553,17 +2581,7 @@ def _norm_arg(arg):
return arg.lower() if isinstance(arg, str) else arg
def _all_functions():
return [
obj
for _, obj in inspect.getmembers(
sys.modules[__name__],
lambda obj: inspect.isclass(obj) and issubclass(obj, Func) and obj not in (AggFunc, Anonymous, Func),
)
]
ALL_FUNCTIONS = _all_functions()
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
def maybe_parse(
@ -2793,6 +2811,37 @@ def from_(*expressions, dialect=None, **opts):
return Select().from_(*expressions, dialect=dialect, **opts)
def update(table, properties, where=None, from_=None, dialect=None, **opts):
"""
Creates an update statement.
Example:
>>> update("my_table", {"x": 1, "y": "2", "z": None}, from_="baz", where="id > 1").sql()
"UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz WHERE id > 1"
Args:
properties (Dict[str, Any]): dictionary of properties to set, which are
auto converted to sql objects, e.g. None -> NULL
where (str): sql conditional parsed into a WHERE statement
from_ (str): sql statement parsed into a FROM statement
dialect (str): the dialect used to parse the input expressions.
**opts: other options to use to parse the input expressions.
Returns:
Update: the syntax tree for the UPDATE statement.
"""
update = Update(this=maybe_parse(table, into=Table, dialect=dialect))
update.set(
"expressions",
[EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) for k, v in properties.items()],
)
if from_:
update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
if where:
update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts))
return update
def condition(expression, dialect=None, **opts):
"""
Initialize a logical condition expression.
@ -2980,12 +3029,13 @@ def column(col, table=None, quoted=None):
def table_(table, db=None, catalog=None, quoted=None):
"""
Build a Table.
"""Build a Table.
Args:
table (str or Expression): column name
db (str or Expression): db name
catalog (str or Expression): catalog name
Returns:
Table: table instance
"""
@ -2996,6 +3046,39 @@ def table_(table, db=None, catalog=None, quoted=None):
)
def convert(value):
"""Convert a python value into an expression object.
Raises an error if a conversion is not possible.
Args:
value (Any): a python object
Returns:
Expression: the equivalent expression object
"""
if isinstance(value, Expression):
return value
if value is None:
return NULL
if isinstance(value, bool):
return Boolean(this=value)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, numbers.Number):
return Literal.number(value)
if isinstance(value, tuple):
return Tuple(expressions=[convert(v) for v in value])
if isinstance(value, list):
return Array(expressions=[convert(v) for v in value])
if isinstance(value, dict):
return Map(
keys=[convert(k) for k in value.keys()],
values=[convert(v) for v in value.values()],
)
raise ValueError(f"Cannot convert {value}")
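Note: convert() generalizes the removed Properties._convert_value and is also what the new update() builder uses for values. A quick illustration (expected results inferred from the branches above):

from sqlglot.expressions import convert

convert((1, "2", None)).sql()  # "(1, '2', NULL)" -- tuples become exp.Tuple
convert([1, 2]).sql()          # "ARRAY(1, 2)"    -- lists become exp.Array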
def replace_children(expression, fun):
"""
Replace children of an expression with the result of a lambda fun(child) -> exp.

View file

@ -46,18 +46,12 @@ class Generator:
"""
TRANSFORMS = {
exp.AnonymousProperty: lambda self, e: self.property_sql(e),
exp.AutoIncrementProperty: lambda self, e: f"AUTO_INCREMENT={self.sql(e, 'value')}",
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
exp.CollateProperty: lambda self, e: f"COLLATE={self.sql(e, 'value')}",
exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.EngineProperty: lambda self, e: f"ENGINE={self.sql(e, 'value')}",
exp.FileFormatProperty: lambda self, e: f"FORMAT={self.sql(e, 'value')}",
exp.LocationProperty: lambda self, e: f"LOCATION {self.sql(e, 'value')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY={self.sql(e.args['value'])}",
exp.SchemaCommentProperty: lambda self, e: f"COMMENT={self.sql(e, 'value')}",
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT={self.sql(e, 'value')}",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
}
@ -72,19 +66,17 @@ class Generator:
STRUCT_DELIMITER = ("<", ">")
ROOT_PROPERTIES = [
exp.AutoIncrementProperty,
exp.CharacterSetProperty,
exp.CollateProperty,
exp.EngineProperty,
exp.SchemaCommentProperty,
]
WITH_PROPERTIES = [
ROOT_PROPERTIES = {
exp.ReturnsProperty,
exp.LanguageProperty,
}
WITH_PROPERTIES = {
exp.AnonymousProperty,
exp.FileFormatProperty,
exp.PartitionedByProperty,
exp.TableFormatProperty,
]
}
__slots__ = (
"time_mapping",
@ -188,6 +180,7 @@ class Generator:
return sql
def unsupported(self, message):
if self.unsupported_level == ErrorLevel.IMMEDIATE:
raise UnsupportedError(message)
self.unsupported_messages.append(message)
@ -261,6 +254,9 @@ class Generator:
if isinstance(expression, exp.Func):
return self.function_fallback_sql(expression)
if isinstance(expression, exp.Property):
return self.property_sql(expression)
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
def annotation_sql(self, expression):
@ -352,9 +348,12 @@ class Generator:
replace = " OR REPLACE" if expression.args.get("replace") else ""
exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
properties = self.sql(expression, "properties")
expression_sql = f"CREATE{replace}{temporary}{unique} {kind}{exists_sql} {this}{properties} {expression_sql}"
expression_sql = (
f"CREATE{replace}{temporary}{unique}{materialized} {kind}{exists_sql} {this}{properties} {expression_sql}"
)
return self.prepend_ctes(expression, expression_sql)
def prepend_ctes(self, expression, sql):
@ -461,10 +460,10 @@ class Generator:
for p in expression.expressions:
p_class = p.__class__
if p_class in self.ROOT_PROPERTIES:
root_properties.append(p)
elif p_class in self.WITH_PROPERTIES:
if p_class in self.WITH_PROPERTIES:
with_properties.append(p)
elif p_class in self.ROOT_PROPERTIES:
root_properties.append(p)
return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties(
exp.Properties(expressions=with_properties)
@ -496,9 +495,12 @@ class Generator:
)
def property_sql(self, expression):
key = expression.name
if isinstance(expression.this, exp.Literal):
key = expression.this.this
else:
key = expression.name
value = self.sql(expression, "value")
return f"{key} = {value}"
return f"{key}={value}"
def insert_sql(self, expression):
kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO"
@ -535,7 +537,8 @@ class Generator:
laterals = self.expressions(expression, key="laterals", sep="")
joins = self.expressions(expression, key="joins", sep="")
return f"{table}{laterals}{joins}"
pivots = self.expressions(expression, key="pivots", sep="")
return f"{table}{laterals}{joins}{pivots}"
def tablesample_sql(self, expression):
if self.alias_post_tablesample and isinstance(expression.this, exp.Alias):
@ -556,7 +559,17 @@ class Generator:
rows = self.sql(expression, "rows")
rows = f"{rows} ROWS" if rows else ""
size = self.sql(expression, "size")
return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){alias}"
seed = self.sql(expression, "seed")
seed = f" SEED ({seed})" if seed else ""
return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){seed}{alias}"
def pivot_sql(self, expression):
this = self.sql(expression, "this")
unpivot = expression.args.get("unpivot")
direction = "UNPIVOT" if unpivot else "PIVOT"
expressions = self.expressions(expression, key="expressions")
field = self.sql(expression, "field")
return f"{this} {direction}({expressions} FOR {field})"
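Note: pivot_sql, together with the new "pivots" query modifier in expressions.py, should let PIVOT/UNPIVOT clauses round-trip. A sketch, assuming a reading dialect that parses PIVOT (e.g. Snowflake):

import sqlglot

sql = "SELECT * FROM tbl PIVOT(SUM(x) FOR y IN ('a', 'b'))"
print(sqlglot.transpile(sql, read="snowflake", write="snowflake")[0])
# SELECT * FROM tbl PIVOT(SUM(x) FOR y IN ('a', 'b'))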
def tuple_sql(self, expression):
return f"({self.expressions(expression, flat=True)})"
@ -681,6 +694,7 @@ class Generator:
def ordered_sql(self, expression):
desc = expression.args.get("desc")
asc = not desc
nulls_first = expression.args.get("nulls_first")
nulls_last = not nulls_first
nulls_are_large = self.null_ordering == "nulls_are_large"
@ -760,6 +774,7 @@ class Generator:
return self.query_modifiers(
expression,
self.wrap(expression),
self.expressions(expression, key="pivots", sep=" "),
f" AS {alias}" if alias else "",
)
@ -1129,6 +1144,9 @@ class Generator:
return f"{op} {expressions_sql}"
return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"
def naked_property(self, expression):
return f"{expression.name} {self.sql(expression, 'value')}"
def set_operation(self, expression, op):
this = self.sql(expression, "this")
op = self.seg(op)
@ -1136,3 +1154,13 @@ class Generator:
def token_sql(self, token_type):
return self.TOKEN_MAPPING.get(token_type, token_type.name)
def userdefinedfunction_sql(self, expression):
this = self.sql(expression, "this")
expressions = self.no_identify(lambda: self.expressions(expression))
return f"{this}({expressions})"
def userdefinedfunctionkwarg_sql(self, expression):
this = self.sql(expression, "this")
kind = self.sql(expression, "kind")
return f"{this} {kind}"

View file

@ -1,5 +1,7 @@
import inspect
import logging
import re
import sys
from contextlib import contextmanager
from enum import Enum
@ -29,6 +31,26 @@ def csv(*args, sep=", "):
return sep.join(arg for arg in args if arg)
def subclasses(module_name, classes, exclude=()):
"""
Returns a list of all subclasses for a specified class set, possibly excluding some of them.
Args:
module_name (str): The name of the module to search for subclasses in.
classes (type|tuple[type]): Class(es) we want to find the subclasses of.
exclude (type|tuple[type]): Class(es) we want to exclude from the returned list.
Returns:
A list of all the target subclasses.
"""
return [
obj
for _, obj in inspect.getmembers(
sys.modules[module_name],
lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
)
]
def apply_index_offset(expressions, offset):
if not offset or len(expressions) != 1:
return expressions
@ -100,7 +122,7 @@ def csv_reader(table):
Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...])
Args:
expression (Expression): An anonymous function READ_CSV
table (exp.Table): A table expression with an anonymous function READ_CSV in it
Returns:
A python csv reader.
@ -121,3 +143,22 @@ def csv_reader(table):
yield csv_.reader(file, delimiter=delimiter)
finally:
file.close()
def find_new_name(taken, base):
"""
Searches for a new name.
Args:
taken (Sequence[str]): set of taken names
base (str): base name to alter
"""
if base not in taken:
return base
i = 2
new = f"{base}_{i}"
while new in taken:
i += 1
new = f"{base}_{i}"
return new
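For example (illustrative):

find_new_name({"y", "y_2"}, "y")  # 'y_3' -- 'y' and 'y_2' are taken
find_new_name({"y"}, "x")         # 'x'   -- base name is free, returned as-is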

View file

@ -0,0 +1,162 @@
from sqlglot import exp
from sqlglot.helper import ensure_list, subclasses
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
"""
Recursively infer & annotate types in an expression syntax tree against a schema.
(TODO -- replace this with a better example after adding some functionality)
Example:
>>> import sqlglot
>>> annotated_expression = annotate_types(sqlglot.parse_one('5 + 5.3'))
>>> annotated_expression.type
<Type.DOUBLE: 'DOUBLE'>
Args:
expression (sqlglot.Expression): Expression to annotate.
schema (dict|sqlglot.optimizer.Schema): Database schema.
annotators (dict): Maps expression type to corresponding annotation function.
coerces_to (dict): Maps expression type to set of types that it can be coerced into.
Returns:
sqlglot.Expression: expression annotated with types
"""
return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
class TypeAnnotator:
ANNOTATORS = {
**{
expr_type: lambda self, expr: self._annotate_unary(expr)
for expr_type in subclasses(exp.__name__, exp.Unary)
},
**{
expr_type: lambda self, expr: self._annotate_binary(expr)
for expr_type in subclasses(exp.__name__, exp.Binary)
},
exp.Cast: lambda self, expr: self._annotate_cast(expr),
exp.DataType: lambda self, expr: self._annotate_data_type(expr),
exp.Literal: lambda self, expr: self._annotate_literal(expr),
exp.Boolean: lambda self, expr: self._annotate_boolean(expr),
}
# Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
COERCES_TO = {
# CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
exp.DataType.Type.TEXT: set(),
exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
exp.DataType.Type.CHAR: {
exp.DataType.Type.NCHAR,
exp.DataType.Type.VARCHAR,
exp.DataType.Type.NVARCHAR,
exp.DataType.Type.TEXT,
},
# TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
exp.DataType.Type.DOUBLE: set(),
exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
exp.DataType.Type.INT: {
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.DOUBLE,
},
exp.DataType.Type.SMALLINT: {
exp.DataType.Type.INT,
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.DOUBLE,
},
exp.DataType.Type.TINYINT: {
exp.DataType.Type.SMALLINT,
exp.DataType.Type.INT,
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.DOUBLE,
},
# DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
exp.DataType.Type.TIMESTAMPLTZ: set(),
exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ},
exp.DataType.Type.DATETIME: {
exp.DataType.Type.TIMESTAMP,
exp.DataType.Type.TIMESTAMPTZ,
exp.DataType.Type.TIMESTAMPLTZ,
},
exp.DataType.Type.DATE: {
exp.DataType.Type.DATETIME,
exp.DataType.Type.TIMESTAMP,
exp.DataType.Type.TIMESTAMPTZ,
exp.DataType.Type.TIMESTAMPLTZ,
},
}
def __init__(self, schema=None, annotators=None, coerces_to=None):
self.schema = schema
self.annotators = annotators or self.ANNOTATORS
self.coerces_to = coerces_to or self.COERCES_TO
def annotate(self, expression):
if not isinstance(expression, exp.Expression):
return None
annotator = self.annotators.get(expression.__class__)
return annotator(self, expression) if annotator else self._annotate_args(expression)
def _annotate_args(self, expression):
for value in expression.args.values():
for v in ensure_list(value):
self.annotate(v)
return expression
def _annotate_cast(self, expression):
expression.type = expression.args["to"].this
return self._annotate_args(expression)
def _annotate_data_type(self, expression):
expression.type = expression.this
return self._annotate_args(expression)
def _maybe_coerce(self, type1, type2):
return type2 if type2 in self.coerces_to[type1] else type1
def _annotate_binary(self, expression):
self._annotate_args(expression)
if isinstance(expression, (exp.Condition, exp.Predicate)):
expression.type = exp.DataType.Type.BOOLEAN
else:
expression.type = self._maybe_coerce(expression.left.type, expression.right.type)
return expression
def _annotate_unary(self, expression):
self._annotate_args(expression)
if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
expression.type = exp.DataType.Type.BOOLEAN
else:
expression.type = expression.this.type
return expression
def _annotate_literal(self, expression):
if expression.is_string:
expression.type = exp.DataType.Type.VARCHAR
elif expression.is_int:
expression.type = exp.DataType.Type.INT
else:
expression.type = exp.DataType.Type.DOUBLE
return expression
def _annotate_boolean(self, expression):
expression.type = exp.DataType.Type.BOOLEAN
return expression
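Note: beyond the docstring example, the COERCES_TO table resolves mixed-type binary expressions to the wider type. An illustrative sketch (module path assumed to be sqlglot.optimizer.annotate_types):

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

e = annotate_types(sqlglot.parse_one("CAST(x AS SMALLINT) + CAST(y AS BIGINT)"))
print(e.type)  # Type.BIGINT, since SMALLINT coerces to BIGINT per COERCES_TO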

View file

@ -1,48 +1,144 @@
import itertools
from sqlglot import alias, exp, select, table
from sqlglot.optimizer.scope import traverse_scope
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.simplify import simplify
def eliminate_subqueries(expression):
"""
Rewrite duplicate subqueries from sqlglot AST.
Rewrite subqueries as CTES, deduplicating if possible.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y")
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
>>> eliminate_subqueries(expression).sql()
'WITH _e_0 AS (SELECT 1 AS x, 2 AS y) SELECT * FROM _e_0 UNION ALL SELECT * FROM _e_0'
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
This also deduplicates common subqueries:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
Args:
expression (sqlglot.Expression): expression to qualify
schema (dict|sqlglot.optimizer.Schema): Database schema
expression (sqlglot.Expression): expression
Returns:
sqlglot.Expression: qualified expression
sqlglot.Expression: expression
"""
if isinstance(expression, exp.Subquery):
# It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
eliminate_subqueries(expression.this)
return expression
expression = simplify(expression)
queries = {}
root = build_scope(expression)
for scope in traverse_scope(expression):
query = scope.expression
queries[query] = queries.get(query, []) + [query]
# Map of alias->Scope|Table
# These are all aliases that are already used in the expression.
# We don't want to create new CTEs that conflict with these names.
taken = {}
sequence = itertools.count()
# All CTE aliases in the root scope are taken
for scope in root.cte_scopes:
taken[scope.expression.parent.alias] = scope
for query, duplicates in queries.items():
if len(duplicates) == 1:
continue
# All table names are taken
for scope in root.traverse():
taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)})
alias_ = f"_e_{next(sequence)}"
# Map of Expression->alias
# Existing CTES in the root expression. We'll use this for deduplication.
existing_ctes = {}
for dup in duplicates:
parent = dup.parent
if isinstance(parent, exp.Subquery):
parent.replace(alias(table(alias_), parent.alias_or_name, table=True))
elif isinstance(parent, exp.Union):
dup.replace(select("*").from_(alias_))
with_ = root.expression.args.get("with")
if with_:
for cte in with_.expressions:
existing_ctes[cte.this] = cte.alias
new_ctes = []
expression.with_(alias_, as_=query, copy=False)
# We're adding more CTEs, but we want to maintain the DAG order.
# Derived tables within an existing CTE need to come before the existing CTE.
for cte_scope in root.cte_scopes:
# Append all the new CTEs from this existing CTE
for scope in cte_scope.traverse():
new_cte = _eliminate(scope, existing_ctes, taken)
if new_cte:
new_ctes.append(new_cte)
# Append the existing CTE itself
new_ctes.append(cte_scope.expression.parent)
# Now append the rest
for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes):
for child_scope in scope.traverse():
new_cte = _eliminate(child_scope, existing_ctes, taken)
if new_cte:
new_ctes.append(new_cte)
if new_ctes:
expression.set("with", exp.With(expressions=new_ctes))
return expression
def _eliminate(scope, existing_ctes, taken):
if scope.is_union:
return _eliminate_union(scope, existing_ctes, taken)
if scope.is_derived_table and not isinstance(scope.expression, (exp.Unnest, exp.Lateral)):
return _eliminate_derived_table(scope, existing_ctes, taken)
def _eliminate_union(scope, existing_ctes, taken):
duplicate_cte_alias = existing_ctes.get(scope.expression)
alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte")
taken[alias] = scope
# Try to maintain the selections
expressions = scope.expression.args.get("expressions")
selects = [
exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
for e in expressions
if e.alias_or_name
]
# If not all selections have an alias, just select *
if len(selects) != len(expressions):
selects = ["*"]
scope.expression.replace(exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias)))
if not duplicate_cte_alias:
existing_ctes[scope.expression] = alias
return exp.CTE(
this=scope.expression,
alias=exp.TableAlias(this=exp.to_identifier(alias)),
)
def _eliminate_derived_table(scope, existing_ctes, taken):
duplicate_cte_alias = existing_ctes.get(scope.expression)
parent = scope.expression.parent
name = alias = parent.alias
if not alias:
name = alias = find_new_name(taken=taken, base="cte")
if duplicate_cte_alias:
name = duplicate_cte_alias
elif taken.get(alias):
name = find_new_name(taken=taken, base=alias)
taken[name] = scope
table = exp.alias_(exp.table_(name), alias=alias)
parent.replace(table)
if not duplicate_cte_alias:
existing_ctes[scope.expression] = name
return exp.CTE(
this=scope.expression,
alias=exp.TableAlias(this=exp.to_identifier(name)),
)

View file

@ -1,45 +1,39 @@
from collections import defaultdict
from sqlglot import expressions as exp
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.optimizer.simplify import simplify
def merge_derived_tables(expression):
def merge_subqueries(expression, leave_tables_isolated=False):
"""
Rewrite sqlglot AST to merge derived tables into the outer query.
This also merges CTEs if they are selected from only once.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)")
>>> merge_derived_tables(expression).sql()
'SELECT x.a FROM x'
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
>>> merge_subqueries(expression).sql()
'SELECT x.a FROM x JOIN y'
If `leave_tables_isolated` is True, this will not merge inner queries into outer
queries if it would result in multiple table selects in a single query:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
'SELECT a FROM (SELECT x.a FROM x) JOIN y'
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
Args:
expression (sqlglot.Expression): expression to optimize
leave_tables_isolated (bool):
Returns:
sqlglot.Expression: optimized expression
"""
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
inner_select = subquery.unnest()
if (
isinstance(outer_scope.expression, exp.Select)
and isinstance(inner_select, exp.Select)
and _mergeable(inner_select)
):
alias = subquery.alias_or_name
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
inner_scope = outer_scope.sources[alias]
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
merge_ctes(expression, leave_tables_isolated)
merge_derived_tables(expression, leave_tables_isolated)
return expression
@ -53,20 +47,81 @@ UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
}
def _mergeable(inner_select):
def merge_ctes(expression, leave_tables_isolated=False):
scopes = traverse_scope(expression)
# All places where we select from CTEs.
# We key on the CTE scope so we can detect CTES that are selected from multiple times.
cte_selections = defaultdict(list)
for outer_scope in scopes:
for table, inner_scope in outer_scope.selected_sources.values():
if isinstance(inner_scope, Scope) and inner_scope.is_cte:
cte_selections[id(inner_scope)].append(
(
outer_scope,
inner_scope,
table,
)
)
singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
for outer_scope, inner_scope, table in singular_cte_selections:
inner_select = inner_scope.expression.unnest()
if _mergeable(outer_scope, inner_select, leave_tables_isolated):
from_or_join = table.find_ancestor(exp.From, exp.Join)
node_to_replace = table
if isinstance(node_to_replace.parent, exp.Alias):
node_to_replace = node_to_replace.parent
alias = node_to_replace.alias
else:
alias = table.name
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, node_to_replace, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_pop_cte(inner_scope)
def merge_derived_tables(expression, leave_tables_isolated=False):
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
inner_select = subquery.unnest()
if _mergeable(outer_scope, inner_select, leave_tables_isolated):
alias = subquery.alias_or_name
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
inner_scope = outer_scope.sources[alias]
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
def _mergeable(outer_scope, inner_select, leave_tables_isolated):
"""
Return True if `inner_select` can be merged into outer query.
Args:
outer_scope (Scope)
inner_select (exp.Select)
leave_tables_isolated (bool)
Returns:
bool: True if can be merged
"""
return (
isinstance(inner_select, exp.Select)
isinstance(outer_scope.expression, exp.Select)
and isinstance(inner_select, exp.Select)
and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
)
@ -84,7 +139,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
conflicts = conflicts - {alias}
for conflict in conflicts:
new_name = _find_new_name(taken, conflict)
new_name = find_new_name(taken, conflict)
source, _ = inner_scope.selected_sources[conflict]
new_alias = exp.to_identifier(new_name)
@ -102,34 +157,19 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
inner_scope.rename_source(conflict, new_name)
def _find_new_name(taken, base):
"""
Searches for a new source name.
Args:
taken (set[str]): set of taken names
base (str): base name to alter
"""
i = 2
new = f"{base}_{i}"
while new in taken:
i += 1
new = f"{base}_{i}"
return new
def _merge_from(outer_scope, inner_scope, subquery):
def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
"""
Merge FROM clause of inner query into outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
subquery (exp.Subquery)
node_to_replace (exp.Subquery|exp.Table)
alias (str)
"""
new_subquery = inner_scope.expression.args.get("from").expressions[0]
subquery.replace(new_subquery)
outer_scope.remove_source(subquery.alias_or_name)
node_to_replace.replace(new_subquery)
outer_scope.remove_source(alias)
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
@ -176,7 +216,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
inner_scope (sqlglot.optimizer.scope.Scope)
alias (str)
"""
# Collect all columns that for the alias of the inner query
# Collect all columns that reference the alias of the inner query
outer_columns = defaultdict(list)
for column in outer_scope.columns:
if column.table == alias:
@ -205,7 +245,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
if not where or not where.this:
return
if isinstance(from_or_join, exp.Join) and from_or_join.side:
if isinstance(from_or_join, exp.Join):
# Merge predicates from an outer join to the ON clause
from_or_join.on(where.this, copy=False)
from_or_join.set("on", simplify(from_or_join.args.get("on")))
@ -230,3 +270,18 @@ def _merge_order(outer_scope, inner_scope):
return
outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
def _pop_cte(inner_scope):
"""
Remove CTE from the AST.
Args:
inner_scope (sqlglot.optimizer.scope.Scope)
"""
cte = inner_scope.expression.parent
with_ = cte.parent
if len(with_.expressions) == 1:
with_.pop()
else:
cte.pop()

View file

@ -1,7 +1,7 @@
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
@ -22,7 +22,7 @@ RULES = (
pushdown_predicates,
optimize_joins,
eliminate_subqueries,
merge_derived_tables,
merge_subqueries,
quote_identities,
)

View file

@ -37,7 +37,7 @@ def pushdown_projections(expression):
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
left, right = scope.union
left, right = scope.union_scopes
referenced_columns[left] = parent_selections
referenced_columns[right] = parent_selections

View file

@ -69,7 +69,7 @@ def ensure_schema(schema):
def fs_get(table):
name = table.this.name.upper()
name = table.this.name
if name.upper() == "READ_CSV":
with csv_reader(table) as reader:

View file

@ -1,3 +1,4 @@
import itertools
from copy import copy
from enum import Enum, auto
@ -32,10 +33,11 @@ class Scope:
The inner query would have `["col1", "col2"]` for its `outer_column_list`
parent (Scope): Parent scope
scope_type (ScopeType): Type of this scope, relative to its parent
subquery_scopes (list[Scope]): List of all child scopes for subqueries.
This does not include derived tables or CTEs.
union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be
a tuple of the left and right child scopes.
subquery_scopes (list[Scope]): List of all child scopes for subqueries
cte_scopes (list[Scope]): List of all child scopes for CTEs
derived_table_scopes (list[Scope]): List of all child scopes for derived tables
union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
a list of the left and right child scopes.
"""
def __init__(
@ -52,7 +54,9 @@ class Scope:
self.parent = parent
self.scope_type = scope_type
self.subquery_scopes = []
self.union = None
self.derived_table_scopes = []
self.cte_scopes = []
self.union_scopes = []
self.clear_cache()
def clear_cache(self):
@ -197,11 +201,16 @@ class Scope:
named_outputs = {e.alias_or_name for e in self.expression.expressions}
self._columns = [
c
for c in columns + external_columns
if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs)
]
self._columns = []
for column in columns + external_columns:
ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Hint)
if (
not ancestor
or column.table
or (column.name not in named_outputs and not isinstance(ancestor, exp.Hint))
):
self._columns.append(column)
return self._columns
@property
@ -283,6 +292,26 @@ class Scope:
"""Determine if this scope is a subquery"""
return self.scope_type == ScopeType.SUBQUERY
@property
def is_derived_table(self):
"""Determine if this scope is a derived table"""
return self.scope_type == ScopeType.DERIVED_TABLE
@property
def is_union(self):
"""Determine if this scope is a union"""
return self.scope_type == ScopeType.UNION
@property
def is_cte(self):
"""Determine if this scope is a common table expression"""
return self.scope_type == ScopeType.CTE
@property
def is_root(self):
"""Determine if this is the root scope"""
return self.scope_type == ScopeType.ROOT
@property
def is_unnest(self):
"""Determine if this scope is an unnest"""
@ -308,6 +337,22 @@ class Scope:
self.sources.pop(name, None)
self.clear_cache()
def __repr__(self):
return f"Scope<{self.expression.sql()}>"
def traverse(self):
"""
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes
):
yield from child_scope.traverse()
yield self
def traverse_scope(expression):
"""
@ -337,6 +382,18 @@ def traverse_scope(expression):
return list(_traverse_scope(Scope(expression)))
def build_scope(expression):
"""
Build a scope tree.
Args:
expression (exp.Expression): expression to build the scope tree for
Returns:
Scope: root scope
"""
return traverse_scope(expression)[-1]
def _traverse_scope(scope):
if isinstance(scope.expression, exp.Select):
yield from _traverse_select(scope)
@ -370,13 +427,14 @@ def _traverse_union(scope):
for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
yield right
scope.union = (left, right)
scope.union_scopes = [left, right]
def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {}
for derived_table in derived_tables:
top = None
for child_scope in _traverse_scope(
scope.branch(
derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
@ -386,11 +444,16 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
)
):
yield child_scope
top = child_scope
# Tables without aliases will be set as ""
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
Until then, this means that only a single, unaliased derived table is allowed (rather,
the latest one wins).
sources[derived_table.alias] = child_scope
if scope_type == ScopeType.CTE:
scope.cte_scopes.append(top)
else:
scope.derived_table_scopes.append(top)
scope.sources.update(sources)
@ -407,8 +470,6 @@ def _add_table_sources(scope):
if table_name in scope.sources:
# This is a reference to a parent source (e.g. a CTE), not an actual table.
scope.sources[source_name] = scope.sources[table_name]
elif source_name in scope.sources:
raise OptimizeError(f"Duplicate table name: {source_name}")
else:
sources[source_name] = table
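Note: build_scope plus Scope.traverse give callers a full walk of the scope tree; the rewritten eliminate_subqueries above relies on exactly this. An illustrative sketch:

from sqlglot import parse_one
from sqlglot.optimizer.scope import build_scope

root = build_scope(parse_one("SELECT a FROM (SELECT a FROM x) AS y"))
for scope in root.traverse():  # depth-first post-order: innermost scopes first
    print(scope)
# Scope<SELECT a FROM x>
# Scope<SELECT a FROM (SELECT a FROM x) AS y>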

View file

@ -99,7 +99,8 @@ class Parser:
TokenType.SMALLMONEY,
TokenType.ROWVERSION,
TokenType.IMAGE,
TokenType.SQL_VARIANT,
TokenType.VARIANT,
TokenType.OBJECT,
*NESTED_TYPE_TOKENS,
}
@ -131,7 +132,6 @@ class Parser:
TokenType.FALSE,
TokenType.FIRST,
TokenType.FOLLOWING,
TokenType.FOR,
TokenType.FORMAT,
TokenType.FUNCTION,
TokenType.GENERATED,
@ -141,20 +141,26 @@ class Parser:
TokenType.ISNULL,
TokenType.INTERVAL,
TokenType.LAZY,
TokenType.LANGUAGE,
TokenType.LEADING,
TokenType.LOCATION,
TokenType.MATERIALIZED,
TokenType.NATURAL,
TokenType.NEXT,
TokenType.ONLY,
TokenType.OPTIMIZE,
TokenType.OPTIONS,
TokenType.ORDINALITY,
TokenType.PARTITIONED_BY,
TokenType.PERCENT,
TokenType.PIVOT,
TokenType.PRECEDING,
TokenType.RANGE,
TokenType.REFERENCES,
TokenType.RETURNS,
TokenType.ROWS,
TokenType.SCHEMA_COMMENT,
TokenType.SEED,
TokenType.SET,
TokenType.SHOW,
TokenType.STORED,
@ -167,6 +173,7 @@ class Parser:
TokenType.TRUE,
TokenType.UNBOUNDED,
TokenType.UNIQUE,
TokenType.UNPIVOT,
TokenType.PROPERTIES,
*SUBQUERY_PREDICATES,
*TYPE_TOKENS,
@ -303,6 +310,8 @@ class Parser:
exp.Condition: lambda self: self._parse_conjunction(),
exp.Expression: lambda self: self._parse_statement(),
exp.Properties: lambda self: self._parse_properties(),
exp.Where: lambda self: self._parse_where(),
exp.Ordered: lambda self: self._parse_ordered(),
"JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
}
@ -355,23 +364,21 @@ class Parser:
PROPERTY_PARSERS = {
TokenType.AUTO_INCREMENT: lambda self: self._parse_auto_increment(),
TokenType.CHARACTER_SET: lambda self: self._parse_character_set(),
TokenType.COLLATE: lambda self: self._parse_collate(),
TokenType.ENGINE: lambda self: self._parse_engine(),
TokenType.FORMAT: lambda self: self._parse_format(),
TokenType.LOCATION: lambda self: self.expression(
exp.LocationProperty,
this=exp.Literal.string("LOCATION"),
value=self._parse_string(),
),
TokenType.PARTITIONED_BY: lambda self: self.expression(
exp.PartitionedByProperty,
this=exp.Literal.string("PARTITIONED_BY"),
value=self._parse_schema(),
),
TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(),
TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(),
TokenType.STORED: lambda self: self._parse_stored(),
TokenType.TABLE_FORMAT: lambda self: self._parse_table_format(),
TokenType.USING: lambda self: self._parse_table_format(),
TokenType.RETURNS: lambda self: self._parse_returns(),
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
}
CONSTRAINT_PARSERS = {
@ -388,6 +395,7 @@ class Parser:
FUNCTION_PARSERS = {
"CONVERT": lambda self: self._parse_convert(),
"EXTRACT": lambda self: self._parse_extract(),
"POSITION": lambda self: self._parse_position(),
"SUBSTRING": lambda self: self._parse_substring(),
"TRIM": lambda self: self._parse_trim(),
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
@ -628,6 +636,10 @@ class Parser:
replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
temporary = self._match(TokenType.TEMPORARY)
unique = self._match(TokenType.UNIQUE)
materialized = self._match(TokenType.MATERIALIZED)
if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
self._match(TokenType.TABLE)
create_token = self._match_set(self.CREATABLES) and self._prev
@ -640,14 +652,15 @@ class Parser:
properties = None
if create_token.token_type == TokenType.FUNCTION:
this = self._parse_var()
this = self._parse_user_defined_function()
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
expression = self._parse_string()
expression = self._parse_select_or_expression()
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index()
elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW):
this = self._parse_table(schema=True)
properties = self._parse_properties(this if isinstance(this, exp.Schema) else None)
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
expression = self._parse_select(nested=True)
@ -661,9 +674,10 @@ class Parser:
temporary=temporary,
replace=replace,
unique=unique,
materialized=materialized,
)
def _parse_property(self, schema):
def _parse_property(self):
if self._match_set(self.PROPERTY_PARSERS):
return self.PROPERTY_PARSERS[self._prev.token_type](self)
if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
@ -673,31 +687,27 @@ class Parser:
key = self._parse_var().this
self._match(TokenType.EQ)
if key.upper() == "PARTITIONED_BY":
expression = exp.PartitionedByProperty
value = self._parse_schema() or self._parse_bracket(self._parse_field())
if schema and not isinstance(value, exp.Schema):
columns = {v.name.upper() for v in value.expressions}
partitions = [
expression for expression in schema.expressions if expression.this.name.upper() in columns
]
schema.set(
"expressions",
[e for e in schema.expressions if e not in partitions],
)
value = self.expression(exp.Schema, expressions=partitions)
else:
value = self._parse_column()
expression = exp.AnonymousProperty
return self.expression(
expression,
exp.AnonymousProperty,
this=exp.Literal.string(key),
value=value,
value=self._parse_column(),
)
return None
def _parse_property_assignment(self, exp_class):
prop = self._prev.text
self._match(TokenType.EQ)
return self.expression(exp_class, this=prop, value=self._parse_var_or_string())
def _parse_partitioned_by(self):
self._match(TokenType.EQ)
return self.expression(
exp.PartitionedByProperty,
this=exp.Literal.string("PARTITIONED_BY"),
value=self._parse_schema() or self._parse_bracket(self._parse_field()),
)
def _parse_stored(self):
self._match(TokenType.ALIAS)
self._match(TokenType.EQ)
@ -707,22 +717,6 @@ class Parser:
value=exp.Literal.string(self._parse_var().name),
)
def _parse_format(self):
self._match(TokenType.EQ)
return self.expression(
exp.FileFormatProperty,
this=exp.Literal.string("FORMAT"),
value=self._parse_string() or self._parse_var(),
)
def _parse_engine(self):
self._match(TokenType.EQ)
return self.expression(
exp.EngineProperty,
this=exp.Literal.string("ENGINE"),
value=self._parse_var_or_string(),
)
def _parse_auto_increment(self):
self._match(TokenType.EQ)
return self.expression(
@ -731,14 +725,6 @@ class Parser:
value=self._parse_var() or self._parse_number(),
)
def _parse_collate(self):
self._match(TokenType.EQ)
return self.expression(
exp.CollateProperty,
this=exp.Literal.string("COLLATE"),
value=self._parse_var_or_string(),
)
def _parse_schema_comment(self):
self._match(TokenType.EQ)
return self.expression(
@ -756,26 +742,34 @@ class Parser:
default=default,
)
def _parse_table_format(self):
self._match(TokenType.EQ)
def _parse_returns(self):
is_table = self._match(TokenType.TABLE)
if is_table:
if self._match(TokenType.LT):
value = self.expression(
exp.Schema, this="TABLE", expressions=self._parse_csv(self._parse_struct_kwargs)
)
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
else:
value = self._parse_schema("TABLE")
else:
value = self._parse_types()
return self.expression(
exp.TableFormatProperty,
this=exp.Literal.string("TABLE_FORMAT"),
value=self._parse_var_or_string(),
exp.ReturnsProperty,
this=exp.Literal.string("RETURNS"),
value=value,
is_table=is_table,
)
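# Usage sketch (hedged; mirrors the Snowflake/BigQuery tests in this change):
# RETURNS TABLE parses into exp.ReturnsProperty with is_table set, which lets
# BigQuery render table functions.
#
#     >>> import sqlglot
#     >>> sqlglot.transpile(
#     ...     "CREATE FUNCTION a() RETURNS TABLE (b INT) AS 'SELECT 1'",
#     ...     read="snowflake",
#     ...     write="bigquery",
#     ... )
#     ['CREATE TABLE FUNCTION a() RETURNS TABLE <b INT64> AS SELECT 1']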
def _parse_properties(self, schema=None):
"""
Schema is included since if the table schema is defined and we later get a partition by expression
then we will define those columns in the partition by section and not with the rest of the
columns.
"""
def _parse_properties(self):
properties = []
while True:
if self._match(TokenType.WITH):
self._match_l_paren()
properties.extend(self._parse_csv(lambda: self._parse_property(schema)))
properties.extend(self._parse_csv(lambda: self._parse_property()))
self._match_r_paren()
elif self._match(TokenType.PROPERTIES):
self._match_l_paren()
@ -790,7 +784,7 @@ class Parser:
)
self._match_r_paren()
else:
identified_property = self._parse_property(schema)
identified_property = self._parse_property()
if not identified_property:
break
properties.append(identified_property)
@ -1003,7 +997,7 @@ class Parser:
)
def _parse_subquery(self, this):
return self.expression(exp.Subquery, this=this, alias=self._parse_table_alias())
return self.expression(exp.Subquery, this=this, pivots=self._parse_pivots(), alias=self._parse_table_alias())
def _parse_query_modifiers(self, this):
if not isinstance(this, self.MODIFIABLES):
@ -1134,14 +1128,18 @@ class Parser:
table = (not schema and self._parse_function()) or self._parse_id_var(False)
while self._match(TokenType.DOT):
catalog = db
db = table
table = self._parse_id_var()
if catalog:
# This allows nesting the table in arbitrarily many dot expressions if needed
table = self.expression(exp.Dot, this=table, expression=self._parse_id_var())
else:
catalog = db
db = table
table = self._parse_id_var()
if not table:
self.raise_error("Expected table name")
this = self.expression(exp.Table, this=table, db=db, catalog=catalog)
this = self.expression(exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots())
if schema:
return self._parse_schema(this=this)
@ -1199,6 +1197,7 @@ class Parser:
percent = None
rows = None
size = None
seed = None
self._match_l_paren()
@ -1220,6 +1219,11 @@ class Parser:
self._match_r_paren()
if self._match(TokenType.SEED):
self._match_l_paren()
seed = self._parse_number()
self._match_r_paren()
return self.expression(
exp.TableSample,
method=method,
@ -1229,6 +1233,51 @@ class Parser:
percent=percent,
rows=rows,
size=size,
seed=seed,
)
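# Usage sketch (hedged; per the Snowflake tests in this change): the SEED
# clause now survives the round trip.
#
#     >>> import sqlglot
#     >>> sqlglot.transpile(
#     ...     "SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)",
#     ...     read="snowflake", write="snowflake",
#     ... )
#     ['SELECT a FROM test TABLESAMPLE BLOCK (0.5) SEED (42)']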
def _parse_pivots(self):
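# iter() with a None sentinel keeps calling _parse_pivot until it returns None,
# collecting consecutive PIVOT/UNPIVOT clauses.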
return list(iter(self._parse_pivot, None))
def _parse_pivot(self):
index = self._index
if self._match(TokenType.PIVOT):
unpivot = False
elif self._match(TokenType.UNPIVOT):
unpivot = True
else:
return None
expressions = []
field = None
if not self._match(TokenType.L_PAREN):
self._retreat(index)
return None
if unpivot:
expressions = self._parse_csv(self._parse_column)
else:
expressions = self._parse_csv(lambda: self._parse_alias(self._parse_function()))
if not self._match(TokenType.FOR):
self.raise_error("Expecting FOR")
value = self._parse_column()
if not self._match(TokenType.IN):
self.raise_error("Expecting IN")
field = self._parse_in(value)
self._match_r_paren()
return self.expression(
exp.Pivot,
expressions=expressions,
field=field,
unpivot=unpivot,
)
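# Usage sketch (hedged; per the BigQuery tests in this change): aliased pivot
# aggregations gain an explicit AS on the round trip.
#
#     >>> import sqlglot
#     >>> sqlglot.transpile(
#     ...     "SELECT * FROM (SELECT a, b, c FROM test) "
#     ...     "PIVOT(SUM(b) d, COUNT(*) e FOR c IN ('x', 'y'))",
#     ...     read="bigquery", write="bigquery",
#     ... )
#     ["SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))"]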
def _parse_where(self):
@ -1384,7 +1433,7 @@ class Parser:
this = self.expression(exp.In, this=this, unnest=unnest)
else:
self._match_l_paren()
expressions = self._parse_csv(lambda: self._parse_select() or self._parse_expression())
expressions = self._parse_csv(self._parse_select_or_expression)
if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
this = self.expression(exp.In, this=this, query=expressions[0])
@ -1577,6 +1626,9 @@ class Parser:
if self._match_set(self.PRIMARY_PARSERS):
return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev)
if self._match_pair(TokenType.DOT, TokenType.NUMBER):
return exp.Literal.number(f"0.{self._prev.text}")
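# Hedged example: a leading-dot literal such as ".2" should now round-trip as
# "0.2", e.g. sqlglot.transpile("SELECT .2") -> ['SELECT 0.2'].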
if self._match(TokenType.L_PAREN):
query = self._parse_select()
@ -1647,6 +1699,23 @@ class Parser:
self._match_r_paren()
return self._parse_window(this)
def _parse_user_defined_function(self):
this = self._parse_var()
if not self._match(TokenType.L_PAREN):
return this
expressions = self._parse_csv(self._parse_udf_kwarg)
self._match_r_paren()
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
def _parse_udf_kwarg(self):
this = self._parse_id_var()
kind = self._parse_types()
if not kind:
return this
return self.expression(exp.UserDefinedFunctionKwarg, this=this, kind=kind)
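# Usage sketch (hedged; per the BigQuery tests in this change): typed
# parameters parse into exp.UserDefinedFunctionKwarg nodes, so definitions
# like this one round-trip.
#
#     >>> import sqlglot
#     >>> sqlglot.transpile(
#     ...     "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)",
#     ...     read="bigquery", write="bigquery",
#     ... )
#     ['CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)']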
def _parse_lambda(self):
index = self._index
@ -1672,9 +1741,10 @@ class Parser:
return self._parse_alias(self._parse_limit(self._parse_order(this)))
conjunction = self._parse_conjunction().transform(self._replace_lambda, {node.name for node in expressions})
return self.expression(
exp.Lambda,
this=self._parse_conjunction(),
this=conjunction,
expressions=expressions,
)
@ -1896,6 +1966,12 @@ class Parser:
to = None
return self.expression(exp.Cast, this=this, to=to)
def _parse_position(self):
substr = self._parse_bitwise()
if self._match(TokenType.IN):
string = self._parse_bitwise()
return self.expression(exp.StrPosition, this=string, substr=substr)
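# Usage sketch (hedged; per the dialect tests in this change):
#
#     >>> import sqlglot
#     >>> sqlglot.transpile("POSITION(' ' IN x)", write="spark")
#     ["LOCATE(' ', x)"]
#     >>> sqlglot.transpile("POSITION(' ' IN x)", write="presto")
#     ["STRPOS(x, ' ')"]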
def _parse_substring(self):
# Postgres supports the form: substring(string [from int] [for int])
# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
@ -2155,6 +2231,9 @@ class Parser:
self._match_r_paren()
return expressions
def _parse_select_or_expression(self):
return self._parse_select() or self._parse_expression()
def _match(self, token_type):
if not self._curr:
return None
@ -2208,3 +2287,9 @@ class Parser:
elif isinstance(this, exp.Identifier):
this = self.expression(exp.Var, this=this.name)
return this
def _replace_lambda(self, node, lambda_variables):
if isinstance(node, exp.Column):
if node.name in lambda_variables:
return node.this
return node
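# Sketch of the effect (hedged, using the X(...) syntax from the identity
# fixtures): lambda parameters are unwrapped from exp.Column to their bare
# identifiers, so qualification leaves them alone while real columns like "a"
# are still picked up.
#
#     >>> import sqlglot
#     >>> from sqlglot import exp
#     >>> lamb = sqlglot.parse_one("X(x -> x + a)").find(exp.Lambda)
#     >>> [col.name for col in lamb.this.find_all(exp.Column)]
#     ['a']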

View file

@ -94,7 +94,8 @@ class TokenType(AutoName):
SMALLMONEY = auto()
ROWVERSION = auto()
IMAGE = auto()
SQL_VARIANT = auto()
VARIANT = auto()
OBJECT = auto()
# keywords
ADD_FILE = auto()
@ -177,6 +178,7 @@ class TokenType(AutoName):
IS = auto()
ISNULL = auto()
JOIN = auto()
LANGUAGE = auto()
LATERAL = auto()
LAZY = auto()
LEADING = auto()
@ -185,6 +187,7 @@ class TokenType(AutoName):
LIMIT = auto()
LOCATION = auto()
MAP = auto()
MATERIALIZED = auto()
MOD = auto()
NATURAL = auto()
NEXT = auto()
@ -208,6 +211,7 @@ class TokenType(AutoName):
PARTITION_BY = auto()
PARTITIONED_BY = auto()
PERCENT = auto()
PIVOT = auto()
PLACEHOLDER = auto()
PRECEDING = auto()
PRIMARY_KEY = auto()
@ -219,12 +223,14 @@ class TokenType(AutoName):
REPLACE = auto()
RESPECT_NULLS = auto()
REFERENCES = auto()
RETURNS = auto()
RIGHT = auto()
RLIKE = auto()
ROLLUP = auto()
ROW = auto()
ROWS = auto()
SCHEMA_COMMENT = auto()
SEED = auto()
SELECT = auto()
SEPARATOR = auto()
SET = auto()
@ -246,6 +252,7 @@ class TokenType(AutoName):
UNCACHE = auto()
UNION = auto()
UNNEST = auto()
UNPIVOT = auto()
UPDATE = auto()
USE = auto()
USING = auto()
@ -440,6 +447,7 @@ class Tokenizer(metaclass=_Tokenizer):
"FULL": TokenType.FULL,
"FUNCTION": TokenType.FUNCTION,
"FOLLOWING": TokenType.FOLLOWING,
"FOR": TokenType.FOR,
"FOREIGN KEY": TokenType.FOREIGN_KEY,
"FORMAT": TokenType.FORMAT,
"FROM": TokenType.FROM,
@ -459,6 +467,7 @@ class Tokenizer(metaclass=_Tokenizer):
"IS": TokenType.IS,
"ISNULL": TokenType.ISNULL,
"JOIN": TokenType.JOIN,
"LANGUAGE": TokenType.LANGUAGE,
"LATERAL": TokenType.LATERAL,
"LAZY": TokenType.LAZY,
"LEADING": TokenType.LEADING,
@ -466,6 +475,7 @@ class Tokenizer(metaclass=_Tokenizer):
"LIKE": TokenType.LIKE,
"LIMIT": TokenType.LIMIT,
"LOCATION": TokenType.LOCATION,
"MATERIALIZED": TokenType.MATERIALIZED,
"NATURAL": TokenType.NATURAL,
"NEXT": TokenType.NEXT,
"NO ACTION": TokenType.NO_ACTION,
@ -473,6 +483,7 @@ class Tokenizer(metaclass=_Tokenizer):
"NULL": TokenType.NULL,
"NULLS FIRST": TokenType.NULLS_FIRST,
"NULLS LAST": TokenType.NULLS_LAST,
"OBJECT": TokenType.OBJECT,
"OFFSET": TokenType.OFFSET,
"ON": TokenType.ON,
"ONLY": TokenType.ONLY,
@ -488,7 +499,9 @@ class Tokenizer(metaclass=_Tokenizer):
"PARTITION": TokenType.PARTITION,
"PARTITION BY": TokenType.PARTITION_BY,
"PARTITIONED BY": TokenType.PARTITIONED_BY,
"PARTITIONED_BY": TokenType.PARTITIONED_BY,
"PERCENT": TokenType.PERCENT,
"PIVOT": TokenType.PIVOT,
"PRECEDING": TokenType.PRECEDING,
"PRIMARY KEY": TokenType.PRIMARY_KEY,
"RANGE": TokenType.RANGE,
@ -497,11 +510,13 @@ class Tokenizer(metaclass=_Tokenizer):
"REPLACE": TokenType.REPLACE,
"RESPECT NULLS": TokenType.RESPECT_NULLS,
"REFERENCES": TokenType.REFERENCES,
"RETURNS": TokenType.RETURNS,
"RIGHT": TokenType.RIGHT,
"RLIKE": TokenType.RLIKE,
"ROLLUP": TokenType.ROLLUP,
"ROW": TokenType.ROW,
"ROWS": TokenType.ROWS,
"SEED": TokenType.SEED,
"SELECT": TokenType.SELECT,
"SET": TokenType.SET,
"SHOW": TokenType.SHOW,
@ -520,6 +535,7 @@ class Tokenizer(metaclass=_Tokenizer):
"TRUNCATE": TokenType.TRUNCATE,
"UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION,
"UNPIVOT": TokenType.UNPIVOT,
"UNNEST": TokenType.UNNEST,
"UPDATE": TokenType.UPDATE,
"USE": TokenType.USE,
@ -577,6 +593,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DATETIME": TokenType.DATETIME,
"UNIQUE": TokenType.UNIQUE,
"STRUCT": TokenType.STRUCT,
"VARIANT": TokenType.VARIANT,
}
WHITE_SPACE = {

View file

@ -12,15 +12,20 @@ def unalias_group(expression):
"""
if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
aliased_selects = {
e.alias: i for i, e in enumerate(expression.parent.expressions, start=1) if isinstance(e, exp.Alias)
e.alias: (i, e.this)
for i, e in enumerate(expression.parent.expressions, start=1)
if isinstance(e, exp.Alias)
}
expression = expression.copy()
for col in expression.find_all(exp.Column):
alias_index = aliased_selects.get(col.name)
if not col.table and alias_index:
col.replace(exp.Literal.number(alias_index))
top_level_expression = None
for item, parent, _ in expression.walk(bfs=False):
top_level_expression = item if isinstance(parent, exp.Group) else top_level_expression
if isinstance(item, exp.Column) and not item.table:
alias_index, col_expression = aliased_selects.get(item.name, (None, None))
if alias_index and top_level_expression != col_expression:
item.replace(exp.Literal.number(alias_index))
return expression
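# Usage sketch (hedged): alias references in GROUP BY become ordinals, while a
# grouped expression that is itself the aliased select item is left intact.
#
#     >>> import sqlglot
#     >>> from sqlglot.transforms import unalias_group
#     >>> sqlglot.parse_one(
#     ...     "SELECT the_date AS the_day, COUNT(*) AS the_count FROM x GROUP BY the_day"
#     ... ).transform(unalias_group).sql()
#     'SELECT the_date AS the_day, COUNT(*) AS the_count FROM x GROUP BY 1'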

View file

@ -236,3 +236,24 @@ class TestBigQuery(Validator):
"snowflake": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a NULLS FIRST LIMIT 10",
},
)
self.validate_all(
"SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
write={
"spark": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
"bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)])",
"snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
},
)
self.validate_all(
"SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) d, COUNT(*) e FOR c IN ('x', 'y'))",
write={
"bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))",
},
)
def test_user_defined_functions(self):
self.validate_identity(
"CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 LANGUAGE js AS 'return x*y;'"
)
self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)")
self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t")

View file

@ -13,9 +13,6 @@ from sqlglot import (
class Validator(unittest.TestCase):
dialect = None
def validate(self, sql, target, **kwargs):
self.assertEqual(transpile(sql, **kwargs)[0], target)
def validate_identity(self, sql):
self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql)
@ -258,6 +255,7 @@ class TestDialect(Validator):
"duckdb": "EPOCH(STRPTIME('2020-01-01', '%Y-%M-%d'))",
"hive": "UNIX_TIMESTAMP('2020-01-01', 'yyyy-mm-dd')",
"presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%i-%d'))",
"starrocks": "UNIX_TIMESTAMP('2020-01-01', '%Y-%i-%d')",
},
)
self.validate_all(
@ -266,6 +264,7 @@ class TestDialect(Validator):
"duckdb": "CAST('2020-01-01' AS DATE)",
"hive": "TO_DATE('2020-01-01')",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
"starrocks": "TO_DATE('2020-01-01')",
},
)
self.validate_all(
@ -341,6 +340,7 @@ class TestDialect(Validator):
"duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), y)",
"hive": "FROM_UNIXTIME(x, y)",
"presto": "DATE_FORMAT(FROM_UNIXTIME(x), y)",
"starrocks": "FROM_UNIXTIME(x, y)",
},
)
self.validate_all(
@ -349,6 +349,7 @@ class TestDialect(Validator):
"duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))",
"hive": "FROM_UNIXTIME(x)",
"presto": "FROM_UNIXTIME(x)",
"starrocks": "FROM_UNIXTIME(x)",
},
)
self.validate_all(
@ -840,10 +841,20 @@ class TestDialect(Validator):
"starrocks": UnsupportedError,
},
)
self.validate_all(
"POSITION(' ' in x)",
write={
"duckdb": "STRPOS(x, ' ')",
"postgres": "STRPOS(x, ' ')",
"presto": "STRPOS(x, ' ')",
"spark": "LOCATE(' ', x)",
},
)
self.validate_all(
"STR_POSITION(x, 'a')",
write={
"duckdb": "STRPOS(x, 'a')",
"postgres": "STRPOS(x, 'a')",
"presto": "STRPOS(x, 'a')",
"spark": "LOCATE('a', x)",
},

View file

@ -1,3 +1,4 @@
from sqlglot import ErrorLevel, UnsupportedError, transpile
from tests.dialects.test_dialect import Validator
@ -250,3 +251,10 @@ class TestDuckDB(Validator):
"spark": "MONTH('2021-03-01')",
},
)
with self.assertRaises(UnsupportedError):
transpile(
"SELECT a FROM b PIVOT(SUM(x) FOR y IN ('z', 'q'))",
read="duckdb",
unsupported_level=ErrorLevel.IMMEDIATE,
)

View file

@ -127,17 +127,17 @@ class TestHive(Validator):
def test_ddl(self):
self.validate_all(
"CREATE TABLE test STORED AS parquet TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1",
"CREATE TABLE test STORED AS parquet TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1",
write={
"presto": "CREATE TABLE test WITH (FORMAT = 'parquet', x = '1', Z = '2') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1",
"spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1",
"presto": "CREATE TABLE test WITH (FORMAT='parquet', x='1', Z='2') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1",
"spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1",
},
)
self.validate_all(
"CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
write={
"presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])",
"presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])",
"hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
"spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
},

View file

@ -119,3 +119,39 @@ class TestMySQL(Validator):
"sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')",
},
)
self.validate_identity(
"CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'"
)
self.validate_identity(
"CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'"
)
self.validate_identity(
"CREATE TABLE z (a INT DEFAULT NULL, PRIMARY KEY(a)) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'"
)
self.validate_all(
"""
CREATE TABLE `t_customer_account` (
"id" int(11) NOT NULL AUTO_INCREMENT,
"customer_id" int(11) DEFAULT NULL COMMENT '客户id',
"bank" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别',
"account_no" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号',
PRIMARY KEY ("id")
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='客户账户表'
""",
write={
"mysql": """CREATE TABLE `t_customer_account` (
'id' INT(11) NOT NULL AUTO_INCREMENT,
'customer_id' INT(11) DEFAULT NULL COMMENT '客户id',
'bank' VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别',
'account_no' VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号',
PRIMARY KEY('id')
)
ENGINE=InnoDB
AUTO_INCREMENT=1
DEFAULT CHARACTER SET=utf8
COLLATE=utf8_bin
COMMENT='客户账户表'"""
},
pretty=True,
)

View file

@ -171,7 +171,7 @@ class TestPresto(Validator):
self.validate_all(
"CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
write={
"presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
"presto": "CREATE TABLE test WITH (FORMAT='PARQUET') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
"spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
},
@ -179,15 +179,15 @@ class TestPresto(Validator):
self.validate_all(
"CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1",
write={
"presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1",
"spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1",
"presto": "CREATE TABLE test WITH (FORMAT='PARQUET', X='1', Z='2') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X'='1', 'Z'='2') AS SELECT 1",
"spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('X'='1', 'Z'='2') AS SELECT 1",
},
)
self.validate_all(
"CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])",
"CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])",
write={
"presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])",
"presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])",
"hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
"spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
},
@ -195,9 +195,9 @@ class TestPresto(Validator):
self.validate_all(
"CREATE TABLE x WITH (bucket_by = ARRAY['y'], bucket_count = 64) AS SELECT 1 AS y",
write={
"presto": "CREATE TABLE x WITH (bucket_by = ARRAY['y'], bucket_count = 64) AS SELECT 1 AS y",
"hive": "CREATE TABLE x TBLPROPERTIES ('bucket_by' = ARRAY('y'), 'bucket_count' = 64) AS SELECT 1 AS y",
"spark": "CREATE TABLE x TBLPROPERTIES ('bucket_by' = ARRAY('y'), 'bucket_count' = 64) AS SELECT 1 AS y",
"presto": "CREATE TABLE x WITH (bucket_by=ARRAY['y'], bucket_count=64) AS SELECT 1 AS y",
"hive": "CREATE TABLE x TBLPROPERTIES ('bucket_by'=ARRAY('y'), 'bucket_count'=64) AS SELECT 1 AS y",
"spark": "CREATE TABLE x TBLPROPERTIES ('bucket_by'=ARRAY('y'), 'bucket_count'=64) AS SELECT 1 AS y",
},
)
self.validate_all(
@ -217,11 +217,12 @@ class TestPresto(Validator):
},
)
self.validate(
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname",
read="presto",
write="presto",
write={
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
},
)
def test_quotes(self):

View file

@ -143,6 +143,31 @@ class TestSnowflake(Validator):
"snowflake": r"SELECT 'a \' \\ \\t \\x21 z $ '",
},
)
self.validate_identity("SELECT REGEXP_LIKE(a, b, c)")
self.validate_all(
"SELECT RLIKE(a, b)",
write={
"snowflake": "SELECT REGEXP_LIKE(a, b)",
},
)
self.validate_all(
"SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)",
write={
"snowflake": "SELECT a FROM test TABLESAMPLE BLOCK (0.5) SEED (42)",
},
)
self.validate_all(
"SELECT a FROM test pivot",
write={
"snowflake": "SELECT a FROM test AS pivot",
},
)
self.validate_all(
"SELECT a FROM test unpivot",
write={
"snowflake": "SELECT a FROM test AS unpivot",
},
)
def test_null_treatment(self):
self.validate_all(
@ -220,3 +245,51 @@ class TestSnowflake(Validator):
"snowflake": "SELECT EXTRACT(month FROM CAST(a AS DATETIME))",
},
)
def test_semi_structured_types(self):
self.validate_identity("SELECT CAST(a AS VARIANT)")
self.validate_all(
"SELECT a::VARIANT",
write={
"snowflake": "SELECT CAST(a AS VARIANT)",
"tsql": "SELECT CAST(a AS SQL_VARIANT)",
},
)
self.validate_identity("SELECT CAST(a AS ARRAY)")
self.validate_all(
"ARRAY_CONSTRUCT(0, 1, 2)",
write={
"snowflake": "[0, 1, 2]",
"bigquery": "[0, 1, 2]",
"duckdb": "LIST_VALUE(0, 1, 2)",
"presto": "ARRAY[0, 1, 2]",
"spark": "ARRAY(0, 1, 2)",
},
)
self.validate_all(
"SELECT a::OBJECT",
write={
"snowflake": "SELECT CAST(a AS OBJECT)",
},
)
def test_ddl(self):
self.validate_identity(
"CREATE TABLE a (x DATE, y BIGINT) WITH (PARTITION BY (x), integration='q', auto_refresh=TRUE, file_format=(type = parquet))"
)
self.validate_identity("CREATE MATERIALIZED VIEW a COMMENT='...' AS SELECT 1 FROM x")
def test_user_defined_functions(self):
self.validate_all(
"CREATE FUNCTION a(x DATE, y BIGINT) RETURNS ARRAY LANGUAGE JAVASCRIPT AS $$ SELECT 1 $$",
write={
"snowflake": "CREATE FUNCTION a(x DATE, y BIGINT) RETURNS ARRAY LANGUAGE JAVASCRIPT AS ' SELECT 1 '",
},
)
self.validate_all(
"CREATE FUNCTION a() RETURNS TABLE (b INT) AS 'SELECT 1'",
write={
"snowflake": "CREATE FUNCTION a() RETURNS TABLE (b INT) AS 'SELECT 1'",
"bigquery": "CREATE TABLE FUNCTION a() RETURNS TABLE <b INT64> AS SELECT 1",
},
)

View file

@ -34,7 +34,7 @@ class TestSpark(Validator):
self.validate_all(
"CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
write={
"presto": "CREATE TABLE x WITH (TABLE_FORMAT = 'ICEBERG', PARTITIONED_BY = ARRAY['MONTHS'])",
"presto": "CREATE TABLE x WITH (TABLE_FORMAT = 'ICEBERG', PARTITIONED_BY=ARRAY['MONTHS'])",
"hive": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
"spark": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
},
@ -42,7 +42,7 @@ class TestSpark(Validator):
self.validate_all(
"CREATE TABLE test STORED AS PARQUET AS SELECT 1",
write={
"presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
"presto": "CREATE TABLE test WITH (FORMAT='PARQUET') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
"spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
},
@ -56,9 +56,9 @@ class TestSpark(Validator):
)
COMMENT='Test comment: blah'
WITH (
PARTITIONED_BY = ARRAY['date'],
FORMAT = 'ICEBERG',
x = '1'
PARTITIONED_BY=ARRAY['date'],
FORMAT='ICEBERG',
x='1'
)""",
"hive": """CREATE TABLE blah (
col_a INT
@ -69,7 +69,7 @@ PARTITIONED BY (
)
STORED AS ICEBERG
TBLPROPERTIES (
'x' = '1'
'x'='1'
)""",
"spark": """CREATE TABLE blah (
col_a INT
@ -80,7 +80,7 @@ PARTITIONED BY (
)
USING ICEBERG
TBLPROPERTIES (
'x' = '1'
'x'='1'
)""",
},
pretty=True,

View file

@ -15,6 +15,14 @@ class TestTSQL(Validator):
},
)
self.validate_all(
"CONVERT(INT, CONVERT(NUMERIC, '444.75'))",
write={
"mysql": "CAST(CAST('444.75' AS DECIMAL) AS INT)",
"tsql": "CAST(CAST('444.75' AS NUMERIC) AS INTEGER)",
},
)
def test_types(self):
self.validate_identity("CAST(x AS XML)")
self.validate_identity("CAST(x AS UNIQUEIDENTIFIER)")
@ -24,3 +32,13 @@ class TestTSQL(Validator):
self.validate_identity("CAST(x AS IMAGE)")
self.validate_identity("CAST(x AS SQL_VARIANT)")
self.validate_identity("CAST(x AS BIT)")
self.validate_all(
"CAST(x AS DATETIME2)",
read={
"": "CAST(x AS DATETIME)",
},
write={
"mysql": "CAST(x AS DATETIME)",
"tsql": "CAST(x AS DATETIME2)",
},
)

View file

@ -8,6 +8,7 @@ SUM(CASE WHEN x > 1 THEN 1 ELSE 0 END) / y
1.1E10
1.12e-10
-11.023E7 * 3
0.2
(1 * 2) / (3 - 5)
((TRUE))
''
@ -167,7 +168,7 @@ SELECT LEAD(a) OVER (ORDER BY b) AS a
SELECT LEAD(a, 1) OVER (PARTITION BY a ORDER BY a) AS x
SELECT LEAD(a, 1, b) OVER (PARTITION BY a ORDER BY a) AS x
SELECT X((a, b) -> a + b, z -> z) AS x
SELECT X(a -> "a" + ("z" - 1))
SELECT X(a -> a + ("z" - 1))
SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)
SELECT test.* FROM test
SELECT a AS b FROM test
@ -258,15 +259,24 @@ SELECT a FROM test TABLESAMPLE(100)
SELECT a FROM test TABLESAMPLE(100 ROWS)
SELECT a FROM test TABLESAMPLE BERNOULLI (50)
SELECT a FROM test TABLESAMPLE SYSTEM (75)
SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q'))
SELECT a FROM test PIVOT(SOMEAGG(x, y, z) FOR q IN (1))
SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) PIVOT(MAX(b) FOR c IN ('d'))
SELECT a FROM (SELECT a, b FROM test) PIVOT(SUM(x) FOR y IN ('z', 'q'))
SELECT a FROM test UNPIVOT(x FOR y IN (z, q)) AS x
SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) AS x TABLESAMPLE(0.1)
SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) UNPIVOT(x FOR y IN (z, q)) AS x
SELECT ABS(a) FROM test
SELECT AVG(a) FROM test
SELECT CEIL(a) FROM test
SELECT CEIL(a, b) FROM test
SELECT COUNT(a) FROM test
SELECT COUNT(1) FROM test
SELECT COUNT(*) FROM test
SELECT COUNT(DISTINCT a) FROM test
SELECT EXP(a) FROM test
SELECT FLOOR(a) FROM test
SELECT FLOOR(a, b) FROM test
SELECT FIRST(a) FROM test
SELECT GREATEST(a, b, c) FROM test
SELECT LAST(a) FROM test
@ -299,6 +309,7 @@ SELECT CAST(a AS MAP<INT, INT>) FROM test
SELECT CAST(a AS TIMESTAMP) FROM test
SELECT CAST(a AS DATE) FROM test
SELECT CAST(a AS ARRAY<INT>) FROM test
SELECT CAST(a AS VARIANT) FROM test
SELECT TRY_CAST(a AS INT) FROM test
SELECT COALESCE(a, b, c) FROM test
SELECT IFNULL(a, b) FROM test
@ -442,13 +453,10 @@ CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id')
CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1)
CREATE TABLE z (a INT(11) NOT NULL COLLATE utf8_bin AUTO_INCREMENT)
CREATE TABLE z (a INT, PRIMARY KEY(a))
CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'
CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'
CREATE TABLE z (a INT DEFAULT NULL, PRIMARY KEY(a)) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'
CREATE TABLE z WITH (FORMAT='parquet') AS SELECT 1
CREATE TABLE z WITH (FORMAT='ORC', x = '2') AS SELECT 1
CREATE TABLE z WITH (FORMAT='ORC', x='2') AS SELECT 1
CREATE TABLE z WITH (TABLE_FORMAT='iceberg', FORMAT='parquet') AS SELECT 1
CREATE TABLE z WITH (TABLE_FORMAT='iceberg', FORMAT='ORC', x = '2') AS SELECT 1
CREATE TABLE z WITH (TABLE_FORMAT='iceberg', FORMAT='ORC', x='2') AS SELECT 1
CREATE TABLE z (z INT) WITH (PARTITIONED_BY=(x INT, y INT))
CREATE TABLE z (z INT) WITH (PARTITIONED_BY=(x INT)) AS SELECT 1
CREATE TABLE z AS (WITH cte AS (SELECT 1) SELECT * FROM cte)
@ -460,6 +468,9 @@ CREATE TEMPORARY FUNCTION f
CREATE TEMPORARY FUNCTION f AS 'g'
CREATE FUNCTION f
CREATE FUNCTION f AS 'g'
CREATE FUNCTION a(b INT, c VARCHAR) AS 'SELECT 1'
CREATE FUNCTION a() LANGUAGE sql
CREATE FUNCTION a() LANGUAGE sql RETURNS INT
CREATE INDEX abc ON t (a)
CREATE INDEX abc ON t (a, b, b)
CREATE UNIQUE INDEX abc ON t (a, b, b)
@ -519,3 +530,4 @@ WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS
WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a
SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z
SELECT ((SELECT 1) + 1)
SELECT * FROM project.dataset.INFORMATION_SCHEMA.TABLES

View file

@ -1,42 +1,79 @@
SELECT 1 AS x, 2 AS y
UNION ALL
SELECT 1 AS x, 2 AS y;
WITH _e_0 AS (
SELECT
1 AS x,
2 AS y
)
SELECT
*
FROM _e_0
UNION ALL
SELECT
*
FROM _e_0;
-- No derived tables
SELECT * FROM x;
SELECT * FROM x;
SELECT x.id
FROM (
SELECT *
FROM x AS x
JOIN y AS y
ON x.id = y.id
) AS x
JOIN (
SELECT *
FROM x AS x
JOIN y AS y
ON x.id = y.id
) AS y
ON x.id = y.id;
WITH _e_0 AS (
SELECT
*
FROM x AS x
JOIN y AS y
ON x.id = y.id
)
SELECT
x.id
FROM "_e_0" AS x
JOIN "_e_0" AS y
ON x.id = y.id;
-- Unaliased derived tables
SELECT a FROM (SELECT b FROM (SELECT c FROM x));
WITH cte AS (SELECT c FROM x), cte_2 AS (SELECT b FROM cte AS cte) SELECT a FROM cte_2 AS cte_2;
-- Joined derived table inside nested derived table
SELECT b FROM (SELECT b FROM (SELECT b FROM x JOIN (SELECT b FROM y) AS y ON x.b = y.b));
WITH y_2 AS (SELECT b FROM y), cte AS (SELECT b FROM x JOIN y_2 AS y ON x.b = y.b), cte_2 AS (SELECT b FROM cte AS cte) SELECT b FROM cte_2 AS cte_2;
-- Aliased derived tables
SELECT a FROM (SELECT b FROM (SELECT c FROM x) AS y) AS z;
WITH y AS (SELECT c FROM x), z AS (SELECT b FROM y AS y) SELECT a FROM z AS z;
-- Existing CTEs
WITH q AS (SELECT c FROM x) SELECT a FROM (SELECT b FROM q AS y) AS z;
WITH q AS (SELECT c FROM x), z AS (SELECT b FROM q AS y) SELECT a FROM z AS z;
-- Derived table inside CTE
WITH x AS (SELECT a FROM (SELECT a FROM x) AS y) SELECT a FROM x;
WITH y AS (SELECT a FROM x), x AS (SELECT a FROM y AS y) SELECT a FROM x;
-- Name conflicts with existing outer derived table
SELECT a FROM (SELECT b FROM (SELECT c FROM x) AS y) AS y;
WITH y AS (SELECT c FROM x), y_2 AS (SELECT b FROM y AS y) SELECT a FROM y_2 AS y;
-- Name conflicts with outer join
SELECT a, b FROM (SELECT c FROM (SELECT d FROM x) AS x) AS y JOIN x ON x.a = y.a;
WITH x_2 AS (SELECT d FROM x), y AS (SELECT c FROM x_2 AS x) SELECT a, b FROM y AS y JOIN x ON x.a = y.a;
-- Name conflicts with table name that is selected in another branch
SELECT * FROM (SELECT * FROM (SELECT a FROM x) AS x) AS y JOIN (SELECT * FROM x) AS z ON x.a = y.a;
WITH x_2 AS (SELECT a FROM x), y AS (SELECT * FROM x_2 AS x), z AS (SELECT * FROM x) SELECT * FROM y AS y JOIN z AS z ON x.a = y.a;
-- Name conflicts with table alias
SELECT a FROM (SELECT a FROM (SELECT a FROM x) AS y) AS z JOIN q AS y;
WITH y AS (SELECT a FROM x), z AS (SELECT a FROM y AS y) SELECT a FROM z AS z JOIN q AS y;
-- Name conflicts with existing CTE
WITH y AS (SELECT a FROM (SELECT a FROM x) AS y) SELECT a FROM y;
WITH y_2 AS (SELECT a FROM x), y AS (SELECT a FROM y_2 AS y) SELECT a FROM y;
-- Union
SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y;
WITH cte AS (SELECT 1 AS x, 2 AS y) SELECT cte.x AS x, cte.y AS y FROM cte AS cte UNION ALL SELECT cte.x AS x, cte.y AS y FROM cte AS cte;
-- Union of selects with derived tables
(SELECT a FROM (SELECT b FROM x)) UNION (SELECT a FROM (SELECT b FROM y));
WITH cte AS (SELECT b FROM x), cte_2 AS (SELECT a FROM cte AS cte), cte_3 AS (SELECT b FROM y), cte_4 AS (SELECT a FROM cte_3 AS cte_3) (SELECT cte_2.a AS a FROM cte_2 AS cte_2) UNION (SELECT cte_4.a AS a FROM cte_4 AS cte_4);
-- Subquery
SELECT a FROM x WHERE b = (SELECT y.c FROM y);
SELECT a FROM x WHERE b = (SELECT y.c FROM y);
-- Correlated subquery
SELECT a FROM x WHERE b = (SELECT c FROM y WHERE y.a = x.a);
SELECT a FROM x WHERE b = (SELECT c FROM y WHERE y.a = x.a);
-- Duplicate CTE
SELECT a FROM (SELECT b FROM x) AS y JOIN (SELECT b FROM x) AS z;
WITH y AS (SELECT b FROM x) SELECT a FROM y AS y JOIN y AS z;
-- Doubly duplicate CTE
SELECT * FROM (SELECT * FROM x JOIN (SELECT * FROM x) AS y) AS z JOIN (SELECT * FROM x JOIN (SELECT * FROM x) AS y) AS q;
WITH y AS (SELECT * FROM x), z AS (SELECT * FROM x JOIN y AS y) SELECT * FROM z AS z JOIN z AS q;
-- Another duplicate...
SELECT x.id FROM (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) AS x JOIN (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) AS y ON x.id = y.id;
WITH x_2 AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT x.id FROM x_2 AS x JOIN x_2 AS y ON x.id = y.id;
-- Root subquery
(SELECT * FROM (SELECT * FROM x)) LIMIT 1;
(WITH cte AS (SELECT * FROM x) SELECT * FROM cte AS cte) LIMIT 1;
-- Existing duplicate CTE
WITH y AS (SELECT a FROM x) SELECT a FROM (SELECT a FROM x) AS y JOIN y AS z;
WITH y AS (SELECT a FROM x) SELECT a FROM y AS y JOIN y AS z;

View file

@ -18,6 +18,14 @@ SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a;
SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;
-- Outer query has join
SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;
# leave_tables_isolated: true
SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM (SELECT x.a AS a, x.b AS b FROM x AS x WHERE x.a > 1) AS x JOIN y AS y ON x.b = y.b;
-- Join on derived table
SELECT a, c FROM x JOIN (SELECT b, c FROM y) AS y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
@ -42,13 +50,9 @@ SELECT q_2.a AS a, q.c AS c, r.c AS c FROM x AS q_2 JOIN y AS r_2 ON q_2.b = r_2
SELECT r.b FROM (SELECT b FROM x AS x) AS q JOIN (SELECT b FROM x) AS r ON q.b = r.b;
SELECT x_2.b AS b FROM x AS x JOIN x AS x_2 ON x.b = x_2.b;
-- WHERE clause in joined derived table is merged
-- WHERE clause in joined derived table is merged to ON clause
SELECT x.a, y.c FROM x JOIN (SELECT b, c FROM y WHERE c > 1) AS y;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y WHERE y.c > 1;
-- WHERE clause in outer joined derived table is merged to ON clause
SELECT x.a, y.c FROM x LEFT JOIN (SELECT b, c FROM y WHERE c > 1) AS y;
SELECT x.a AS a, y.c AS c FROM x AS x LEFT JOIN y AS y ON y.c > 1;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON y.c > 1;
-- Comma JOIN in outer query
SELECT x.a, y.c FROM (SELECT a FROM x) AS x, (SELECT c FROM y) AS y;
@ -61,3 +65,35 @@ SELECT x.a AS a, z.c AS c FROM x AS x CROSS JOIN y AS z;
-- (Regression) Column in ORDER BY
SELECT * FROM (SELECT * FROM (SELECT * FROM x)) ORDER BY a LIMIT 1;
SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a LIMIT 1;
-- CTE
WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x;
SELECT x.a AS a, x.b AS b FROM x AS x;
-- CTE with outer table alias
WITH y AS (SELECT a, b FROM x) SELECT a, b FROM y AS z;
SELECT x.a AS a, x.b AS b FROM x AS x;
-- Nested CTE
WITH x AS (SELECT a FROM x), x2 AS (SELECT a FROM x) SELECT a FROM x2;
SELECT x.a AS a FROM x AS x;
-- CTE WHERE clause is merged
WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, SUM(b) FROM x GROUP BY a;
SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a;
-- CTE Outer query has join
WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, c FROM x AS x JOIN y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;
-- CTE with inner table alias
WITH y AS (SELECT a, b FROM x AS q) SELECT a, b FROM y AS z;
SELECT q.a AS a, q.b AS b FROM x AS q;
-- Duplicate queries to CTE
WITH x AS (SELECT a, b FROM x) SELECT x.a, y.b FROM x JOIN x AS y;
WITH x AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT x.a AS a, y.b AS b FROM x JOIN x AS y;
-- Nested CTE
SELECT * FROM (WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x);
SELECT x.a AS a, x.b AS b FROM x AS x;

View file

@ -65,18 +65,14 @@ WITH "cte1" AS (
SELECT
"x"."a" AS "a"
FROM "x" AS "x"
), "cte2" AS (
SELECT
"cte1"."a" + 1 AS "a"
FROM "cte1"
)
SELECT
"cte1"."a" AS "a"
FROM "cte1"
UNION ALL
SELECT
"cte2"."a" AS "a"
FROM "cte2";
"cte1"."a" + 1 AS "a"
FROM "cte1";
SELECT a, SUM(b)
FROM (
@ -86,18 +82,19 @@ FROM (
) d
WHERE (TRUE AND TRUE OR 'a' = 'b') AND a > 1
GROUP BY a;
SELECT
"x"."a" AS "a",
SUM("y"."b") AS "_col_1"
FROM "x" AS "x"
LEFT JOIN (
WITH "_u_0" AS (
SELECT
MAX("y"."b") AS "_col_0",
"y"."a" AS "_u_1"
FROM "y" AS "y"
GROUP BY
"y"."a"
) AS "_u_0"
)
SELECT
"x"."a" AS "a",
SUM("y"."b") AS "_col_1"
FROM "x" AS "x"
LEFT JOIN "_u_0" AS "_u_0"
ON "x"."a" = "_u_0"."_u_1"
JOIN "y" AS "y"
ON "x"."a" = "y"."a"
@ -127,3 +124,16 @@ LIMIT 1;
FROM "y" AS "y"
)
LIMIT 1;
# dialect: spark
SELECT /*+ BROADCAST(y) */ x.b FROM x JOIN y ON x.b = y.b;
SELECT /*+ BROADCAST(`y`) */
`x`.`b` AS `b`
FROM `x` AS `x`
JOIN `y` AS `y`
ON `x`.`b` = `y`.`b`;
SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
SELECT
AGGREGATE(ARRAY("x"."a", "x"."b"), 0, ("x", "acc") -> "x" + "acc" + "x"."a") AS "sum_agg"
FROM "x" AS "x";

View file

@ -69,6 +69,9 @@ SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x
SELECT x.b, x.a FROM x LEFT JOIN y ON x.b = y.b QUALIFY ROW_NUMBER() OVER(PARTITION BY x.b ORDER BY x.a DESC) = 1;
SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a DESC) = 1;
SELECT AGGREGATE(ARRAY(a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + x.a) AS sum_agg FROM x AS x;
--------------------------------------
-- Derived tables
--------------------------------------
@ -231,3 +234,10 @@ SELECT COALESCE(x.b, y.b) AS b FROM x AS x JOIN y AS y ON x.b = y.b WHERE COALES
SELECT b FROM x JOIN y USING (b) JOIN z USING (b);
SELECT COALESCE(x.b, y.b, z.b) AS b FROM x AS x JOIN y AS y ON x.b = y.b JOIN z AS z ON x.b = z.b;
--------------------------------------
-- Hint with table reference
--------------------------------------
# dialect: spark
SELECT /*+ BROADCAST(y) */ x.b FROM x JOIN y ON x.b = y.b;
SELECT /*+ BROADCAST(y) */ x.b AS b FROM x AS x JOIN y AS y ON x.b = y.b;

View file

@ -5,7 +5,6 @@ SELECT z.* FROM x;
SELECT x FROM x;
INSERT INTO x VALUES (1, 2);
SELECT a FROM x AS z JOIN y AS z;
WITH z AS (SELECT * FROM x) SELECT * FROM x AS z;
SELECT a FROM x JOIN (SELECT b FROM y WHERE y.b = x.c);
SELECT a FROM x AS y JOIN (SELECT a FROM y) AS q ON y.a = q.a;
SELECT q.a FROM (SELECT x.b FROM x) AS z JOIN (SELECT a FROM z) AS q ON z.b = q.a;

View file

@ -97,19 +97,32 @@ order by
p_partkey
limit
100;
WITH "_e_0" AS (
WITH "partsupp_2" AS (
SELECT
"partsupp"."ps_partkey" AS "ps_partkey",
"partsupp"."ps_suppkey" AS "ps_suppkey",
"partsupp"."ps_supplycost" AS "ps_supplycost"
FROM "partsupp" AS "partsupp"
), "_e_1" AS (
), "region_2" AS (
SELECT
"region"."r_regionkey" AS "r_regionkey",
"region"."r_name" AS "r_name"
FROM "region" AS "region"
WHERE
"region"."r_name" = 'EUROPE'
), "_u_0" AS (
SELECT
MIN("partsupp"."ps_supplycost") AS "_col_0",
"partsupp"."ps_partkey" AS "_u_1"
FROM "partsupp_2" AS "partsupp"
CROSS JOIN "region_2" AS "region"
JOIN "nation" AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey"
JOIN "supplier" AS "supplier"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey"
GROUP BY
"partsupp"."ps_partkey"
)
SELECT
"supplier"."s_acctbal" AS "s_acctbal",
@ -121,25 +134,12 @@ SELECT
"supplier"."s_phone" AS "s_phone",
"supplier"."s_comment" AS "s_comment"
FROM "part" AS "part"
LEFT JOIN (
SELECT
MIN("partsupp"."ps_supplycost") AS "_col_0",
"partsupp"."ps_partkey" AS "_u_1"
FROM "_e_0" AS "partsupp"
CROSS JOIN "_e_1" AS "region"
JOIN "nation" AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey"
JOIN "supplier" AS "supplier"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey"
GROUP BY
"partsupp"."ps_partkey"
) AS "_u_0"
LEFT JOIN "_u_0" AS "_u_0"
ON "part"."p_partkey" = "_u_0"."_u_1"
CROSS JOIN "_e_1" AS "region"
CROSS JOIN "region_2" AS "region"
JOIN "nation" AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey"
JOIN "_e_0" AS "partsupp"
JOIN "partsupp_2" AS "partsupp"
ON "part"."p_partkey" = "partsupp"."ps_partkey"
JOIN "supplier" AS "supplier"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
@ -193,12 +193,12 @@ SELECT
FROM "customer" AS "customer"
JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
AND "orders"."o_orderdate" < '1995-03-15'
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "lineitem"."l_shipdate" > '1995-03-15'
WHERE
"customer"."c_mktsegment" = 'BUILDING'
AND "lineitem"."l_shipdate" > '1995-03-15'
AND "orders"."o_orderdate" < '1995-03-15'
GROUP BY
"lineitem"."l_orderkey",
"orders"."o_orderdate",
@ -232,11 +232,7 @@ group by
o_orderpriority
order by
o_orderpriority;
SELECT
"orders"."o_orderpriority" AS "o_orderpriority",
COUNT(*) AS "order_count"
FROM "orders" AS "orders"
LEFT JOIN (
WITH "_u_0" AS (
SELECT
"lineitem"."l_orderkey" AS "l_orderkey"
FROM "lineitem" AS "lineitem"
@ -244,7 +240,12 @@ LEFT JOIN (
"lineitem"."l_commitdate" < "lineitem"."l_receiptdate"
GROUP BY
"lineitem"."l_orderkey"
) AS "_u_0"
)
SELECT
"orders"."o_orderpriority" AS "o_orderpriority",
COUNT(*) AS "order_count"
FROM "orders" AS "orders"
LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."l_orderkey" = "orders"."o_orderkey"
WHERE
"orders"."o_orderdate" < CAST('1993-10-01' AS DATE)
@ -290,7 +291,10 @@ SELECT
FROM "customer" AS "customer"
JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
CROSS JOIN "region" AS "region"
AND "orders"."o_orderdate" < CAST('1995-01-01' AS DATE)
AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE)
JOIN "region" AS "region"
ON "region"."r_name" = 'ASIA'
JOIN "nation" AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey"
JOIN "supplier" AS "supplier"
@ -299,10 +303,6 @@ JOIN "supplier" AS "supplier"
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "lineitem"."l_suppkey" = "supplier"."s_suppkey"
WHERE
"orders"."o_orderdate" < CAST('1995-01-01' AS DATE)
AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE)
AND "region"."r_name" = 'ASIA'
GROUP BY
"nation"."n_name"
ORDER BY
@ -371,7 +371,7 @@ order by
supp_nation,
cust_nation,
l_year;
WITH "_e_0" AS (
WITH "n1" AS (
SELECT
"nation"."n_nationkey" AS "n_nationkey",
"nation"."n_name" AS "n_name"
@ -389,14 +389,15 @@ SELECT
)) AS "revenue"
FROM "supplier" AS "supplier"
JOIN "lineitem" AS "lineitem"
ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
ON "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
AND "supplier"."s_suppkey" = "lineitem"."l_suppkey"
JOIN "orders" AS "orders"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
JOIN "customer" AS "customer"
ON "customer"."c_custkey" = "orders"."o_custkey"
JOIN "_e_0" AS "n1"
JOIN "n1" AS "n1"
ON "supplier"."s_nationkey" = "n1"."n_nationkey"
JOIN "_e_0" AS "n2"
JOIN "n1" AS "n2"
ON "customer"."c_nationkey" = "n2"."n_nationkey"
AND (
"n1"."n_name" = 'FRANCE'
@ -406,8 +407,6 @@ JOIN "_e_0" AS "n2"
"n1"."n_name" = 'GERMANY'
OR "n2"."n_name" = 'GERMANY'
)
WHERE
"lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
GROUP BY
"n1"."n_name",
"n2"."n_name",
@ -469,13 +468,15 @@ SELECT
1 - "lineitem"."l_discount"
)) AS "mkt_share"
FROM "part" AS "part"
CROSS JOIN "region" AS "region"
JOIN "region" AS "region"
ON "region"."r_name" = 'AMERICA'
JOIN "nation" AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey"
JOIN "customer" AS "customer"
ON "customer"."c_nationkey" = "nation"."n_nationkey"
JOIN "orders" AS "orders"
ON "orders"."o_custkey" = "customer"."c_custkey"
AND "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "part"."p_partkey" = "lineitem"."l_partkey"
@ -484,9 +485,7 @@ JOIN "supplier" AS "supplier"
JOIN "nation" AS "nation_2"
ON "supplier"."s_nationkey" = "nation_2"."n_nationkey"
WHERE
"orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
AND "part"."p_type" = 'ECONOMY ANODIZED STEEL'
AND "region"."r_name" = 'AMERICA'
"part"."p_type" = 'ECONOMY ANODIZED STEEL'
GROUP BY
EXTRACT(year FROM "orders"."o_orderdate")
ORDER BY
@ -604,14 +603,13 @@ SELECT
FROM "customer" AS "customer"
JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
JOIN "nation" AS "nation"
ON "customer"."c_nationkey" = "nation"."n_nationkey"
WHERE
"lineitem"."l_returnflag" = 'R'
AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE)
AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE)
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "lineitem"."l_returnflag" = 'R'
JOIN "nation" AS "nation"
ON "customer"."c_nationkey" = "nation"."n_nationkey"
GROUP BY
"customer"."c_custkey",
"customer"."c_name",
@ -654,12 +652,12 @@ group by
)
order by
value desc;
WITH "_e_0" AS (
WITH "supplier_2" AS (
SELECT
"supplier"."s_suppkey" AS "s_suppkey",
"supplier"."s_nationkey" AS "s_nationkey"
FROM "supplier" AS "supplier"
), "_e_1" AS (
), "nation_2" AS (
SELECT
"nation"."n_nationkey" AS "n_nationkey",
"nation"."n_name" AS "n_name"
@ -671,9 +669,9 @@ SELECT
"partsupp"."ps_partkey" AS "ps_partkey",
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value"
FROM "partsupp" AS "partsupp"
JOIN "_e_0" AS "supplier"
JOIN "supplier_2" AS "supplier"
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "_e_1" AS "nation"
JOIN "nation_2" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
GROUP BY
"partsupp"."ps_partkey"
@ -682,9 +680,9 @@ HAVING
SELECT
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0"
FROM "partsupp" AS "partsupp"
JOIN "_e_0" AS "supplier"
JOIN "supplier_2" AS "supplier"
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "_e_1" AS "nation"
JOIN "nation_2" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
)
ORDER BY
@ -737,13 +735,12 @@ SELECT
END) AS "low_line_count"
FROM "orders" AS "orders"
JOIN "lineitem" AS "lineitem"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
WHERE
"lineitem"."l_commitdate" < "lineitem"."l_receiptdate"
ON "lineitem"."l_commitdate" < "lineitem"."l_receiptdate"
AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE)
AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE)
AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate"
AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP')
AND "orders"."o_orderkey" = "lineitem"."l_orderkey"
GROUP BY
"lineitem"."l_shipmode"
ORDER BY
@ -772,10 +769,7 @@ group by
order by
custdist desc,
c_count desc;
SELECT
"c_orders"."c_count" AS "c_count",
COUNT(*) AS "custdist"
FROM (
WITH "c_orders" AS (
SELECT
COUNT("orders"."o_orderkey") AS "c_count"
FROM "customer" AS "customer"
@ -784,7 +778,11 @@ FROM (
AND NOT "orders"."o_comment" LIKE '%special%requests%'
GROUP BY
"customer"."c_custkey"
) AS "c_orders"
)
SELECT
"c_orders"."c_count" AS "c_count",
COUNT(*) AS "custdist"
FROM "c_orders" AS "c_orders"
GROUP BY
"c_orders"."c_count"
ORDER BY
@ -920,13 +918,7 @@ order by
p_brand,
p_type,
p_size;
SELECT
"part"."p_brand" AS "p_brand",
"part"."p_type" AS "p_type",
"part"."p_size" AS "p_size",
COUNT(DISTINCT "partsupp"."ps_suppkey") AS "supplier_cnt"
FROM "partsupp" AS "partsupp"
LEFT JOIN (
WITH "_u_0" AS (
SELECT
"supplier"."s_suppkey" AS "s_suppkey"
FROM "supplier" AS "supplier"
@ -934,15 +926,22 @@ LEFT JOIN (
"supplier"."s_comment" LIKE '%Customer%Complaints%'
GROUP BY
"supplier"."s_suppkey"
) AS "_u_0"
)
SELECT
"part"."p_brand" AS "p_brand",
"part"."p_type" AS "p_type",
"part"."p_size" AS "p_size",
COUNT(DISTINCT "partsupp"."ps_suppkey") AS "supplier_cnt"
FROM "partsupp" AS "partsupp"
LEFT JOIN "_u_0" AS "_u_0"
ON "partsupp"."ps_suppkey" = "_u_0"."s_suppkey"
JOIN "part" AS "part"
ON "part"."p_partkey" = "partsupp"."ps_partkey"
WHERE
"_u_0"."s_suppkey" IS NULL
AND "part"."p_brand" <> 'Brand#45'
ON "part"."p_brand" <> 'Brand#45'
AND "part"."p_partkey" = "partsupp"."ps_partkey"
AND "part"."p_size" IN (49, 14, 23, 45, 19, 3, 36, 9)
AND NOT "part"."p_type" LIKE 'MEDIUM POLISHED%'
WHERE
"_u_0"."s_suppkey" IS NULL
GROUP BY
"part"."p_brand",
"part"."p_type",
@ -973,24 +972,25 @@ where
where
l_partkey = p_partkey
);
SELECT
SUM("lineitem"."l_extendedprice") / 7.0 AS "avg_yearly"
FROM "lineitem" AS "lineitem"
JOIN "part" AS "part"
ON "part"."p_partkey" = "lineitem"."l_partkey"
LEFT JOIN (
WITH "_u_0" AS (
SELECT
0.2 * AVG("lineitem"."l_quantity") AS "_col_0",
"lineitem"."l_partkey" AS "_u_1"
FROM "lineitem" AS "lineitem"
GROUP BY
"lineitem"."l_partkey"
) AS "_u_0"
)
SELECT
SUM("lineitem"."l_extendedprice") / 7.0 AS "avg_yearly"
FROM "lineitem" AS "lineitem"
JOIN "part" AS "part"
ON "part"."p_brand" = 'Brand#23'
AND "part"."p_container" = 'MED BOX'
AND "part"."p_partkey" = "lineitem"."l_partkey"
LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."_u_1" = "part"."p_partkey"
WHERE
"lineitem"."l_quantity" < "_u_0"."_col_0"
AND "part"."p_brand" = 'Brand#23'
AND "part"."p_container" = 'MED BOX'
AND NOT "_u_0"."_u_1" IS NULL;
--------------------------------------
@ -1030,6 +1030,16 @@ order by
o_orderdate
limit
100;
WITH "_u_0" AS (
SELECT
"lineitem"."l_orderkey" AS "l_orderkey"
FROM "lineitem" AS "lineitem"
GROUP BY
"lineitem"."l_orderkey",
"lineitem"."l_orderkey"
HAVING
SUM("lineitem"."l_quantity") > 300
)
SELECT
"customer"."c_name" AS "c_name",
"customer"."c_custkey" AS "c_custkey",
@ -1040,16 +1050,7 @@ SELECT
FROM "customer" AS "customer"
JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
LEFT JOIN (
SELECT
"lineitem"."l_orderkey" AS "l_orderkey"
FROM "lineitem" AS "lineitem"
GROUP BY
"lineitem"."l_orderkey",
"lineitem"."l_orderkey"
HAVING
SUM("lineitem"."l_quantity") > 300
) AS "_u_0"
LEFT JOIN "_u_0" AS "_u_0"
ON "orders"."o_orderkey" = "_u_0"."l_orderkey"
JOIN "lineitem" AS "lineitem"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
@ -1200,38 +1201,34 @@ where
and n_name = 'CANADA'
order by
s_name;
SELECT
"supplier"."s_name" AS "s_name",
"supplier"."s_address" AS "s_address"
FROM "supplier" AS "supplier"
LEFT JOIN (
WITH "_u_0" AS (
SELECT
0.5 * SUM("lineitem"."l_quantity") AS "_col_0",
"lineitem"."l_partkey" AS "_u_1",
"lineitem"."l_suppkey" AS "_u_2"
FROM "lineitem" AS "lineitem"
WHERE
"lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE)
AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE)
GROUP BY
"lineitem"."l_partkey",
"lineitem"."l_suppkey"
), "_u_3" AS (
SELECT
"part"."p_partkey" AS "p_partkey"
FROM "part" AS "part"
WHERE
"part"."p_name" LIKE 'forest%'
GROUP BY
"part"."p_partkey"
), "_u_4" AS (
SELECT
"partsupp"."ps_suppkey" AS "ps_suppkey"
FROM "partsupp" AS "partsupp"
LEFT JOIN (
SELECT
0.5 * SUM("lineitem"."l_quantity") AS "_col_0",
"lineitem"."l_partkey" AS "_u_1",
"lineitem"."l_suppkey" AS "_u_2"
FROM "lineitem" AS "lineitem"
WHERE
"lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE)
AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE)
GROUP BY
"lineitem"."l_partkey",
"lineitem"."l_suppkey"
) AS "_u_0"
LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."_u_1" = "partsupp"."ps_partkey"
AND "_u_0"."_u_2" = "partsupp"."ps_suppkey"
LEFT JOIN (
SELECT
"part"."p_partkey" AS "p_partkey"
FROM "part" AS "part"
WHERE
"part"."p_name" LIKE 'forest%'
GROUP BY
"part"."p_partkey"
) AS "_u_3"
LEFT JOIN "_u_3" AS "_u_3"
ON "partsupp"."ps_partkey" = "_u_3"."p_partkey"
WHERE
"partsupp"."ps_availqty" > "_u_0"."_col_0"
@ -1240,13 +1237,18 @@ LEFT JOIN (
AND NOT "_u_3"."p_partkey" IS NULL
GROUP BY
"partsupp"."ps_suppkey"
) AS "_u_4"
)
SELECT
"supplier"."s_name" AS "s_name",
"supplier"."s_address" AS "s_address"
FROM "supplier" AS "supplier"
LEFT JOIN "_u_4" AS "_u_4"
ON "supplier"."s_suppkey" = "_u_4"."ps_suppkey"
JOIN "nation" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
ON "nation"."n_name" = 'CANADA'
AND "supplier"."s_nationkey" = "nation"."n_nationkey"
WHERE
"nation"."n_name" = 'CANADA'
AND NOT "_u_4"."ps_suppkey" IS NULL
NOT "_u_4"."ps_suppkey" IS NULL
ORDER BY
"s_name";
@ -1294,22 +1296,14 @@ order by
s_name
limit
100;
SELECT
"supplier"."s_name" AS "s_name",
COUNT(*) AS "numwait"
FROM "supplier" AS "supplier"
JOIN "lineitem" AS "lineitem"
ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
LEFT JOIN (
WITH "_u_0" AS (
SELECT
"l2"."l_orderkey" AS "l_orderkey",
ARRAY_AGG("l2"."l_suppkey") AS "_u_1"
FROM "lineitem" AS "l2"
GROUP BY
"l2"."l_orderkey"
) AS "_u_0"
ON "_u_0"."l_orderkey" = "lineitem"."l_orderkey"
LEFT JOIN (
), "_u_2" AS (
SELECT
"l3"."l_orderkey" AS "l_orderkey",
ARRAY_AGG("l3"."l_suppkey") AS "_u_3"
@ -1318,20 +1312,29 @@ LEFT JOIN (
"l3"."l_receiptdate" > "l3"."l_commitdate"
GROUP BY
"l3"."l_orderkey"
) AS "_u_2"
)
SELECT
"supplier"."s_name" AS "s_name",
COUNT(*) AS "numwait"
FROM "supplier" AS "supplier"
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_receiptdate" > "lineitem"."l_commitdate"
AND "supplier"."s_suppkey" = "lineitem"."l_suppkey"
LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."l_orderkey" = "lineitem"."l_orderkey"
LEFT JOIN "_u_2" AS "_u_2"
ON "_u_2"."l_orderkey" = "lineitem"."l_orderkey"
JOIN "orders" AS "orders"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
AND "orders"."o_orderstatus" = 'F'
JOIN "nation" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
ON "nation"."n_name" = 'SAUDI ARABIA'
AND "supplier"."s_nationkey" = "nation"."n_nationkey"
WHERE
(
"_u_2"."l_orderkey" IS NULL
OR NOT ARRAY_ANY("_u_2"."_u_3", "_x" -> "_x" <> "lineitem"."l_suppkey")
)
AND "lineitem"."l_receiptdate" > "lineitem"."l_commitdate"
AND "nation"."n_name" = 'SAUDI ARABIA'
AND "orders"."o_orderstatus" = 'F'
AND ARRAY_ANY("_u_0"."_u_1", "_x" -> "_x" <> "lineitem"."l_suppkey")
AND NOT "_u_0"."l_orderkey" IS NULL
GROUP BY
@ -1381,18 +1384,19 @@ group by
cntrycode
order by
cntrycode;
SELECT
SUBSTRING("customer"."c_phone", 1, 2) AS "cntrycode",
COUNT(*) AS "numcust",
SUM("customer"."c_acctbal") AS "totacctbal"
FROM "customer" AS "customer"
LEFT JOIN (
WITH "_u_0" AS (
SELECT
"orders"."o_custkey" AS "_u_1"
FROM "orders" AS "orders"
GROUP BY
"orders"."o_custkey"
) AS "_u_0"
)
SELECT
SUBSTRING("customer"."c_phone", 1, 2) AS "cntrycode",
COUNT(*) AS "numcust",
SUM("customer"."c_acctbal") AS "totacctbal"
FROM "customer" AS "customer"
LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."_u_1" = "customer"."c_custkey"
WHERE
"_u_0"."_u_1" IS NULL
View file
@ -264,22 +264,3 @@ CREATE TABLE "t_customer_account" (
"account_no" VARCHAR(100)
);
CREATE TABLE "t_customer_account" (
"id" int(11) NOT NULL AUTO_INCREMENT,
"customer_id" int(11) DEFAULT NULL COMMENT '客户id',
"bank" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别',
"account_no" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号',
PRIMARY KEY ("id")
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='客户账户表';
CREATE TABLE "t_customer_account" (
"id" INT(11) NOT NULL AUTO_INCREMENT,
"customer_id" INT(11) DEFAULT NULL COMMENT '客户id',
"bank" VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别',
"account_no" VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号',
PRIMARY KEY("id")
)
ENGINE=InnoDB
AUTO_INCREMENT=1
DEFAULT CHARACTER SET=utf8
COLLATE=utf8_bin
COMMENT='客户账户表';
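
Note: the fixture above pins down MySQL DDL round-tripping, including ENGINE, AUTO_INCREMENT, charset, collation, and comment properties. A hedged round-trip sketch (simplified column list; output layout follows the pretty printer):

from sqlglot import parse_one

ddl = (
    "CREATE TABLE `t_customer_account` ("
    "`id` int(11) NOT NULL AUTO_INCREMENT, PRIMARY KEY (`id`)"
    ") ENGINE=InnoDB COMMENT='客户账户表'"
)
# read="mysql" handles backticked identifiers and MySQL table options.
print(parse_one(ddl, read="mysql").sql(dialect="mysql", pretty=True))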
View file
@ -270,7 +270,7 @@ class TestBuild(unittest.TestCase):
                lambda: parse_one("SELECT * FROM y")
                .assert_is(exp.Select)
                .ctas("foo.x", properties={"format": "parquet", "y": "2"}),
                "CREATE TABLE foo.x STORED AS PARQUET TBLPROPERTIES ('y' = '2') AS SELECT * FROM y",
                "CREATE TABLE foo.x STORED AS PARQUET TBLPROPERTIES ('y'='2') AS SELECT * FROM y",
                "hive",
            ),
            (lambda: and_("x=1", "y=1"), "x = 1 AND y = 1"),
@ -308,6 +308,18 @@ class TestBuild(unittest.TestCase):
                lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select("x"),
                "SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned",
            ),
            (
                lambda: exp.update("tbl", {"x": None, "y": {"x": 1}}),
                "UPDATE tbl SET x = NULL, y = MAP('x', 1)",
            ),
            (
                lambda: exp.update("tbl", {"x": 1}, where="y > 0"),
                "UPDATE tbl SET x = 1 WHERE y > 0",
            ),
            (
                lambda: exp.update("tbl", {"x": 1}, from_="tbl2"),
                "UPDATE tbl SET x = 1 FROM tbl2",
            ),
        ]:
            with self.subTest(sql):
                self.assertEqual(expression().sql(dialect[0] if dialect else None), sql)
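
Note: the three new cases exercise the exp.update builder. For reference, a usage sketch with the generic dialect (None renders as NULL and dicts as MAP, via exp.convert):

from sqlglot import exp

print(exp.update("tbl", {"x": 1}, where="y > 0").sql())
# UPDATE tbl SET x = 1 WHERE y > 0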
View file
@ -27,6 +27,8 @@ class TestExpressions(unittest.TestCase):
parse_one("ROW() OVER (partition BY y)"),
)
self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)"))
self.assertEqual(exp.Table(pivots=[]), exp.Table())
self.assertNotEqual(exp.Table(pivots=[None]), exp.Table())
def test_find(self):
expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y")
@ -280,6 +282,19 @@ class TestExpressions(unittest.TestCase):
        expression.find(exp.Table).replace(parse_one("y"))
        self.assertEqual(expression.sql(), "SELECT c, b FROM y")

    def test_pop(self):
        expression = parse_one("SELECT a, b FROM x")
        expression.find(exp.Column).pop()
        self.assertEqual(expression.sql(), "SELECT b FROM x")
        expression.find(exp.Column).pop()
        self.assertEqual(expression.sql(), "SELECT FROM x")
        expression.pop()
        self.assertEqual(expression.sql(), "SELECT FROM x")

        expression = parse_one("WITH x AS (SELECT a FROM x) SELECT * FROM x")
        expression.find(exp.With).pop()
        self.assertEqual(expression.sql(), "SELECT * FROM x")
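
Note: test_pop covers in-place node removal. A compact usage sketch:

from sqlglot import exp, parse_one

tree = parse_one("SELECT a, b FROM x")
tree.find(exp.Column).pop()  # removes the first column from the tree
assert tree.sql() == "SELECT b FROM x"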
    def test_walk(self):
        expression = parse_one("SELECT * FROM (SELECT * FROM x)")
        self.assertEqual(len(list(expression.walk())), 9)
@ -316,6 +331,7 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(parse_one("MAX(a)"), exp.Max)
self.assertIsInstance(parse_one("MIN(a)"), exp.Min)
self.assertIsInstance(parse_one("MONTH(a)"), exp.Month)
self.assertIsInstance(parse_one("POSITION(' ' IN a)"), exp.StrPosition)
self.assertIsInstance(parse_one("POW(a, 2)"), exp.Pow)
self.assertIsInstance(parse_one("POWER(a, 2)"), exp.Pow)
self.assertIsInstance(parse_one("QUANTILE(a, 0.90)"), exp.Quantile)
@ -420,7 +436,7 @@ class TestExpressions(unittest.TestCase):
            exp.Properties.from_dict(
                {
                    "FORMAT": "parquet",
                    "PARTITIONED_BY": [exp.to_identifier("a"), exp.to_identifier("b")],
                    "PARTITIONED_BY": (exp.to_identifier("a"), exp.to_identifier("b")),
                    "custom": 1,
                    "TABLE_FORMAT": exp.to_identifier("test_format"),
                    "ENGINE": None,
@ -444,4 +460,17 @@ class TestExpressions(unittest.TestCase):
            ),
        )
        self.assertRaises(ValueError, exp.Properties.from_dict, {"FORMAT": {"key": "value"}})
        self.assertRaises(ValueError, exp.Properties.from_dict, {"FORMAT": object})

    def test_convert(self):
        for value, expected in [
            (1, "1"),
            ("1", "'1'"),
            (None, "NULL"),
            (True, "TRUE"),
            ((1, "2", None), "(1, '2', NULL)"),
            ([1, "2", None], "ARRAY(1, '2', NULL)"),
            ({"x": None}, "MAP('x', NULL)"),
        ]:
            with self.subTest(value):
                self.assertEqual(exp.convert(value).sql(), expected)
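
Note: exp.convert is the primitive that builders such as exp.update use to turn Python values into SQL expression nodes. Sketch:

from sqlglot import exp

print(exp.convert([1, "2", None]).sql())  # ARRAY(1, '2', NULL)
print(exp.convert({"x": None}).sql())     # MAP('x', NULL)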
View file
@ -1,9 +1,11 @@
import unittest
from functools import partial
from sqlglot import optimizer, parse_one, table
from sqlglot import exp, optimizer, parse_one, table
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.schema import MappingSchema, ensure_schema
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.optimizer.scope import build_scope, traverse_scope
from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures
@ -27,11 +29,17 @@ class TestOptimizer(unittest.TestCase):
        }

    def check_file(self, file, func, pretty=False, **kwargs):
        for meta, sql, expected in load_sql_fixture_pairs(f"optimizer/{file}.sql"):
        for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1):
            dialect = meta.get("dialect")
            with self.subTest(sql):
            leave_tables_isolated = meta.get("leave_tables_isolated")

            func_kwargs = {**kwargs}
            if leave_tables_isolated is not None:
                func_kwargs["leave_tables_isolated"] = leave_tables_isolated.lower() in ("true", "1")

            with self.subTest(f"{i}, {sql}"):
                self.assertEqual(
                    func(parse_one(sql, read=dialect), **kwargs).sql(pretty=pretty, dialect=dialect),
                    func(parse_one(sql, read=dialect), **func_kwargs).sql(pretty=pretty, dialect=dialect),
                    expected,
                )
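
Note: check_file now threads per-fixture metadata into the rule under test, coercing the string flag to a bool before forwarding it. Assuming fixture meta is parsed from leading comment lines of the form "# leave_tables_isolated: true" (the format load_sql_fixture_pairs appears to use), the forwarding boils down to:

from functools import partial

from sqlglot import optimizer

meta = {"leave_tables_isolated": "true"}  # hypothetical parsed fixture meta
func = partial(
    optimizer.merge_subqueries.merge_subqueries,
    leave_tables_isolated=meta["leave_tables_isolated"].lower() in ("true", "1"),
)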
@ -123,21 +131,20 @@ class TestOptimizer(unittest.TestCase):
            optimizer.optimize_joins.optimize_joins,
        )

    def test_eliminate_subqueries(self):
        self.check_file(
            "eliminate_subqueries",
            optimizer.eliminate_subqueries.eliminate_subqueries,
            pretty=True,
    def test_merge_subqueries(self):
        optimize = partial(
            optimizer.optimize,
            rules=[
                optimizer.qualify_tables.qualify_tables,
                optimizer.qualify_columns.qualify_columns,
                optimizer.merge_subqueries.merge_subqueries,
            ],
        )

    def test_merge_derived_tables(self):
        def optimize(expression, **kwargs):
            expression = optimizer.qualify_tables.qualify_tables(expression)
            expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
            expression = optimizer.merge_derived_tables.merge_derived_tables(expression)
            return expression

        self.check_file("merge_subqueries", optimize, schema=self.schema)
        self.check_file("merge_derived_tables", optimize, schema=self.schema)

    def test_eliminate_subqueries(self):
        self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries)
    def test_tpch(self):
        self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)
@ -257,17 +264,73 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
ON s.b = r.b
WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b)
"""
        scopes = traverse_scope(parse_one(sql))
        self.assertEqual(len(scopes), 5)
        self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
        self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
        self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
        self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y")
        self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())
        for scopes in traverse_scope(parse_one(sql)), list(build_scope(parse_one(sql)).traverse()):
            self.assertEqual(len(scopes), 5)
            self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
            self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
            self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
            self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y")
            self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())
        self.assertEqual(set(scopes[4].sources), {"q", "r", "s"})
        self.assertEqual(len(scopes[4].columns), 6)
        self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"})
        self.assertEqual(scopes[4].source_columns("q"), [])
        self.assertEqual(len(scopes[4].source_columns("r")), 2)
        self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"})
            self.assertEqual(set(scopes[4].sources), {"q", "r", "s"})
            self.assertEqual(len(scopes[4].columns), 6)
            self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"})
            self.assertEqual(scopes[4].source_columns("q"), [])
            self.assertEqual(len(scopes[4].source_columns("r")), 2)
            self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"})
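
Note: build_scope is the new single-root counterpart to traverse_scope, and the test asserts that both walk the same scope sequence. Sketch:

from sqlglot import parse_one
from sqlglot.optimizer.scope import build_scope, traverse_scope

tree = parse_one("SELECT sub.a FROM (SELECT a FROM x) AS sub")
root = build_scope(tree)
# traverse() yields child scopes first and the root scope last,
# matching the order produced by traverse_scope.
assert [s.expression.sql() for s in root.traverse()] == [
    s.expression.sql() for s in traverse_scope(tree)
]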
    def test_literal_type_annotation(self):
        tests = {
            "SELECT 5": exp.DataType.Type.INT,
            "SELECT 5.3": exp.DataType.Type.DOUBLE,
            "SELECT 'bla'": exp.DataType.Type.VARCHAR,
            "5": exp.DataType.Type.INT,
            "5.3": exp.DataType.Type.DOUBLE,
            "'bla'": exp.DataType.Type.VARCHAR,
        }

        for sql, target_type in tests.items():
            expression = parse_one(sql)
            annotated_expression = annotate_types(expression)
            self.assertEqual(annotated_expression.find(exp.Literal).type, target_type)

    def test_boolean_type_annotation(self):
        tests = {
            "SELECT TRUE": exp.DataType.Type.BOOLEAN,
            "FALSE": exp.DataType.Type.BOOLEAN,
        }

        for sql, target_type in tests.items():
            expression = parse_one(sql)
            annotated_expression = annotate_types(expression)
            self.assertEqual(annotated_expression.find(exp.Boolean).type, target_type)

    def test_cast_type_annotation(self):
        expression = parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")
        annotate_types(expression)

        self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ)
        self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR)
        self.assertEqual(expression.args["to"].type, exp.DataType.Type.TIMESTAMPTZ)
        self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)

    def test_cache_annotation(self):
        expression = parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
        annotated_expression = annotate_types(expression)
        self.assertEqual(annotated_expression.expression.expressions[0].type, exp.DataType.Type.INT)

    def test_binary_annotation(self):
        expression = parse_one("SELECT 0.0 + (2 + 3)")
        annotate_types(expression)
        expression = expression.expressions[0]

        self.assertEqual(expression.type, exp.DataType.Type.DOUBLE)
        self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE)
        self.assertEqual(expression.right.type, exp.DataType.Type.INT)
        self.assertEqual(expression.right.this.type, exp.DataType.Type.INT)
        self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT)
        self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT)
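
Note: the annotation tests pin down annotate_types, which decorates each node with a type in place and returns the expression. Usage sketch:

from sqlglot import exp, parse_one
from sqlglot.optimizer.annotate_types import annotate_types

expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)"))
# Binary expressions widen to the broader operand type: DOUBLE + INT -> DOUBLE.
assert expression.expressions[0].type == exp.DataType.Type.DOUBLE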
View file
@ -21,6 +21,11 @@ class TestParser(unittest.TestCase):
        self.assertIsNotNone(parse_one("date").find(exp.Column))

    def test_float(self):
        self.assertEqual(parse_one(".2"), parse_one("0.2"))
        self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
        self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)"))

    def test_table(self):
        tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
        self.assertEqual(tables, ["a", "b.c", "d"])
View file
@ -6,11 +6,32 @@ from sqlglot.transforms import unalias_group
class TestTime(unittest.TestCase):
    def validate(self, transform, sql, target):
        self.assertEqual(parse_one(sql).transform(transform).sql(), target)
        with self.subTest(sql):
            self.assertEqual(parse_one(sql).transform(transform).sql(), target)

    def test_unalias_group(self):
        self.validate(
            unalias_group,
            "SELECT a, b AS b, c AS c, 4 FROM x GROUP BY a, b, x.c, 4",
            "SELECT a, b AS b, c AS c, 4 FROM x GROUP BY a, 2, x.c, 4",
            "SELECT a, b AS b, c AS c, 4 FROM x GROUP BY a, b, x.c, 4",
        )
        self.validate(
            unalias_group,
            "SELECT TO_DATE(the_date) AS the_date, CUSTOM_UDF(other_col) AS other_col, last_col AS aliased_last, COUNT(*) AS the_count FROM x GROUP BY TO_DATE(the_date), CUSTOM_UDF(other_col), aliased_last",
            "SELECT TO_DATE(the_date) AS the_date, CUSTOM_UDF(other_col) AS other_col, last_col AS aliased_last, COUNT(*) AS the_count FROM x GROUP BY TO_DATE(the_date), CUSTOM_UDF(other_col), 3",
        )
        self.validate(
            unalias_group,
            "SELECT SOME_UDF(TO_DATE(the_date)) AS the_date, COUNT(*) AS the_count FROM x GROUP BY SOME_UDF(TO_DATE(the_date))",
            "SELECT SOME_UDF(TO_DATE(the_date)) AS the_date, COUNT(*) AS the_count FROM x GROUP BY SOME_UDF(TO_DATE(the_date))",
        )
        self.validate(
            unalias_group,
            "SELECT SOME_UDF(TO_DATE(the_date)) AS new_date, COUNT(*) AS the_count FROM x GROUP BY new_date",
            "SELECT SOME_UDF(TO_DATE(the_date)) AS new_date, COUNT(*) AS the_count FROM x GROUP BY 1",
        )
        self.validate(
            unalias_group,
            "SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date",
            "SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date",
        )
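
Note: per the cases above, unalias_group now rewrites a GROUP BY item to its positional ordinal only when the item refers to a select-list alias that does not simply restate the underlying column; plain columns and repeated expressions are left alone. Sketch:

from sqlglot import parse_one
from sqlglot.transforms import unalias_group

sql = "SELECT SOME_UDF(TO_DATE(the_date)) AS new_date, COUNT(*) AS the_count FROM x GROUP BY new_date"
print(parse_one(sql).transform(unalias_group).sql())
# SELECT SOME_UDF(TO_DATE(the_date)) AS new_date, COUNT(*) AS the_count FROM x GROUP BY 1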