
Merging upstream version 6.2.6.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:40:43 +01:00
parent 0f5b9ddee1
commit 66e2d714bf
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
49 changed files with 1741 additions and 566 deletions

View file

@@ -1,6 +1,8 @@
 #!/bin/bash -e
-python -m autoflake -i -r \
+[[ -z "${GITHUB_ACTIONS}" ]] && RETURN_ERROR_CODE='' || RETURN_ERROR_CODE='--check'
+python -m autoflake -i -r ${RETURN_ERROR_CODE} \
   --expand-star-imports \
   --remove-all-unused-imports \
   --ignore-init-module-imports \
@@ -8,5 +10,5 @@ python -m autoflake -i -r \
   --remove-unused-variables \
   sqlglot/ tests/
 python -m isort --profile black sqlglot/ tests/
-python -m black --line-length 120 sqlglot/ tests/
+python -m black ${RETURN_ERROR_CODE} --line-length 120 sqlglot/ tests/
 python -m unittest
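
The effect: the same formatters still rewrite files in place locally, but only verify on CI, where GITHUB_ACTIONS is set. A minimal Python sketch of that toggle (illustrative only, not part of the repo):

import os

# Hypothetical mirror of the script's logic: no flag locally, "--check" on CI.
flag = "--check" if os.environ.get("GITHUB_ACTIONS") else ""
cmd = ["python", "-m", "black", *([flag] if flag else []), "--line-length", "120", "sqlglot/", "tests/"]
print(" ".join(cmd))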

View file

@@ -20,7 +20,7 @@ from sqlglot.generator import Generator
 from sqlglot.parser import Parser
 from sqlglot.tokens import Tokenizer, TokenType
 
-__version__ = "6.2.1"
+__version__ = "6.2.6"
 
 pretty = False

View file

@@ -33,6 +33,49 @@ def _date_add_sql(data_type, kind):
     return func
 
 
+def _subquery_to_unnest_if_values(self, expression):
+    if not isinstance(expression.this, exp.Values):
+        return self.subquery_sql(expression)
+    rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.this.find_all(exp.Tuple)]
+    structs = []
+    for row in rows:
+        aliases = [
+            exp.alias_(value, column_name) for value, column_name in zip(row, expression.args["alias"].args["columns"])
+        ]
+        structs.append(exp.Struct(expressions=aliases))
+    unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
+    return self.unnest_sql(unnest_exp)
+
+
+def _returnsproperty_sql(self, expression):
+    value = expression.args.get("value")
+    if isinstance(value, exp.Schema):
+        value = f"{value.this} <{self.expressions(value)}>"
+    else:
+        value = self.sql(value)
+    return f"RETURNS {value}"
+
+
+def _create_sql(self, expression):
+    kind = expression.args.get("kind")
+    returns = expression.find(exp.ReturnsProperty)
+    if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
+        expression = expression.copy()
+        expression.set("kind", "TABLE FUNCTION")
+        if isinstance(
+            expression.expression,
+            (
+                exp.Subquery,
+                exp.Literal,
+            ),
+        ):
+            expression.set("expression", expression.expression.this)
+        return self.create_sql(expression)
+    return self.create_sql(expression)
+
+
 class BigQuery(Dialect):
     unnest_column_only = True
@@ -77,8 +120,14 @@ class BigQuery(Dialect):
             TokenType.CURRENT_TIME: exp.CurrentTime,
         }
 
+        NESTED_TYPE_TOKENS = {
+            *Parser.NESTED_TYPE_TOKENS,
+            TokenType.TABLE,
+        }
+
     class Generator(Generator):
         TRANSFORMS = {
+            **Generator.TRANSFORMS,
             exp.Array: inline_array_sql,
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.DateAdd: _date_add_sql("DATE", "ADD"),
@@ -91,6 +140,9 @@ class BigQuery(Dialect):
             exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
             exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
             exp.VariancePop: rename_func("VAR_POP"),
+            exp.Subquery: _subquery_to_unnest_if_values,
+            exp.ReturnsProperty: _returnsproperty_sql,
+            exp.Create: _create_sql,
         }
 
         TYPE_MAPPING = {
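
A rough sketch of what the new _subquery_to_unnest_if_values transform is aiming at; the input dialect and the exact output below are illustrative assumptions, not taken from this diff:

import sqlglot

# A VALUES-backed subquery has no direct BigQuery equivalent, so it is rewritten
# into UNNEST over an array of STRUCTs, one struct per row.
sql = "SELECT a, b FROM (VALUES (1, 'x'), (2, 'y')) AS t(a, b)"
print(sqlglot.transpile(sql, read="presto", write="bigquery")[0])
# e.g. SELECT a, b FROM UNNEST([STRUCT(1 AS a, 'x' AS b), STRUCT(2 AS a, 'y' AS b)])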

View file

@@ -245,6 +245,11 @@ def no_tablesample_sql(self, expression):
     return self.sql(expression.this)
 
 
+def no_pivot_sql(self, expression):
+    self.unsupported("PIVOT unsupported")
+    return self.sql(expression)
+
+
 def no_trycast_sql(self, expression):
     return self.cast_sql(expression)
@@ -282,3 +287,30 @@ def format_time_lambda(exp_class, dialect, default=None):
     )
 
     return _format_time
+
+
+def create_with_partitions_sql(self, expression):
+    """
+    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
+    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
+    columns are removed from the create statement.
+    """
+    has_schema = isinstance(expression.this, exp.Schema)
+    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
+
+    if has_schema and is_partitionable:
+        expression = expression.copy()
+        prop = expression.find(exp.PartitionedByProperty)
+        value = prop and prop.args.get("value")
+        if prop and not isinstance(value, exp.Schema):
+            schema = expression.this
+            columns = {v.name.upper() for v in value.expressions}
+            partitions = [col for col in schema.expressions if col.name.upper() in columns]
+            schema.set(
+                "expressions",
+                [e for e in schema.expressions if e not in partitions],
+            )
+            prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions)))
+            expression.set("this", schema)
+
+    return self.create_sql(expression)
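
A sketch of the behavior the docstring describes, using an illustrative Spark-style input (exact rendering assumed):

import sqlglot

# Columns named in PARTITIONED BY are lifted out of the column list and
# re-emitted, with their types, inside the PARTITIONED BY schema.
sql = "CREATE TABLE t (a INT, ds STRING) PARTITIONED BY (ds)"
print(sqlglot.transpile(sql, read="spark", write="spark")[0])
# e.g. CREATE TABLE t (a INT) PARTITIONED BY (ds STRING)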

View file

@@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import (
     arrow_json_extract_scalar_sql,
     arrow_json_extract_sql,
     format_time_lambda,
+    no_pivot_sql,
     no_safe_divide_sql,
     no_tablesample_sql,
     rename_func,
@@ -122,6 +123,7 @@ class DuckDB(Dialect):
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
             exp.JSONBExtract: arrow_json_extract_sql,
             exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
+            exp.Pivot: no_pivot_sql,
             exp.RegexpLike: rename_func("REGEXP_MATCHES"),
             exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
             exp.SafeDivide: no_safe_divide_sql,

View file

@@ -2,6 +2,7 @@ from sqlglot import exp, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     approx_count_distinct_sql,
+    create_with_partitions_sql,
     format_time_lambda,
     if_sql,
     no_ilike_sql,
@@ -218,15 +219,6 @@ class Hive(Dialect):
         }
 
     class Generator(Generator):
-        ROOT_PROPERTIES = [
-            exp.PartitionedByProperty,
-            exp.FileFormatProperty,
-            exp.SchemaCommentProperty,
-            exp.LocationProperty,
-            exp.TableFormatProperty,
-        ]
-
-        WITH_PROPERTIES = [exp.AnonymousProperty]
-
         TYPE_MAPPING = {
             **Generator.TYPE_MAPPING,
             exp.DataType.Type.TEXT: "STRING",
@@ -255,13 +247,13 @@ class Hive(Dialect):
             exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
             exp.Map: _map_sql,
             HiveMap: _map_sql,
-            exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}",
+            exp.Create: create_with_partitions_sql,
             exp.Quantile: rename_func("PERCENTILE"),
             exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
             exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
             exp.RegexpSplit: rename_func("SPLIT"),
             exp.SafeDivide: no_safe_divide_sql,
-            exp.SchemaCommentProperty: lambda self, e: f"COMMENT {self.sql(e.args['value'])}",
+            exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
             exp.SetAgg: rename_func("COLLECT_SET"),
             exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
             exp.StrPosition: lambda self, e: f"LOCATE({csv(self.sql(e, 'substr'), self.sql(e, 'this'), self.sql(e, 'position'))})",
@@ -282,6 +274,17 @@ class Hive(Dialect):
             exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({csv(self.sql(e, 'this'), _time_format(self, e))})",
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
             exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
+            exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}",
+        }
+
+        WITH_PROPERTIES = {exp.AnonymousProperty}
+
+        ROOT_PROPERTIES = {
+            exp.PartitionedByProperty,
+            exp.FileFormatProperty,
+            exp.SchemaCommentProperty,
+            exp.LocationProperty,
+            exp.TableFormatProperty,
         }
 
         def with_properties(self, properties):

View file

@@ -172,6 +172,11 @@ class MySQL(Dialect):
             ),
         }
 
+        PROPERTY_PARSERS = {
+            **Parser.PROPERTY_PARSERS,
+            TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
+        }
+
     class Generator(Generator):
         NULL_ORDERING_SUPPORTED = False
@@ -190,3 +195,13 @@ class MySQL(Dialect):
             exp.StrToTime: _str_to_date_sql,
             exp.Trim: _trim_sql,
         }
+
+        ROOT_PROPERTIES = {
+            exp.EngineProperty,
+            exp.AutoIncrementProperty,
+            exp.CharacterSetProperty,
+            exp.CollateProperty,
+            exp.SchemaCommentProperty,
+        }
+
+        WITH_PROPERTIES = {}
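
Roughly what the new ENGINE property parser plus the MySQL root-property routing should allow; the round trip below is an assumption, not taken from this diff:

import sqlglot

# ENGINE=... now parses into an exp.EngineProperty and is emitted at the root
# of the CREATE statement instead of inside a WITH (...) block.
print(sqlglot.transpile("CREATE TABLE t (a INT) ENGINE=InnoDB", read="mysql", write="mysql")[0])
# e.g. CREATE TABLE t (a INT) ENGINE=InnoDB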

View file

@@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
     no_paren_current_date_sql,
     no_tablesample_sql,
     no_trycast_sql,
+    str_position_sql,
 )
 from sqlglot.generator import Generator
 from sqlglot.parser import Parser
@@ -158,7 +159,6 @@ class Postgres(Dialect):
             "ALWAYS": TokenType.ALWAYS,
             "BY DEFAULT": TokenType.BY_DEFAULT,
             "IDENTITY": TokenType.IDENTITY,
-            "FOR": TokenType.FOR,
             "GENERATED": TokenType.GENERATED,
             "DOUBLE PRECISION": TokenType.DOUBLE,
             "BIGSERIAL": TokenType.BIGSERIAL,
@@ -204,6 +204,7 @@ class Postgres(Dialect):
             exp.DateAdd: _date_add_sql("+"),
             exp.DateSub: _date_add_sql("-"),
             exp.Lateral: _lateral_sql,
+            exp.StrPosition: str_position_sql,
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.Substring: _substring_sql,
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",

View file

@@ -146,13 +146,16 @@ class Presto(Dialect):
         STRUCT_DELIMITER = ("(", ")")
 
-        WITH_PROPERTIES = [
+        ROOT_PROPERTIES = {
+            exp.SchemaCommentProperty,
+        }
+
+        WITH_PROPERTIES = {
             exp.PartitionedByProperty,
             exp.FileFormatProperty,
-            exp.SchemaCommentProperty,
             exp.AnonymousProperty,
             exp.TableFormatProperty,
-        ]
+        }
 
         TYPE_MAPPING = {
             **Generator.TYPE_MAPPING,
@@ -184,13 +187,11 @@ class Presto(Dialect):
             exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
             exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
             exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
-            exp.FileFormatProperty: lambda self, e: self.property_sql(e),
             exp.If: if_sql,
             exp.ILike: no_ilike_sql,
             exp.Initcap: _initcap_sql,
             exp.Lateral: _explode_to_unnest_sql,
             exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
-            exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}",
             exp.Quantile: _quantile_sql,
             exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
             exp.SafeDivide: no_safe_divide_sql,
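
Illustratively, for Presto DDL this splits properties between the statement root and the WITH clause (output assumed):

import sqlglot

# COMMENT is now a root property, while format/partitioning stay in WITH (...).
sql = "CREATE TABLE t (a INT) COMMENT 'nums' WITH (format='PARQUET')"
print(sqlglot.transpile(sql, read="presto", write="presto")[0])
# e.g. CREATE TABLE t (a INT) COMMENT 'nums' WITH (format='PARQUET')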

View file

@@ -1,5 +1,10 @@
 from sqlglot import exp
-from sqlglot.dialects.dialect import Dialect, format_time_lambda, rename_func
+from sqlglot.dialects.dialect import (
+    Dialect,
+    format_time_lambda,
+    inline_array_sql,
+    rename_func,
+)
 from sqlglot.expressions import Literal
 from sqlglot.generator import Generator
 from sqlglot.helper import list_get
@@ -104,6 +109,8 @@ class Snowflake(Dialect):
             "ARRAYAGG": exp.ArrayAgg.from_arg_list,
             "IFF": exp.If.from_arg_list,
             "TO_TIMESTAMP": _snowflake_to_timestamp,
+            "ARRAY_CONSTRUCT": exp.Array.from_arg_list,
+            "RLIKE": exp.RegexpLike.from_arg_list,
         }
 
         FUNCTION_PARSERS = {
@@ -111,6 +118,11 @@ class Snowflake(Dialect):
             "DATE_PART": lambda self: self._parse_extract(),
         }
 
+        FUNC_TOKENS = {
+            *Parser.FUNC_TOKENS,
+            TokenType.RLIKE,
+        }
+
         COLUMN_OPERATORS = {
             **Parser.COLUMN_OPERATORS,
             TokenType.COLON: lambda self, this, path: self.expression(
@@ -120,6 +132,11 @@ class Snowflake(Dialect):
             ),
         }
 
+        PROPERTY_PARSERS = {
+            **Parser.PROPERTY_PARSERS,
+            TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(),
+        }
+
     class Tokenizer(Tokenizer):
         QUOTES = ["'", "$$"]
         ESCAPE = "\\"
@@ -137,6 +154,7 @@ class Snowflake(Dialect):
             "TIMESTAMP_NTZ": TokenType.TIMESTAMP,
             "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
             "TIMESTAMPNTZ": TokenType.TIMESTAMP,
+            "SAMPLE": TokenType.TABLE_SAMPLE,
         }
 
     class Generator(Generator):
@@ -145,6 +163,8 @@ class Snowflake(Dialect):
             exp.If: rename_func("IFF"),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.UnixToTime: _unix_to_time,
+            exp.Array: inline_array_sql,
+            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
         }
 
         TYPE_MAPPING = {
@@ -152,6 +172,13 @@ class Snowflake(Dialect):
             exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
         }
 
+        ROOT_PROPERTIES = {
+            exp.PartitionedByProperty,
+            exp.ReturnsProperty,
+            exp.LanguageProperty,
+            exp.SchemaCommentProperty,
+        }
+
         def except_op(self, expression):
             if not expression.args.get("distinct", False):
                 self.unsupported("EXCEPT with All is not supported in Snowflake")
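
Two of the Snowflake additions sketched (outputs assumed, not from the diff):

import sqlglot

# ARRAY_CONSTRUCT parses to exp.Array, which Snowflake now writes back inline.
print(sqlglot.transpile("SELECT ARRAY_CONSTRUCT(1, 2, 3)", read="snowflake", write="snowflake")[0])
# e.g. SELECT [1, 2, 3]

# RLIKE is accepted as a function name and maps to exp.RegexpLike.
print(sqlglot.transpile("SELECT RLIKE(col, 'a.*')", read="snowflake", write="hive")[0])
# e.g. SELECT col RLIKE 'a.*'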

View file

@@ -1,5 +1,9 @@
 from sqlglot import exp
-from sqlglot.dialects.dialect import no_ilike_sql, rename_func
+from sqlglot.dialects.dialect import (
+    create_with_partitions_sql,
+    no_ilike_sql,
+    rename_func,
+)
 from sqlglot.dialects.hive import Hive, HiveMap
 from sqlglot.helper import list_get
@@ -10,7 +14,7 @@ def _create_sql(self, e):
     if kind.upper() == "TABLE" and temporary is True:
         return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
-    return self.create_sql(e)
+    return create_with_partitions_sql(self, e)
 
 
 def _map_sql(self, expression):
@@ -73,6 +77,7 @@ class Spark(Hive):
         }
 
    class Generator(Hive.Generator):
         TYPE_MAPPING = {
             **Hive.Generator.TYPE_MAPPING,
             exp.DataType.Type.TINYINT: "BYTE",

View file

@@ -1,4 +1,5 @@
 from sqlglot import exp
+from sqlglot.dialects.dialect import rename_func
 from sqlglot.dialects.mysql import MySQL
@@ -10,3 +11,12 @@ class StarRocks(MySQL):
             exp.DataType.Type.TIMESTAMP: "DATETIME",
             exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
         }
+
+        TRANSFORMS = {
+            **MySQL.Generator.TRANSFORMS,
+            exp.DateDiff: rename_func("DATEDIFF"),
+            exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.TimeStrToDate: rename_func("TO_DATE"),
+            exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.UnixToTime: rename_func("FROM_UNIXTIME"),
+        }

View file

@@ -1,6 +1,7 @@
 from sqlglot import exp
 from sqlglot.dialects.dialect import Dialect
 from sqlglot.generator import Generator
+from sqlglot.parser import Parser
 from sqlglot.tokens import Tokenizer, TokenType
@@ -17,6 +18,7 @@ class TSQL(Dialect):
             "REAL": TokenType.FLOAT,
             "NTEXT": TokenType.TEXT,
             "SMALLDATETIME": TokenType.DATETIME,
+            "DATETIME2": TokenType.DATETIME,
             "DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
             "TIME": TokenType.TIMESTAMP,
             "VARBINARY": TokenType.BINARY,
@@ -24,15 +26,24 @@ class TSQL(Dialect):
             "MONEY": TokenType.MONEY,
             "SMALLMONEY": TokenType.SMALLMONEY,
             "ROWVERSION": TokenType.ROWVERSION,
-            "SQL_VARIANT": TokenType.SQL_VARIANT,
             "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
             "XML": TokenType.XML,
+            "SQL_VARIANT": TokenType.VARIANT,
         }
 
+    class Parser(Parser):
+        def _parse_convert(self):
+            to = self._parse_types()
+            self._match(TokenType.COMMA)
+            this = self._parse_field()
+            return self.expression(exp.Cast, this=this, to=to)
+
     class Generator(Generator):
         TYPE_MAPPING = {
             **Generator.TYPE_MAPPING,
             exp.DataType.Type.BOOLEAN: "BIT",
             exp.DataType.Type.INT: "INTEGER",
             exp.DataType.Type.DECIMAL: "NUMERIC",
+            exp.DataType.Type.DATETIME: "DATETIME2",
+            exp.DataType.Type.VARIANT: "SQL_VARIANT",
         }
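
What the new CONVERT handling implies; the hookup of _parse_convert into the parser isn't shown in this hunk, so treat the round trip below as assumed behavior:

import sqlglot

# T-SQL CONVERT(type, value) is parsed into a plain CAST node.
print(sqlglot.transpile("SELECT CONVERT(VARCHAR(10), name) FROM t", read="tsql", write="tsql")[0])
# e.g. SELECT CAST(name AS VARCHAR(10)) FROM t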

View file

@@ -3,17 +3,11 @@ import time
 
 from sqlglot import parse_one
 from sqlglot.executor.python import PythonExecutor
-from sqlglot.optimizer import RULES, optimize
-from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
+from sqlglot.optimizer import optimize
 from sqlglot.planner import Plan
 
 logger = logging.getLogger("sqlglot")
 
-OPTIMIZER_RULES = list(RULES)
-
-# The executor needs isolated table selects
-OPTIMIZER_RULES.remove(merge_derived_tables)
-
 
 def execute(sql, schema, read=None):
     """
@@ -34,7 +28,7 @@ def execute(sql, schema, read=None):
     """
     expression = parse_one(sql, read=read)
     now = time.time()
-    expression = optimize(expression, schema, rules=OPTIMIZER_RULES)
+    expression = optimize(expression, schema, leave_tables_isolated=True)
     logger.debug("Optimization finished: %f", time.time() - now)
     logger.debug("Optimized SQL: %s", expression.sql(pretty=True))
     plan = Plan(expression)

View file

@@ -1,13 +1,17 @@
-import inspect
 import numbers
 import re
-import sys
 from collections import deque
 from copy import deepcopy
 from enum import auto
 
 from sqlglot.errors import ParseError
-from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list, list_get
+from sqlglot.helper import (
+    AutoName,
+    camel_to_snake_case,
+    ensure_list,
+    list_get,
+    subclasses,
+)
 
 
 class _Expression(type):
@@ -31,12 +35,13 @@ class Expression(metaclass=_Expression):
     key = None
     arg_types = {"this": True}
-    __slots__ = ("args", "parent", "arg_key")
+    __slots__ = ("args", "parent", "arg_key", "type")
 
     def __init__(self, **args):
         self.args = args
         self.parent = None
         self.arg_key = None
+        self.type = None
 
         for arg_key, value in self.args.items():
             self._set_parent(arg_key, value)
@@ -384,7 +389,7 @@ class Expression(metaclass=_Expression):
             'SELECT y FROM tbl'
 
         Args:
-            expression (Expression): new node
+            expression (Expression|None): new node
 
         Returns :
             the new expression or expressions
@@ -398,6 +403,12 @@ class Expression(metaclass=_Expression):
         replace_children(parent, lambda child: expression if child is self else child)
         return expression
 
+    def pop(self):
+        """
+        Remove this expression from its AST.
+        """
+        self.replace(None)
+
     def assert_is(self, type_):
         """
         Assert that this `Expression` is an instance of `type_`.
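
A quick illustration of the new pop helper (expected output shown as a comment):

from sqlglot import exp, parse_one

tree = parse_one("SELECT a FROM t WHERE b = 1")
tree.find(exp.Where).pop()  # detach the WHERE clause from the tree
print(tree.sql())
# SELECT a FROM t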
@@ -527,9 +538,18 @@ class Create(Expression):
         "temporary": False,
         "replace": False,
         "unique": False,
+        "materialized": False,
     }
 
 
+class UserDefinedFunction(Expression):
+    arg_types = {"this": True, "expressions": False}
+
+
+class UserDefinedFunctionKwarg(Expression):
+    arg_types = {"this": True, "kind": True, "default": False}
+
+
 class CharacterSet(Expression):
     arg_types = {"this": True, "default": False}
@@ -887,6 +907,14 @@ class AnonymousProperty(Property):
     pass
 
 
+class ReturnsProperty(Property):
+    arg_types = {"this": True, "value": True, "is_table": False}
+
+
+class LanguageProperty(Property):
+    pass
+
+
 class Properties(Expression):
     arg_types = {"expressions": True}
@@ -907,25 +935,9 @@ class Properties(Expression):
         expressions = []
         for key, value in properties_dict.items():
             property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
-            expressions.append(property_cls(this=Literal.string(key), value=cls._convert_value(value)))
+            expressions.append(property_cls(this=Literal.string(key), value=convert(value)))
         return cls(expressions=expressions)
 
-    @staticmethod
-    def _convert_value(value):
-        if value is None:
-            return NULL
-        if isinstance(value, Expression):
-            return value
-        if isinstance(value, bool):
-            return Boolean(this=value)
-        if isinstance(value, str):
-            return Literal.string(value)
-        if isinstance(value, numbers.Number):
-            return Literal.number(value)
-        if isinstance(value, list):
-            return Tuple(expressions=[Properties._convert_value(v) for v in value])
-        raise ValueError(f"Unsupported type '{type(value)}' for value '{value}'")
-
 
 class Qualify(Expression):
     pass
@@ -1030,6 +1042,7 @@ class Subqueryable:
 QUERY_MODIFIERS = {
     "laterals": False,
     "joins": False,
+    "pivots": False,
     "where": False,
     "group": False,
     "having": False,
@@ -1051,6 +1064,7 @@ class Table(Expression):
         "catalog": False,
         "laterals": False,
         "joins": False,
+        "pivots": False,
     }
@@ -1643,6 +1657,16 @@ class TableSample(Expression):
         "percent": False,
         "rows": False,
         "size": False,
+        "seed": False,
+    }
+
+
+class Pivot(Expression):
+    arg_types = {
+        "this": False,
+        "expressions": True,
+        "field": True,
+        "unpivot": True,
     }
@@ -1741,7 +1765,8 @@ class DataType(Expression):
         SMALLMONEY = auto()
         ROWVERSION = auto()
         IMAGE = auto()
-        SQL_VARIANT = auto()
+        VARIANT = auto()
+        OBJECT = auto()
 
     @classmethod
     def build(cls, dtype, **kwargs):
@@ -2124,6 +2149,7 @@ class TryCast(Cast):
 
 class Ceil(Func):
+    arg_types = {"this": True, "decimals": False}
     _sql_names = ["CEIL", "CEILING"]
@@ -2254,7 +2280,7 @@ class Explode(Func):
 
 class Floor(Func):
-    pass
+    arg_types = {"this": True, "decimals": False}
 
 
 class Greatest(Func):
@@ -2371,7 +2397,7 @@ class Reduce(Func):
 
 class RegexpLike(Func):
-    arg_types = {"this": True, "expression": True}
+    arg_types = {"this": True, "expression": True, "flag": False}
 
 
 class RegexpSplit(Func):
@@ -2540,6 +2566,8 @@ def _norm_args(expression):
     for k, arg in expression.args.items():
         if isinstance(arg, list):
             arg = [_norm_arg(a) for a in arg]
+            if not arg:
+                arg = None
         else:
             arg = _norm_arg(arg)
@@ -2553,17 +2581,7 @@ def _norm_arg(arg):
     return arg.lower() if isinstance(arg, str) else arg
 
 
-def _all_functions():
-    return [
-        obj
-        for _, obj in inspect.getmembers(
-            sys.modules[__name__],
-            lambda obj: inspect.isclass(obj) and issubclass(obj, Func) and obj not in (AggFunc, Anonymous, Func),
-        )
-    ]
-
-
-ALL_FUNCTIONS = _all_functions()
+ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
 
 
 def maybe_parse(
@@ -2793,6 +2811,37 @@ def from_(*expressions, dialect=None, **opts):
     return Select().from_(*expressions, dialect=dialect, **opts)
 
 
+def update(table, properties, where=None, from_=None, dialect=None, **opts):
+    """
+    Creates an update statement.
+
+    Example:
+        >>> update("my_table", {"x": 1, "y": "2", "z": None}, from_="baz", where="id > 1").sql()
+        "UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz WHERE id > 1"
+
+    Args:
+        *properties (Dict[str, Any]): dictionary of properties to set which are
+            auto converted to sql objects eg None -> NULL
+        where (str): sql conditional parsed into a WHERE statement
+        from_ (str): sql statement parsed into a FROM statement
+        dialect (str): the dialect used to parse the input expressions.
+        **opts: other options to use to parse the input expressions.
+
+    Returns:
+        Update: the syntax tree for the UPDATE statement.
+    """
+    update = Update(this=maybe_parse(table, into=Table, dialect=dialect))
+    update.set(
+        "expressions",
+        [EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) for k, v in properties.items()],
+    )
+    if from_:
+        update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
+    if where:
+        update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts))
+    return update
+
+
 def condition(expression, dialect=None, **opts):
     """
     Initialize a logical condition expression.
@@ -2980,12 +3029,13 @@ def column(col, table=None, quoted=None):
 
 def table_(table, db=None, catalog=None, quoted=None):
-    """
-    Build a Table.
+    """Build a Table.
 
     Args:
         table (str or Expression): column name
         db (str or Expression): db name
         catalog (str or Expression): catalog name
+
     Returns:
         Table: table instance
     """
@@ -2996,6 +3046,39 @@ def table_(table, db=None, catalog=None, quoted=None):
     )
 
 
+def convert(value):
+    """Convert a python value into an expression object.
+
+    Raises an error if a conversion is not possible.
+
+    Args:
+        value (Any): a python object
+
+    Returns:
+        Expression: the equivalent expression object
+    """
+    if isinstance(value, Expression):
+        return value
+    if value is None:
+        return NULL
+    if isinstance(value, bool):
+        return Boolean(this=value)
+    if isinstance(value, str):
+        return Literal.string(value)
+    if isinstance(value, numbers.Number):
+        return Literal.number(value)
+    if isinstance(value, tuple):
+        return Tuple(expressions=[convert(v) for v in value])
+    if isinstance(value, list):
+        return Array(expressions=[convert(v) for v in value])
+    if isinstance(value, dict):
+        return Map(
+            keys=[convert(k) for k in value.keys()],
+            values=[convert(v) for v in value.values()],
+        )
+    raise ValueError(f"Cannot convert {value}")
+
+
 def replace_children(expression, fun):
     """
     Replace children of an expression with the result of a lambda fun(child) -> exp.
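
A short illustration of convert on a few Python values (expected SQL shown as comments):

from sqlglot import expressions as exp

print(exp.convert((1, "2", None)).sql())  # (1, '2', NULL)
print(exp.convert([1, 2]).sql())          # ARRAY(1, 2)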

View file

@@ -46,18 +46,12 @@ class Generator:
     """
 
     TRANSFORMS = {
-        exp.AnonymousProperty: lambda self, e: self.property_sql(e),
-        exp.AutoIncrementProperty: lambda self, e: f"AUTO_INCREMENT={self.sql(e, 'value')}",
         exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
-        exp.CollateProperty: lambda self, e: f"COLLATE={self.sql(e, 'value')}",
         exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
         exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
-        exp.EngineProperty: lambda self, e: f"ENGINE={self.sql(e, 'value')}",
-        exp.FileFormatProperty: lambda self, e: f"FORMAT={self.sql(e, 'value')}",
-        exp.LocationProperty: lambda self, e: f"LOCATION {self.sql(e, 'value')}",
-        exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY={self.sql(e.args['value'])}",
-        exp.SchemaCommentProperty: lambda self, e: f"COMMENT={self.sql(e, 'value')}",
-        exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT={self.sql(e, 'value')}",
+        exp.LanguageProperty: lambda self, e: self.naked_property(e),
+        exp.LocationProperty: lambda self, e: self.naked_property(e),
+        exp.ReturnsProperty: lambda self, e: self.naked_property(e),
         exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
     }
@@ -72,19 +66,17 @@ class Generator:
     STRUCT_DELIMITER = ("<", ">")
 
-    ROOT_PROPERTIES = [
-        exp.AutoIncrementProperty,
-        exp.CharacterSetProperty,
-        exp.CollateProperty,
-        exp.EngineProperty,
-        exp.SchemaCommentProperty,
-    ]
-
-    WITH_PROPERTIES = [
+    ROOT_PROPERTIES = {
+        exp.ReturnsProperty,
+        exp.LanguageProperty,
+    }
+
+    WITH_PROPERTIES = {
         exp.AnonymousProperty,
         exp.FileFormatProperty,
         exp.PartitionedByProperty,
         exp.TableFormatProperty,
-    ]
+    }
 
     __slots__ = (
         "time_mapping",
@@ -188,6 +180,7 @@ class Generator:
         return sql
 
     def unsupported(self, message):
         if self.unsupported_level == ErrorLevel.IMMEDIATE:
             raise UnsupportedError(message)
         self.unsupported_messages.append(message)
@@ -261,6 +254,9 @@ class Generator:
         if isinstance(expression, exp.Func):
             return self.function_fallback_sql(expression)
 
+        if isinstance(expression, exp.Property):
+            return self.property_sql(expression)
+
         raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
 
     def annotation_sql(self, expression):
@@ -352,9 +348,12 @@ class Generator:
         replace = " OR REPLACE" if expression.args.get("replace") else ""
         exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
         unique = " UNIQUE" if expression.args.get("unique") else ""
+        materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
         properties = self.sql(expression, "properties")
 
-        expression_sql = f"CREATE{replace}{temporary}{unique} {kind}{exists_sql} {this}{properties} {expression_sql}"
+        expression_sql = (
+            f"CREATE{replace}{temporary}{unique}{materialized} {kind}{exists_sql} {this}{properties} {expression_sql}"
+        )
         return self.prepend_ctes(expression, expression_sql)
 
     def prepend_ctes(self, expression, sql):
@@ -461,10 +460,10 @@ class Generator:
         for p in expression.expressions:
             p_class = p.__class__
-            if p_class in self.ROOT_PROPERTIES:
-                root_properties.append(p)
-            elif p_class in self.WITH_PROPERTIES:
+            if p_class in self.WITH_PROPERTIES:
                 with_properties.append(p)
+            elif p_class in self.ROOT_PROPERTIES:
+                root_properties.append(p)
 
         return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties(
             exp.Properties(expressions=with_properties)
@@ -496,6 +495,9 @@ class Generator:
         )
 
     def property_sql(self, expression):
-        key = expression.name
+        if isinstance(expression.this, exp.Literal):
+            key = expression.this.this
+        else:
+            key = expression.name
         value = self.sql(expression, "value")
         return f"{key}={value}"
@@ -535,7 +537,8 @@ class Generator:
         laterals = self.expressions(expression, key="laterals", sep="")
         joins = self.expressions(expression, key="joins", sep="")
-        return f"{table}{laterals}{joins}"
+        pivots = self.expressions(expression, key="pivots", sep="")
+        return f"{table}{laterals}{joins}{pivots}"
 
     def tablesample_sql(self, expression):
         if self.alias_post_tablesample and isinstance(expression.this, exp.Alias):
@@ -556,7 +559,17 @@ class Generator:
         rows = self.sql(expression, "rows")
         rows = f"{rows} ROWS" if rows else ""
         size = self.sql(expression, "size")
-        return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){alias}"
+        seed = self.sql(expression, "seed")
+        seed = f" SEED ({seed})" if seed else ""
+        return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){seed}{alias}"
+
+    def pivot_sql(self, expression):
+        this = self.sql(expression, "this")
+        unpivot = expression.args.get("unpivot")
+        direction = "UNPIVOT" if unpivot else "PIVOT"
+        expressions = self.expressions(expression, key="expressions")
+        field = self.sql(expression, "field")
+        return f"{this} {direction}({expressions} FOR {field})"
 
     def tuple_sql(self, expression):
         return f"({self.expressions(expression, flat=True)})"
@@ -681,6 +694,7 @@ class Generator:
     def ordered_sql(self, expression):
         desc = expression.args.get("desc")
+        asc = not desc
         nulls_first = expression.args.get("nulls_first")
         nulls_last = not nulls_first
         nulls_are_large = self.null_ordering == "nulls_are_large"
@@ -760,6 +774,7 @@ class Generator:
         return self.query_modifiers(
             expression,
             self.wrap(expression),
+            self.expressions(expression, key="pivots", sep=" "),
             f" AS {alias}" if alias else "",
         )
@@ -1129,6 +1144,9 @@ class Generator:
             return f"{op} {expressions_sql}"
         return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"
 
+    def naked_property(self, expression):
+        return f"{expression.name} {self.sql(expression, 'value')}"
+
     def set_operation(self, expression, op):
         this = self.sql(expression, "this")
         op = self.seg(op)
@@ -1136,3 +1154,13 @@ class Generator:
 
     def token_sql(self, token_type):
         return self.TOKEN_MAPPING.get(token_type, token_type.name)
+
+    def userdefinedfunction_sql(self, expression):
+        this = self.sql(expression, "this")
+        expressions = self.no_identify(lambda: self.expressions(expression))
+        return f"{this}({expressions})"
+
+    def userdefinedfunctionkwarg_sql(self, expression):
+        this = self.sql(expression, "this")
+        kind = self.sql(expression, "kind")
+        return f"{this} {kind}"

View file

@@ -1,5 +1,7 @@
+import inspect
 import logging
 import re
+import sys
 from contextlib import contextmanager
 from enum import Enum
@@ -29,6 +31,26 @@ def csv(*args, sep=", "):
     return sep.join(arg for arg in args if arg)
 
 
+def subclasses(module_name, classes, exclude=()):
+    """
+    Returns a list of all subclasses for a specified class set, possibly excluding some of them.
+
+    Args:
+        module_name (str): The name of the module to search for subclasses in.
+        classes (type|tuple[type]): Class(es) we want to find the subclasses of.
+        exclude (type|tuple[type]): Class(es) we want to exclude from the returned list.
+
+    Returns:
+        A list of all the target subclasses.
+    """
+    return [
+        obj
+        for _, obj in inspect.getmembers(
+            sys.modules[module_name],
+            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
+        )
+    ]
+
+
 def apply_index_offset(expressions, offset):
     if not offset or len(expressions) != 1:
         return expressions
@@ -100,7 +122,7 @@ def csv_reader(table):
     Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...])
 
     Args:
-        expression (Expression): An anonymous function READ_CSV
+        table (exp.Table): A table expression with an anonymous function READ_CSV in it
 
     Returns:
         A python csv reader.
@@ -121,3 +143,22 @@ def csv_reader(table):
             yield csv_.reader(file, delimiter=delimiter)
     finally:
         file.close()
+
+
+def find_new_name(taken, base):
+    """
+    Searches for a new name.
+
+    Args:
+        taken (Sequence[str]): set of taken names
+        base (str): base name to alter
+    """
+    if base not in taken:
+        return base
+
+    i = 2
+    new = f"{base}_{i}"
+    while new in taken:
+        i += 1
+        new = f"{base}_{i}"
+    return new
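
For instance, the behavior follows directly from the loop above:

from sqlglot.helper import find_new_name

print(find_new_name({"cte", "cte_2"}, "cte"))  # cte_3
print(find_new_name({"x"}, "y"))               # y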

View file

@@ -0,0 +1,162 @@
+from sqlglot import exp
+from sqlglot.helper import ensure_list, subclasses
+
+
+def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
+    """
+    Recursively infer & annotate types in an expression syntax tree against a schema.
+
+    (TODO -- replace this with a better example after adding some functionality)
+    Example:
+        >>> import sqlglot
+        >>> annotated_expression = annotate_types(sqlglot.parse_one('5 + 5.3'))
+        >>> annotated_expression.type
+        <Type.DOUBLE: 'DOUBLE'>
+
+    Args:
+        expression (sqlglot.Expression): Expression to annotate.
+        schema (dict|sqlglot.optimizer.Schema): Database schema.
+        annotators (dict): Maps expression type to corresponding annotation function.
+        coerces_to (dict): Maps expression type to set of types that it can be coerced into.
+
+    Returns:
+        sqlglot.Expression: expression annotated with types
+    """
+    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
+
+
+class TypeAnnotator:
+    ANNOTATORS = {
+        **{
+            expr_type: lambda self, expr: self._annotate_unary(expr)
+            for expr_type in subclasses(exp.__name__, exp.Unary)
+        },
+        **{
+            expr_type: lambda self, expr: self._annotate_binary(expr)
+            for expr_type in subclasses(exp.__name__, exp.Binary)
+        },
+        exp.Cast: lambda self, expr: self._annotate_cast(expr),
+        exp.DataType: lambda self, expr: self._annotate_data_type(expr),
+        exp.Literal: lambda self, expr: self._annotate_literal(expr),
+        exp.Boolean: lambda self, expr: self._annotate_boolean(expr),
+    }
+
+    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
+    COERCES_TO = {
+        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
+        exp.DataType.Type.TEXT: set(),
+        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
+        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
+        exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
+        exp.DataType.Type.CHAR: {
+            exp.DataType.Type.NCHAR,
+            exp.DataType.Type.VARCHAR,
+            exp.DataType.Type.NVARCHAR,
+            exp.DataType.Type.TEXT,
+        },
+        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
+        exp.DataType.Type.DOUBLE: set(),
+        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
+        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
+        exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
+        exp.DataType.Type.INT: {
+            exp.DataType.Type.BIGINT,
+            exp.DataType.Type.DECIMAL,
+            exp.DataType.Type.FLOAT,
+            exp.DataType.Type.DOUBLE,
+        },
+        exp.DataType.Type.SMALLINT: {
+            exp.DataType.Type.INT,
+            exp.DataType.Type.BIGINT,
+            exp.DataType.Type.DECIMAL,
+            exp.DataType.Type.FLOAT,
+            exp.DataType.Type.DOUBLE,
+        },
+        exp.DataType.Type.TINYINT: {
+            exp.DataType.Type.SMALLINT,
+            exp.DataType.Type.INT,
+            exp.DataType.Type.BIGINT,
+            exp.DataType.Type.DECIMAL,
+            exp.DataType.Type.FLOAT,
+            exp.DataType.Type.DOUBLE,
+        },
+        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
+        exp.DataType.Type.TIMESTAMPLTZ: set(),
+        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
+        exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ},
+        exp.DataType.Type.DATETIME: {
+            exp.DataType.Type.TIMESTAMP,
+            exp.DataType.Type.TIMESTAMPTZ,
+            exp.DataType.Type.TIMESTAMPLTZ,
+        },
+        exp.DataType.Type.DATE: {
+            exp.DataType.Type.DATETIME,
+            exp.DataType.Type.TIMESTAMP,
+            exp.DataType.Type.TIMESTAMPTZ,
+            exp.DataType.Type.TIMESTAMPLTZ,
+        },
+    }
+
+    def __init__(self, schema=None, annotators=None, coerces_to=None):
+        self.schema = schema
+        self.annotators = annotators or self.ANNOTATORS
+        self.coerces_to = coerces_to or self.COERCES_TO
+
+    def annotate(self, expression):
+        if not isinstance(expression, exp.Expression):
+            return None
+        annotator = self.annotators.get(expression.__class__)
+        return annotator(self, expression) if annotator else self._annotate_args(expression)
+
+    def _annotate_args(self, expression):
+        for value in expression.args.values():
+            for v in ensure_list(value):
+                self.annotate(v)
+        return expression
+
+    def _annotate_cast(self, expression):
+        expression.type = expression.args["to"].this
+        return self._annotate_args(expression)
+
+    def _annotate_data_type(self, expression):
+        expression.type = expression.this
+        return self._annotate_args(expression)
+
+    def _maybe_coerce(self, type1, type2):
+        return type2 if type2 in self.coerces_to[type1] else type1
+
+    def _annotate_binary(self, expression):
+        self._annotate_args(expression)
+        if isinstance(expression, (exp.Condition, exp.Predicate)):
+            expression.type = exp.DataType.Type.BOOLEAN
+        else:
+            expression.type = self._maybe_coerce(expression.left.type, expression.right.type)
+        return expression
+
+    def _annotate_unary(self, expression):
+        self._annotate_args(expression)
+        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
+            expression.type = exp.DataType.Type.BOOLEAN
+        else:
+            expression.type = expression.this.type
+        return expression
+
+    def _annotate_literal(self, expression):
+        if expression.is_string:
+            expression.type = exp.DataType.Type.VARCHAR
+        elif expression.is_int:
+            expression.type = exp.DataType.Type.INT
+        else:
+            expression.type = exp.DataType.Type.DOUBLE
+        return expression
+
+    def _annotate_boolean(self, expression):
+        expression.type = exp.DataType.Type.BOOLEAN
+        return expression
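
One more worked case of the coercion table above (expected annotation shown as a comment):

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

# INT coerces to DOUBLE along the numeric chain TINYINT < ... < FLOAT < DOUBLE.
e = annotate_types(sqlglot.parse_one("CAST(x AS INT) + 5.3"))
print(e.type)  # Type.DOUBLE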

View file

@@ -1,48 +1,144 @@
 import itertools
 
-from sqlglot import alias, exp, select, table
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot import expressions as exp
+from sqlglot.helper import find_new_name
+from sqlglot.optimizer.scope import build_scope
 from sqlglot.optimizer.simplify import simplify
 
 
 def eliminate_subqueries(expression):
     """
-    Rewrite duplicate subqueries from sqlglot AST.
+    Rewrite subqueries as CTES, deduplicating if possible.
 
     Example:
         >>> import sqlglot
-        >>> expression = sqlglot.parse_one("SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y")
+        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
        >>> eliminate_subqueries(expression).sql()
-        'WITH _e_0 AS (SELECT 1 AS x, 2 AS y) SELECT * FROM _e_0 UNION ALL SELECT * FROM _e_0'
+        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
+
+    This also deduplicates common subqueries:
+        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
+        >>> eliminate_subqueries(expression).sql()
+        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
 
     Args:
-        expression (sqlglot.Expression): expression to qualify
-        schema (dict|sqlglot.optimizer.Schema): Database schema
+        expression (sqlglot.Expression): expression
     Returns:
-        sqlglot.Expression: qualified expression
+        sqlglot.Expression: expression
     """
+    if isinstance(expression, exp.Subquery):
+        # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
+        eliminate_subqueries(expression.this)
+        return expression
+
     expression = simplify(expression)
-    queries = {}
+    root = build_scope(expression)
 
-    for scope in traverse_scope(expression):
-        query = scope.expression
-        queries[query] = queries.get(query, []) + [query]
+    # Map of alias->Scope|Table
+    # These are all aliases that are already used in the expression.
+    # We don't want to create new CTEs that conflict with these names.
+    taken = {}
 
-    sequence = itertools.count()
+    # All CTE aliases in the root scope are taken
+    for scope in root.cte_scopes:
+        taken[scope.expression.parent.alias] = scope
 
-    for query, duplicates in queries.items():
-        if len(duplicates) == 1:
-            continue
+    # All table names are taken
+    for scope in root.traverse():
+        taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)})
 
-        alias_ = f"_e_{next(sequence)}"
+    # Map of Expression->alias
+    # Existing CTES in the root expression. We'll use this for deduplication.
+    existing_ctes = {}
 
-        for dup in duplicates:
-            parent = dup.parent
-            if isinstance(parent, exp.Subquery):
-                parent.replace(alias(table(alias_), parent.alias_or_name, table=True))
-            elif isinstance(parent, exp.Union):
-                dup.replace(select("*").from_(alias_))
+    with_ = root.expression.args.get("with")
+    if with_:
+        for cte in with_.expressions:
+            existing_ctes[cte.this] = cte.alias
+    new_ctes = []
 
-        expression.with_(alias_, as_=query, copy=False)
+    # We're adding more CTEs, but we want to maintain the DAG order.
+    # Derived tables within an existing CTE need to come before the existing CTE.
+    for cte_scope in root.cte_scopes:
+        # Append all the new CTEs from this existing CTE
+        for scope in cte_scope.traverse():
+            new_cte = _eliminate(scope, existing_ctes, taken)
+            if new_cte:
+                new_ctes.append(new_cte)
+
+        # Append the existing CTE itself
+        new_ctes.append(cte_scope.expression.parent)
+
+    # Now append the rest
+    for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes):
+        for child_scope in scope.traverse():
+            new_cte = _eliminate(child_scope, existing_ctes, taken)
+            if new_cte:
+                new_ctes.append(new_cte)
+
+    if new_ctes:
+        expression.set("with", exp.With(expressions=new_ctes))
 
     return expression
+
+
+def _eliminate(scope, existing_ctes, taken):
+    if scope.is_union:
+        return _eliminate_union(scope, existing_ctes, taken)
+
+    if scope.is_derived_table and not isinstance(scope.expression, (exp.Unnest, exp.Lateral)):
+        return _eliminate_derived_table(scope, existing_ctes, taken)
+
+
+def _eliminate_union(scope, existing_ctes, taken):
+    duplicate_cte_alias = existing_ctes.get(scope.expression)
+
+    alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte")
+
+    taken[alias] = scope
+
+    # Try to maintain the selections
+    expressions = scope.expression.args.get("expressions")
+    selects = [
+        exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
+        for e in expressions
+        if e.alias_or_name
+    ]
+
+    # If not all selections have an alias, just select *
+    if len(selects) != len(expressions):
+        selects = ["*"]
+
+    scope.expression.replace(exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias)))
+
+    if not duplicate_cte_alias:
+        existing_ctes[scope.expression] = alias
+        return exp.CTE(
+            this=scope.expression,
+            alias=exp.TableAlias(this=exp.to_identifier(alias)),
+        )
+
+
+def _eliminate_derived_table(scope, existing_ctes, taken):
+    duplicate_cte_alias = existing_ctes.get(scope.expression)
+    parent = scope.expression.parent
+    name = alias = parent.alias
+
+    if not alias:
+        name = alias = find_new_name(taken=taken, base="cte")
+
+    if duplicate_cte_alias:
+        name = duplicate_cte_alias
+    elif taken.get(alias):
+        name = find_new_name(taken=taken, base=alias)
+
+    taken[name] = scope
+
+    table = exp.alias_(exp.table_(name), alias=alias)
+    parent.replace(table)
+
+    if not duplicate_cte_alias:
+        existing_ctes[scope.expression] = name
+        return exp.CTE(
+            this=scope.expression,
+            alias=exp.TableAlias(this=exp.to_identifier(name)),
+        )

View file

@ -1,45 +1,39 @@
from collections import defaultdict from collections import defaultdict
from sqlglot import expressions as exp from sqlglot import expressions as exp
from sqlglot.optimizer.scope import traverse_scope from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.optimizer.simplify import simplify from sqlglot.optimizer.simplify import simplify
def merge_derived_tables(expression): def merge_subqueries(expression, leave_tables_isolated=False):
""" """
Rewrite sqlglot AST to merge derived tables into the outer query. Rewrite sqlglot AST to merge derived tables into the outer query.
This also merges CTEs if they are selected from only once.
Example: Example:
>>> import sqlglot >>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)") >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
>>> merge_derived_tables(expression).sql() >>> merge_subqueries(expression).sql()
'SELECT x.a FROM x' 'SELECT x.a FROM x JOIN y'
If `leave_tables_isolated` is True, this will not merge inner queries into outer
queries if it would result in multiple table selects in a single query:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
'SELECT a FROM (SELECT x.a FROM x) JOIN y'
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
Args: Args:
expression (sqlglot.Expression): expression to optimize expression (sqlglot.Expression): expression to optimize
leave_tables_isolated (bool):
Returns: Returns:
sqlglot.Expression: optimized expression sqlglot.Expression: optimized expression
""" """
for outer_scope in traverse_scope(expression): merge_ctes(expression, leave_tables_isolated)
for subquery in outer_scope.derived_tables: merge_derived_tables(expression, leave_tables_isolated)
inner_select = subquery.unnest()
if (
isinstance(outer_scope.expression, exp.Select)
and isinstance(inner_select, exp.Select)
and _mergeable(inner_select)
):
alias = subquery.alias_or_name
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
inner_scope = outer_scope.sources[alias]
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
return expression return expression
@ -53,20 +47,81 @@ UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
} }
def _mergeable(inner_select): def merge_ctes(expression, leave_tables_isolated=False):
scopes = traverse_scope(expression)
# All places where we select from CTEs.
# We key on the CTE scope so we can detect CTES that are selected from multiple times.
cte_selections = defaultdict(list)
for outer_scope in scopes:
for table, inner_scope in outer_scope.selected_sources.values():
if isinstance(inner_scope, Scope) and inner_scope.is_cte:
cte_selections[id(inner_scope)].append(
(
outer_scope,
inner_scope,
table,
)
)
singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
for outer_scope, inner_scope, table in singular_cte_selections:
inner_select = inner_scope.expression.unnest()
if _mergeable(outer_scope, inner_select, leave_tables_isolated):
from_or_join = table.find_ancestor(exp.From, exp.Join)
node_to_replace = table
if isinstance(node_to_replace.parent, exp.Alias):
node_to_replace = node_to_replace.parent
alias = node_to_replace.alias
else:
alias = table.name
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, node_to_replace, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_pop_cte(inner_scope)
def merge_derived_tables(expression, leave_tables_isolated=False):
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
inner_select = subquery.unnest()
if _mergeable(outer_scope, inner_select, leave_tables_isolated):
alias = subquery.alias_or_name
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
inner_scope = outer_scope.sources[alias]
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
def _mergeable(outer_scope, inner_select, leave_tables_isolated):
""" """
Return True if `inner_select` can be merged into outer query. Return True if `inner_select` can be merged into outer query.
Args: Args:
outer_scope (Scope)
inner_select (exp.Select) inner_select (exp.Select)
leave_tables_isolated (bool)
Returns: Returns:
bool: True if can be merged bool: True if can be merged
""" """
return ( return (
isinstance(inner_select, exp.Select) isinstance(outer_scope.expression, exp.Select)
and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from") and inner_select.args.get("from")
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
) )
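As a rough illustration of the guard above (a sketch using only the public sqlglot entry points; the query is assumed, not from the test suite), an inner query whose projections contain an aggregate trips the exp.AggFunc check and is left untouched:

    import sqlglot
    from sqlglot.optimizer.merge_subqueries import merge_subqueries

    # SUM(a) in the inner projections makes the subquery non-mergeable
    expression = sqlglot.parse_one("SELECT b FROM (SELECT SUM(a) AS b FROM x)")
    print(merge_subqueries(expression).sql())
    # SELECT b FROM (SELECT SUM(a) AS b FROM x)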
@ -84,7 +139,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
conflicts = conflicts - {alias} conflicts = conflicts - {alias}
for conflict in conflicts: for conflict in conflicts:
new_name = _find_new_name(taken, conflict) new_name = find_new_name(taken, conflict)
source, _ = inner_scope.selected_sources[conflict] source, _ = inner_scope.selected_sources[conflict]
new_alias = exp.to_identifier(new_name) new_alias = exp.to_identifier(new_name)
@ -102,34 +157,19 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
inner_scope.rename_source(conflict, new_name) inner_scope.rename_source(conflict, new_name)
def _find_new_name(taken, base): def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
"""
Searches for a new source name.
Args:
taken (set[str]): set of taken names
base (str): base name to alter
"""
i = 2
new = f"{base}_{i}"
while new in taken:
i += 1
new = f"{base}_{i}"
return new
def _merge_from(outer_scope, inner_scope, subquery):
""" """
Merge FROM clause of inner query into outer query. Merge FROM clause of inner query into outer query.
Args: Args:
outer_scope (sqlglot.optimizer.scope.Scope) outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope) inner_scope (sqlglot.optimizer.scope.Scope)
subquery (exp.Subquery) node_to_replace (exp.Subquery|exp.Table)
alias (str)
""" """
new_subquery = inner_scope.expression.args.get("from").expressions[0] new_subquery = inner_scope.expression.args.get("from").expressions[0]
subquery.replace(new_subquery) node_to_replace.replace(new_subquery)
outer_scope.remove_source(subquery.alias_or_name) outer_scope.remove_source(alias)
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
@ -176,7 +216,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
inner_scope (sqlglot.optimizer.scope.Scope) inner_scope (sqlglot.optimizer.scope.Scope)
alias (str) alias (str)
""" """
# Collect all columns that for the alias of the inner query # Collect all columns that reference the alias of the inner query
outer_columns = defaultdict(list) outer_columns = defaultdict(list)
for column in outer_scope.columns: for column in outer_scope.columns:
if column.table == alias: if column.table == alias:
@ -205,7 +245,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
if not where or not where.this: if not where or not where.this:
return return
if isinstance(from_or_join, exp.Join) and from_or_join.side: if isinstance(from_or_join, exp.Join):
# Merge predicates from an outer join to the ON clause # Merge predicates from an outer join to the ON clause
from_or_join.on(where.this, copy=False) from_or_join.on(where.this, copy=False)
from_or_join.set("on", simplify(from_or_join.args.get("on"))) from_or_join.set("on", simplify(from_or_join.args.get("on")))
@ -230,3 +270,18 @@ def _merge_order(outer_scope, inner_scope):
return return
outer_scope.expression.set("order", inner_scope.expression.args.get("order")) outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
def _pop_cte(inner_scope):
"""
Remove CTE from the AST.
Args:
inner_scope (sqlglot.optimizer.scope.Scope)
"""
cte = inner_scope.expression.parent
with_ = cte.parent
if len(with_.expressions) == 1:
with_.pop()
else:
cte.pop()
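A minimal end-to-end sketch of the CTE path (query assumed for illustration): a CTE that is selected from exactly once is inlined, and _pop_cte then drops the now-empty WITH clause:

    import sqlglot
    from sqlglot.optimizer.merge_subqueries import merge_subqueries

    expression = sqlglot.parse_one("WITH q AS (SELECT x.a FROM x) SELECT q.a FROM q")
    print(merge_subqueries(expression).sql())
    # SELECT x.a FROM x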
View file
@ -1,7 +1,7 @@
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.merge_derived_tables import merge_derived_tables from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.normalize import normalize from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
@ -22,7 +22,7 @@ RULES = (
pushdown_predicates, pushdown_predicates,
optimize_joins, optimize_joins,
eliminate_subqueries, eliminate_subqueries,
merge_derived_tables, merge_subqueries,
quote_identities, quote_identities,
) )
View file
@ -37,7 +37,7 @@ def pushdown_projections(expression):
parent_selections = {SELECT_ALL} parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union): if isinstance(scope.expression, exp.Union):
left, right = scope.union left, right = scope.union_scopes
referenced_columns[left] = parent_selections referenced_columns[left] = parent_selections
referenced_columns[right] = parent_selections referenced_columns[right] = parent_selections
View file
@ -69,7 +69,7 @@ def ensure_schema(schema):
def fs_get(table): def fs_get(table):
name = table.this.name.upper() name = table.this.name
if name.upper() == "READ_CSV": if name.upper() == "READ_CSV":
with csv_reader(table) as reader: with csv_reader(table) as reader:
View file
@ -1,3 +1,4 @@
import itertools
from copy import copy from copy import copy
from enum import Enum, auto from enum import Enum, auto
@ -32,10 +33,11 @@ class Scope:
The inner query would have `["col1", "col2"]` for its `outer_column_list` The inner query would have `["col1", "col2"]` for its `outer_column_list`
parent (Scope): Parent scope parent (Scope): Parent scope
scope_type (ScopeType): Type of this scope, relative to its parent scope_type (ScopeType): Type of this scope, relative to its parent
subquery_scopes (list[Scope]): List of all child scopes for subqueries. subquery_scopes (list[Scope]): List of all child scopes for subqueries
This does not include derived tables or CTEs. cte_scopes (list[Scope]): List of all child scopes for CTEs
union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be derived_table_scopes (list[Scope]): List of all child scopes for derived tables
a tuple of the left and right child scopes. union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
a list of the left and right child scopes.
""" """
def __init__( def __init__(
@ -52,7 +54,9 @@ class Scope:
self.parent = parent self.parent = parent
self.scope_type = scope_type self.scope_type = scope_type
self.subquery_scopes = [] self.subquery_scopes = []
self.union = None self.derived_table_scopes = []
self.cte_scopes = []
self.union_scopes = []
self.clear_cache() self.clear_cache()
def clear_cache(self): def clear_cache(self):
@ -197,11 +201,16 @@ class Scope:
named_outputs = {e.alias_or_name for e in self.expression.expressions} named_outputs = {e.alias_or_name for e in self.expression.expressions}
self._columns = [ self._columns = []
c for column in columns + external_columns:
for c in columns + external_columns ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Hint)
if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs) if (
] not ancestor
or column.table
or (column.name not in named_outputs and not isinstance(ancestor, exp.Hint))
):
self._columns.append(column)
return self._columns return self._columns
@property @property
@ -283,6 +292,26 @@ class Scope:
"""Determine if this scope is a subquery""" """Determine if this scope is a subquery"""
return self.scope_type == ScopeType.SUBQUERY return self.scope_type == ScopeType.SUBQUERY
@property
def is_derived_table(self):
"""Determine if this scope is a derived table"""
return self.scope_type == ScopeType.DERIVED_TABLE
@property
def is_union(self):
"""Determine if this scope is a union"""
return self.scope_type == ScopeType.UNION
@property
def is_cte(self):
"""Determine if this scope is a common table expression"""
return self.scope_type == ScopeType.CTE
@property
def is_root(self):
"""Determine if this is the root scope"""
return self.scope_type == ScopeType.ROOT
@property @property
def is_unnest(self): def is_unnest(self):
"""Determine if this scope is an unnest""" """Determine if this scope is an unnest"""
@ -308,6 +337,22 @@ class Scope:
self.sources.pop(name, None) self.sources.pop(name, None)
self.clear_cache() self.clear_cache()
def __repr__(self):
return f"Scope<{self.expression.sql()}>"
def traverse(self):
"""
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes
):
yield from child_scope.traverse()
yield self
def traverse_scope(expression): def traverse_scope(expression):
""" """
@ -337,6 +382,18 @@ def traverse_scope(expression):
return list(_traverse_scope(Scope(expression))) return list(_traverse_scope(Scope(expression)))
def build_scope(expression):
"""
Build a scope tree.
Args:
expression (exp.Expression): expression to build the scope tree for
Returns:
Scope: root scope
"""
return traverse_scope(expression)[-1]
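A small usage sketch for the new helpers (the query is assumed, and the printed shapes are illustrative; they rely on the __repr__ added above):

    import sqlglot
    from sqlglot.optimizer.scope import build_scope

    root = build_scope(sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y"))
    for scope in root.traverse():
        # depth-first post-order: the derived table scope prints before the root
        print(scope)
    # Scope<SELECT a FROM x>
    # Scope<SELECT a FROM (SELECT a FROM x) AS y>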
def _traverse_scope(scope): def _traverse_scope(scope):
if isinstance(scope.expression, exp.Select): if isinstance(scope.expression, exp.Select):
yield from _traverse_select(scope) yield from _traverse_select(scope)
@ -370,13 +427,14 @@ def _traverse_union(scope):
for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)): for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
yield right yield right
scope.union = (left, right) scope.union_scopes = [left, right]
def _traverse_derived_tables(derived_tables, scope, scope_type): def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {} sources = {}
for derived_table in derived_tables: for derived_table in derived_tables:
top = None
for child_scope in _traverse_scope( for child_scope in _traverse_scope(
scope.branch( scope.branch(
derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this, derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
@ -386,11 +444,16 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
) )
): ):
yield child_scope yield child_scope
top = child_scope
# Tables without aliases will be set as "" # Tables without aliases will be set as ""
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather, # Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins). # the latest one wins).
sources[derived_table.alias] = child_scope sources[derived_table.alias] = child_scope
if scope_type == ScopeType.CTE:
scope.cte_scopes.append(top)
else:
scope.derived_table_scopes.append(top)
scope.sources.update(sources) scope.sources.update(sources)
@ -407,8 +470,6 @@ def _add_table_sources(scope):
if table_name in scope.sources: if table_name in scope.sources:
# This is a reference to a parent source (e.g. a CTE), not an actual table. # This is a reference to a parent source (e.g. a CTE), not an actual table.
scope.sources[source_name] = scope.sources[table_name] scope.sources[source_name] = scope.sources[table_name]
elif source_name in scope.sources:
raise OptimizeError(f"Duplicate table name: {source_name}")
else: else:
sources[source_name] = table sources[source_name] = table
View file
@ -99,7 +99,8 @@ class Parser:
TokenType.SMALLMONEY, TokenType.SMALLMONEY,
TokenType.ROWVERSION, TokenType.ROWVERSION,
TokenType.IMAGE, TokenType.IMAGE,
TokenType.SQL_VARIANT, TokenType.VARIANT,
TokenType.OBJECT,
*NESTED_TYPE_TOKENS, *NESTED_TYPE_TOKENS,
} }
@ -131,7 +132,6 @@ class Parser:
TokenType.FALSE, TokenType.FALSE,
TokenType.FIRST, TokenType.FIRST,
TokenType.FOLLOWING, TokenType.FOLLOWING,
TokenType.FOR,
TokenType.FORMAT, TokenType.FORMAT,
TokenType.FUNCTION, TokenType.FUNCTION,
TokenType.GENERATED, TokenType.GENERATED,
@ -141,20 +141,26 @@ class Parser:
TokenType.ISNULL, TokenType.ISNULL,
TokenType.INTERVAL, TokenType.INTERVAL,
TokenType.LAZY, TokenType.LAZY,
TokenType.LANGUAGE,
TokenType.LEADING, TokenType.LEADING,
TokenType.LOCATION, TokenType.LOCATION,
TokenType.MATERIALIZED,
TokenType.NATURAL, TokenType.NATURAL,
TokenType.NEXT, TokenType.NEXT,
TokenType.ONLY, TokenType.ONLY,
TokenType.OPTIMIZE, TokenType.OPTIMIZE,
TokenType.OPTIONS, TokenType.OPTIONS,
TokenType.ORDINALITY, TokenType.ORDINALITY,
TokenType.PARTITIONED_BY,
TokenType.PERCENT, TokenType.PERCENT,
TokenType.PIVOT,
TokenType.PRECEDING, TokenType.PRECEDING,
TokenType.RANGE, TokenType.RANGE,
TokenType.REFERENCES, TokenType.REFERENCES,
TokenType.RETURNS,
TokenType.ROWS, TokenType.ROWS,
TokenType.SCHEMA_COMMENT, TokenType.SCHEMA_COMMENT,
TokenType.SEED,
TokenType.SET, TokenType.SET,
TokenType.SHOW, TokenType.SHOW,
TokenType.STORED, TokenType.STORED,
@ -167,6 +173,7 @@ class Parser:
TokenType.TRUE, TokenType.TRUE,
TokenType.UNBOUNDED, TokenType.UNBOUNDED,
TokenType.UNIQUE, TokenType.UNIQUE,
TokenType.UNPIVOT,
TokenType.PROPERTIES, TokenType.PROPERTIES,
*SUBQUERY_PREDICATES, *SUBQUERY_PREDICATES,
*TYPE_TOKENS, *TYPE_TOKENS,
@ -303,6 +310,8 @@ class Parser:
exp.Condition: lambda self: self._parse_conjunction(), exp.Condition: lambda self: self._parse_conjunction(),
exp.Expression: lambda self: self._parse_statement(), exp.Expression: lambda self: self._parse_statement(),
exp.Properties: lambda self: self._parse_properties(), exp.Properties: lambda self: self._parse_properties(),
exp.Where: lambda self: self._parse_where(),
exp.Ordered: lambda self: self._parse_ordered(),
"JOIN_TYPE": lambda self: self._parse_join_side_and_kind(), "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
} }
@ -355,23 +364,21 @@ class Parser:
PROPERTY_PARSERS = { PROPERTY_PARSERS = {
TokenType.AUTO_INCREMENT: lambda self: self._parse_auto_increment(), TokenType.AUTO_INCREMENT: lambda self: self._parse_auto_increment(),
TokenType.CHARACTER_SET: lambda self: self._parse_character_set(), TokenType.CHARACTER_SET: lambda self: self._parse_character_set(),
TokenType.COLLATE: lambda self: self._parse_collate(),
TokenType.ENGINE: lambda self: self._parse_engine(),
TokenType.FORMAT: lambda self: self._parse_format(),
TokenType.LOCATION: lambda self: self.expression( TokenType.LOCATION: lambda self: self.expression(
exp.LocationProperty, exp.LocationProperty,
this=exp.Literal.string("LOCATION"), this=exp.Literal.string("LOCATION"),
value=self._parse_string(), value=self._parse_string(),
), ),
TokenType.PARTITIONED_BY: lambda self: self.expression( TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(),
exp.PartitionedByProperty,
this=exp.Literal.string("PARTITIONED_BY"),
value=self._parse_schema(),
),
TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(), TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(),
TokenType.STORED: lambda self: self._parse_stored(), TokenType.STORED: lambda self: self._parse_stored(),
TokenType.TABLE_FORMAT: lambda self: self._parse_table_format(), TokenType.RETURNS: lambda self: self._parse_returns(),
TokenType.USING: lambda self: self._parse_table_format(), TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
} }
CONSTRAINT_PARSERS = { CONSTRAINT_PARSERS = {
@ -388,6 +395,7 @@ class Parser:
FUNCTION_PARSERS = { FUNCTION_PARSERS = {
"CONVERT": lambda self: self._parse_convert(), "CONVERT": lambda self: self._parse_convert(),
"EXTRACT": lambda self: self._parse_extract(), "EXTRACT": lambda self: self._parse_extract(),
"POSITION": lambda self: self._parse_position(),
"SUBSTRING": lambda self: self._parse_substring(), "SUBSTRING": lambda self: self._parse_substring(),
"TRIM": lambda self: self._parse_trim(), "TRIM": lambda self: self._parse_trim(),
"CAST": lambda self: self._parse_cast(self.STRICT_CAST), "CAST": lambda self: self._parse_cast(self.STRICT_CAST),
@ -628,6 +636,10 @@ class Parser:
replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE) replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
temporary = self._match(TokenType.TEMPORARY) temporary = self._match(TokenType.TEMPORARY)
unique = self._match(TokenType.UNIQUE) unique = self._match(TokenType.UNIQUE)
materialized = self._match(TokenType.MATERIALIZED)
if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
self._match(TokenType.TABLE)
create_token = self._match_set(self.CREATABLES) and self._prev create_token = self._match_set(self.CREATABLES) and self._prev
@ -640,14 +652,15 @@ class Parser:
properties = None properties = None
if create_token.token_type == TokenType.FUNCTION: if create_token.token_type == TokenType.FUNCTION:
this = self._parse_var() this = self._parse_user_defined_function()
properties = self._parse_properties()
if self._match(TokenType.ALIAS): if self._match(TokenType.ALIAS):
expression = self._parse_string() expression = self._parse_select_or_expression()
elif create_token.token_type == TokenType.INDEX: elif create_token.token_type == TokenType.INDEX:
this = self._parse_index() this = self._parse_index()
elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW): elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW):
this = self._parse_table(schema=True) this = self._parse_table(schema=True)
properties = self._parse_properties(this if isinstance(this, exp.Schema) else None) properties = self._parse_properties()
if self._match(TokenType.ALIAS): if self._match(TokenType.ALIAS):
expression = self._parse_select(nested=True) expression = self._parse_select(nested=True)
@ -661,9 +674,10 @@ class Parser:
temporary=temporary, temporary=temporary,
replace=replace, replace=replace,
unique=unique, unique=unique,
materialized=materialized,
) )
def _parse_property(self, schema): def _parse_property(self):
if self._match_set(self.PROPERTY_PARSERS): if self._match_set(self.PROPERTY_PARSERS):
return self.PROPERTY_PARSERS[self._prev.token_type](self) return self.PROPERTY_PARSERS[self._prev.token_type](self)
if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET): if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
@ -673,31 +687,27 @@ class Parser:
key = self._parse_var().this key = self._parse_var().this
self._match(TokenType.EQ) self._match(TokenType.EQ)
if key.upper() == "PARTITIONED_BY":
expression = exp.PartitionedByProperty
value = self._parse_schema() or self._parse_bracket(self._parse_field())
if schema and not isinstance(value, exp.Schema):
columns = {v.name.upper() for v in value.expressions}
partitions = [
expression for expression in schema.expressions if expression.this.name.upper() in columns
]
schema.set(
"expressions",
[e for e in schema.expressions if e not in partitions],
)
value = self.expression(exp.Schema, expressions=partitions)
else:
value = self._parse_column()
expression = exp.AnonymousProperty
return self.expression( return self.expression(
expression, exp.AnonymousProperty,
this=exp.Literal.string(key), this=exp.Literal.string(key),
value=value, value=self._parse_column(),
) )
return None return None
def _parse_property_assignment(self, exp_class):
prop = self._prev.text
self._match(TokenType.EQ)
return self.expression(exp_class, this=prop, value=self._parse_var_or_string())
def _parse_partitioned_by(self):
self._match(TokenType.EQ)
return self.expression(
exp.PartitionedByProperty,
this=exp.Literal.string("PARTITIONED_BY"),
value=self._parse_schema() or self._parse_bracket(self._parse_field()),
)
def _parse_stored(self): def _parse_stored(self):
self._match(TokenType.ALIAS) self._match(TokenType.ALIAS)
self._match(TokenType.EQ) self._match(TokenType.EQ)
@ -707,22 +717,6 @@ class Parser:
value=exp.Literal.string(self._parse_var().name), value=exp.Literal.string(self._parse_var().name),
) )
def _parse_format(self):
self._match(TokenType.EQ)
return self.expression(
exp.FileFormatProperty,
this=exp.Literal.string("FORMAT"),
value=self._parse_string() or self._parse_var(),
)
def _parse_engine(self):
self._match(TokenType.EQ)
return self.expression(
exp.EngineProperty,
this=exp.Literal.string("ENGINE"),
value=self._parse_var_or_string(),
)
def _parse_auto_increment(self): def _parse_auto_increment(self):
self._match(TokenType.EQ) self._match(TokenType.EQ)
return self.expression( return self.expression(
@ -731,14 +725,6 @@ class Parser:
value=self._parse_var() or self._parse_number(), value=self._parse_var() or self._parse_number(),
) )
def _parse_collate(self):
self._match(TokenType.EQ)
return self.expression(
exp.CollateProperty,
this=exp.Literal.string("COLLATE"),
value=self._parse_var_or_string(),
)
def _parse_schema_comment(self): def _parse_schema_comment(self):
self._match(TokenType.EQ) self._match(TokenType.EQ)
return self.expression( return self.expression(
@ -756,26 +742,34 @@ class Parser:
default=default, default=default,
) )
def _parse_table_format(self): def _parse_returns(self):
self._match(TokenType.EQ) is_table = self._match(TokenType.TABLE)
if is_table:
if self._match(TokenType.LT):
value = self.expression(
exp.Schema, this="TABLE", expressions=self._parse_csv(self._parse_struct_kwargs)
)
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
else:
value = self._parse_schema("TABLE")
else:
value = self._parse_types()
return self.expression( return self.expression(
exp.TableFormatProperty, exp.ReturnsProperty,
this=exp.Literal.string("TABLE_FORMAT"), this=exp.Literal.string("RETURNS"),
value=self._parse_var_or_string(), value=value,
is_table=is_table,
) )
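For reference, the BigQuery test added further below exercises this branch; a RETURNS TABLE signature round-trips through the angle-bracket schema form (a sketch mirroring that validate_identity call):

    import sqlglot

    sql = "CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t"
    assert sqlglot.transpile(sql, read="bigquery", write="bigquery")[0] == sql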
def _parse_properties(self, schema=None): def _parse_properties(self):
"""
Schema is included since if the table schema is defined and we later get a partition by expression
then we will define those columns in the partition by section and not in with the rest of the
columns
"""
properties = [] properties = []
while True: while True:
if self._match(TokenType.WITH): if self._match(TokenType.WITH):
self._match_l_paren() self._match_l_paren()
properties.extend(self._parse_csv(lambda: self._parse_property(schema))) properties.extend(self._parse_csv(lambda: self._parse_property()))
self._match_r_paren() self._match_r_paren()
elif self._match(TokenType.PROPERTIES): elif self._match(TokenType.PROPERTIES):
self._match_l_paren() self._match_l_paren()
@ -790,7 +784,7 @@ class Parser:
) )
self._match_r_paren() self._match_r_paren()
else: else:
identified_property = self._parse_property(schema) identified_property = self._parse_property()
if not identified_property: if not identified_property:
break break
properties.append(identified_property) properties.append(identified_property)
@ -1003,7 +997,7 @@ class Parser:
) )
def _parse_subquery(self, this): def _parse_subquery(self, this):
return self.expression(exp.Subquery, this=this, alias=self._parse_table_alias()) return self.expression(exp.Subquery, this=this, pivots=self._parse_pivots(), alias=self._parse_table_alias())
def _parse_query_modifiers(self, this): def _parse_query_modifiers(self, this):
if not isinstance(this, self.MODIFIABLES): if not isinstance(this, self.MODIFIABLES):
@ -1134,6 +1128,10 @@ class Parser:
table = (not schema and self._parse_function()) or self._parse_id_var(False) table = (not schema and self._parse_function()) or self._parse_id_var(False)
while self._match(TokenType.DOT): while self._match(TokenType.DOT):
if catalog:
# This allows nesting the table in arbitrarily many dot expressions if needed
table = self.expression(exp.Dot, this=table, expression=self._parse_id_var())
else:
catalog = db catalog = db
db = table db = table
table = self._parse_id_var() table = self._parse_id_var()
@ -1141,7 +1139,7 @@ class Parser:
if not table: if not table:
self.raise_error("Expected table name") self.raise_error("Expected table name")
this = self.expression(exp.Table, this=table, db=db, catalog=catalog) this = self.expression(exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots())
if schema: if schema:
return self._parse_schema(this=this) return self._parse_schema(this=this)
@ -1199,6 +1197,7 @@ class Parser:
percent = None percent = None
rows = None rows = None
size = None size = None
seed = None
self._match_l_paren() self._match_l_paren()
@ -1220,6 +1219,11 @@ class Parser:
self._match_r_paren() self._match_r_paren()
if self._match(TokenType.SEED):
self._match_l_paren()
seed = self._parse_number()
self._match_r_paren()
return self.expression( return self.expression(
exp.TableSample, exp.TableSample,
method=method, method=method,
@ -1229,6 +1233,51 @@ class Parser:
percent=percent, percent=percent,
rows=rows, rows=rows,
size=size, size=size,
seed=seed,
)
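Grounded in the Snowflake test added further below, the new seed argument round-trips like this (sketch):

    import sqlglot

    print(sqlglot.transpile("SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)", read="snowflake", write="snowflake")[0])
    # SELECT a FROM test TABLESAMPLE BLOCK (0.5) SEED (42)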
def _parse_pivots(self):
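# Two-argument iter() keeps calling self._parse_pivot until it returns None, collecting consecutive PIVOT/UNPIVOT clauses.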
return list(iter(self._parse_pivot, None))
def _parse_pivot(self):
index = self._index
if self._match(TokenType.PIVOT):
unpivot = False
elif self._match(TokenType.UNPIVOT):
unpivot = True
else:
return None
expressions = []
field = None
if not self._match(TokenType.L_PAREN):
self._retreat(index)
return None
if unpivot:
expressions = self._parse_csv(self._parse_column)
else:
expressions = self._parse_csv(lambda: self._parse_alias(self._parse_function()))
if not self._match(TokenType.FOR):
self.raise_error("Expecting FOR")
value = self._parse_column()
if not self._match(TokenType.IN):
self.raise_error("Expecting IN")
field = self._parse_in(value)
self._match_r_paren()
return self.expression(
exp.Pivot,
expressions=expressions,
field=field,
unpivot=unpivot,
) )
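The identity fixtures added below (e.g. SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q'))) exercise this parser; a quick sketch of the round trip with the default dialect:

    import sqlglot

    sql = "SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q'))"
    assert sqlglot.transpile(sql)[0] == sql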
def _parse_where(self): def _parse_where(self):
@ -1384,7 +1433,7 @@ class Parser:
this = self.expression(exp.In, this=this, unnest=unnest) this = self.expression(exp.In, this=this, unnest=unnest)
else: else:
self._match_l_paren() self._match_l_paren()
expressions = self._parse_csv(lambda: self._parse_select() or self._parse_expression()) expressions = self._parse_csv(self._parse_select_or_expression)
if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable): if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
this = self.expression(exp.In, this=this, query=expressions[0]) this = self.expression(exp.In, this=this, query=expressions[0])
@ -1577,6 +1626,9 @@ class Parser:
if self._match_set(self.PRIMARY_PARSERS): if self._match_set(self.PRIMARY_PARSERS):
return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev) return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev)
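# A leading-dot numeric literal like .2 is normalized to 0.2 (covered by the new 0.2 identity fixture below).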
if self._match_pair(TokenType.DOT, TokenType.NUMBER):
return exp.Literal.number(f"0.{self._prev.text}")
if self._match(TokenType.L_PAREN): if self._match(TokenType.L_PAREN):
query = self._parse_select() query = self._parse_select()
@ -1647,6 +1699,23 @@ class Parser:
self._match_r_paren() self._match_r_paren()
return self._parse_window(this) return self._parse_window(this)
def _parse_user_defined_function(self):
this = self._parse_var()
if not self._match(TokenType.L_PAREN):
return this
expressions = self._parse_csv(self._parse_udf_kwarg)
self._match_r_paren()
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
def _parse_udf_kwarg(self):
this = self._parse_id_var()
kind = self._parse_types()
if not kind:
return this
return self.expression(exp.UserDefinedFunctionKwarg, this=this, kind=kind)
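A sketch of the signature syntax these two helpers accept, taken from the identity fixtures added below:

    import sqlglot

    sql = "CREATE FUNCTION a(b INT, c VARCHAR) AS 'SELECT 1'"
    assert sqlglot.transpile(sql)[0] == sql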
def _parse_lambda(self): def _parse_lambda(self):
index = self._index index = self._index
@ -1672,9 +1741,10 @@ class Parser:
return self._parse_alias(self._parse_limit(self._parse_order(this))) return self._parse_alias(self._parse_limit(self._parse_order(this)))
conjunction = self._parse_conjunction().transform(self._replace_lambda, {node.name for node in expressions})
return self.expression( return self.expression(
exp.Lambda, exp.Lambda,
this=self._parse_conjunction(), this=conjunction,
expressions=expressions, expressions=expressions,
) )
@ -1896,6 +1966,12 @@ class Parser:
to = None to = None
return self.expression(exp.Cast, this=this, to=to) return self.expression(exp.Cast, this=this, to=to)
def _parse_position(self):
substr = self._parse_bitwise()
if self._match(TokenType.IN):
string = self._parse_bitwise()
return self.expression(exp.StrPosition, this=string, substr=substr)
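POSITION(substr IN string) is normalized into the existing StrPosition node, so it transpiles the same way as STR_POSITION; a sketch mirroring the new dialect test below:

    import sqlglot

    print(sqlglot.transpile("POSITION(' ' in x)", write="duckdb")[0])
    # STRPOS(x, ' ')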
def _parse_substring(self): def _parse_substring(self):
# Postgres supports the form: substring(string [from int] [for int]) # Postgres supports the form: substring(string [from int] [for int])
# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
@ -2155,6 +2231,9 @@ class Parser:
self._match_r_paren() self._match_r_paren()
return expressions return expressions
def _parse_select_or_expression(self):
return self._parse_select() or self._parse_expression()
def _match(self, token_type): def _match(self, token_type):
if not self._curr: if not self._curr:
return None return None
@ -2208,3 +2287,9 @@ class Parser:
elif isinstance(this, exp.Identifier): elif isinstance(this, exp.Identifier):
this = self.expression(exp.Var, this=this.name) this = self.expression(exp.Var, this=this.name)
return this return this
def _replace_lambda(self, node, lambda_variables):
if isinstance(node, exp.Column):
if node.name in lambda_variables:
return node.this
return node
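This is what fixes the lambda fixture below: a lambda parameter referenced in the body is rewritten from a Column to its bare identifier, so the generator stops quoting it (a sketch with the default dialect):

    import sqlglot

    sql = 'SELECT X(a -> a + ("z" - 1))'
    # previously the lambda variable came back quoted, as "a"
    assert sqlglot.transpile(sql)[0] == sql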
View file
@ -94,7 +94,8 @@ class TokenType(AutoName):
SMALLMONEY = auto() SMALLMONEY = auto()
ROWVERSION = auto() ROWVERSION = auto()
IMAGE = auto() IMAGE = auto()
SQL_VARIANT = auto() VARIANT = auto()
OBJECT = auto()
# keywords # keywords
ADD_FILE = auto() ADD_FILE = auto()
@ -177,6 +178,7 @@ class TokenType(AutoName):
IS = auto() IS = auto()
ISNULL = auto() ISNULL = auto()
JOIN = auto() JOIN = auto()
LANGUAGE = auto()
LATERAL = auto() LATERAL = auto()
LAZY = auto() LAZY = auto()
LEADING = auto() LEADING = auto()
@ -185,6 +187,7 @@ class TokenType(AutoName):
LIMIT = auto() LIMIT = auto()
LOCATION = auto() LOCATION = auto()
MAP = auto() MAP = auto()
MATERIALIZED = auto()
MOD = auto() MOD = auto()
NATURAL = auto() NATURAL = auto()
NEXT = auto() NEXT = auto()
@ -208,6 +211,7 @@ class TokenType(AutoName):
PARTITION_BY = auto() PARTITION_BY = auto()
PARTITIONED_BY = auto() PARTITIONED_BY = auto()
PERCENT = auto() PERCENT = auto()
PIVOT = auto()
PLACEHOLDER = auto() PLACEHOLDER = auto()
PRECEDING = auto() PRECEDING = auto()
PRIMARY_KEY = auto() PRIMARY_KEY = auto()
@ -219,12 +223,14 @@ class TokenType(AutoName):
REPLACE = auto() REPLACE = auto()
RESPECT_NULLS = auto() RESPECT_NULLS = auto()
REFERENCES = auto() REFERENCES = auto()
RETURNS = auto()
RIGHT = auto() RIGHT = auto()
RLIKE = auto() RLIKE = auto()
ROLLUP = auto() ROLLUP = auto()
ROW = auto() ROW = auto()
ROWS = auto() ROWS = auto()
SCHEMA_COMMENT = auto() SCHEMA_COMMENT = auto()
SEED = auto()
SELECT = auto() SELECT = auto()
SEPARATOR = auto() SEPARATOR = auto()
SET = auto() SET = auto()
@ -246,6 +252,7 @@ class TokenType(AutoName):
UNCACHE = auto() UNCACHE = auto()
UNION = auto() UNION = auto()
UNNEST = auto() UNNEST = auto()
UNPIVOT = auto()
UPDATE = auto() UPDATE = auto()
USE = auto() USE = auto()
USING = auto() USING = auto()
@ -440,6 +447,7 @@ class Tokenizer(metaclass=_Tokenizer):
"FULL": TokenType.FULL, "FULL": TokenType.FULL,
"FUNCTION": TokenType.FUNCTION, "FUNCTION": TokenType.FUNCTION,
"FOLLOWING": TokenType.FOLLOWING, "FOLLOWING": TokenType.FOLLOWING,
"FOR": TokenType.FOR,
"FOREIGN KEY": TokenType.FOREIGN_KEY, "FOREIGN KEY": TokenType.FOREIGN_KEY,
"FORMAT": TokenType.FORMAT, "FORMAT": TokenType.FORMAT,
"FROM": TokenType.FROM, "FROM": TokenType.FROM,
@ -459,6 +467,7 @@ class Tokenizer(metaclass=_Tokenizer):
"IS": TokenType.IS, "IS": TokenType.IS,
"ISNULL": TokenType.ISNULL, "ISNULL": TokenType.ISNULL,
"JOIN": TokenType.JOIN, "JOIN": TokenType.JOIN,
"LANGUAGE": TokenType.LANGUAGE,
"LATERAL": TokenType.LATERAL, "LATERAL": TokenType.LATERAL,
"LAZY": TokenType.LAZY, "LAZY": TokenType.LAZY,
"LEADING": TokenType.LEADING, "LEADING": TokenType.LEADING,
@ -466,6 +475,7 @@ class Tokenizer(metaclass=_Tokenizer):
"LIKE": TokenType.LIKE, "LIKE": TokenType.LIKE,
"LIMIT": TokenType.LIMIT, "LIMIT": TokenType.LIMIT,
"LOCATION": TokenType.LOCATION, "LOCATION": TokenType.LOCATION,
"MATERIALIZED": TokenType.MATERIALIZED,
"NATURAL": TokenType.NATURAL, "NATURAL": TokenType.NATURAL,
"NEXT": TokenType.NEXT, "NEXT": TokenType.NEXT,
"NO ACTION": TokenType.NO_ACTION, "NO ACTION": TokenType.NO_ACTION,
@ -473,6 +483,7 @@ class Tokenizer(metaclass=_Tokenizer):
"NULL": TokenType.NULL, "NULL": TokenType.NULL,
"NULLS FIRST": TokenType.NULLS_FIRST, "NULLS FIRST": TokenType.NULLS_FIRST,
"NULLS LAST": TokenType.NULLS_LAST, "NULLS LAST": TokenType.NULLS_LAST,
"OBJECT": TokenType.OBJECT,
"OFFSET": TokenType.OFFSET, "OFFSET": TokenType.OFFSET,
"ON": TokenType.ON, "ON": TokenType.ON,
"ONLY": TokenType.ONLY, "ONLY": TokenType.ONLY,
@ -488,7 +499,9 @@ class Tokenizer(metaclass=_Tokenizer):
"PARTITION": TokenType.PARTITION, "PARTITION": TokenType.PARTITION,
"PARTITION BY": TokenType.PARTITION_BY, "PARTITION BY": TokenType.PARTITION_BY,
"PARTITIONED BY": TokenType.PARTITIONED_BY, "PARTITIONED BY": TokenType.PARTITIONED_BY,
"PARTITIONED_BY": TokenType.PARTITIONED_BY,
"PERCENT": TokenType.PERCENT, "PERCENT": TokenType.PERCENT,
"PIVOT": TokenType.PIVOT,
"PRECEDING": TokenType.PRECEDING, "PRECEDING": TokenType.PRECEDING,
"PRIMARY KEY": TokenType.PRIMARY_KEY, "PRIMARY KEY": TokenType.PRIMARY_KEY,
"RANGE": TokenType.RANGE, "RANGE": TokenType.RANGE,
@ -497,11 +510,13 @@ class Tokenizer(metaclass=_Tokenizer):
"REPLACE": TokenType.REPLACE, "REPLACE": TokenType.REPLACE,
"RESPECT NULLS": TokenType.RESPECT_NULLS, "RESPECT NULLS": TokenType.RESPECT_NULLS,
"REFERENCES": TokenType.REFERENCES, "REFERENCES": TokenType.REFERENCES,
"RETURNS": TokenType.RETURNS,
"RIGHT": TokenType.RIGHT, "RIGHT": TokenType.RIGHT,
"RLIKE": TokenType.RLIKE, "RLIKE": TokenType.RLIKE,
"ROLLUP": TokenType.ROLLUP, "ROLLUP": TokenType.ROLLUP,
"ROW": TokenType.ROW, "ROW": TokenType.ROW,
"ROWS": TokenType.ROWS, "ROWS": TokenType.ROWS,
"SEED": TokenType.SEED,
"SELECT": TokenType.SELECT, "SELECT": TokenType.SELECT,
"SET": TokenType.SET, "SET": TokenType.SET,
"SHOW": TokenType.SHOW, "SHOW": TokenType.SHOW,
@ -520,6 +535,7 @@ class Tokenizer(metaclass=_Tokenizer):
"TRUNCATE": TokenType.TRUNCATE, "TRUNCATE": TokenType.TRUNCATE,
"UNBOUNDED": TokenType.UNBOUNDED, "UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION, "UNION": TokenType.UNION,
"UNPIVOT": TokenType.UNPIVOT,
"UNNEST": TokenType.UNNEST, "UNNEST": TokenType.UNNEST,
"UPDATE": TokenType.UPDATE, "UPDATE": TokenType.UPDATE,
"USE": TokenType.USE, "USE": TokenType.USE,
@ -577,6 +593,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DATETIME": TokenType.DATETIME, "DATETIME": TokenType.DATETIME,
"UNIQUE": TokenType.UNIQUE, "UNIQUE": TokenType.UNIQUE,
"STRUCT": TokenType.STRUCT, "STRUCT": TokenType.STRUCT,
"VARIANT": TokenType.VARIANT,
} }
WHITE_SPACE = { WHITE_SPACE = {
View file
@ -12,15 +12,20 @@ def unalias_group(expression):
""" """
if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
aliased_selects = { aliased_selects = {
e.alias: i for i, e in enumerate(expression.parent.expressions, start=1) if isinstance(e, exp.Alias) e.alias: (i, e.this)
for i, e in enumerate(expression.parent.expressions, start=1)
if isinstance(e, exp.Alias)
} }
expression = expression.copy() expression = expression.copy()
for col in expression.find_all(exp.Column): top_level_expression = None
alias_index = aliased_selects.get(col.name) for item, parent, _ in expression.walk(bfs=False):
if not col.table and alias_index: top_level_expression = item if isinstance(parent, exp.Group) else top_level_expression
col.replace(exp.Literal.number(alias_index)) if isinstance(item, exp.Column) and not item.table:
alias_index, col_expression = aliased_selects.get(item.name, (None, None))
if alias_index and top_level_expression != col_expression:
item.replace(exp.Literal.number(alias_index))
return expression return expression
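The walk tracks the top-level expression of each GROUP BY item so an alias that is grouped by its own defining expression is left alone; only bare alias references collapse to ordinals. A sketch of the intended behavior (input assumed, based on the unalias_group docstring):

    import sqlglot
    from sqlglot.transforms import unalias_group

    expression = sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b")
    print(expression.transform(unalias_group).sql())
    # SELECT a AS b FROM x GROUP BY 1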
View file
@ -236,3 +236,24 @@ class TestBigQuery(Validator):
"snowflake": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a NULLS FIRST LIMIT 10", "snowflake": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a NULLS FIRST LIMIT 10",
}, },
) )
self.validate_all(
"SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
write={
"spark": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
"bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)])",
"snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
},
)
self.validate_all(
"SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) d, COUNT(*) e FOR c IN ('x', 'y'))",
write={
"bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))",
},
)
def test_user_defined_functions(self):
self.validate_identity(
"CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 LANGUAGE js AS 'return x*y;'"
)
self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)")
self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t")
View file
@ -13,9 +13,6 @@ from sqlglot import (
class Validator(unittest.TestCase): class Validator(unittest.TestCase):
dialect = None dialect = None
def validate(self, sql, target, **kwargs):
self.assertEqual(transpile(sql, **kwargs)[0], target)
def validate_identity(self, sql): def validate_identity(self, sql):
self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql) self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql)
@ -258,6 +255,7 @@ class TestDialect(Validator):
"duckdb": "EPOCH(STRPTIME('2020-01-01', '%Y-%M-%d'))", "duckdb": "EPOCH(STRPTIME('2020-01-01', '%Y-%M-%d'))",
"hive": "UNIX_TIMESTAMP('2020-01-01', 'yyyy-mm-dd')", "hive": "UNIX_TIMESTAMP('2020-01-01', 'yyyy-mm-dd')",
"presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%i-%d'))", "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%i-%d'))",
"starrocks": "UNIX_TIMESTAMP('2020-01-01', '%Y-%i-%d')",
}, },
) )
self.validate_all( self.validate_all(
@ -266,6 +264,7 @@ class TestDialect(Validator):
"duckdb": "CAST('2020-01-01' AS DATE)", "duckdb": "CAST('2020-01-01' AS DATE)",
"hive": "TO_DATE('2020-01-01')", "hive": "TO_DATE('2020-01-01')",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
"starrocks": "TO_DATE('2020-01-01')",
}, },
) )
self.validate_all( self.validate_all(
@ -341,6 +340,7 @@ class TestDialect(Validator):
"duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), y)", "duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), y)",
"hive": "FROM_UNIXTIME(x, y)", "hive": "FROM_UNIXTIME(x, y)",
"presto": "DATE_FORMAT(FROM_UNIXTIME(x), y)", "presto": "DATE_FORMAT(FROM_UNIXTIME(x), y)",
"starrocks": "FROM_UNIXTIME(x, y)",
}, },
) )
self.validate_all( self.validate_all(
@ -349,6 +349,7 @@ class TestDialect(Validator):
"duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))", "duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))",
"hive": "FROM_UNIXTIME(x)", "hive": "FROM_UNIXTIME(x)",
"presto": "FROM_UNIXTIME(x)", "presto": "FROM_UNIXTIME(x)",
"starrocks": "FROM_UNIXTIME(x)",
}, },
) )
self.validate_all( self.validate_all(
@ -840,10 +841,20 @@ class TestDialect(Validator):
"starrocks": UnsupportedError, "starrocks": UnsupportedError,
}, },
) )
self.validate_all(
"POSITION(' ' in x)",
write={
"duckdb": "STRPOS(x, ' ')",
"postgres": "STRPOS(x, ' ')",
"presto": "STRPOS(x, ' ')",
"spark": "LOCATE(' ', x)",
},
)
self.validate_all( self.validate_all(
"STR_POSITION(x, 'a')", "STR_POSITION(x, 'a')",
write={ write={
"duckdb": "STRPOS(x, 'a')", "duckdb": "STRPOS(x, 'a')",
"postgres": "STRPOS(x, 'a')",
"presto": "STRPOS(x, 'a')", "presto": "STRPOS(x, 'a')",
"spark": "LOCATE('a', x)", "spark": "LOCATE('a', x)",
}, },
View file
@ -1,3 +1,4 @@
from sqlglot import ErrorLevel, UnsupportedError, transpile
from tests.dialects.test_dialect import Validator from tests.dialects.test_dialect import Validator
@ -250,3 +251,10 @@ class TestDuckDB(Validator):
"spark": "MONTH('2021-03-01')", "spark": "MONTH('2021-03-01')",
}, },
) )
with self.assertRaises(UnsupportedError):
transpile(
"SELECT a FROM b PIVOT(SUM(x) FOR y IN ('z', 'q'))",
read="duckdb",
unsupported_level=ErrorLevel.IMMEDIATE,
)
View file
@ -119,3 +119,39 @@ class TestMySQL(Validator):
"sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')", "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')",
}, },
) )
self.validate_identity(
"CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'"
)
self.validate_identity(
"CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'"
)
self.validate_identity(
"CREATE TABLE z (a INT DEFAULT NULL, PRIMARY KEY(a)) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'"
)
self.validate_all(
"""
CREATE TABLE `t_customer_account` (
"id" int(11) NOT NULL AUTO_INCREMENT,
"customer_id" int(11) DEFAULT NULL COMMENT '客户id',
"bank" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别',
"account_no" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号',
PRIMARY KEY ("id")
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='客户账户表'
""",
write={
"mysql": """CREATE TABLE `t_customer_account` (
'id' INT(11) NOT NULL AUTO_INCREMENT,
'customer_id' INT(11) DEFAULT NULL COMMENT '客户id',
'bank' VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别',
'account_no' VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号',
PRIMARY KEY('id')
)
ENGINE=InnoDB
AUTO_INCREMENT=1
DEFAULT CHARACTER SET=utf8
COLLATE=utf8_bin
COMMENT='客户账户表'"""
},
pretty=True,
)
View file
@ -217,11 +217,12 @@ class TestPresto(Validator):
}, },
) )
self.validate( self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", write={
read="presto", "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname",
write="presto", "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
},
) )
def test_quotes(self): def test_quotes(self):

View file
"snowflake": r"SELECT 'a \' \\ \\t \\x21 z $ '", "snowflake": r"SELECT 'a \' \\ \\t \\x21 z $ '",
}, },
) )
self.validate_identity("SELECT REGEXP_LIKE(a, b, c)")
self.validate_all(
"SELECT RLIKE(a, b)",
write={
"snowflake": "SELECT REGEXP_LIKE(a, b)",
},
)
self.validate_all(
"SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)",
write={
"snowflake": "SELECT a FROM test TABLESAMPLE BLOCK (0.5) SEED (42)",
},
)
self.validate_all(
"SELECT a FROM test pivot",
write={
"snowflake": "SELECT a FROM test AS pivot",
},
)
self.validate_all(
"SELECT a FROM test unpivot",
write={
"snowflake": "SELECT a FROM test AS unpivot",
},
)
def test_null_treatment(self): def test_null_treatment(self):
self.validate_all( self.validate_all(
@ -220,3 +245,51 @@ class TestSnowflake(Validator):
"snowflake": "SELECT EXTRACT(month FROM CAST(a AS DATETIME))", "snowflake": "SELECT EXTRACT(month FROM CAST(a AS DATETIME))",
}, },
) )
def test_semi_structured_types(self):
self.validate_identity("SELECT CAST(a AS VARIANT)")
self.validate_all(
"SELECT a::VARIANT",
write={
"snowflake": "SELECT CAST(a AS VARIANT)",
"tsql": "SELECT CAST(a AS SQL_VARIANT)",
},
)
self.validate_identity("SELECT CAST(a AS ARRAY)")
self.validate_all(
"ARRAY_CONSTRUCT(0, 1, 2)",
write={
"snowflake": "[0, 1, 2]",
"bigquery": "[0, 1, 2]",
"duckdb": "LIST_VALUE(0, 1, 2)",
"presto": "ARRAY[0, 1, 2]",
"spark": "ARRAY(0, 1, 2)",
},
)
self.validate_all(
"SELECT a::OBJECT",
write={
"snowflake": "SELECT CAST(a AS OBJECT)",
},
)
def test_ddl(self):
self.validate_identity(
"CREATE TABLE a (x DATE, y BIGINT) WITH (PARTITION BY (x), integration='q', auto_refresh=TRUE, file_format=(type = parquet))"
)
self.validate_identity("CREATE MATERIALIZED VIEW a COMMENT='...' AS SELECT 1 FROM x")
def test_user_defined_functions(self):
self.validate_all(
"CREATE FUNCTION a(x DATE, y BIGINT) RETURNS ARRAY LANGUAGE JAVASCRIPT AS $$ SELECT 1 $$",
write={
"snowflake": "CREATE FUNCTION a(x DATE, y BIGINT) RETURNS ARRAY LANGUAGE JAVASCRIPT AS ' SELECT 1 '",
},
)
self.validate_all(
"CREATE FUNCTION a() RETURNS TABLE (b INT) AS 'SELECT 1'",
write={
"snowflake": "CREATE FUNCTION a() RETURNS TABLE (b INT) AS 'SELECT 1'",
"bigquery": "CREATE TABLE FUNCTION a() RETURNS TABLE <b INT64> AS SELECT 1",
},
)
View file
@ -15,6 +15,14 @@ class TestTSQL(Validator):
}, },
) )
self.validate_all(
"CONVERT(INT, CONVERT(NUMERIC, '444.75'))",
write={
"mysql": "CAST(CAST('444.75' AS DECIMAL) AS INT)",
"tsql": "CAST(CAST('444.75' AS NUMERIC) AS INTEGER)",
},
)
def test_types(self): def test_types(self):
self.validate_identity("CAST(x AS XML)") self.validate_identity("CAST(x AS XML)")
self.validate_identity("CAST(x AS UNIQUEIDENTIFIER)") self.validate_identity("CAST(x AS UNIQUEIDENTIFIER)")
@ -24,3 +32,13 @@ class TestTSQL(Validator):
self.validate_identity("CAST(x AS IMAGE)") self.validate_identity("CAST(x AS IMAGE)")
self.validate_identity("CAST(x AS SQL_VARIANT)") self.validate_identity("CAST(x AS SQL_VARIANT)")
self.validate_identity("CAST(x AS BIT)") self.validate_identity("CAST(x AS BIT)")
self.validate_all(
"CAST(x AS DATETIME2)",
read={
"": "CAST(x AS DATETIME)",
},
write={
"mysql": "CAST(x AS DATETIME)",
"tsql": "CAST(x AS DATETIME2)",
},
)
View file
@ -8,6 +8,7 @@ SUM(CASE WHEN x > 1 THEN 1 ELSE 0 END) / y
1.1E10 1.1E10
1.12e-10 1.12e-10
-11.023E7 * 3 -11.023E7 * 3
0.2
(1 * 2) / (3 - 5) (1 * 2) / (3 - 5)
((TRUE)) ((TRUE))
'' ''
@ -167,7 +168,7 @@ SELECT LEAD(a) OVER (ORDER BY b) AS a
SELECT LEAD(a, 1) OVER (PARTITION BY a ORDER BY a) AS x SELECT LEAD(a, 1) OVER (PARTITION BY a ORDER BY a) AS x
SELECT LEAD(a, 1, b) OVER (PARTITION BY a ORDER BY a) AS x SELECT LEAD(a, 1, b) OVER (PARTITION BY a ORDER BY a) AS x
SELECT X((a, b) -> a + b, z -> z) AS x SELECT X((a, b) -> a + b, z -> z) AS x
SELECT X(a -> "a" + ("z" - 1)) SELECT X(a -> a + ("z" - 1))
SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0) SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)
SELECT test.* FROM test SELECT test.* FROM test
SELECT a AS b FROM test SELECT a AS b FROM test
@ -258,15 +259,24 @@ SELECT a FROM test TABLESAMPLE(100)
SELECT a FROM test TABLESAMPLE(100 ROWS) SELECT a FROM test TABLESAMPLE(100 ROWS)
SELECT a FROM test TABLESAMPLE BERNOULLI (50) SELECT a FROM test TABLESAMPLE BERNOULLI (50)
SELECT a FROM test TABLESAMPLE SYSTEM (75) SELECT a FROM test TABLESAMPLE SYSTEM (75)
SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q'))
SELECT a FROM test PIVOT(SOMEAGG(x, y, z) FOR q IN (1))
SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) PIVOT(MAX(b) FOR c IN ('d'))
SELECT a FROM (SELECT a, b FROM test) PIVOT(SUM(x) FOR y IN ('z', 'q'))
SELECT a FROM test UNPIVOT(x FOR y IN (z, q)) AS x
SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) AS x TABLESAMPLE(0.1)
SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) UNPIVOT(x FOR y IN (z, q)) AS x
SELECT ABS(a) FROM test SELECT ABS(a) FROM test
SELECT AVG(a) FROM test SELECT AVG(a) FROM test
SELECT CEIL(a) FROM test SELECT CEIL(a) FROM test
SELECT CEIL(a, b) FROM test
SELECT COUNT(a) FROM test SELECT COUNT(a) FROM test
SELECT COUNT(1) FROM test SELECT COUNT(1) FROM test
SELECT COUNT(*) FROM test SELECT COUNT(*) FROM test
SELECT COUNT(DISTINCT a) FROM test SELECT COUNT(DISTINCT a) FROM test
SELECT EXP(a) FROM test SELECT EXP(a) FROM test
SELECT FLOOR(a) FROM test SELECT FLOOR(a) FROM test
SELECT FLOOR(a, b) FROM test
SELECT FIRST(a) FROM test SELECT FIRST(a) FROM test
SELECT GREATEST(a, b, c) FROM test SELECT GREATEST(a, b, c) FROM test
SELECT LAST(a) FROM test SELECT LAST(a) FROM test
@ -299,6 +309,7 @@ SELECT CAST(a AS MAP<INT, INT>) FROM test
SELECT CAST(a AS TIMESTAMP) FROM test SELECT CAST(a AS TIMESTAMP) FROM test
SELECT CAST(a AS DATE) FROM test SELECT CAST(a AS DATE) FROM test
SELECT CAST(a AS ARRAY<INT>) FROM test SELECT CAST(a AS ARRAY<INT>) FROM test
SELECT CAST(a AS VARIANT) FROM test
SELECT TRY_CAST(a AS INT) FROM test SELECT TRY_CAST(a AS INT) FROM test
SELECT COALESCE(a, b, c) FROM test SELECT COALESCE(a, b, c) FROM test
SELECT IFNULL(a, b) FROM test SELECT IFNULL(a, b) FROM test
@ -442,9 +453,6 @@ CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id')
CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1) CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1)
CREATE TABLE z (a INT(11) NOT NULL COLLATE utf8_bin AUTO_INCREMENT) CREATE TABLE z (a INT(11) NOT NULL COLLATE utf8_bin AUTO_INCREMENT)
CREATE TABLE z (a INT, PRIMARY KEY(a)) CREATE TABLE z (a INT, PRIMARY KEY(a))
CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'
CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'
CREATE TABLE z (a INT DEFAULT NULL, PRIMARY KEY(a)) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'
CREATE TABLE z WITH (FORMAT='parquet') AS SELECT 1 CREATE TABLE z WITH (FORMAT='parquet') AS SELECT 1
CREATE TABLE z WITH (FORMAT='ORC', x='2') AS SELECT 1 CREATE TABLE z WITH (FORMAT='ORC', x='2') AS SELECT 1
CREATE TABLE z WITH (TABLE_FORMAT='iceberg', FORMAT='parquet') AS SELECT 1 CREATE TABLE z WITH (TABLE_FORMAT='iceberg', FORMAT='parquet') AS SELECT 1
@ -460,6 +468,9 @@ CREATE TEMPORARY FUNCTION f
CREATE TEMPORARY FUNCTION f AS 'g' CREATE TEMPORARY FUNCTION f AS 'g'
CREATE FUNCTION f CREATE FUNCTION f
CREATE FUNCTION f AS 'g' CREATE FUNCTION f AS 'g'
CREATE FUNCTION a(b INT, c VARCHAR) AS 'SELECT 1'
CREATE FUNCTION a() LANGUAGE sql
CREATE FUNCTION a() LANGUAGE sql RETURNS INT
CREATE INDEX abc ON t (a) CREATE INDEX abc ON t (a)
CREATE INDEX abc ON t (a, b, b) CREATE INDEX abc ON t (a, b, b)
CREATE UNIQUE INDEX abc ON t (a, b, b) CREATE UNIQUE INDEX abc ON t (a, b, b)
@ -519,3 +530,4 @@ WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS
WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a
SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z
SELECT ((SELECT 1) + 1) SELECT ((SELECT 1) + 1)
SELECT * FROM project.dataset.INFORMATION_SCHEMA.TABLES
View file
@ -1,42 +1,79 @@
SELECT 1 AS x, 2 AS y -- No derived tables
UNION ALL SELECT * FROM x;
SELECT 1 AS x, 2 AS y; SELECT * FROM x;
WITH _e_0 AS (
SELECT
1 AS x,
2 AS y
)
SELECT
*
FROM _e_0
UNION ALL
SELECT
*
FROM _e_0;
SELECT x.id -- Unaliased derived tables
FROM ( SELECT a FROM (SELECT b FROM (SELECT c FROM x));
SELECT * WITH cte AS (SELECT c FROM x), cte_2 AS (SELECT b FROM cte AS cte) SELECT a FROM cte_2 AS cte_2;
FROM x AS x
JOIN y AS y -- Joined derived table inside nested derived table
ON x.id = y.id SELECT b FROM (SELECT b FROM (SELECT b FROM x JOIN (SELECT b FROM y) AS y ON x.b = y.b));
) AS x WITH y_2 AS (SELECT b FROM y), cte AS (SELECT b FROM x JOIN y_2 AS y ON x.b = y.b), cte_2 AS (SELECT b FROM cte AS cte) SELECT b FROM cte_2 AS cte_2;
JOIN (
SELECT * -- Aliased derived tables
FROM x AS x SELECT a FROM (SELECT b FROM (SELECT c FROM x) AS y) AS z;
JOIN y AS y WITH y AS (SELECT c FROM x), z AS (SELECT b FROM y AS y) SELECT a FROM z AS z;
ON x.id = y.id
) AS y -- Existing CTEs
ON x.id = y.id; WITH q AS (SELECT c FROM x) SELECT a FROM (SELECT b FROM q AS y) AS z;
WITH _e_0 AS ( WITH q AS (SELECT c FROM x), z AS (SELECT b FROM q AS y) SELECT a FROM z AS z;
SELECT
* -- Derived table inside CTE
FROM x AS x WITH x AS (SELECT a FROM (SELECT a FROM x) AS y) SELECT a FROM x;
JOIN y AS y WITH y AS (SELECT a FROM x), x AS (SELECT a FROM y AS y) SELECT a FROM x;
ON x.id = y.id
) -- Name conflicts with existing outer derived table
SELECT SELECT a FROM (SELECT b FROM (SELECT c FROM x) AS y) AS y;
x.id WITH y AS (SELECT c FROM x), y_2 AS (SELECT b FROM y AS y) SELECT a FROM y_2 AS y;
FROM "_e_0" AS x
JOIN "_e_0" AS y -- Name conflicts with outer join
ON x.id = y.id; SELECT a, b FROM (SELECT c FROM (SELECT d FROM x) AS x) AS y JOIN x ON x.a = y.a;
WITH x_2 AS (SELECT d FROM x), y AS (SELECT c FROM x_2 AS x) SELECT a, b FROM y AS y JOIN x ON x.a = y.a;
-- Name conflicts with table name that is selected in another branch
SELECT * FROM (SELECT * FROM (SELECT a FROM x) AS x) AS y JOIN (SELECT * FROM x) AS z ON x.a = y.a;
WITH x_2 AS (SELECT a FROM x), y AS (SELECT * FROM x_2 AS x), z AS (SELECT * FROM x) SELECT * FROM y AS y JOIN z AS z ON x.a = y.a;
-- Name conflicts with table alias
SELECT a FROM (SELECT a FROM (SELECT a FROM x) AS y) AS z JOIN q AS y;
WITH y AS (SELECT a FROM x), z AS (SELECT a FROM y AS y) SELECT a FROM z AS z JOIN q AS y;
-- Name conflicts with existing CTE
WITH y AS (SELECT a FROM (SELECT a FROM x) AS y) SELECT a FROM y;
WITH y_2 AS (SELECT a FROM x), y AS (SELECT a FROM y_2 AS y) SELECT a FROM y;
-- Union
SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y;
WITH cte AS (SELECT 1 AS x, 2 AS y) SELECT cte.x AS x, cte.y AS y FROM cte AS cte UNION ALL SELECT cte.x AS x, cte.y AS y FROM cte AS cte;
-- Union of selects with derived tables
(SELECT a FROM (SELECT b FROM x)) UNION (SELECT a FROM (SELECT b FROM y));
WITH cte AS (SELECT b FROM x), cte_2 AS (SELECT a FROM cte AS cte), cte_3 AS (SELECT b FROM y), cte_4 AS (SELECT a FROM cte_3 AS cte_3) (SELECT cte_2.a AS a FROM cte_2 AS cte_2) UNION (SELECT cte_4.a AS a FROM cte_4 AS cte_4);
-- Subquery
SELECT a FROM x WHERE b = (SELECT y.c FROM y);
SELECT a FROM x WHERE b = (SELECT y.c FROM y);
-- Correlated subquery
SELECT a FROM x WHERE b = (SELECT c FROM y WHERE y.a = x.a);
SELECT a FROM x WHERE b = (SELECT c FROM y WHERE y.a = x.a);
-- Duplicate CTE
SELECT a FROM (SELECT b FROM x) AS y JOIN (SELECT b FROM x) AS z;
WITH y AS (SELECT b FROM x) SELECT a FROM y AS y JOIN y AS z;
-- Doubly duplicate CTE
SELECT * FROM (SELECT * FROM x JOIN (SELECT * FROM x) AS y) AS z JOIN (SELECT * FROM x JOIN (SELECT * FROM x) AS y) AS q;
WITH y AS (SELECT * FROM x), z AS (SELECT * FROM x JOIN y AS y) SELECT * FROM z AS z JOIN z AS q;
-- Another duplicate...
SELECT x.id FROM (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) AS x JOIN (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) AS y ON x.id = y.id;
WITH x_2 AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT x.id FROM x_2 AS x JOIN x_2 AS y ON x.id = y.id;
-- Root subquery
(SELECT * FROM (SELECT * FROM x)) LIMIT 1;
(WITH cte AS (SELECT * FROM x) SELECT * FROM cte AS cte) LIMIT 1;
-- Existing duplicate CTE
WITH y AS (SELECT a FROM x) SELECT a FROM (SELECT a FROM x) AS y JOIN y AS z;
WITH y AS (SELECT a FROM x) SELECT a FROM y AS y JOIN y AS z;
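The pairs above exercise eliminate_subqueries: every derived table is promoted to a CTE, and structurally identical subqueries collapse into one. A minimal sketch of the "Duplicate CTE" pair at the API level, using the import path the test module further down in this diff relies on:

    from sqlglot import parse_one
    from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

    # Two identical derived tables collapse into a single CTE that is
    # then referenced under both aliases.
    sql = "SELECT a FROM (SELECT b FROM x) AS y JOIN (SELECT b FROM x) AS z"
    print(eliminate_subqueries(parse_one(sql)).sql())
    # WITH y AS (SELECT b FROM x) SELECT a FROM y AS y JOIN y AS z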
View file
@ -18,6 +18,14 @@ SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a;
SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b; SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1; SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;
-- Outer query has join
SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;
# leave_tables_isolated: true
SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM (SELECT x.a AS a, x.b AS b FROM x AS x WHERE x.a > 1) AS x JOIN y AS y ON x.b = y.b;
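The `# leave_tables_isolated: true` marker is new fixture metadata: the updated check_file in test_optimizer.py (further down in this diff) forwards it as a keyword argument through optimizer.optimize to merge_subqueries. A sketch of the pair above driven by hand, mirroring the rule chain the test builds:

    from functools import partial
    from sqlglot import optimizer, parse_one

    optimize = partial(
        optimizer.optimize,
        rules=[
            optimizer.qualify_tables.qualify_tables,
            optimizer.qualify_columns.qualify_columns,
            optimizer.merge_subqueries.merge_subqueries,
        ],
    )
    schema = {"x": {"a": "INT", "b": "INT"}, "y": {"b": "INT", "c": "INT"}}
    sql = "SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b"

    # With the flag set, the derived table is qualified but kept isolated
    # instead of being merged into the outer join.
    print(optimize(parse_one(sql), schema=schema, leave_tables_isolated=True).sql())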
-- Join on derived table -- Join on derived table
SELECT a, c FROM x JOIN (SELECT b, c FROM y) AS y ON x.b = y.b; SELECT a, c FROM x JOIN (SELECT b, c FROM y) AS y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
@ -42,13 +50,9 @@ SELECT q_2.a AS a, q.c AS c, r.c AS c FROM x AS q_2 JOIN y AS r_2 ON q_2.b = r_2
SELECT r.b FROM (SELECT b FROM x AS x) AS q JOIN (SELECT b FROM x) AS r ON q.b = r.b; SELECT r.b FROM (SELECT b FROM x AS x) AS q JOIN (SELECT b FROM x) AS r ON q.b = r.b;
SELECT x_2.b AS b FROM x AS x JOIN x AS x_2 ON x.b = x_2.b; SELECT x_2.b AS b FROM x AS x JOIN x AS x_2 ON x.b = x_2.b;
-- WHERE clause in joined derived table is merged -- WHERE clause in joined derived table is merged to ON clause
SELECT x.a, y.c FROM x JOIN (SELECT b, c FROM y WHERE c > 1) AS y; SELECT x.a, y.c FROM x JOIN (SELECT b, c FROM y WHERE c > 1) AS y;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y WHERE y.c > 1; SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON y.c > 1;
-- WHERE clause in outer joined derived table is merged to ON clause
SELECT x.a, y.c FROM x LEFT JOIN (SELECT b, c FROM y WHERE c > 1) AS y;
SELECT x.a AS a, y.c AS c FROM x AS x LEFT JOIN y AS y ON y.c > 1;
-- Comma JOIN in outer query -- Comma JOIN in outer query
SELECT x.a, y.c FROM (SELECT a FROM x) AS x, (SELECT c FROM y) AS y; SELECT x.a, y.c FROM (SELECT a FROM x) AS x, (SELECT c FROM y) AS y;
@ -61,3 +65,35 @@ SELECT x.a AS a, z.c AS c FROM x AS x CROSS JOIN y AS z;
-- (Regression) Column in ORDER BY -- (Regression) Column in ORDER BY
SELECT * FROM (SELECT * FROM (SELECT * FROM x)) ORDER BY a LIMIT 1; SELECT * FROM (SELECT * FROM (SELECT * FROM x)) ORDER BY a LIMIT 1;
SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a LIMIT 1; SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a LIMIT 1;
-- CTE
WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x;
SELECT x.a AS a, x.b AS b FROM x AS x;
-- CTE with outer table alias
WITH y AS (SELECT a, b FROM x) SELECT a, b FROM y AS z;
SELECT x.a AS a, x.b AS b FROM x AS x;
-- Nested CTE
WITH x AS (SELECT a FROM x), x2 AS (SELECT a FROM x) SELECT a FROM x2;
SELECT x.a AS a FROM x AS x;
-- CTE WHERE clause is merged
WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, SUM(b) FROM x GROUP BY a;
SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a;
-- CTE Outer query has join
WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, c FROM x AS x JOIN y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;
-- CTE with inner table alias
WITH y AS (SELECT a, b FROM x AS q) SELECT a, b FROM y AS z;
SELECT q.a AS a, q.b AS b FROM x AS q;
-- Duplicate queries to CTE
WITH x AS (SELECT a, b FROM x) SELECT x.a, y.b FROM x JOIN x AS y;
WITH x AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT x.a AS a, y.b AS b FROM x JOIN x AS y;
-- Nested CTE
SELECT * FROM (WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x);
SELECT x.a AS a, x.b AS b FROM x AS x;
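The CTE cases mirror the derived-table ones: a CTE with a single consumer is inlined into the outer query, while one referenced twice ("Duplicate queries to CTE") stays put. Reproducing the first pair with the same rule chain as in the previous sketch:

    from functools import partial
    from sqlglot import optimizer, parse_one

    optimize = partial(
        optimizer.optimize,
        rules=[
            optimizer.qualify_tables.qualify_tables,
            optimizer.qualify_columns.qualify_columns,
            optimizer.merge_subqueries.merge_subqueries,
        ],
    )
    schema = {"x": {"a": "INT", "b": "INT"}}
    sql = "WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x"
    print(optimize(parse_one(sql), schema=schema).sql())
    # SELECT x.a AS a, x.b AS b FROM x AS x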
View file
@ -65,18 +65,14 @@ WITH "cte1" AS (
SELECT SELECT
"x"."a" AS "a" "x"."a" AS "a"
FROM "x" AS "x" FROM "x" AS "x"
), "cte2" AS (
SELECT
"cte1"."a" + 1 AS "a"
FROM "cte1"
) )
SELECT SELECT
"cte1"."a" AS "a" "cte1"."a" AS "a"
FROM "cte1" FROM "cte1"
UNION ALL UNION ALL
SELECT SELECT
"cte2"."a" AS "a" "cte1"."a" + 1 AS "a"
FROM "cte2"; FROM "cte1";
SELECT a, SUM(b) SELECT a, SUM(b)
FROM ( FROM (
@ -86,18 +82,19 @@ FROM (
) d ) d
WHERE (TRUE AND TRUE OR 'a' = 'b') AND a > 1 WHERE (TRUE AND TRUE OR 'a' = 'b') AND a > 1
GROUP BY a; GROUP BY a;
SELECT WITH "_u_0" AS (
"x"."a" AS "a",
SUM("y"."b") AS "_col_1"
FROM "x" AS "x"
LEFT JOIN (
SELECT SELECT
MAX("y"."b") AS "_col_0", MAX("y"."b") AS "_col_0",
"y"."a" AS "_u_1" "y"."a" AS "_u_1"
FROM "y" AS "y" FROM "y" AS "y"
GROUP BY GROUP BY
"y"."a" "y"."a"
) AS "_u_0" )
SELECT
"x"."a" AS "a",
SUM("y"."b") AS "_col_1"
FROM "x" AS "x"
LEFT JOIN "_u_0" AS "_u_0"
ON "x"."a" = "_u_0"."_u_1" ON "x"."a" = "_u_0"."_u_1"
JOIN "y" AS "y" JOIN "y" AS "y"
ON "x"."a" = "y"."a" ON "x"."a" = "y"."a"
@ -127,3 +124,16 @@ LIMIT 1;
FROM "y" AS "y" FROM "y" AS "y"
) )
LIMIT 1; LIMIT 1;
# dialect: spark
SELECT /*+ BROADCAST(y) */ x.b FROM x JOIN y ON x.b = y.b;
SELECT /*+ BROADCAST(`y`) */
`x`.`b` AS `b`
FROM `x` AS `x`
JOIN `y` AS `y`
ON `x`.`b` = `y`.`b`;
SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
SELECT
AGGREGATE(ARRAY("x"."a", "x"."b"), 0, ("x", "acc") -> "x" + "acc" + "x"."a") AS "sum_agg"
FROM "x" AS "x";
View file
@ -69,6 +69,9 @@ SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x
SELECT x.b, x.a FROM x LEFT JOIN y ON x.b = y.b QUALIFY ROW_NUMBER() OVER(PARTITION BY x.b ORDER BY x.a DESC) = 1; SELECT x.b, x.a FROM x LEFT JOIN y ON x.b = y.b QUALIFY ROW_NUMBER() OVER(PARTITION BY x.b ORDER BY x.a DESC) = 1;
SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a DESC) = 1; SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a DESC) = 1;
SELECT AGGREGATE(ARRAY(a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + x.a) AS sum_agg FROM x AS x;
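This new fixture pins down lambda scoping in qualify_columns: a lambda parameter shadows same-named columns, so the x inside `(x, acc) -> ...` stays bare while a, which is not a parameter, resolves to x.a. A sketch, assuming a schema in which x has columns a and b:

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_tables import qualify_tables
    from sqlglot.optimizer.qualify_columns import qualify_columns

    schema = {"x": {"a": "INT", "b": "INT"}}
    sql = "SELECT AGGREGATE(ARRAY(a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x"
    print(qualify_columns(qualify_tables(parse_one(sql)), schema).sql())
    # SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + x.a) AS sum_agg FROM x AS x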
-------------------------------------- --------------------------------------
-- Derived tables -- Derived tables
-------------------------------------- --------------------------------------
@ -231,3 +234,10 @@ SELECT COALESCE(x.b, y.b) AS b FROM x AS x JOIN y AS y ON x.b = y.b WHERE COALES
SELECT b FROM x JOIN y USING (b) JOIN z USING (b); SELECT b FROM x JOIN y USING (b) JOIN z USING (b);
SELECT COALESCE(x.b, y.b, z.b) AS b FROM x AS x JOIN y AS y ON x.b = y.b JOIN z AS z ON x.b = z.b; SELECT COALESCE(x.b, y.b, z.b) AS b FROM x AS x JOIN y AS y ON x.b = y.b JOIN z AS z ON x.b = z.b;
--------------------------------------
-- Hint with table reference
--------------------------------------
# dialect: spark
SELECT /*+ BROADCAST(y) */ x.b FROM x JOIN y ON x.b = y.b;
SELECT /*+ BROADCAST(y) */ x.b AS b FROM x AS x JOIN y AS y ON x.b = y.b;
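Hints name tables by alias, so qualification must leave the hint's identifier intact rather than expanding or qualifying it. A sketch of the pair above, assuming spark for both parsing and generation:

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_tables import qualify_tables
    from sqlglot.optimizer.qualify_columns import qualify_columns

    schema = {"x": {"b": "INT"}, "y": {"b": "INT"}}
    sql = "SELECT /*+ BROADCAST(y) */ x.b FROM x JOIN y ON x.b = y.b"
    expression = qualify_columns(qualify_tables(parse_one(sql, read="spark")), schema)
    print(expression.sql(dialect="spark"))
    # SELECT /*+ BROADCAST(y) */ x.b AS b FROM x AS x JOIN y AS y ON x.b = y.b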
View file
@ -5,7 +5,6 @@ SELECT z.* FROM x;
SELECT x FROM x; SELECT x FROM x;
INSERT INTO x VALUES (1, 2); INSERT INTO x VALUES (1, 2);
SELECT a FROM x AS z JOIN y AS z; SELECT a FROM x AS z JOIN y AS z;
WITH z AS (SELECT * FROM x) SELECT * FROM x AS z;
SELECT a FROM x JOIN (SELECT b FROM y WHERE y.b = x.c); SELECT a FROM x JOIN (SELECT b FROM y WHERE y.b = x.c);
SELECT a FROM x AS y JOIN (SELECT a FROM y) AS q ON y.a = q.a; SELECT a FROM x AS y JOIN (SELECT a FROM y) AS q ON y.a = q.a;
SELECT q.a FROM (SELECT x.b FROM x) AS z JOIN (SELECT a FROM z) AS q ON z.b = q.a; SELECT q.a FROM (SELECT x.b FROM x) AS z JOIN (SELECT a FROM z) AS q ON z.b = q.a;
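These are negative fixtures: each statement should make column qualification fail loudly instead of guessing. An illustrative invocation (the schema is made up here, and the exact rule that raises may vary by case):

    from sqlglot import parse_one
    from sqlglot.errors import OptimizeError
    from sqlglot.optimizer.qualify_columns import qualify_columns

    schema = {"x": {"a": "INT"}, "y": {"a": "INT"}}
    try:
        # Both join sources claim the alias z, so resolution is ambiguous.
        qualify_columns(parse_one("SELECT a FROM x AS z JOIN y AS z"), schema)
    except OptimizeError as exc:
        print(f"rejected: {exc}")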
View file
@ -97,19 +97,32 @@ order by
p_partkey p_partkey
limit limit
100; 100;
WITH "_e_0" AS ( WITH "partsupp_2" AS (
SELECT SELECT
"partsupp"."ps_partkey" AS "ps_partkey", "partsupp"."ps_partkey" AS "ps_partkey",
"partsupp"."ps_suppkey" AS "ps_suppkey", "partsupp"."ps_suppkey" AS "ps_suppkey",
"partsupp"."ps_supplycost" AS "ps_supplycost" "partsupp"."ps_supplycost" AS "ps_supplycost"
FROM "partsupp" AS "partsupp" FROM "partsupp" AS "partsupp"
), "_e_1" AS ( ), "region_2" AS (
SELECT SELECT
"region"."r_regionkey" AS "r_regionkey", "region"."r_regionkey" AS "r_regionkey",
"region"."r_name" AS "r_name" "region"."r_name" AS "r_name"
FROM "region" AS "region" FROM "region" AS "region"
WHERE WHERE
"region"."r_name" = 'EUROPE' "region"."r_name" = 'EUROPE'
), "_u_0" AS (
SELECT
MIN("partsupp"."ps_supplycost") AS "_col_0",
"partsupp"."ps_partkey" AS "_u_1"
FROM "partsupp_2" AS "partsupp"
CROSS JOIN "region_2" AS "region"
JOIN "nation" AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey"
JOIN "supplier" AS "supplier"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey"
GROUP BY
"partsupp"."ps_partkey"
) )
SELECT SELECT
"supplier"."s_acctbal" AS "s_acctbal", "supplier"."s_acctbal" AS "s_acctbal",
@ -121,25 +134,12 @@ SELECT
"supplier"."s_phone" AS "s_phone", "supplier"."s_phone" AS "s_phone",
"supplier"."s_comment" AS "s_comment" "supplier"."s_comment" AS "s_comment"
FROM "part" AS "part" FROM "part" AS "part"
LEFT JOIN ( LEFT JOIN "_u_0" AS "_u_0"
SELECT
MIN("partsupp"."ps_supplycost") AS "_col_0",
"partsupp"."ps_partkey" AS "_u_1"
FROM "_e_0" AS "partsupp"
CROSS JOIN "_e_1" AS "region"
JOIN "nation" AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey"
JOIN "supplier" AS "supplier"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey"
GROUP BY
"partsupp"."ps_partkey"
) AS "_u_0"
ON "part"."p_partkey" = "_u_0"."_u_1" ON "part"."p_partkey" = "_u_0"."_u_1"
CROSS JOIN "_e_1" AS "region" CROSS JOIN "region_2" AS "region"
JOIN "nation" AS "nation" JOIN "nation" AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey" ON "nation"."n_regionkey" = "region"."r_regionkey"
JOIN "_e_0" AS "partsupp" JOIN "partsupp_2" AS "partsupp"
ON "part"."p_partkey" = "partsupp"."ps_partkey" ON "part"."p_partkey" = "partsupp"."ps_partkey"
JOIN "supplier" AS "supplier" JOIN "supplier" AS "supplier"
ON "supplier"."s_nationkey" = "nation"."n_nationkey" ON "supplier"."s_nationkey" = "nation"."n_nationkey"
@ -193,12 +193,12 @@ SELECT
FROM "customer" AS "customer" FROM "customer" AS "customer"
JOIN "orders" AS "orders" JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey" ON "customer"."c_custkey" = "orders"."o_custkey"
AND "orders"."o_orderdate" < '1995-03-15'
JOIN "lineitem" AS "lineitem" JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey" ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "lineitem"."l_shipdate" > '1995-03-15'
WHERE WHERE
"customer"."c_mktsegment" = 'BUILDING' "customer"."c_mktsegment" = 'BUILDING'
AND "lineitem"."l_shipdate" > '1995-03-15'
AND "orders"."o_orderdate" < '1995-03-15'
GROUP BY GROUP BY
"lineitem"."l_orderkey", "lineitem"."l_orderkey",
"orders"."o_orderdate", "orders"."o_orderdate",
@ -232,11 +232,7 @@ group by
o_orderpriority o_orderpriority
order by order by
o_orderpriority; o_orderpriority;
SELECT WITH "_u_0" AS (
"orders"."o_orderpriority" AS "o_orderpriority",
COUNT(*) AS "order_count"
FROM "orders" AS "orders"
LEFT JOIN (
SELECT SELECT
"lineitem"."l_orderkey" AS "l_orderkey" "lineitem"."l_orderkey" AS "l_orderkey"
FROM "lineitem" AS "lineitem" FROM "lineitem" AS "lineitem"
@ -244,7 +240,12 @@ LEFT JOIN (
"lineitem"."l_commitdate" < "lineitem"."l_receiptdate" "lineitem"."l_commitdate" < "lineitem"."l_receiptdate"
GROUP BY GROUP BY
"lineitem"."l_orderkey" "lineitem"."l_orderkey"
) AS "_u_0" )
SELECT
"orders"."o_orderpriority" AS "o_orderpriority",
COUNT(*) AS "order_count"
FROM "orders" AS "orders"
LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."l_orderkey" = "orders"."o_orderkey" ON "_u_0"."l_orderkey" = "orders"."o_orderkey"
WHERE WHERE
"orders"."o_orderdate" < CAST('1993-10-01' AS DATE) "orders"."o_orderdate" < CAST('1993-10-01' AS DATE)
@ -290,7 +291,10 @@ SELECT
FROM "customer" AS "customer" FROM "customer" AS "customer"
JOIN "orders" AS "orders" JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey" ON "customer"."c_custkey" = "orders"."o_custkey"
CROSS JOIN "region" AS "region" AND "orders"."o_orderdate" < CAST('1995-01-01' AS DATE)
AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE)
JOIN "region" AS "region"
ON "region"."r_name" = 'ASIA'
JOIN "nation" AS "nation" JOIN "nation" AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey" ON "nation"."n_regionkey" = "region"."r_regionkey"
JOIN "supplier" AS "supplier" JOIN "supplier" AS "supplier"
@ -299,10 +303,6 @@ JOIN "supplier" AS "supplier"
JOIN "lineitem" AS "lineitem" JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey" ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "lineitem"."l_suppkey" = "supplier"."s_suppkey" AND "lineitem"."l_suppkey" = "supplier"."s_suppkey"
WHERE
"orders"."o_orderdate" < CAST('1995-01-01' AS DATE)
AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE)
AND "region"."r_name" = 'ASIA'
GROUP BY GROUP BY
"nation"."n_name" "nation"."n_name"
ORDER BY ORDER BY
@ -371,7 +371,7 @@ order by
supp_nation, supp_nation,
cust_nation, cust_nation,
l_year; l_year;
WITH "_e_0" AS ( WITH "n1" AS (
SELECT SELECT
"nation"."n_nationkey" AS "n_nationkey", "nation"."n_nationkey" AS "n_nationkey",
"nation"."n_name" AS "n_name" "nation"."n_name" AS "n_name"
@ -389,14 +389,15 @@ SELECT
)) AS "revenue" )) AS "revenue"
FROM "supplier" AS "supplier" FROM "supplier" AS "supplier"
JOIN "lineitem" AS "lineitem" JOIN "lineitem" AS "lineitem"
ON "supplier"."s_suppkey" = "lineitem"."l_suppkey" ON "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
AND "supplier"."s_suppkey" = "lineitem"."l_suppkey"
JOIN "orders" AS "orders" JOIN "orders" AS "orders"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey" ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
JOIN "customer" AS "customer" JOIN "customer" AS "customer"
ON "customer"."c_custkey" = "orders"."o_custkey" ON "customer"."c_custkey" = "orders"."o_custkey"
JOIN "_e_0" AS "n1" JOIN "n1" AS "n1"
ON "supplier"."s_nationkey" = "n1"."n_nationkey" ON "supplier"."s_nationkey" = "n1"."n_nationkey"
JOIN "_e_0" AS "n2" JOIN "n1" AS "n2"
ON "customer"."c_nationkey" = "n2"."n_nationkey" ON "customer"."c_nationkey" = "n2"."n_nationkey"
AND ( AND (
"n1"."n_name" = 'FRANCE' "n1"."n_name" = 'FRANCE'
@ -406,8 +407,6 @@ JOIN "_e_0" AS "n2"
"n1"."n_name" = 'GERMANY' "n1"."n_name" = 'GERMANY'
OR "n2"."n_name" = 'GERMANY' OR "n2"."n_name" = 'GERMANY'
) )
WHERE
"lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
GROUP BY GROUP BY
"n1"."n_name", "n1"."n_name",
"n2"."n_name", "n2"."n_name",
@ -469,13 +468,15 @@ SELECT
1 - "lineitem"."l_discount" 1 - "lineitem"."l_discount"
)) AS "mkt_share" )) AS "mkt_share"
FROM "part" AS "part" FROM "part" AS "part"
CROSS JOIN "region" AS "region" JOIN "region" AS "region"
ON "region"."r_name" = 'AMERICA'
JOIN "nation" AS "nation" JOIN "nation" AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey" ON "nation"."n_regionkey" = "region"."r_regionkey"
JOIN "customer" AS "customer" JOIN "customer" AS "customer"
ON "customer"."c_nationkey" = "nation"."n_nationkey" ON "customer"."c_nationkey" = "nation"."n_nationkey"
JOIN "orders" AS "orders" JOIN "orders" AS "orders"
ON "orders"."o_custkey" = "customer"."c_custkey" ON "orders"."o_custkey" = "customer"."c_custkey"
AND "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
JOIN "lineitem" AS "lineitem" JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey" ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "part"."p_partkey" = "lineitem"."l_partkey" AND "part"."p_partkey" = "lineitem"."l_partkey"
@ -484,9 +485,7 @@ JOIN "supplier" AS "supplier"
JOIN "nation" AS "nation_2" JOIN "nation" AS "nation_2"
ON "supplier"."s_nationkey" = "nation_2"."n_nationkey" ON "supplier"."s_nationkey" = "nation_2"."n_nationkey"
WHERE WHERE
"orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) "part"."p_type" = 'ECONOMY ANODIZED STEEL'
AND "part"."p_type" = 'ECONOMY ANODIZED STEEL'
AND "region"."r_name" = 'AMERICA'
GROUP BY GROUP BY
EXTRACT(year FROM "orders"."o_orderdate") EXTRACT(year FROM "orders"."o_orderdate")
ORDER BY ORDER BY
@ -604,14 +603,13 @@ SELECT
FROM "customer" AS "customer" FROM "customer" AS "customer"
JOIN "orders" AS "orders" JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey" ON "customer"."c_custkey" = "orders"."o_custkey"
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
JOIN "nation" AS "nation"
ON "customer"."c_nationkey" = "nation"."n_nationkey"
WHERE
"lineitem"."l_returnflag" = 'R'
AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE) AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE)
AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE) AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE)
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "lineitem"."l_returnflag" = 'R'
JOIN "nation" AS "nation"
ON "customer"."c_nationkey" = "nation"."n_nationkey"
GROUP BY GROUP BY
"customer"."c_custkey", "customer"."c_custkey",
"customer"."c_name", "customer"."c_name",
@ -654,12 +652,12 @@ group by
) )
order by order by
value desc; value desc;
WITH "_e_0" AS ( WITH "supplier_2" AS (
SELECT SELECT
"supplier"."s_suppkey" AS "s_suppkey", "supplier"."s_suppkey" AS "s_suppkey",
"supplier"."s_nationkey" AS "s_nationkey" "supplier"."s_nationkey" AS "s_nationkey"
FROM "supplier" AS "supplier" FROM "supplier" AS "supplier"
), "_e_1" AS ( ), "nation_2" AS (
SELECT SELECT
"nation"."n_nationkey" AS "n_nationkey", "nation"."n_nationkey" AS "n_nationkey",
"nation"."n_name" AS "n_name" "nation"."n_name" AS "n_name"
@ -671,9 +669,9 @@ SELECT
"partsupp"."ps_partkey" AS "ps_partkey", "partsupp"."ps_partkey" AS "ps_partkey",
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value" SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value"
FROM "partsupp" AS "partsupp" FROM "partsupp" AS "partsupp"
JOIN "_e_0" AS "supplier" JOIN "supplier_2" AS "supplier"
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "_e_1" AS "nation" JOIN "nation_2" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey" ON "supplier"."s_nationkey" = "nation"."n_nationkey"
GROUP BY GROUP BY
"partsupp"."ps_partkey" "partsupp"."ps_partkey"
@ -682,9 +680,9 @@ HAVING
SELECT SELECT
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0" SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0"
FROM "partsupp" AS "partsupp" FROM "partsupp" AS "partsupp"
JOIN "_e_0" AS "supplier" JOIN "supplier_2" AS "supplier"
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "_e_1" AS "nation" JOIN "nation_2" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey" ON "supplier"."s_nationkey" = "nation"."n_nationkey"
) )
ORDER BY ORDER BY
@ -737,13 +735,12 @@ SELECT
END) AS "low_line_count" END) AS "low_line_count"
FROM "orders" AS "orders" FROM "orders" AS "orders"
JOIN "lineitem" AS "lineitem" JOIN "lineitem" AS "lineitem"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey" ON "lineitem"."l_commitdate" < "lineitem"."l_receiptdate"
WHERE
"lineitem"."l_commitdate" < "lineitem"."l_receiptdate"
AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE) AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE)
AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE) AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE)
AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate" AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate"
AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP') AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP')
AND "orders"."o_orderkey" = "lineitem"."l_orderkey"
GROUP BY GROUP BY
"lineitem"."l_shipmode" "lineitem"."l_shipmode"
ORDER BY ORDER BY
@ -772,10 +769,7 @@ group by
order by order by
custdist desc, custdist desc,
c_count desc; c_count desc;
SELECT WITH "c_orders" AS (
"c_orders"."c_count" AS "c_count",
COUNT(*) AS "custdist"
FROM (
SELECT SELECT
COUNT("orders"."o_orderkey") AS "c_count" COUNT("orders"."o_orderkey") AS "c_count"
FROM "customer" AS "customer" FROM "customer" AS "customer"
@ -784,7 +778,11 @@ FROM (
AND NOT "orders"."o_comment" LIKE '%special%requests%' AND NOT "orders"."o_comment" LIKE '%special%requests%'
GROUP BY GROUP BY
"customer"."c_custkey" "customer"."c_custkey"
) AS "c_orders" )
SELECT
"c_orders"."c_count" AS "c_count",
COUNT(*) AS "custdist"
FROM "c_orders" AS "c_orders"
GROUP BY GROUP BY
"c_orders"."c_count" "c_orders"."c_count"
ORDER BY ORDER BY
@ -920,13 +918,7 @@ order by
p_brand, p_brand,
p_type, p_type,
p_size; p_size;
SELECT WITH "_u_0" AS (
"part"."p_brand" AS "p_brand",
"part"."p_type" AS "p_type",
"part"."p_size" AS "p_size",
COUNT(DISTINCT "partsupp"."ps_suppkey") AS "supplier_cnt"
FROM "partsupp" AS "partsupp"
LEFT JOIN (
SELECT SELECT
"supplier"."s_suppkey" AS "s_suppkey" "supplier"."s_suppkey" AS "s_suppkey"
FROM "supplier" AS "supplier" FROM "supplier" AS "supplier"
@ -934,15 +926,22 @@ LEFT JOIN (
"supplier"."s_comment" LIKE '%Customer%Complaints%' "supplier"."s_comment" LIKE '%Customer%Complaints%'
GROUP BY GROUP BY
"supplier"."s_suppkey" "supplier"."s_suppkey"
) AS "_u_0" )
SELECT
"part"."p_brand" AS "p_brand",
"part"."p_type" AS "p_type",
"part"."p_size" AS "p_size",
COUNT(DISTINCT "partsupp"."ps_suppkey") AS "supplier_cnt"
FROM "partsupp" AS "partsupp"
LEFT JOIN "_u_0" AS "_u_0"
ON "partsupp"."ps_suppkey" = "_u_0"."s_suppkey" ON "partsupp"."ps_suppkey" = "_u_0"."s_suppkey"
JOIN "part" AS "part" JOIN "part" AS "part"
ON "part"."p_partkey" = "partsupp"."ps_partkey" ON "part"."p_brand" <> 'Brand#45'
WHERE AND "part"."p_partkey" = "partsupp"."ps_partkey"
"_u_0"."s_suppkey" IS NULL
AND "part"."p_brand" <> 'Brand#45'
AND "part"."p_size" IN (49, 14, 23, 45, 19, 3, 36, 9) AND "part"."p_size" IN (49, 14, 23, 45, 19, 3, 36, 9)
AND NOT "part"."p_type" LIKE 'MEDIUM POLISHED%' AND NOT "part"."p_type" LIKE 'MEDIUM POLISHED%'
WHERE
"_u_0"."s_suppkey" IS NULL
GROUP BY GROUP BY
"part"."p_brand", "part"."p_brand",
"part"."p_type", "part"."p_type",
@ -973,24 +972,25 @@ where
where where
l_partkey = p_partkey l_partkey = p_partkey
); );
SELECT WITH "_u_0" AS (
SUM("lineitem"."l_extendedprice") / 7.0 AS "avg_yearly"
FROM "lineitem" AS "lineitem"
JOIN "part" AS "part"
ON "part"."p_partkey" = "lineitem"."l_partkey"
LEFT JOIN (
SELECT SELECT
0.2 * AVG("lineitem"."l_quantity") AS "_col_0", 0.2 * AVG("lineitem"."l_quantity") AS "_col_0",
"lineitem"."l_partkey" AS "_u_1" "lineitem"."l_partkey" AS "_u_1"
FROM "lineitem" AS "lineitem" FROM "lineitem" AS "lineitem"
GROUP BY GROUP BY
"lineitem"."l_partkey" "lineitem"."l_partkey"
) AS "_u_0" )
SELECT
SUM("lineitem"."l_extendedprice") / 7.0 AS "avg_yearly"
FROM "lineitem" AS "lineitem"
JOIN "part" AS "part"
ON "part"."p_brand" = 'Brand#23'
AND "part"."p_container" = 'MED BOX'
AND "part"."p_partkey" = "lineitem"."l_partkey"
LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."_u_1" = "part"."p_partkey" ON "_u_0"."_u_1" = "part"."p_partkey"
WHERE WHERE
"lineitem"."l_quantity" < "_u_0"."_col_0" "lineitem"."l_quantity" < "_u_0"."_col_0"
AND "part"."p_brand" = 'Brand#23'
AND "part"."p_container" = 'MED BOX'
AND NOT "_u_0"."_u_1" IS NULL; AND NOT "_u_0"."_u_1" IS NULL;
-------------------------------------- --------------------------------------
@ -1030,6 +1030,16 @@ order by
o_orderdate o_orderdate
limit limit
100; 100;
WITH "_u_0" AS (
SELECT
"lineitem"."l_orderkey" AS "l_orderkey"
FROM "lineitem" AS "lineitem"
GROUP BY
"lineitem"."l_orderkey",
"lineitem"."l_orderkey"
HAVING
SUM("lineitem"."l_quantity") > 300
)
SELECT SELECT
"customer"."c_name" AS "c_name", "customer"."c_name" AS "c_name",
"customer"."c_custkey" AS "c_custkey", "customer"."c_custkey" AS "c_custkey",
@ -1040,16 +1050,7 @@ SELECT
FROM "customer" AS "customer" FROM "customer" AS "customer"
JOIN "orders" AS "orders" JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey" ON "customer"."c_custkey" = "orders"."o_custkey"
LEFT JOIN ( LEFT JOIN "_u_0" AS "_u_0"
SELECT
"lineitem"."l_orderkey" AS "l_orderkey"
FROM "lineitem" AS "lineitem"
GROUP BY
"lineitem"."l_orderkey",
"lineitem"."l_orderkey"
HAVING
SUM("lineitem"."l_quantity") > 300
) AS "_u_0"
ON "orders"."o_orderkey" = "_u_0"."l_orderkey" ON "orders"."o_orderkey" = "_u_0"."l_orderkey"
JOIN "lineitem" AS "lineitem" JOIN "lineitem" AS "lineitem"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey" ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
@ -1200,15 +1201,7 @@ where
and n_name = 'CANADA' and n_name = 'CANADA'
order by order by
s_name; s_name;
SELECT WITH "_u_0" AS (
"supplier"."s_name" AS "s_name",
"supplier"."s_address" AS "s_address"
FROM "supplier" AS "supplier"
LEFT JOIN (
SELECT
"partsupp"."ps_suppkey" AS "ps_suppkey"
FROM "partsupp" AS "partsupp"
LEFT JOIN (
SELECT SELECT
0.5 * SUM("lineitem"."l_quantity") AS "_col_0", 0.5 * SUM("lineitem"."l_quantity") AS "_col_0",
"lineitem"."l_partkey" AS "_u_1", "lineitem"."l_partkey" AS "_u_1",
@ -1220,10 +1213,7 @@ LEFT JOIN (
GROUP BY GROUP BY
"lineitem"."l_partkey", "lineitem"."l_partkey",
"lineitem"."l_suppkey" "lineitem"."l_suppkey"
) AS "_u_0" ), "_u_3" AS (
ON "_u_0"."_u_1" = "partsupp"."ps_partkey"
AND "_u_0"."_u_2" = "partsupp"."ps_suppkey"
LEFT JOIN (
SELECT SELECT
"part"."p_partkey" AS "p_partkey" "part"."p_partkey" AS "p_partkey"
FROM "part" AS "part" FROM "part" AS "part"
@ -1231,7 +1221,14 @@ LEFT JOIN (
"part"."p_name" LIKE 'forest%' "part"."p_name" LIKE 'forest%'
GROUP BY GROUP BY
"part"."p_partkey" "part"."p_partkey"
) AS "_u_3" ), "_u_4" AS (
SELECT
"partsupp"."ps_suppkey" AS "ps_suppkey"
FROM "partsupp" AS "partsupp"
LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."_u_1" = "partsupp"."ps_partkey"
AND "_u_0"."_u_2" = "partsupp"."ps_suppkey"
LEFT JOIN "_u_3" AS "_u_3"
ON "partsupp"."ps_partkey" = "_u_3"."p_partkey" ON "partsupp"."ps_partkey" = "_u_3"."p_partkey"
WHERE WHERE
"partsupp"."ps_availqty" > "_u_0"."_col_0" "partsupp"."ps_availqty" > "_u_0"."_col_0"
@ -1240,13 +1237,18 @@ LEFT JOIN (
AND NOT "_u_3"."p_partkey" IS NULL AND NOT "_u_3"."p_partkey" IS NULL
GROUP BY GROUP BY
"partsupp"."ps_suppkey" "partsupp"."ps_suppkey"
) AS "_u_4" )
SELECT
"supplier"."s_name" AS "s_name",
"supplier"."s_address" AS "s_address"
FROM "supplier" AS "supplier"
LEFT JOIN "_u_4" AS "_u_4"
ON "supplier"."s_suppkey" = "_u_4"."ps_suppkey" ON "supplier"."s_suppkey" = "_u_4"."ps_suppkey"
JOIN "nation" AS "nation" JOIN "nation" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey" ON "nation"."n_name" = 'CANADA'
AND "supplier"."s_nationkey" = "nation"."n_nationkey"
WHERE WHERE
"nation"."n_name" = 'CANADA' NOT "_u_4"."ps_suppkey" IS NULL
AND NOT "_u_4"."ps_suppkey" IS NULL
ORDER BY ORDER BY
"s_name"; "s_name";
@ -1294,22 +1296,14 @@ order by
s_name s_name
limit limit
100; 100;
SELECT WITH "_u_0" AS (
"supplier"."s_name" AS "s_name",
COUNT(*) AS "numwait"
FROM "supplier" AS "supplier"
JOIN "lineitem" AS "lineitem"
ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
LEFT JOIN (
SELECT SELECT
"l2"."l_orderkey" AS "l_orderkey", "l2"."l_orderkey" AS "l_orderkey",
ARRAY_AGG("l2"."l_suppkey") AS "_u_1" ARRAY_AGG("l2"."l_suppkey") AS "_u_1"
FROM "lineitem" AS "l2" FROM "lineitem" AS "l2"
GROUP BY GROUP BY
"l2"."l_orderkey" "l2"."l_orderkey"
) AS "_u_0" ), "_u_2" AS (
ON "_u_0"."l_orderkey" = "lineitem"."l_orderkey"
LEFT JOIN (
SELECT SELECT
"l3"."l_orderkey" AS "l_orderkey", "l3"."l_orderkey" AS "l_orderkey",
ARRAY_AGG("l3"."l_suppkey") AS "_u_3" ARRAY_AGG("l3"."l_suppkey") AS "_u_3"
@ -1318,20 +1312,29 @@ LEFT JOIN (
"l3"."l_receiptdate" > "l3"."l_commitdate" "l3"."l_receiptdate" > "l3"."l_commitdate"
GROUP BY GROUP BY
"l3"."l_orderkey" "l3"."l_orderkey"
) AS "_u_2" )
SELECT
"supplier"."s_name" AS "s_name",
COUNT(*) AS "numwait"
FROM "supplier" AS "supplier"
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_receiptdate" > "lineitem"."l_commitdate"
AND "supplier"."s_suppkey" = "lineitem"."l_suppkey"
LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."l_orderkey" = "lineitem"."l_orderkey"
LEFT JOIN "_u_2" AS "_u_2"
ON "_u_2"."l_orderkey" = "lineitem"."l_orderkey" ON "_u_2"."l_orderkey" = "lineitem"."l_orderkey"
JOIN "orders" AS "orders" JOIN "orders" AS "orders"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey" ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
AND "orders"."o_orderstatus" = 'F'
JOIN "nation" AS "nation" JOIN "nation" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey" ON "nation"."n_name" = 'SAUDI ARABIA'
AND "supplier"."s_nationkey" = "nation"."n_nationkey"
WHERE WHERE
( (
"_u_2"."l_orderkey" IS NULL "_u_2"."l_orderkey" IS NULL
OR NOT ARRAY_ANY("_u_2"."_u_3", "_x" -> "_x" <> "lineitem"."l_suppkey") OR NOT ARRAY_ANY("_u_2"."_u_3", "_x" -> "_x" <> "lineitem"."l_suppkey")
) )
AND "lineitem"."l_receiptdate" > "lineitem"."l_commitdate"
AND "nation"."n_name" = 'SAUDI ARABIA'
AND "orders"."o_orderstatus" = 'F'
AND ARRAY_ANY("_u_0"."_u_1", "_x" -> "_x" <> "lineitem"."l_suppkey") AND ARRAY_ANY("_u_0"."_u_1", "_x" -> "_x" <> "lineitem"."l_suppkey")
AND NOT "_u_0"."l_orderkey" IS NULL AND NOT "_u_0"."l_orderkey" IS NULL
GROUP BY GROUP BY
@ -1381,18 +1384,19 @@ group by
cntrycode cntrycode
order by order by
cntrycode; cntrycode;
SELECT WITH "_u_0" AS (
SUBSTRING("customer"."c_phone", 1, 2) AS "cntrycode",
COUNT(*) AS "numcust",
SUM("customer"."c_acctbal") AS "totacctbal"
FROM "customer" AS "customer"
LEFT JOIN (
SELECT SELECT
"orders"."o_custkey" AS "_u_1" "orders"."o_custkey" AS "_u_1"
FROM "orders" AS "orders" FROM "orders" AS "orders"
GROUP BY GROUP BY
"orders"."o_custkey" "orders"."o_custkey"
) AS "_u_0" )
SELECT
SUBSTRING("customer"."c_phone", 1, 2) AS "cntrycode",
COUNT(*) AS "numcust",
SUM("customer"."c_acctbal") AS "totacctbal"
FROM "customer" AS "customer"
LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."_u_1" = "customer"."c_custkey" ON "_u_0"."_u_1" = "customer"."c_custkey"
WHERE WHERE
"_u_0"."_u_1" IS NULL "_u_0"."_u_1" IS NULL
View file
@ -264,22 +264,3 @@ CREATE TABLE "t_customer_account" (
"account_no" VARCHAR(100) "account_no" VARCHAR(100)
); );
CREATE TABLE "t_customer_account" (
"id" int(11) NOT NULL AUTO_INCREMENT,
"customer_id" int(11) DEFAULT NULL COMMENT '客户id',
"bank" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别',
"account_no" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号',
PRIMARY KEY ("id")
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='客户账户表';
CREATE TABLE "t_customer_account" (
"id" INT(11) NOT NULL AUTO_INCREMENT,
"customer_id" INT(11) DEFAULT NULL COMMENT '客户id',
"bank" VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别',
"account_no" VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号',
PRIMARY KEY("id")
)
ENGINE=InnoDB
AUTO_INCREMENT=1
DEFAULT CHARACTER SET=utf8
COLLATE=utf8_bin
COMMENT='客户账户表';
View file
@ -308,6 +308,18 @@ class TestBuild(unittest.TestCase):
lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select("x"), lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select("x"),
"SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned", "SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned",
), ),
(
lambda: exp.update("tbl", {"x": None, "y": {"x": 1}}),
"UPDATE tbl SET x = NULL, y = MAP('x', 1)",
),
(
lambda: exp.update("tbl", {"x": 1}, where="y > 0"),
"UPDATE tbl SET x = 1 WHERE y > 0",
),
(
lambda: exp.update("tbl", {"x": 1}, from_="tbl2"),
"UPDATE tbl SET x = 1 FROM tbl2",
),
]: ]:
with self.subTest(sql): with self.subTest(sql):
self.assertEqual(expression().sql(dialect[0] if dialect else None), sql) self.assertEqual(expression().sql(dialect[0] if dialect else None), sql)
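The new exp.update builder mirrors exp.select: plain Python values on the right-hand side are converted (None to NULL, dicts to MAP), and where/from_ attach the optional clauses. Usage matching the expectations above:

    from sqlglot import exp

    print(exp.update("tbl", {"x": None, "y": {"x": 1}}).sql())
    # UPDATE tbl SET x = NULL, y = MAP('x', 1)
    print(exp.update("tbl", {"x": 1}, where="y > 0").sql())
    # UPDATE tbl SET x = 1 WHERE y > 0
    print(exp.update("tbl", {"x": 1}, from_="tbl2").sql())
    # UPDATE tbl SET x = 1 FROM tbl2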
View file
@ -27,6 +27,8 @@ class TestExpressions(unittest.TestCase):
parse_one("ROW() OVER (partition BY y)"), parse_one("ROW() OVER (partition BY y)"),
) )
self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)")) self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)"))
self.assertEqual(exp.Table(pivots=[]), exp.Table())
self.assertNotEqual(exp.Table(pivots=[None]), exp.Table())
def test_find(self): def test_find(self):
expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y") expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y")
@ -280,6 +282,19 @@ class TestExpressions(unittest.TestCase):
expression.find(exp.Table).replace(parse_one("y")) expression.find(exp.Table).replace(parse_one("y"))
self.assertEqual(expression.sql(), "SELECT c, b FROM y") self.assertEqual(expression.sql(), "SELECT c, b FROM y")
def test_pop(self):
expression = parse_one("SELECT a, b FROM x")
expression.find(exp.Column).pop()
self.assertEqual(expression.sql(), "SELECT b FROM x")
expression.find(exp.Column).pop()
self.assertEqual(expression.sql(), "SELECT FROM x")
expression.pop()
self.assertEqual(expression.sql(), "SELECT FROM x")
expression = parse_one("WITH x AS (SELECT a FROM x) SELECT * FROM x")
expression.find(exp.With).pop()
self.assertEqual(expression.sql(), "SELECT * FROM x")
def test_walk(self): def test_walk(self):
expression = parse_one("SELECT * FROM (SELECT * FROM x)") expression = parse_one("SELECT * FROM (SELECT * FROM x)")
self.assertEqual(len(list(expression.walk())), 9) self.assertEqual(len(list(expression.walk())), 9)
@ -316,6 +331,7 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(parse_one("MAX(a)"), exp.Max) self.assertIsInstance(parse_one("MAX(a)"), exp.Max)
self.assertIsInstance(parse_one("MIN(a)"), exp.Min) self.assertIsInstance(parse_one("MIN(a)"), exp.Min)
self.assertIsInstance(parse_one("MONTH(a)"), exp.Month) self.assertIsInstance(parse_one("MONTH(a)"), exp.Month)
self.assertIsInstance(parse_one("POSITION(' ' IN a)"), exp.StrPosition)
self.assertIsInstance(parse_one("POW(a, 2)"), exp.Pow) self.assertIsInstance(parse_one("POW(a, 2)"), exp.Pow)
self.assertIsInstance(parse_one("POWER(a, 2)"), exp.Pow) self.assertIsInstance(parse_one("POWER(a, 2)"), exp.Pow)
self.assertIsInstance(parse_one("QUANTILE(a, 0.90)"), exp.Quantile) self.assertIsInstance(parse_one("QUANTILE(a, 0.90)"), exp.Quantile)
@ -420,7 +436,7 @@ class TestExpressions(unittest.TestCase):
exp.Properties.from_dict( exp.Properties.from_dict(
{ {
"FORMAT": "parquet", "FORMAT": "parquet",
"PARTITIONED_BY": [exp.to_identifier("a"), exp.to_identifier("b")], "PARTITIONED_BY": (exp.to_identifier("a"), exp.to_identifier("b")),
"custom": 1, "custom": 1,
"TABLE_FORMAT": exp.to_identifier("test_format"), "TABLE_FORMAT": exp.to_identifier("test_format"),
"ENGINE": None, "ENGINE": None,
@ -444,4 +460,17 @@ class TestExpressions(unittest.TestCase):
), ),
) )
self.assertRaises(ValueError, exp.Properties.from_dict, {"FORMAT": {"key": "value"}}) self.assertRaises(ValueError, exp.Properties.from_dict, {"FORMAT": object})
def test_convert(self):
for value, expected in [
(1, "1"),
("1", "'1'"),
(None, "NULL"),
(True, "TRUE"),
((1, "2", None), "(1, '2', NULL)"),
([1, "2", None], "ARRAY(1, '2', NULL)"),
({"x": None}, "MAP('x', NULL)"),
]:
with self.subTest(value):
self.assertEqual(exp.convert(value).sql(), expected)
View file
@ -1,9 +1,11 @@
import unittest import unittest
from functools import partial
from sqlglot import optimizer, parse_one, table from sqlglot import exp, optimizer, parse_one, table
from sqlglot.errors import OptimizeError from sqlglot.errors import OptimizeError
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.schema import MappingSchema, ensure_schema from sqlglot.optimizer.schema import MappingSchema, ensure_schema
from sqlglot.optimizer.scope import traverse_scope from sqlglot.optimizer.scope import build_scope, traverse_scope
from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures
@ -27,11 +29,17 @@ class TestOptimizer(unittest.TestCase):
} }
def check_file(self, file, func, pretty=False, **kwargs): def check_file(self, file, func, pretty=False, **kwargs):
for meta, sql, expected in load_sql_fixture_pairs(f"optimizer/{file}.sql"): for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1):
dialect = meta.get("dialect") dialect = meta.get("dialect")
with self.subTest(sql): leave_tables_isolated = meta.get("leave_tables_isolated")
func_kwargs = {**kwargs}
if leave_tables_isolated is not None:
func_kwargs["leave_tables_isolated"] = leave_tables_isolated.lower() in ("true", "1")
with self.subTest(f"{i}, {sql}"):
self.assertEqual( self.assertEqual(
func(parse_one(sql, read=dialect), **kwargs).sql(pretty=pretty, dialect=dialect), func(parse_one(sql, read=dialect), **func_kwargs).sql(pretty=pretty, dialect=dialect),
expected, expected,
) )
@ -123,21 +131,20 @@ class TestOptimizer(unittest.TestCase):
optimizer.optimize_joins.optimize_joins, optimizer.optimize_joins.optimize_joins,
) )
def test_eliminate_subqueries(self): def test_merge_subqueries(self):
self.check_file( optimize = partial(
"eliminate_subqueries", optimizer.optimize,
optimizer.eliminate_subqueries.eliminate_subqueries, rules=[
pretty=True, optimizer.qualify_tables.qualify_tables,
optimizer.qualify_columns.qualify_columns,
optimizer.merge_subqueries.merge_subqueries,
],
) )
def test_merge_derived_tables(self): self.check_file("merge_subqueries", optimize, schema=self.schema)
def optimize(expression, **kwargs):
expression = optimizer.qualify_tables.qualify_tables(expression)
expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
expression = optimizer.merge_derived_tables.merge_derived_tables(expression)
return expression
self.check_file("merge_derived_tables", optimize, schema=self.schema) def test_eliminate_subqueries(self):
self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries)
def test_tpch(self): def test_tpch(self):
self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True) self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)
@ -257,7 +264,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
ON s.b = r.b ON s.b = r.b
WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b) WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b)
""" """
scopes = traverse_scope(parse_one(sql)) for scopes in traverse_scope(parse_one(sql)), list(build_scope(parse_one(sql)).traverse()):
self.assertEqual(len(scopes), 5) self.assertEqual(len(scopes), 5)
self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x") self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
@ -271,3 +278,59 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(scopes[4].source_columns("q"), []) self.assertEqual(scopes[4].source_columns("q"), [])
self.assertEqual(len(scopes[4].source_columns("r")), 2) self.assertEqual(len(scopes[4].source_columns("r")), 2)
self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"}) self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"})
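The parametrized loop above asserts that build_scope(...).traverse() yields the same flat, innermost-first ordering traverse_scope returns, just starting from an explicit root Scope. A minimal sketch:

    from sqlglot import parse_one
    from sqlglot.optimizer.scope import build_scope

    root = build_scope(parse_one("SELECT a FROM (SELECT a FROM x) AS y"))
    for scope in root.traverse():
        print(scope.expression.sql())
    # SELECT a FROM x   <- the derived table's scope comes first, the root last
    # SELECT a FROM (SELECT a FROM x) AS y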
def test_literal_type_annotation(self):
tests = {
"SELECT 5": exp.DataType.Type.INT,
"SELECT 5.3": exp.DataType.Type.DOUBLE,
"SELECT 'bla'": exp.DataType.Type.VARCHAR,
"5": exp.DataType.Type.INT,
"5.3": exp.DataType.Type.DOUBLE,
"'bla'": exp.DataType.Type.VARCHAR,
}
for sql, target_type in tests.items():
expression = parse_one(sql)
annotated_expression = annotate_types(expression)
self.assertEqual(annotated_expression.find(exp.Literal).type, target_type)
def test_boolean_type_annotation(self):
tests = {
"SELECT TRUE": exp.DataType.Type.BOOLEAN,
"FALSE": exp.DataType.Type.BOOLEAN,
}
for sql, target_type in tests.items():
expression = parse_one(sql)
annotated_expression = annotate_types(expression)
self.assertEqual(annotated_expression.find(exp.Boolean).type, target_type)
def test_cast_type_annotation(self):
expression = parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")
annotate_types(expression)
self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ)
self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR)
self.assertEqual(expression.args["to"].type, exp.DataType.Type.TIMESTAMPTZ)
self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)
def test_cache_annotation(self):
expression = parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
annotated_expression = annotate_types(expression)
self.assertEqual(annotated_expression.expression.expressions[0].type, exp.DataType.Type.INT)
def test_binary_annotation(self):
expression = parse_one("SELECT 0.0 + (2 + 3)")
annotate_types(expression)
expression = expression.expressions[0]
self.assertEqual(expression.type, exp.DataType.Type.DOUBLE)
self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE)
self.assertEqual(expression.right.type, exp.DataType.Type.INT)
self.assertEqual(expression.right.this.type, exp.DataType.Type.INT)
self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT)
self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT)
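annotate_types decorates the tree in place, bottom-up: literals get INT/DOUBLE/VARCHAR, casts take their target type, and binaries widen their operands (DOUBLE + INT is DOUBLE). Reproducing the binary case directly:

    from sqlglot import exp, parse_one
    from sqlglot.optimizer.annotate_types import annotate_types

    projection = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0]
    assert projection.type == exp.DataType.Type.DOUBLE  # widened by the 0.0 operand
    assert projection.right.type == exp.DataType.Type.INT  # (2 + 3) stays integral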
View file
@ -21,6 +21,11 @@ class TestParser(unittest.TestCase):
self.assertIsNotNone(parse_one("date").find(exp.Column)) self.assertIsNotNone(parse_one("date").find(exp.Column))
def test_float(self):
self.assertEqual(parse_one(".2"), parse_one("0.2"))
self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)"))
def test_table(self): def test_table(self):
tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)] tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
self.assertEqual(tables, ["a", "b.c", "d"]) self.assertEqual(tables, ["a", "b.c", "d"])
View file
@ -6,11 +6,32 @@ from sqlglot.transforms import unalias_group
class TestTime(unittest.TestCase): class TestTime(unittest.TestCase):
def validate(self, transform, sql, target): def validate(self, transform, sql, target):
with self.subTest(sql):
self.assertEqual(parse_one(sql).transform(transform).sql(), target) self.assertEqual(parse_one(sql).transform(transform).sql(), target)
def test_unalias_group(self): def test_unalias_group(self):
self.validate( self.validate(
unalias_group, unalias_group,
"SELECT a, b AS b, c AS c, 4 FROM x GROUP BY a, b, x.c, 4", "SELECT a, b AS b, c AS c, 4 FROM x GROUP BY a, b, x.c, 4",
"SELECT a, b AS b, c AS c, 4 FROM x GROUP BY a, 2, x.c, 4", "SELECT a, b AS b, c AS c, 4 FROM x GROUP BY a, b, x.c, 4",
)
self.validate(
unalias_group,
"SELECT TO_DATE(the_date) AS the_date, CUSTOM_UDF(other_col) AS other_col, last_col AS aliased_last, COUNT(*) AS the_count FROM x GROUP BY TO_DATE(the_date), CUSTOM_UDF(other_col), aliased_last",
"SELECT TO_DATE(the_date) AS the_date, CUSTOM_UDF(other_col) AS other_col, last_col AS aliased_last, COUNT(*) AS the_count FROM x GROUP BY TO_DATE(the_date), CUSTOM_UDF(other_col), 3",
)
self.validate(
unalias_group,
"SELECT SOME_UDF(TO_DATE(the_date)) AS the_date, COUNT(*) AS the_count FROM x GROUP BY SOME_UDF(TO_DATE(the_date))",
"SELECT SOME_UDF(TO_DATE(the_date)) AS the_date, COUNT(*) AS the_count FROM x GROUP BY SOME_UDF(TO_DATE(the_date))",
)
self.validate(
unalias_group,
"SELECT SOME_UDF(TO_DATE(the_date)) AS new_date, COUNT(*) AS the_count FROM x GROUP BY new_date",
"SELECT SOME_UDF(TO_DATE(the_date)) AS new_date, COUNT(*) AS the_count FROM x GROUP BY 1",
)
self.validate(
unalias_group,
"SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date",
"SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date",
) )
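The rule only swaps a GROUP BY item for its ordinal when the item is a bare reference to a projection alias that does not also name a source column (aliased_last above, but not the_date); grouping by the projected expression itself is left untouched. A quick check of the alias case:

    from sqlglot import parse_one
    from sqlglot.transforms import unalias_group

    sql = (
        "SELECT SOME_UDF(TO_DATE(the_date)) AS new_date, COUNT(*) AS the_count "
        "FROM x GROUP BY new_date"
    )
    print(parse_one(sql).transform(unalias_group).sql())
    # ... GROUP BY 1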