
Merging upstream version 6.1.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 08:04:41 +01:00
parent 3c6d649c90
commit 08ecea3adf
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
61 changed files with 1844 additions and 1555 deletions

CHANGELOG.md (new file)

@@ -0,0 +1,11 @@
Changelog
=========
v6.1.0
------
Changes:
- New: mysql group_concat separator [49a4099](https://github.com/tobymao/sqlglot/commit/49a4099adc93780eeffef8204af36559eab50a9f)
- Improvement: Better nested select parsing [45603f](https://github.com/tobymao/sqlglot/commit/45603f14bf9146dc3f8b330b85a0e25b77630b9b)
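A quick sketch of the new separator support; the round-trip output shown is my expectation for this release, not something stated in the changelog:

```python
import sqlglot

# Parse MySQL GROUP_CONCAT with a custom SEPARATOR and regenerate it.
sql = "SELECT GROUP_CONCAT(name SEPARATOR ', ') FROM users"
print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])
# SELECT GROUP_CONCAT(name SEPARATOR ', ') FROM users
```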

LICENSE

@@ -1,6 +1,6 @@
 MIT License

-Copyright (c) 2021 Toby Mao
+Copyright (c) 2022 Toby Mao

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal

run_checks.sh

@@ -8,5 +8,5 @@ python -m autoflake -i -r \
 --remove-unused-variables \
 sqlglot/ tests/
 python -m isort --profile black sqlglot/ tests/
-python -m black sqlglot/ tests/
+python -m black --line-length 120 sqlglot/ tests/
 python -m unittest

sqlglot/__init__.py

@@ -20,7 +20,7 @@ from sqlglot.generator import Generator
 from sqlglot.parser import Parser
 from sqlglot.tokens import Tokenizer, TokenType

-__version__ = "6.0.4"
+__version__ = "6.1.1"

 pretty = False

sqlglot/__main__.py

@@ -49,12 +49,7 @@ args = parser.parse_args()
 error_level = sqlglot.ErrorLevel[args.error_level.upper()]

 if args.parse:
-    sqls = [
-        repr(expression)
-        for expression in sqlglot.parse(
-            args.sql, read=args.read, error_level=error_level
-        )
-    ]
+    sqls = [repr(expression) for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)]
 else:
     sqls = sqlglot.transpile(
         args.sql,

sqlglot/dialects/__init__.py

@@ -7,6 +7,7 @@ from sqlglot.dialects.mysql import MySQL
 from sqlglot.dialects.oracle import Oracle
 from sqlglot.dialects.postgres import Postgres
 from sqlglot.dialects.presto import Presto
+from sqlglot.dialects.redshift import Redshift
 from sqlglot.dialects.snowflake import Snowflake
 from sqlglot.dialects.spark import Spark
 from sqlglot.dialects.sqlite import SQLite

sqlglot/dialects/bigquery.py

@@ -44,6 +44,7 @@ class BigQuery(Dialect):
         ]
         IDENTIFIERS = ["`"]
         ESCAPE = "\\"
+        HEX_STRINGS = [("0x", ""), ("0X", "")]

         KEYWORDS = {
             **Tokenizer.KEYWORDS,
@@ -120,9 +121,5 @@ class BigQuery(Dialect):
     def intersect_op(self, expression):
         if not expression.args.get("distinct", False):
-            self.unsupported(
-                "INTERSECT without DISTINCT is not supported in BigQuery"
-            )
-        return (
-            f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
-        )
+            self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
+        return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"

sqlglot/dialects/dialect.py

@@ -20,6 +20,7 @@ class Dialects(str, Enum):
     ORACLE = "oracle"
     POSTGRES = "postgres"
     PRESTO = "presto"
+    REDSHIFT = "redshift"
     SNOWFLAKE = "snowflake"
     SPARK = "spark"
     SQLITE = "sqlite"
@@ -53,12 +54,19 @@ class _Dialect(type):
         klass.generator_class = getattr(klass, "Generator", Generator)

         klass.tokenizer = klass.tokenizer_class()
-        klass.quote_start, klass.quote_end = list(klass.tokenizer_class.QUOTES.items())[
-            0
-        ]
-        klass.identifier_start, klass.identifier_end = list(
-            klass.tokenizer_class.IDENTIFIERS.items()
-        )[0]
+        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
+        klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
+
+        if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS:
+            bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
+            klass.generator_class.TRANSFORMS[
+                exp.BitString
+            ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
+        if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS:
+            hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
+            klass.generator_class.TRANSFORMS[
+                exp.HexString
+            ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"

         return klass
@@ -122,9 +130,7 @@ class Dialect(metaclass=_Dialect):
         return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)

     def parse_into(self, expression_type, sql, **opts):
-        return self.parser(**opts).parse_into(
-            expression_type, self.tokenizer.tokenize(sql), sql
-        )
+        return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql)

     def generate(self, expression, **opts):
         return self.generator(**opts).generate(expression)
@@ -164,9 +170,7 @@ class Dialect(metaclass=_Dialect):
 def rename_func(name):
-    return (
-        lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})"
-    )
+    return lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})"

 def approx_count_distinct_sql(self, expression):
@@ -260,8 +264,7 @@ def format_time_lambda(exp_class, dialect, default=None):
     return exp_class(
         this=list_get(args, 0),
         format=Dialect[dialect].format_time(
-            list_get(args, 1)
-            or (Dialect[dialect].time_format if default is True else default)
+            list_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
         ),
     )
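To illustrate what the new `_BIT_STRINGS`/`_HEX_STRINGS` wiring buys: a hex literal read in one dialect can be re-emitted in another dialect's hex-string syntax. A minimal sketch; the exact output is my expectation under these rules, not taken from the diff:

```python
import sqlglot

# BigQuery tokenizes 0xCC as a hex string; SQLite declares x'...' hex
# strings, so the auto-registered transform re-emits it in that form.
print(sqlglot.transpile("SELECT 0xCC", read="bigquery", write="sqlite")[0])
# SELECT x'CC'
```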

sqlglot/dialects/duckdb.py

@@ -63,10 +63,7 @@ def _sort_array_reverse(args):
 def _struct_pack_sql(self, expression):
-    args = [
-        self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e)
-        for e in expression.expressions
-    ]
+    args = [self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) for e in expression.expressions]
     return f"STRUCT_PACK({', '.join(args)})"

sqlglot/dialects/hive.py

@@ -109,9 +109,7 @@ def _unnest_to_explode_sql(self, expression):
                 alias=exp.TableAlias(this=alias.this, columns=[column]),
             )
         )
-        for expression, column in zip(
-            unnest.expressions, alias.columns if alias else []
-        )
+        for expression, column in zip(unnest.expressions, alias.columns if alias else [])
     )
     return self.join_sql(expression)
@@ -206,14 +204,11 @@ class Hive(Dialect):
                 substr=list_get(args, 0),
                 position=list_get(args, 2),
             ),
-            "LOG": (
-                lambda args: exp.Log.from_arg_list(args)
-                if len(args) > 1
-                else exp.Ln.from_arg_list(args)
-            ),
+            "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
             "MAP": _parse_map,
             "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
             "PERCENTILE": exp.Quantile.from_arg_list,
+            "PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list,
             "COLLECT_SET": exp.SetAgg.from_arg_list,
             "SIZE": exp.ArraySize.from_arg_list,
             "SPLIT": exp.RegexpSplit.from_arg_list,
@@ -262,6 +257,7 @@ class Hive(Dialect):
             HiveMap: _map_sql,
             exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}",
             exp.Quantile: rename_func("PERCENTILE"),
+            exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
             exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
             exp.RegexpSplit: rename_func("SPLIT"),
             exp.SafeDivide: no_safe_divide_sql,
@@ -296,8 +292,7 @@ class Hive(Dialect):
     def datatype_sql(self, expression):
         if (
-            expression.this
-            in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
+            expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
             and not expression.expressions
         ):
             expression = exp.DataType.build("text")
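The new `exp.ApproxQuantile` node gives Hive's `PERCENTILE_APPROX` and Presto's `APPROX_PERCENTILE` a shared representation. A sketch of the expected rewrite:

```python
import sqlglot

print(sqlglot.transpile("SELECT PERCENTILE_APPROX(x, 0.5) FROM t", read="hive", write="presto")[0])
# SELECT APPROX_PERCENTILE(x, 0.5) FROM t
```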

sqlglot/dialects/mysql.py

@@ -49,6 +49,21 @@ def _str_to_date_sql(self, expression):
     return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"

+def _trim_sql(self, expression):
+    target = self.sql(expression, "this")
+    trim_type = self.sql(expression, "position")
+    remove_chars = self.sql(expression, "expression")
+
+    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't mysql-specific
+    if not remove_chars:
+        return self.trim_sql(expression)
+
+    trim_type = f"{trim_type} " if trim_type else ""
+    remove_chars = f"{remove_chars} " if remove_chars else ""
+    from_part = "FROM " if trim_type or remove_chars else ""
+    return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
+
 def _date_add(expression_class):
     def func(args):
         interval = list_get(args, 1)
@@ -88,9 +103,12 @@ class MySQL(Dialect):
         QUOTES = ["'", '"']
         COMMENTS = ["--", "#", ("/*", "*/")]
         IDENTIFIERS = ["`"]
+        BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
+        HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]

         KEYWORDS = {
             **Tokenizer.KEYWORDS,
+            "SEPARATOR": TokenType.SEPARATOR,
             "_ARMSCII8": TokenType.INTRODUCER,
             "_ASCII": TokenType.INTRODUCER,
             "_BIG5": TokenType.INTRODUCER,
@@ -145,6 +163,15 @@ class MySQL(Dialect):
             "STR_TO_DATE": _str_to_date,
         }

+        FUNCTION_PARSERS = {
+            **Parser.FUNCTION_PARSERS,
+            "GROUP_CONCAT": lambda self: self.expression(
+                exp.GroupConcat,
+                this=self._parse_lambda(),
+                separator=self._match(TokenType.SEPARATOR) and self._parse_field(),
+            ),
+        }
+
     class Generator(Generator):
         NULL_ORDERING_SUPPORTED = False
@@ -158,6 +185,8 @@ class MySQL(Dialect):
             exp.DateAdd: _date_add_sql("ADD"),
             exp.DateSub: _date_add_sql("SUB"),
             exp.DateTrunc: _date_trunc_sql,
+            exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
             exp.StrToDate: _str_to_date_sql,
             exp.StrToTime: _str_to_date_sql,
+            exp.Trim: _trim_sql,
         }
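With `exp.Trim` wired through the MySQL dialect, the extended TRIM forms should round-trip, while the plain forms fall back to the generic `trim_sql` (LTRIM/RTRIM) path added in the generator. A sketch with expected output:

```python
import sqlglot

# MySQL-specific TRIM with an explicit removal string is preserved verbatim.
print(sqlglot.transpile("SELECT TRIM(BOTH 'x' FROM col)", read="mysql", write="mysql")[0])
# SELECT TRIM(BOTH 'x' FROM col)
```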

sqlglot/dialects/oracle.py

@@ -51,6 +51,14 @@ class Oracle(Dialect):
             sep="",
         )

+    def alias_sql(self, expression):
+        if isinstance(expression.this, exp.Table):
+            to_sql = self.sql(expression, "alias")
+            # oracle does not allow "AS" between table and alias
+            to_sql = f" {to_sql}" if to_sql else ""
+            return f"{self.sql(expression, 'this')}{to_sql}"
+        return super().alias_sql(expression)
+
     def offset_sql(self, expression):
         return f"{super().offset_sql(expression)} ROWS"

sqlglot/dialects/postgres.py

@@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import (
 from sqlglot.generator import Generator
 from sqlglot.parser import Parser
 from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.transforms import delegate, preprocess
@@ -32,11 +33,96 @@ def _date_add_sql(kind):
     return func

+def _lateral_sql(self, expression):
+    this = self.sql(expression, "this")
+    if isinstance(expression.this, exp.Subquery):
+        return f"LATERAL{self.sep()}{this}"
+    alias = expression.args["alias"]
+    table = alias.name
+    table = f" {table}" if table else table
+    columns = self.expressions(alias, key="columns", flat=True)
+    columns = f" AS {columns}" if columns else ""
+    return f"LATERAL{self.sep()}{this}{table}{columns}"
+
+def _substring_sql(self, expression):
+    this = self.sql(expression, "this")
+    start = self.sql(expression, "start")
+    length = self.sql(expression, "length")
+
+    from_part = f" FROM {start}" if start else ""
+    for_part = f" FOR {length}" if length else ""
+
+    return f"SUBSTRING({this}{from_part}{for_part})"
+
+def _trim_sql(self, expression):
+    target = self.sql(expression, "this")
+    trim_type = self.sql(expression, "position")
+    remove_chars = self.sql(expression, "expression")
+    collation = self.sql(expression, "collation")
+
+    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't postgres-specific
+    if not remove_chars and not collation:
+        return self.trim_sql(expression)
+
+    trim_type = f"{trim_type} " if trim_type else ""
+    remove_chars = f"{remove_chars} " if remove_chars else ""
+    from_part = "FROM " if trim_type or remove_chars else ""
+    collation = f" COLLATE {collation}" if collation else ""
+    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
+
+def _auto_increment_to_serial(expression):
+    auto = expression.find(exp.AutoIncrementColumnConstraint)
+
+    if auto:
+        expression = expression.copy()
+        expression.args["constraints"].remove(auto.parent)
+        kind = expression.args["kind"]
+
+        if kind.this == exp.DataType.Type.INT:
+            kind.replace(exp.DataType(this=exp.DataType.Type.SERIAL))
+        elif kind.this == exp.DataType.Type.SMALLINT:
+            kind.replace(exp.DataType(this=exp.DataType.Type.SMALLSERIAL))
+        elif kind.this == exp.DataType.Type.BIGINT:
+            kind.replace(exp.DataType(this=exp.DataType.Type.BIGSERIAL))
+
+    return expression
+
+def _serial_to_generated(expression):
+    kind = expression.args["kind"]
+
+    if kind.this == exp.DataType.Type.SERIAL:
+        data_type = exp.DataType(this=exp.DataType.Type.INT)
+    elif kind.this == exp.DataType.Type.SMALLSERIAL:
+        data_type = exp.DataType(this=exp.DataType.Type.SMALLINT)
+    elif kind.this == exp.DataType.Type.BIGSERIAL:
+        data_type = exp.DataType(this=exp.DataType.Type.BIGINT)
+    else:
+        data_type = None
+
+    if data_type:
+        expression = expression.copy()
+        expression.args["kind"].replace(data_type)
+        constraints = expression.args["constraints"]
+        generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False))
+        notnull = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint())
+        if notnull not in constraints:
+            constraints.insert(0, notnull)
+        if generated not in constraints:
+            constraints.insert(0, generated)
+
+    return expression
+
 class Postgres(Dialect):
     null_ordering = "nulls_are_large"
     time_format = "'YYYY-MM-DD HH24:MI:SS'"
     time_mapping = {
-        "AM": "%p",  # AM or PM
+        "AM": "%p",
+        "PM": "%p",
         "D": "%w",  # 1-based day of week
         "DD": "%d",  # day of month
         "DDD": "%j",  # zero padded day of year
@@ -65,14 +151,25 @@ class Postgres(Dialect):
     }

     class Tokenizer(Tokenizer):
+        BIT_STRINGS = [("b'", "'"), ("B'", "'")]
+        HEX_STRINGS = [("x'", "'"), ("X'", "'")]
         KEYWORDS = {
             **Tokenizer.KEYWORDS,
-            "SERIAL": TokenType.AUTO_INCREMENT,
+            "ALWAYS": TokenType.ALWAYS,
+            "BY DEFAULT": TokenType.BY_DEFAULT,
+            "IDENTITY": TokenType.IDENTITY,
+            "FOR": TokenType.FOR,
+            "GENERATED": TokenType.GENERATED,
+            "DOUBLE PRECISION": TokenType.DOUBLE,
+            "BIGSERIAL": TokenType.BIGSERIAL,
+            "SERIAL": TokenType.SERIAL,
+            "SMALLSERIAL": TokenType.SMALLSERIAL,
             "UUID": TokenType.UUID,
         }

     class Parser(Parser):
         STRICT_CAST = False

         FUNCTIONS = {
             **Parser.FUNCTIONS,
             "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"),
@@ -86,14 +183,18 @@ class Postgres(Dialect):
             exp.DataType.Type.FLOAT: "REAL",
             exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
             exp.DataType.Type.BINARY: "BYTEA",
-        }
-
-        TOKEN_MAPPING = {
-            TokenType.AUTO_INCREMENT: "SERIAL",
+            exp.DataType.Type.DATETIME: "TIMESTAMP",
         }

         TRANSFORMS = {
             **Generator.TRANSFORMS,
+            exp.ColumnDef: preprocess(
+                [
+                    _auto_increment_to_serial,
+                    _serial_to_generated,
+                ],
+                delegate("columndef_sql"),
+            ),
             exp.JSONExtract: arrow_json_extract_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
             exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}",
@@ -102,8 +203,11 @@ class Postgres(Dialect):
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.DateAdd: _date_add_sql("+"),
             exp.DateSub: _date_add_sql("-"),
+            exp.Lateral: _lateral_sql,
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.Substring: _substring_sql,
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TableSample: no_tablesample_sql,
+            exp.Trim: _trim_sql,
             exp.TryCast: no_trycast_sql,
         }
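Putting the two `ColumnDef` preprocessors together: an AUTO_INCREMENT column first becomes a SERIAL type, which is then lowered to a GENERATED identity column. A sketch; the exact output string is my expectation for this release, not quoted from the diff:

```python
import sqlglot

print(sqlglot.transpile("CREATE TABLE t (id INT AUTO_INCREMENT)", read="mysql", write="postgres")[0])
# e.g. CREATE TABLE t (id INT GENERATED BY DEFAULT AS IDENTITY NOT NULL)
```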

sqlglot/dialects/presto.py

@@ -96,9 +96,7 @@ def _ts_or_ds_to_date_sql(self, expression):
     time_format = self.format_time(expression)
     if time_format and time_format not in (Presto.time_format, Presto.date_format):
         return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
-    return (
-        f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
-    )
+    return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"

 def _ts_or_ds_add_sql(self, expression):
@@ -141,6 +139,7 @@ class Presto(Dialect):
             "FROM_UNIXTIME": exp.UnixToTime.from_arg_list,
             "STRPOS": exp.StrPosition.from_arg_list,
             "TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
+            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
         }

     class Generator(Generator):
@@ -193,6 +192,7 @@ class Presto(Dialect):
             exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}",
             exp.Quantile: _quantile_sql,
+            exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
             exp.SafeDivide: no_safe_divide_sql,
             exp.Schema: _schema_sql,
             exp.SortArray: _no_sort_array,

sqlglot/dialects/redshift.py (new file)

@@ -0,0 +1,34 @@
from sqlglot import exp
from sqlglot.dialects.postgres import Postgres
from sqlglot.tokens import TokenType


class Redshift(Postgres):
    time_format = "'YYYY-MM-DD HH:MI:SS'"
    time_mapping = {
        **Postgres.time_mapping,
        "MON": "%b",
        "HH": "%H",
    }

    class Tokenizer(Postgres.Tokenizer):
        ESCAPE = "\\"

        KEYWORDS = {
            **Postgres.Tokenizer.KEYWORDS,
            "GEOMETRY": TokenType.GEOMETRY,
            "GEOGRAPHY": TokenType.GEOGRAPHY,
            "HLLSKETCH": TokenType.HLLSKETCH,
            "SUPER": TokenType.SUPER,
            "TIME": TokenType.TIMESTAMP,
            "TIMETZ": TokenType.TIMESTAMPTZ,
            "VARBYTE": TokenType.BINARY,
            "SIMILAR TO": TokenType.SIMILAR_TO,
        }

    class Generator(Postgres.Generator):
        TYPE_MAPPING = {
            **Postgres.Generator.TYPE_MAPPING,
            exp.DataType.Type.BINARY: "VARBYTE",
            exp.DataType.Type.INT: "INTEGER",
        }
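Because Redshift subclasses Postgres, "redshift" becomes a usable read/write dialect with only the overrides above. A sketch with expected output:

```python
import sqlglot

# Redshift maps the INT type to INTEGER via its TYPE_MAPPING override.
print(sqlglot.transpile("SELECT CAST(x AS INT)", write="redshift")[0])
# SELECT CAST(x AS INTEGER)
```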

sqlglot/dialects/snowflake.py

@@ -23,9 +23,7 @@ def _snowflake_to_timestamp(args):
         # case: <numeric_expr> [ , <scale> ]
         if second_arg.name not in ["0", "3", "9"]:
-            raise ValueError(
-                f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
-            )
+            raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9")

         if second_arg.name == "0":
             timescale = exp.UnixToTime.SECONDS

sqlglot/dialects/spark.py

@@ -65,12 +65,11 @@ class Spark(Hive):
             this=list_get(args, 0),
             start=exp.Sub(
                 this=exp.Length(this=list_get(args, 0)),
-                expression=exp.Add(
-                    this=list_get(args, 1), expression=exp.Literal.number(1)
-                ),
+                expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)),
             ),
             length=list_get(args, 1),
         ),
+        "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
     }

     class Generator(Hive.Generator):
@@ -82,11 +81,7 @@ class Spark(Hive):
         }

         TRANSFORMS = {
-            **{
-                k: v
-                for k, v in Hive.Generator.TRANSFORMS.items()
-                if k not in {exp.ArraySort}
-            },
+            **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}},
             exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
             exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
             exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
@@ -102,5 +97,5 @@ class Spark(Hive):
             HiveMap: _map_sql,
         }

-    def bitstring_sql(self, expression):
-        return f"X'{self.sql(expression, 'this')}'"
+    class Tokenizer(Hive.Tokenizer):
+        HEX_STRINGS = [("X'", "'")]

sqlglot/dialects/sqlite.py

@@ -16,6 +16,7 @@ from sqlglot.tokens import Tokenizer, TokenType
 class SQLite(Dialect):
     class Tokenizer(Tokenizer):
         IDENTIFIERS = ['"', ("[", "]"), "`"]
+        HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]

         KEYWORDS = {
             **Tokenizer.KEYWORDS,

sqlglot/dialects/trino.py

@@ -8,3 +8,6 @@ class Trino(Presto):
         **Presto.Generator.TRANSFORMS,
         exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
     }
+
+    class Tokenizer(Presto.Tokenizer):
+        HEX_STRINGS = [("X'", "'")]

sqlglot/diff.py

@@ -115,13 +115,8 @@ class ChangeDistiller:
         for kept_source_node_id, kept_target_node_id in matching_set:
             source_node = self._source_index[kept_source_node_id]
             target_node = self._target_index[kept_target_node_id]
-            if (
-                not isinstance(source_node, LEAF_EXPRESSION_TYPES)
-                or source_node == target_node
-            ):
-                edit_script.extend(
-                    self._generate_move_edits(source_node, target_node, matching_set)
-                )
+            if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node:
+                edit_script.extend(self._generate_move_edits(source_node, target_node, matching_set))
                 edit_script.append(Keep(source_node, target_node))
             else:
                 edit_script.append(Update(source_node, target_node))
@@ -132,9 +127,7 @@ class ChangeDistiller:
         source_args = [id(e) for e in _expression_only_args(source)]
         target_args = [id(e) for e in _expression_only_args(target)]
-        args_lcs = set(
-            _lcs(source_args, target_args, lambda l, r: (l, r) in matching_set)
-        )
+        args_lcs = set(_lcs(source_args, target_args, lambda l, r: (l, r) in matching_set))

         move_edits = []
         for a in source_args:
@@ -148,14 +141,10 @@ class ChangeDistiller:
         matching_set = leaves_matching_set.copy()

         ordered_unmatched_source_nodes = {
-            id(n[0]): None
-            for n in self._source.bfs()
-            if id(n[0]) in self._unmatched_source_nodes
+            id(n[0]): None for n in self._source.bfs() if id(n[0]) in self._unmatched_source_nodes
         }
         ordered_unmatched_target_nodes = {
-            id(n[0]): None
-            for n in self._target.bfs()
-            if id(n[0]) in self._unmatched_target_nodes
+            id(n[0]): None for n in self._target.bfs() if id(n[0]) in self._unmatched_target_nodes
         }

         for source_node_id in ordered_unmatched_source_nodes:
@@ -169,18 +158,13 @@ class ChangeDistiller:
                 max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
                 if max_leaves_num:
                     common_leaves_num = sum(
-                        1 if s in source_leaf_ids and t in target_leaf_ids else 0
-                        for s, t in leaves_matching_set
+                        1 if s in source_leaf_ids and t in target_leaf_ids else 0 for s, t in leaves_matching_set
                     )
                     leaf_similarity_score = common_leaves_num / max_leaves_num
                 else:
                     leaf_similarity_score = 0.0

-                adjusted_t = (
-                    self.t
-                    if min(len(source_leaf_ids), len(target_leaf_ids)) > 4
-                    else 0.4
-                )
+                adjusted_t = self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4

                 if leaf_similarity_score >= 0.8 or (
                     leaf_similarity_score >= adjusted_t
@@ -217,10 +201,7 @@ class ChangeDistiller:
         matching_set = set()
         while candidate_matchings:
             _, _, source_leaf, target_leaf = heappop(candidate_matchings)
-            if (
-                id(source_leaf) in self._unmatched_source_nodes
-                and id(target_leaf) in self._unmatched_target_nodes
-            ):
+            if id(source_leaf) in self._unmatched_source_nodes and id(target_leaf) in self._unmatched_target_nodes:
                 matching_set.add((id(source_leaf), id(target_leaf)))
                 self._unmatched_source_nodes.remove(id(source_leaf))
                 self._unmatched_target_nodes.remove(id(target_leaf))
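For context, `ChangeDistiller` is the matcher behind `sqlglot.diff.diff`, which turns two parsed trees into an edit script. A minimal sketch of how it is driven:

```python
from sqlglot import parse_one
from sqlglot.diff import diff

# Returns a list of Insert/Remove/Move/Update/Keep edits between the trees.
edit_script = diff(parse_one("SELECT a + b"), parse_one("SELECT a - b"))
```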

sqlglot/executor/__init__.py

@@ -3,11 +3,17 @@ import time

 from sqlglot import parse_one
 from sqlglot.executor.python import PythonExecutor
-from sqlglot.optimizer import optimize
+from sqlglot.optimizer import RULES, optimize
+from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
 from sqlglot.planner import Plan

 logger = logging.getLogger("sqlglot")

+OPTIMIZER_RULES = list(RULES)
+
+# The executor needs isolated table selects
+OPTIMIZER_RULES.remove(merge_derived_tables)
+

 def execute(sql, schema, read=None):
     """
@@ -28,7 +34,7 @@ def execute(sql, schema, read=None):
     """
     expression = parse_one(sql, read=read)
     now = time.time()
-    expression = optimize(expression, schema)
+    expression = optimize(expression, schema, rules=OPTIMIZER_RULES)
     logger.debug("Optimization finished: %f", time.time() - now)
     logger.debug("Optimized SQL: %s", expression.sql(pretty=True))
     plan = Plan(expression)
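The same `rules=` hook is available to callers of `optimize` directly. A sketch mirroring what the executor does above (the table name and schema are made up for illustration):

```python
from sqlglot import parse_one
from sqlglot.optimizer import RULES, optimize
from sqlglot.optimizer.merge_derived_tables import merge_derived_tables

# Run the optimizer with one rule removed, so derived-table selects stay isolated.
rules = [rule for rule in RULES if rule is not merge_derived_tables]
expression = optimize(parse_one("SELECT a FROM (SELECT a FROM x)"), schema={"x": {"a": "INT"}}, rules=rules)
```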

sqlglot/executor/context.py

@@ -19,9 +19,7 @@ class Context:
             env (Optional[dict]): dictionary of functions within the execution context
         """
         self.tables = tables
-        self.range_readers = {
-            name: table.range_reader for name, table in self.tables.items()
-        }
+        self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
         self.row_readers = {name: table.reader for name, table in tables.items()}
         self.env = {**(env or {}), "scope": self.row_readers}

sqlglot/executor/python.py

@@ -26,11 +26,7 @@ class PythonExecutor:
         while queue:
             node = queue.pop()
             context = self.context(
-                {
-                    name: table
-                    for dep in node.dependencies
-                    for name, table in contexts[dep].tables.items()
-                }
+                {name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()}
             )
             running.add(node)
@@ -151,9 +147,7 @@ class PythonExecutor:
             return self.context({name: table for name in ctx.tables})

         for name, join in step.joins.items():
-            join_context = self.context(
-                {**join_context.tables, name: context.tables[name]}
-            )
+            join_context = self.context({**join_context.tables, name: context.tables[name]})

             if join.get("source_key"):
                 table = self.hash_join(join, source, name, join_context)
@@ -247,9 +241,7 @@ class PythonExecutor:
         if step.operands:
             source_table = context.tables[source]
-            operand_table = Table(
-                source_table.columns + self.table(step.operands).columns
-            )
+            operand_table = Table(source_table.columns + self.table(step.operands).columns)

             for reader, ctx in context:
                 operand_table.append(reader.row + ctx.eval_tuple(operands))

sqlglot/executor/table.py

@@ -37,10 +37,7 @@ class Table:
                 break
             lines.append(
-                " ".join(
-                    str(row[column]).rjust(widths[column])[0 : widths[column]]
-                    for column in self.columns
-                )
+                " ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns)
             )
         return "\n".join(lines)

sqlglot/expressions.py

@@ -47,10 +47,7 @@ class Expression(metaclass=_Expression):
         return hash(
             (
                 self.key,
-                tuple(
-                    (k, tuple(v) if isinstance(v, list) else v)
-                    for k, v in _norm_args(self).items()
-                ),
+                tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()),
             )
         )
@@ -116,9 +113,22 @@ class Expression(metaclass=_Expression):
             item.parent = parent
         return new

+    def append(self, arg_key, value):
+        """
+        Appends value to arg_key if it's a list or sets it as a new list.
+
+        Args:
+            arg_key (str): name of the list expression arg
+            value (Any): value to append to the list
+        """
+        if not isinstance(self.args.get(arg_key), list):
+            self.args[arg_key] = []
+        self.args[arg_key].append(value)
+        self._set_parent(arg_key, value)
+
     def set(self, arg_key, value):
         """
-        Sets `arg` to `value`.
+        Sets `arg_key` to `value`.

         Args:
             arg_key (str): name of the expression arg
@@ -267,6 +277,14 @@ class Expression(metaclass=_Expression):
             expression = expression.this
         return expression

+    def unalias(self):
+        """
+        Returns the inner expression if this is an Alias.
+        """
+        if isinstance(self, Alias):
+            return self.this
+        return self
+
     def unnest_operands(self):
         """
         Returns unnested operands as a tuple.
@@ -279,9 +297,7 @@ class Expression(metaclass=_Expression):
             A AND B AND C -> [A, B, C]
         """
-        for node, _, _ in self.dfs(
-            prune=lambda n, p, *_: p and not isinstance(n, self.__class__)
-        ):
+        for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)):
             if not isinstance(node, self.__class__):
                 yield node.unnest() if unnest else node
@@ -314,9 +330,7 @@ class Expression(metaclass=_Expression):
         args = {
             k: ", ".join(
-                v.to_s(hide_missing=hide_missing, level=level + 1)
-                if hasattr(v, "to_s")
-                else str(v)
+                v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
                 for v in ensure_list(vs)
                 if v is not None
             )
@@ -354,9 +368,7 @@ class Expression(metaclass=_Expression):
             new_node.parent = node.parent
             return new_node

-        replace_children(
-            new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs)
-        )
+        replace_children(new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs))
         return new_node

     def replace(self, expression):
@@ -546,6 +558,10 @@ class BitString(Condition):
     pass

+class HexString(Condition):
+    pass
+
 class Column(Condition):
     arg_types = {"this": True, "table": False}
@@ -566,35 +582,44 @@ class ColumnConstraint(Expression):
     arg_types = {"this": False, "kind": True}

-class AutoIncrementColumnConstraint(Expression):
+class ColumnConstraintKind(Expression):
     pass

-class CheckColumnConstraint(Expression):
+class AutoIncrementColumnConstraint(ColumnConstraintKind):
     pass

-class CollateColumnConstraint(Expression):
+class CheckColumnConstraint(ColumnConstraintKind):
     pass

-class CommentColumnConstraint(Expression):
+class CollateColumnConstraint(ColumnConstraintKind):
     pass

-class DefaultColumnConstraint(Expression):
+class CommentColumnConstraint(ColumnConstraintKind):
     pass

-class NotNullColumnConstraint(Expression):
+class DefaultColumnConstraint(ColumnConstraintKind):
     pass

-class PrimaryKeyColumnConstraint(Expression):
+class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
+    # this: True -> ALWAYS, this: False -> BY DEFAULT
+    arg_types = {"this": True, "expression": False}
+
+class NotNullColumnConstraint(ColumnConstraintKind):
     pass

-class UniqueColumnConstraint(Expression):
+class PrimaryKeyColumnConstraint(ColumnConstraintKind):
+    pass
+
+class UniqueColumnConstraint(ColumnConstraintKind):
     pass
@@ -651,9 +676,7 @@ class Identifier(Expression):
         return bool(self.args.get("quoted"))

     def __eq__(self, other):
-        return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(
-            other.this
-        )
+        return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this)

     def __hash__(self):
         return hash((self.key, self.this.lower()))
@@ -709,9 +732,7 @@ class Literal(Condition):
     def __eq__(self, other):
         return (
-            isinstance(other, Literal)
-            and self.this == other.this
-            and self.args["is_string"] == other.args["is_string"]
+            isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"]
         )

     def __hash__(self):
@@ -733,6 +754,7 @@ class Join(Expression):
         "side": False,
         "kind": False,
         "using": False,
+        "natural": False,
     }

     @property
@@ -743,6 +765,10 @@ class Join(Expression):
     def side(self):
         return self.text("side").upper()

+    @property
+    def alias_or_name(self):
+        return self.this.alias_or_name
+
     def on(self, *expressions, append=True, dialect=None, copy=True, **opts):
         """
         Append to or set the ON expressions.
@@ -873,10 +899,6 @@ class Reference(Expression):
     arg_types = {"this": True, "expressions": True}

-class Table(Expression):
-    arg_types = {"this": True, "db": False, "catalog": False}
-
 class Tuple(Expression):
     arg_types = {"expressions": False}
@@ -986,6 +1008,16 @@ QUERY_MODIFIERS = {
 }

+class Table(Expression):
+    arg_types = {
+        "this": True,
+        "db": False,
+        "catalog": False,
+        "laterals": False,
+        "joins": False,
+    }
+
 class Union(Subqueryable, Expression):
     arg_types = {
         "with": False,
@@ -1396,7 +1428,9 @@ class Select(Subqueryable, Expression):
             join.this.replace(join.this.subquery())

         if join_type:
-            side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
+            natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
+            if natural:
+                join.set("natural", True)
             if side:
                 join.set("side", side.text)
             if kind:
@@ -1529,10 +1563,7 @@ class Select(Subqueryable, Expression):
         properties_expression = None
         if properties:
             properties_str = " ".join(
-                [
-                    f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}"
-                    for k, v in properties.items()
-                ]
+                [f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in properties.items()]
             )
             properties_expression = maybe_parse(
                 properties_str,
@@ -1654,6 +1685,7 @@ class DataType(Expression):
         DECIMAL = auto()
         BOOLEAN = auto()
         JSON = auto()
+        INTERVAL = auto()
         TIMESTAMP = auto()
         TIMESTAMPTZ = auto()
         DATE = auto()
@@ -1662,15 +1694,19 @@ class DataType(Expression):
         MAP = auto()
         UUID = auto()
         GEOGRAPHY = auto()
+        GEOMETRY = auto()
         STRUCT = auto()
         NULLABLE = auto()
+        HLLSKETCH = auto()
+        SUPER = auto()
+        SERIAL = auto()
+        SMALLSERIAL = auto()
+        BIGSERIAL = auto()

     @classmethod
     def build(cls, dtype, **kwargs):
         return DataType(
-            this=dtype
-            if isinstance(dtype, DataType.Type)
-            else DataType.Type[dtype.upper()],
+            this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
             **kwargs,
         )
@@ -1798,6 +1834,14 @@ class Like(Binary, Predicate):
     pass

+class SimilarTo(Binary, Predicate):
+    pass
+
+class Distance(Binary):
+    pass
+
 class LT(Binary, Predicate):
     pass
@@ -1899,6 +1943,10 @@ class IgnoreNulls(Expression):
     pass

+class RespectNulls(Expression):
+    pass
+
 # Functions
 class Func(Condition):
     """
@@ -1924,9 +1972,7 @@ class Func(Condition):
         all_arg_keys = list(cls.arg_types)
         # If this function supports variable length argument treat the last argument as such.
-        non_var_len_arg_keys = (
-            all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
-        )
+        non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys

         args_dict = {}
         arg_idx = 0
@@ -1944,9 +1990,7 @@ class Func(Condition):
     @classmethod
     def sql_names(cls):
         if cls is Func:
-            raise NotImplementedError(
-                "SQL name is only supported by concrete function implementations"
-            )
+            raise NotImplementedError("SQL name is only supported by concrete function implementations")
         if not hasattr(cls, "_sql_names"):
             cls._sql_names = [camel_to_snake_case(cls.__name__)]
         return cls._sql_names
@@ -2178,6 +2222,10 @@ class Greatest(Func):
     is_var_len_args = True

+class GroupConcat(Func):
+    arg_types = {"this": True, "separator": False}
+
 class If(Func):
     arg_types = {"this": True, "true": True, "false": False}
@@ -2274,6 +2322,10 @@ class Quantile(AggFunc):
     arg_types = {"this": True, "quantile": True}

+class ApproxQuantile(Quantile):
+    pass
+
 class Reduce(Func):
     arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
@@ -2306,8 +2358,10 @@ class Split(Func):
     arg_types = {"this": True, "expression": True}

+# Start may be omitted in the case of postgres
+# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
 class Substring(Func):
-    arg_types = {"this": True, "start": True, "length": False}
+    arg_types = {"this": True, "start": False, "length": False}

 class StrPosition(Func):
@@ -2379,6 +2433,15 @@ class TimeStrToUnix(Func):
     pass

+class Trim(Func):
+    arg_types = {
+        "this": True,
+        "position": False,
+        "expression": False,
+        "collation": False,
+    }
+
 class TsOrDsAdd(Func, TimeUnit):
     arg_types = {"this": True, "expression": True, "unit": False}
@@ -2455,9 +2518,7 @@ def _all_functions():
         obj
         for _, obj in inspect.getmembers(
             sys.modules[__name__],
-            lambda obj: inspect.isclass(obj)
-            and issubclass(obj, Func)
-            and obj not in (AggFunc, Anonymous, Func),
+            lambda obj: inspect.isclass(obj) and issubclass(obj, Func) and obj not in (AggFunc, Anonymous, Func),
         )
     ]
@@ -2633,9 +2694,7 @@ def _apply_conjunction_builder(

 def _combine(expressions, operator, dialect=None, **opts):
-    expressions = [
-        condition(expression, dialect=dialect, **opts) for expression in expressions
-    ]
+    expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
     this = expressions[0]
     if expressions[1:]:
         this = _wrap_operator(this)
@@ -2809,9 +2868,7 @@ def to_identifier(alias, quoted=None):
         quoted = not re.match(SAFE_IDENTIFIER_RE, alias)
         identifier = Identifier(this=alias, quoted=quoted)
     else:
-        raise ValueError(
-            f"Alias needs to be a string or an Identifier, got: {alias.__class__}"
-        )
+        raise ValueError(f"Alias needs to be a string or an Identifier, got: {alias.__class__}")
     return identifier
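A short sketch of the two new `Expression` helpers added above, `append` and `unalias` (expected outputs in comments):

```python
from sqlglot import parse_one

select = parse_one("SELECT a AS x FROM t")
print(select.expressions[0].unalias().sql())  # a

select.append("expressions", parse_one("b"))  # push another projection
print(select.sql())                           # SELECT a AS x, b FROM t
```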

sqlglot/generator.py

@@ -41,6 +41,8 @@ class Generator:
         max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError.
             This is only relevant if unsupported_level is ErrorLevel.RAISE.
             Default: 3
+        leading_comma (bool): if the comma is leading or trailing in select statements
+            Default: False
     """

     TRANSFORMS = {
@@ -108,6 +110,7 @@ class Generator:
         "_indent",
         "_replace_backslash",
         "_escaped_quote_end",
+        "_leading_comma",
     )

     def __init__(
@@ -131,6 +134,7 @@ class Generator:
         unsupported_level=ErrorLevel.WARN,
         null_ordering=None,
         max_unsupported=3,
+        leading_comma=False,
     ):
         import sqlglot
@@ -157,6 +161,7 @@ class Generator:
         self._indent = indent
         self._replace_backslash = self.escape == "\\"
         self._escaped_quote_end = self.escape + self.quote_end
+        self._leading_comma = leading_comma

     def generate(self, expression):
         """
@@ -178,9 +183,7 @@ class Generator:
             for msg in self.unsupported_messages:
                 logger.warning(msg)
         elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
-            raise UnsupportedError(
-                concat_errors(self.unsupported_messages, self.max_unsupported)
-            )
+            raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported))

         return sql
@@ -197,9 +200,7 @@ class Generator:
     def wrap(self, expression):
         this_sql = self.indent(
-            self.sql(expression)
-            if isinstance(expression, (exp.Select, exp.Union))
-            else self.sql(expression, "this"),
+            self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"),
            level=1,
            pad=0,
        )
@@ -251,9 +252,7 @@ class Generator:
             return transform

         if not isinstance(expression, exp.Expression):
-            raise ValueError(
-                f"Expected an Expression. Received {type(expression)}: {expression}"
-            )
+            raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")

         exp_handler_name = f"{expression.key}_sql"
         if hasattr(self, exp_handler_name):
lazy = " LAZY" if expression.args.get("lazy") else "" lazy = " LAZY" if expression.args.get("lazy") else ""
table = self.sql(expression, "this") table = self.sql(expression, "this")
options = expression.args.get("options") options = expression.args.get("options")
options = ( options = f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" if options else ""
f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})"
if options
else ""
)
sql = self.sql(expression, "expression") sql = self.sql(expression, "expression")
sql = f" AS{self.sep()}{sql}" if sql else "" sql = f" AS{self.sep()}{sql}" if sql else ""
sql = f"CACHE{lazy} TABLE {table}{options}{sql}" sql = f"CACHE{lazy} TABLE {table}{options}{sql}"
@ -306,9 +301,7 @@ class Generator:
def columndef_sql(self, expression): def columndef_sql(self, expression):
column = self.sql(expression, "this") column = self.sql(expression, "this")
kind = self.sql(expression, "kind") kind = self.sql(expression, "kind")
constraints = self.expressions( constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
expression, key="constraints", sep=" ", flat=True
)
if not constraints: if not constraints:
return f"{column} {kind}" return f"{column} {kind}"
@ -338,6 +331,9 @@ class Generator:
default = self.sql(expression, "this") default = self.sql(expression, "this")
return f"DEFAULT {default}" return f"DEFAULT {default}"
def generatedasidentitycolumnconstraint_sql(self, expression):
return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY"
def notnullcolumnconstraint_sql(self, _): def notnullcolumnconstraint_sql(self, _):
return "NOT NULL" return "NOT NULL"
@ -384,7 +380,10 @@ class Generator:
return f"{alias}{columns}" return f"{alias}{columns}"
def bitstring_sql(self, expression): def bitstring_sql(self, expression):
return f"b'{self.sql(expression, 'this')}'" return self.sql(expression, "this")
def hexstring_sql(self, expression):
return self.sql(expression, "this")
def datatype_sql(self, expression): def datatype_sql(self, expression):
type_value = expression.this type_value = expression.this
@ -452,10 +451,7 @@ class Generator:
def partition_sql(self, expression): def partition_sql(self, expression):
keys = csv( keys = csv(
*[ *[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")]
f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"]
for k, v in expression.args.get("this")
]
) )
return f"PARTITION({keys})" return f"PARTITION({keys})"
@ -470,9 +466,9 @@ class Generator:
elif p_class in self.WITH_PROPERTIES: elif p_class in self.WITH_PROPERTIES:
with_properties.append(p) with_properties.append(p)
return self.root_properties( return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties(
exp.Properties(expressions=root_properties) exp.Properties(expressions=with_properties)
) + self.with_properties(exp.Properties(expressions=with_properties)) )
def root_properties(self, properties): def root_properties(self, properties):
if properties.expressions: if properties.expressions:
@ -508,11 +504,7 @@ class Generator:
kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO" kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO"
this = self.sql(expression, "this") this = self.sql(expression, "this")
exists = " IF EXISTS " if expression.args.get("exists") else " " exists = " IF EXISTS " if expression.args.get("exists") else " "
partition_sql = ( partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else ""
self.sql(expression, "partition")
if expression.args.get("partition")
else ""
)
expression_sql = self.sql(expression, "expression") expression_sql = self.sql(expression, "expression")
sep = self.sep() if partition_sql else "" sep = self.sep() if partition_sql else ""
sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}" sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}"
@ -531,7 +523,7 @@ class Generator:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
def table_sql(self, expression): def table_sql(self, expression):
return ".".join( table = ".".join(
part part
for part in [ for part in [
self.sql(expression, "catalog"), self.sql(expression, "catalog"),
@ -541,6 +533,10 @@ class Generator:
if part if part
) )
laterals = self.expressions(expression, key="laterals", sep="")
joins = self.expressions(expression, key="joins", sep="")
return f"{table}{laterals}{joins}"
def tablesample_sql(self, expression): def tablesample_sql(self, expression):
if self.alias_post_tablesample and isinstance(expression.this, exp.Alias): if self.alias_post_tablesample and isinstance(expression.this, exp.Alias):
this = self.sql(expression.this, "this") this = self.sql(expression.this, "this")
@@ -586,11 +582,7 @@ class Generator:
     def group_sql(self, expression):
         group_by = self.op_expressions("GROUP BY", expression)
         grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
-        grouping_sets = (
-            f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}"
-            if grouping_sets
-            else ""
-        )
+        grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
         cube = self.expressions(expression, key="cube", indent=False)
         cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
         rollup = self.expressions(expression, key="rollup", indent=False)
@@ -603,7 +595,16 @@ class Generator:
     def join_sql(self, expression):
         op_sql = self.seg(
-            " ".join(op for op in (expression.side, expression.kind, "JOIN") if op)
+            " ".join(
+                op
+                for op in (
+                    "NATURAL" if expression.args.get("natural") else None,
+                    expression.side,
+                    expression.kind,
+                    "JOIN",
+                )
+                if op
+            )
         )
         on_sql = self.sql(expression, "on")
         using = expression.args.get("using")
@@ -630,9 +631,9 @@ class Generator:
     def lateral_sql(self, expression):
         this = self.sql(expression, "this")
-        op_sql = self.seg(
-            f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}"
-        )
+        if isinstance(expression.this, exp.Subquery):
+            return f"LATERAL{self.sep()}{this}"
+        op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}")
         alias = expression.args["alias"]
         table = alias.name
         table = f" {table}" if table else table
@@ -688,21 +689,13 @@ class Generator:
         sort_order = " DESC" if desc else ""
         nulls_sort_change = ""
-        if nulls_first and (
-            (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
-        ):
+        if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last):
             nulls_sort_change = " NULLS FIRST"
-        elif (
-            nulls_last
-            and ((asc and nulls_are_small) or (desc and nulls_are_large))
-            and not nulls_are_last
-        ):
+        elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last:
             nulls_sort_change = " NULLS LAST"

         if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
-            self.unsupported(
-                "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
-            )
+            self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect")
             nulls_sort_change = ""

         return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
@@ -798,14 +791,20 @@ class Generator:
     def window_sql(self, expression):
         this = self.sql(expression, "this")
+
         partition = self.expressions(expression, key="partition_by", flat=True)
         partition = f"PARTITION BY {partition}" if partition else ""
+
         order = expression.args.get("order")
         order_sql = self.order_sql(order, flat=True) if order else ""
+
         partition_sql = partition + " " if partition and order else partition
+
         spec = expression.args.get("spec")
         spec_sql = " " + self.window_spec_sql(spec) if spec else ""
+
         alias = self.sql(expression, "alias")
+
         if expression.arg_key == "window":
             this = f"{self.seg('WINDOW')} {this} AS"
         else:
@@ -818,13 +817,8 @@ class Generator:
     def window_spec_sql(self, expression):
         kind = self.sql(expression, "kind")
-        start = csv(
-            self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" "
-        )
-        end = (
-            csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
-            or "CURRENT ROW"
-        )
+        start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
+        end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW"
         return f"{kind} BETWEEN {start} AND {end}"

     def withingroup_sql(self, expression):
@ -879,6 +873,17 @@ class Generator:
expression_sql = self.sql(expression, "expression") expression_sql = self.sql(expression, "expression")
return f"EXTRACT({this} FROM {expression_sql})" return f"EXTRACT({this} FROM {expression_sql})"
def trim_sql(self, expression):
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
if trim_type == "LEADING":
return f"LTRIM({target})"
elif trim_type == "TRAILING":
return f"RTRIM({target})"
else:
return f"TRIM({target})"
def check_sql(self, expression): def check_sql(self, expression):
this = self.sql(expression, key="this") this = self.sql(expression, key="this")
return f"CHECK ({this})" return f"CHECK ({this})"
@ -898,9 +903,7 @@ class Generator:
return f"UNIQUE ({columns})" return f"UNIQUE ({columns})"
def if_sql(self, expression): def if_sql(self, expression):
return self.case_sql( return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
exp.Case(ifs=[expression], default=expression.args.get("false"))
)
def in_sql(self, expression): def in_sql(self, expression):
query = expression.args.get("query") query = expression.args.get("query")
@ -917,7 +920,9 @@ class Generator:
return f"(SELECT {self.sql(unnest)})" return f"(SELECT {self.sql(unnest)})"
def interval_sql(self, expression): def interval_sql(self, expression):
return f"INTERVAL {self.sql(expression, 'this')} {self.sql(expression, 'unit')}" unit = self.sql(expression, "unit")
unit = f" {unit}" if unit else ""
return f"INTERVAL {self.sql(expression, 'this')}{unit}"
def reference_sql(self, expression): def reference_sql(self, expression):
this = self.sql(expression, "this") this = self.sql(expression, "this")
@ -925,9 +930,7 @@ class Generator:
return f"REFERENCES {this}({expressions})" return f"REFERENCES {this}({expressions})"
def anonymous_sql(self, expression): def anonymous_sql(self, expression):
args = self.indent( args = self.indent(self.expressions(expression, flat=True), skip_first=True, skip_last=True)
self.expressions(expression, flat=True), skip_first=True, skip_last=True
)
return f"{self.normalize_func(self.sql(expression, 'this'))}({args})" return f"{self.normalize_func(self.sql(expression, 'this'))}({args})"
def paren_sql(self, expression): def paren_sql(self, expression):
@ -1006,6 +1009,9 @@ class Generator:
def ignorenulls_sql(self, expression): def ignorenulls_sql(self, expression):
return f"{self.sql(expression, 'this')} IGNORE NULLS" return f"{self.sql(expression, 'this')} IGNORE NULLS"
def respectnulls_sql(self, expression):
return f"{self.sql(expression, 'this')} RESPECT NULLS"
def intdiv_sql(self, expression): def intdiv_sql(self, expression):
return self.sql( return self.sql(
exp.Cast( exp.Cast(
@ -1023,6 +1029,9 @@ class Generator:
def div_sql(self, expression): def div_sql(self, expression):
return self.binary(expression, "/") return self.binary(expression, "/")
def distance_sql(self, expression):
return self.binary(expression, "<->")
def dot_sql(self, expression): def dot_sql(self, expression):
return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}" return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}"
@ -1047,6 +1056,9 @@ class Generator:
def like_sql(self, expression): def like_sql(self, expression):
return self.binary(expression, "LIKE") return self.binary(expression, "LIKE")
def similarto_sql(self, expression):
return self.binary(expression, "SIMILAR TO")
def lt_sql(self, expression): def lt_sql(self, expression):
return self.binary(expression, "<") return self.binary(expression, "<")
@ -1069,14 +1081,10 @@ class Generator:
return self.binary(expression, "-") return self.binary(expression, "-")
def trycast_sql(self, expression): def trycast_sql(self, expression):
return ( return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
)
def binary(self, expression, op): def binary(self, expression, op):
return ( return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
)
def function_fallback_sql(self, expression): def function_fallback_sql(self, expression):
args = [] args = []
@ -1089,9 +1097,7 @@ class Generator:
return f"{self.normalize_func(expression.sql_name())}({args_str})" return f"{self.normalize_func(expression.sql_name())}({args_str})"
def format_time(self, expression): def format_time(self, expression):
return format_time( return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)
self.sql(expression, "format"), self.time_mapping, self.time_trie
)
def expressions(self, expression, key=None, flat=False, indent=True, sep=", "): def expressions(self, expression, key=None, flat=False, indent=True, sep=", "):
expressions = expression.args.get(key or "expressions") expressions = expression.args.get(key or "expressions")
@ -1102,7 +1108,14 @@ class Generator:
if flat: if flat:
return sep.join(self.sql(e) for e in expressions) return sep.join(self.sql(e) for e in expressions)
expressions = self.sep(sep).join(self.sql(e) for e in expressions) sql = (self.sql(e) for e in expressions)
# the only time leading_comma changes the output is if pretty print is enabled
if self._leading_comma and self.pretty:
pad = " " * self.pad
expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql))
else:
expressions = self.sep(sep).join(sql)
if indent: if indent:
return self.indent(expressions, skip_first=False) return self.indent(expressions, skip_first=False)
return expressions return expressions
@ -1116,9 +1129,7 @@ class Generator:
def set_operation(self, expression, op): def set_operation(self, expression, op):
this = self.sql(expression, "this") this = self.sql(expression, "this")
op = self.seg(op) op = self.seg(op)
return self.query_modifiers( return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}")
expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
)
def token_sql(self, token_type): def token_sql(self, token_type):
return self.TOKEN_MAPPING.get(token_type, token_type.name) return self.TOKEN_MAPPING.get(token_type, token_type.name)
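As a quick illustration of the new trim_sql fallback above: dialects whose TRIM syntax lacks a LEADING/TRAILING position can now emit LTRIM/RTRIM instead. A minimal doctest-style sketch, assuming this build wires TRIM through the new _parse_trim/trim_sql pair; the exact output strings are inferred from the logic above rather than verified:

    >>> import sqlglot
    >>> sqlglot.transpile("SELECT TRIM(LEADING FROM x)")[0]  # position "LEADING" -> LTRIM
    'SELECT LTRIM(x)'
    >>> sqlglot.transpile("SELECT TRIM(x)")[0]  # no position -> plain TRIM
    'SELECT TRIM(x)'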
View file
@@ -1,2 +1,2 @@
-from sqlglot.optimizer.optimizer import optimize
+from sqlglot.optimizer.optimizer import RULES, optimize
 from sqlglot.optimizer.schema import Schema
View file
@@ -13,9 +13,7 @@ def isolate_table_selects(expression):
             continue

         if not isinstance(source.parent, exp.Alias):
-            raise OptimizeError(
-                "Tables require an alias. Run qualify_tables optimization."
-            )
+            raise OptimizeError("Tables require an alias. Run qualify_tables optimization.")

         parent = source.parent
View file
@@ -0,0 +1,232 @@
from collections import defaultdict
from sqlglot import expressions as exp
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.optimizer.simplify import simplify
def merge_derived_tables(expression):
"""
Rewrite sqlglot AST to merge derived tables into the outer query.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)")
>>> merge_derived_tables(expression).sql()
'SELECT x.a FROM x'
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
Args:
expression (sqlglot.Expression): expression to optimize
Returns:
sqlglot.Expression: optimized expression
"""
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
inner_select = subquery.unnest()
if (
isinstance(outer_scope.expression, exp.Select)
and isinstance(inner_select, exp.Select)
and _mergeable(inner_select)
):
alias = subquery.alias_or_name
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
inner_scope = outer_scope.sources[alias]
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
return expression
# If a derived table has these Select args, it can't be merged
UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
"expressions",
"from",
"joins",
"where",
"order",
}
def _mergeable(inner_select):
"""
Return True if `inner_select` can be merged into outer query.
Args:
inner_select (exp.Select)
Returns:
bool: True if can be merged
"""
return (
isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
)
def _rename_inner_sources(outer_scope, inner_scope, alias):
"""
Renames any sources in the inner query that conflict with names in the outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
alias (str)
"""
taken = set(outer_scope.selected_sources)
conflicts = taken.intersection(set(inner_scope.selected_sources))
conflicts = conflicts - {alias}
for conflict in conflicts:
new_name = _find_new_name(taken, conflict)
source, _ = inner_scope.selected_sources[conflict]
new_alias = exp.to_identifier(new_name)
if isinstance(source, exp.Subquery):
source.set("alias", exp.TableAlias(this=new_alias))
elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias):
source.parent.set("alias", new_alias)
elif isinstance(source, exp.Table):
source.replace(exp.alias_(source.copy(), new_alias))
for column in inner_scope.source_columns(conflict):
column.set("table", exp.to_identifier(new_name))
inner_scope.rename_source(conflict, new_name)
def _find_new_name(taken, base):
"""
Searches for a new source name.
Args:
taken (set[str]): set of taken names
base (str): base name to alter
"""
i = 2
new = f"{base}_{i}"
while new in taken:
i += 1
new = f"{base}_{i}"
return new
def _merge_from(outer_scope, inner_scope, subquery):
"""
Merge FROM clause of inner query into outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
subquery (exp.Subquery)
"""
new_subquery = inner_scope.expression.args.get("from").expressions[0]
subquery.replace(new_subquery)
outer_scope.remove_source(subquery.alias_or_name)
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
def _merge_joins(outer_scope, inner_scope, from_or_join):
"""
Merge JOIN clauses of inner query into outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
from_or_join (exp.From|exp.Join)
"""
new_joins = []
comma_joins = inner_scope.expression.args.get("from").expressions[1:]
for subquery in comma_joins:
new_joins.append(exp.Join(this=subquery, kind="CROSS"))
outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])
joins = inner_scope.expression.args.get("joins") or []
for join in joins:
new_joins.append(join)
outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])
if new_joins:
outer_joins = outer_scope.expression.args.get("joins", [])
# Maintain the join order
if isinstance(from_or_join, exp.From):
position = 0
else:
position = outer_joins.index(from_or_join) + 1
outer_joins[position:position] = new_joins
outer_scope.expression.set("joins", outer_joins)
def _merge_expressions(outer_scope, inner_scope, alias):
"""
Merge projections of inner query into outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
alias (str)
"""
# Collect all columns that reference the alias of the inner query
outer_columns = defaultdict(list)
for column in outer_scope.columns:
if column.table == alias:
outer_columns[column.name].append(column)
# Replace columns with the projection expression in the inner query
for expression in inner_scope.expression.expressions:
projection_name = expression.alias_or_name
if not projection_name:
continue
columns_to_replace = outer_columns.get(projection_name, [])
for column in columns_to_replace:
column.replace(expression.unalias())
def _merge_where(outer_scope, inner_scope, from_or_join):
"""
Merge WHERE clause of inner query into outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
from_or_join (exp.From|exp.Join)
"""
where = inner_scope.expression.args.get("where")
if not where or not where.this:
return
if isinstance(from_or_join, exp.Join) and from_or_join.side:
# Merge predicates from an outer join to the ON clause
from_or_join.on(where.this, copy=False)
from_or_join.set("on", simplify(from_or_join.args.get("on")))
else:
outer_scope.expression.where(where.this, copy=False)
outer_scope.expression.set("where", simplify(outer_scope.expression.args.get("where")))
def _merge_order(outer_scope, inner_scope):
"""
Merge ORDER clause of inner query into outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
"""
if (
any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"])
or len(outer_scope.selected_sources) != 1
or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
):
return
outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
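A usage sketch for the new rule: the first call mirrors the docstring above, while the second shows a query that _mergeable rejects because the inner SELECT projects an aggregate, so it comes back unchanged (second output inferred from the logic, not a verified run):

    >>> import sqlglot
    >>> from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
    >>> merge_derived_tables(sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)")).sql()
    'SELECT x.a FROM x'
    >>> merge_derived_tables(sqlglot.parse_one("SELECT a FROM (SELECT SUM(x.a) AS a FROM x)")).sql()
    'SELECT a FROM (SELECT SUM(x.a) AS a FROM x)'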
View file
@@ -22,18 +22,14 @@ def normalize(expression, dnf=False, max_distance=128):
     """
     expression = simplify(expression)

-    expression = while_changing(
-        expression, lambda e: distributive_law(e, dnf, max_distance)
-    )
+    expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance))
     return simplify(expression)


 def normalized(expression, dnf=False):
     ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)

-    return not any(
-        connector.find_ancestor(ancestor) for connector in expression.find_all(root)
-    )
+    return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root))


 def normalization_distance(expression, dnf=False):
@@ -54,9 +50,7 @@ def normalization_distance(expression, dnf=False):
     Returns:
         int: difference
     """
-    return sum(_predicate_lengths(expression, dnf)) - (
-        len(list(expression.find_all(exp.Connector))) + 1
-    )
+    return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1)


 def _predicate_lengths(expression, dnf):
@@ -73,11 +67,7 @@ def _predicate_lengths(expression, dnf):
     left, right = expression.args.values()

     if isinstance(expression, exp.And if dnf else exp.Or):
-        x = [
-            a + b
-            for a in _predicate_lengths(left, dnf)
-            for b in _predicate_lengths(right, dnf)
-        ]
+        x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)]
         return x
     return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
@@ -102,9 +92,7 @@ def distributive_law(expression, dnf, max_distance):
     to_func = exp.and_ if to_exp == exp.And else exp.or_

     if isinstance(a, to_exp) and isinstance(b, to_exp):
-        if len(tuple(a.find_all(exp.Connector))) > len(
-            tuple(b.find_all(exp.Connector))
-        ):
+        if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
             return _distribute(a, b, from_func, to_func)
         return _distribute(b, a, from_func, to_func)

     if isinstance(a, to_exp):
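For orientation, normalize drives an expression toward conjunctive normal form by default (dnf=True flips the target), using the distributive_law helper reflowed above. A small sketch; the exact output string is inferred from the distributive rewrite rather than taken from this build's tests:

    >>> import sqlglot
    >>> from sqlglot.optimizer.normalize import normalize, normalized
    >>> normalize(sqlglot.parse_one("(a AND b) OR c")).sql()
    '(a OR c) AND (b OR c)'
    >>> normalized(sqlglot.parse_one("(a OR c) AND (b OR c)"))
    True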
View file
@@ -68,8 +68,4 @@ def normalize(expression):

 def other_table_names(join, exclude):
-    return [
-        name
-        for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
-        if name != exclude
-    ]
+    return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude]
View file
@@ -1,6 +1,7 @@
 from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
 from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
 from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
+from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
 from sqlglot.optimizer.normalize import normalize
 from sqlglot.optimizer.optimize_joins import optimize_joins
 from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
@@ -10,8 +11,23 @@ from sqlglot.optimizer.qualify_tables import qualify_tables
 from sqlglot.optimizer.quote_identities import quote_identities
 from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

+RULES = (
+    qualify_tables,
+    isolate_table_selects,
+    qualify_columns,
+    pushdown_projections,
+    normalize,
+    unnest_subqueries,
+    expand_multi_table_selects,
+    pushdown_predicates,
+    optimize_joins,
+    eliminate_subqueries,
+    merge_derived_tables,
+    quote_identities,
+)

-def optimize(expression, schema=None, db=None, catalog=None):
+
+def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwargs):
     """
     Rewrite a sqlglot AST into an optimized form.
@@ -25,19 +41,18 @@ def optimize(expression, schema=None, db=None, catalog=None):
                 3. {catalog: {db: {table: {col: type}}}}
         db (str): specify the default database, as might be set by a `USE DATABASE db` statement
         catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
+        rules (list): sequence of optimizer rules to use
+        **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
     Returns:
         sqlglot.Expression: optimized expression
     """
+    possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
     expression = expression.copy()
-    expression = qualify_tables(expression, db=db, catalog=catalog)
-    expression = isolate_table_selects(expression)
-    expression = qualify_columns(expression, schema)
-    expression = pushdown_projections(expression)
-    expression = normalize(expression)
-    expression = unnest_subqueries(expression)
-    expression = expand_multi_table_selects(expression)
-    expression = pushdown_predicates(expression)
-    expression = optimize_joins(expression)
-    expression = eliminate_subqueries(expression)
-    expression = quote_identities(expression)
+    for rule in rules:
+        # Find any additional rule parameters, beyond `expression`
+        rule_params = rule.__code__.co_varnames
+        rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
+        expression = rule(expression, **rule_kwargs)
+
     return expression
View file
@@ -42,11 +42,7 @@ def pushdown(condition, sources):
     condition = condition.replace(simplify(condition))
     cnf_like = normalized(condition) or not normalized(condition, dnf=True)

-    predicates = list(
-        condition.flatten()
-        if isinstance(condition, exp.And if cnf_like else exp.Or)
-        else [condition]
-    )
+    predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])

     if cnf_like:
         pushdown_cnf(predicates, sources)
@@ -105,17 +101,11 @@ def pushdown_dnf(predicates, scope):
             for column in predicate.find_all(exp.Column):
                 if column.table == table:
                     condition = column.find_ancestor(exp.Condition)
-                    predicate_condition = (
-                        exp.and_(predicate_condition, condition)
-                        if predicate_condition
-                        else condition
-                    )
+                    predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition

             if predicate_condition:
                 conditions[table] = (
-                    exp.or_(conditions[table], predicate_condition)
-                    if table in conditions
-                    else predicate_condition
+                    exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition
                 )

     for name, node in nodes.items():
@@ -133,9 +123,7 @@ def pushdown_dnf(predicates, scope):
 def nodes_for_predicate(predicate, sources):
     nodes = {}
     tables = exp.column_table_names(predicate)
-    where_condition = isinstance(
-        predicate.find_ancestor(exp.Join, exp.Where), exp.Where
-    )
+    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

     for table in tables:
         node, source = sources.get(table) or (None, None)
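A pushdown sketch for reference: predicates that only touch an inner source move into that source, leaving a residual WHERE TRUE that simplify cleans up in the full pipeline. The expected output here is my reading of the rule, not a verified run against this build:

    >>> import sqlglot
    >>> from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
    >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
    >>> pushdown_predicates(sqlglot.parse_one(sql)).sql()
    'SELECT * FROM (SELECT * FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'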
View file
@@ -226,9 +226,7 @@ def _expand_stars(scope, resolver):
             tables = list(scope.selected_sources)
             _add_except_columns(expression, tables, except_columns)
             _add_replace_columns(expression, tables, replace_columns)
-        elif isinstance(expression, exp.Column) and isinstance(
-            expression.this, exp.Star
-        ):
+        elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star):
             tables = [expression.table]
             _add_except_columns(expression.this, tables, except_columns)
             _add_replace_columns(expression.this, tables, replace_columns)
@@ -245,9 +243,7 @@ def _expand_stars(scope, resolver):
                 if name not in except_columns.get(table_id, set()):
                     alias_ = replace_columns.get(table_id, {}).get(name, name)
                     column = exp.column(name, table)
-                    new_selections.append(
-                        alias(column, alias_) if alias_ != name else column
-                    )
+                    new_selections.append(alias(column, alias_) if alias_ != name else column)

     scope.expression.set("expressions", new_selections)
@@ -280,9 +276,7 @@ def _qualify_outputs(scope):
     """Ensure all output columns are aliased"""
     new_selections = []

-    for i, (selection, aliased_column) in enumerate(
-        itertools.zip_longest(scope.selects, scope.outer_column_list)
-    ):
+    for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)):
         if isinstance(selection, exp.Column):
             # convoluted setter because a simple selection.replace(alias) would require a copy
             alias_ = alias(exp.column(""), alias=selection.name)
@@ -302,11 +296,7 @@ def _qualify_outputs(scope):

 def _check_unknown_tables(scope):
-    if (
-        scope.external_columns
-        and not scope.is_unnest
-        and not scope.is_correlated_subquery
-    ):
+    if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery:
         raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
@@ -334,20 +324,14 @@ class _Resolver:
             (str) table name
         """
         if self._unambiguous_columns is None:
-            self._unambiguous_columns = self._get_unambiguous_columns(
-                self._get_all_source_columns()
-            )
+            self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns())
         return self._unambiguous_columns.get(column_name)

     @property
     def all_columns(self):
         """All available columns of all sources in this scope"""
         if self._all_columns is None:
-            self._all_columns = set(
-                column
-                for columns in self._get_all_source_columns().values()
-                for column in columns
-            )
+            self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
         return self._all_columns

     def get_source_columns(self, name):
@@ -369,9 +353,7 @@ class _Resolver:
     def _get_all_source_columns(self):
         if self._source_columns is None:
-            self._source_columns = {
-                k: self.get_source_columns(k) for k in self.scope.selected_sources
-            }
+            self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources}
         return self._source_columns

     def _get_unambiguous_columns(self, source_columns):
@@ -389,9 +371,7 @@ class _Resolver:
         source_columns = list(source_columns.items())

         first_table, first_columns = source_columns[0]
-        unambiguous_columns = {
-            col: first_table for col in self._find_unique_columns(first_columns)
-        }
+        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
         all_columns = set(unambiguous_columns)

         for table, columns in source_columns[1:]:
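For context, _expand_stars and _qualify_outputs above serve qualify_columns, which rewrites every column reference to be table-qualified and aliased against a schema. A minimal sketch (schema shape as in the optimize docstring above; output inferred from the rule's contract):

    >>> import sqlglot
    >>> from sqlglot.optimizer.qualify_columns import qualify_columns
    >>> qualify_columns(sqlglot.parse_one("SELECT col FROM tbl"), schema={"tbl": {"col": "INT"}}).sql()
    'SELECT tbl.col AS col FROM tbl'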
View file
@@ -27,9 +27,7 @@ def qualify_tables(expression, db=None, catalog=None):
         for derived_table in scope.ctes + scope.derived_tables:
             if not derived_table.args.get("alias"):
                 alias_ = f"_q_{next(sequence)}"
-                derived_table.set(
-                    "alias", exp.TableAlias(this=exp.to_identifier(alias_))
-                )
+                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
                 scope.rename_source(None, alias_)

         for source in scope.sources.values():
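The reflowed derived_table.set call belongs to qualify_tables, the rule that also synthesizes the _q_<n> aliases seen above and fills in a default db/catalog. Sketch (output inferred from the rule's behavior):

    >>> import sqlglot
    >>> from sqlglot.optimizer.qualify_tables import qualify_tables
    >>> qualify_tables(sqlglot.parse_one("SELECT 1 FROM tbl"), db="db").sql()
    'SELECT 1 FROM db.tbl AS tbl'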
View file
@@ -57,9 +57,7 @@ class MappingSchema(Schema):
         for forbidden in self.forbidden_args:
             if table.text(forbidden):
-                raise ValueError(
-                    f"Schema doesn't support {forbidden}. Received: {table.sql()}"
-                )
+                raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")

         return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
View file
@@ -104,9 +104,7 @@ class Scope:
             elif isinstance(node, exp.CTE):
                 self._ctes.append(node)
                 prune = True
-            elif isinstance(node, exp.Subquery) and isinstance(
-                parent, (exp.From, exp.Join)
-            ):
+            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
                 self._derived_tables.append(node)
                 prune = True
             elif isinstance(node, exp.Subqueryable):
@@ -195,20 +193,14 @@ class Scope:
         self._ensure_collected()
         columns = self._raw_columns

-        external_columns = [
-            column
-            for scope in self.subquery_scopes
-            for column in scope.external_columns
-        ]
+        external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns]

         named_outputs = {e.alias_or_name for e in self.expression.expressions}

         self._columns = [
             c
             for c in columns + external_columns
-            if not (
-                c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs
-            )
+            if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs)
         ]
         return self._columns
@@ -229,9 +221,7 @@ class Scope:
             for table in self.tables:
                 referenced_names.append(
                     (
-                        table.parent.alias
-                        if isinstance(table.parent, exp.Alias)
-                        else table.name,
+                        table.parent.alias if isinstance(table.parent, exp.Alias) else table.name,
                         table,
                     )
                 )
@@ -274,9 +264,7 @@ class Scope:
             sources in the current scope.
         """
         if self._external_columns is None:
-            self._external_columns = [
-                c for c in self.columns if c.table not in self.selected_sources
-            ]
+            self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
         return self._external_columns

     def source_columns(self, source_name):
@@ -310,6 +298,16 @@ class Scope:
         columns = self.sources.pop(old_name or "", [])
         self.sources[new_name] = columns

+    def add_source(self, name, source):
+        """Add a source to this scope"""
+        self.sources[name] = source
+        self.clear_cache()
+
+    def remove_source(self, name):
+        """Remove a source from this scope"""
+        self.sources.pop(name, None)
+        self.clear_cache()
+

 def traverse_scope(expression):
     """
@@ -334,7 +332,7 @@ def traverse_scope(expression):
     Args:
         expression (exp.Expression): expression to traverse
     Returns:
-        List[Scope]: scope instances
+        list[Scope]: scope instances
     """
     return list(_traverse_scope(Scope(expression)))
@@ -356,9 +354,7 @@ def _traverse_scope(scope):
 def _traverse_select(scope):
     yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
     yield from _traverse_subqueries(scope)
-    yield from _traverse_derived_tables(
-        scope.derived_tables, scope, ScopeType.DERIVED_TABLE
-    )
+    yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
     _add_table_sources(scope)
@@ -367,15 +363,11 @@ def _traverse_union(scope):
     # The last scope to be yield should be the top most scope
     left = None
-    for left in _traverse_scope(
-        scope.branch(scope.expression.left, scope_type=ScopeType.UNION)
-    ):
+    for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
         yield left

     right = None
-    for right in _traverse_scope(
-        scope.branch(scope.expression.right, scope_type=ScopeType.UNION)
-    ):
+    for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
         yield right

     scope.union = (left, right)
@@ -387,14 +379,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
     for derived_table in derived_tables:
         for child_scope in _traverse_scope(
             scope.branch(
-                derived_table
-                if isinstance(derived_table, (exp.Unnest, exp.Lateral))
-                else derived_table.this,
+                derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
                 add_sources=sources if scope_type == ScopeType.CTE else None,
                 outer_column_list=derived_table.alias_column_names,
-                scope_type=ScopeType.UNNEST
-                if isinstance(derived_table, exp.Unnest)
-                else scope_type,
+                scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
             )
         ):
             yield child_scope
@@ -430,9 +418,7 @@ def _add_table_sources(scope):
 def _traverse_subqueries(scope):
     for subquery in scope.subqueries:
         top = None
-        for child_scope in _traverse_scope(
-            scope.branch(subquery, scope_type=ScopeType.SUBQUERY)
-        ):
+        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope
        scope.subquery_scopes.append(top)
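The new add_source/remove_source hooks (used by merge_derived_tables above) mutate the source mapping that traverse_scope builds. A doctest-style sketch of the traversal itself, with outputs inferred from the Scope semantics rather than copied from a run:

    >>> import sqlglot
    >>> from sqlglot.optimizer.scope import traverse_scope
    >>> scopes = traverse_scope(sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y"))
    >>> scopes[0].expression.sql(), list(scopes[0].sources)
    ('SELECT a FROM x', ['x'])
    >>> scopes[1].expression.sql(), list(scopes[1].sources)
    ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])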
View file
@@ -188,9 +188,7 @@ def absorb_and_eliminate(expression):
                     aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
                 elif is_complement(b, ab):
                     ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
-                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(
-                    a.flatten()
-                ):
+                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                     a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
                 elif isinstance(b, kind):
                     # eliminate
@@ -227,9 +225,7 @@ def simplify_literals(expression):
             operands.append(a)

         if len(operands) < size:
-            return functools.reduce(
-                lambda a, b: expression.__class__(this=a, expression=b), operands
-            )
+            return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
     elif isinstance(expression, exp.Neg):
         this = expression.this
         if this.is_number:
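absorb_and_eliminate and simplify_literals are individual passes inside simplify, which constant-folds and prunes boolean connectives; a one-line sketch:

    >>> import sqlglot
    >>> from sqlglot.optimizer.simplify import simplify
    >>> simplify(sqlglot.parse_one("TRUE AND TRUE")).sql()
    'TRUE'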
View file
@@ -89,11 +89,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
         return

     if isinstance(predicate, exp.Binary):
-        key = (
-            predicate.right
-            if any(node is column for node, *_ in predicate.left.walk())
-            else predicate.left
-        )
+        key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left
     else:
         return
@@ -124,9 +120,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
     # if the value of the subquery is not an agg or a key, we need to collect it into an array
     # so that it can be grouped
     if not value.find(exp.AggFunc) and value.this not in group_by:
-        select.select(
-            f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False
-        )
+        select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False)

     # exists queries should not have any selects as it only checks if there are any rows
     # all selects will be added by the optimizer and only used for join keys
@@ -151,16 +145,12 @@ def decorrelate(select, parent_select, external_columns, sequence):
         else:
             parent_predicate = _replace(parent_predicate, "TRUE")
     elif isinstance(parent_predicate, exp.All):
-        parent_predicate = _replace(
-            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
-        )
+        parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})")
     elif isinstance(parent_predicate, exp.Any):
         if value.this in group_by:
             parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
         else:
-            parent_predicate = _replace(
-                parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})"
-            )
+            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
     elif isinstance(parent_predicate, exp.In):
         if value.this in group_by:
             parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
@@ -178,9 +168,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
         if key in group_by:
             key.replace(nested)
-            parent_predicate = _replace(
-                parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
-            )
+            parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)")
         elif isinstance(predicate, exp.EQ):
             parent_predicate = _replace(
                 parent_predicate,
View file
@@ -78,6 +78,7 @@ class Parser:
         TokenType.TEXT,
         TokenType.BINARY,
         TokenType.JSON,
+        TokenType.INTERVAL,
         TokenType.TIMESTAMP,
         TokenType.TIMESTAMPTZ,
         TokenType.DATETIME,
@@ -85,6 +86,12 @@ class Parser:
         TokenType.DECIMAL,
         TokenType.UUID,
         TokenType.GEOGRAPHY,
+        TokenType.GEOMETRY,
+        TokenType.HLLSKETCH,
+        TokenType.SUPER,
+        TokenType.SERIAL,
+        TokenType.SMALLSERIAL,
+        TokenType.BIGSERIAL,
         *NESTED_TYPE_TOKENS,
     }
@@ -100,13 +107,14 @@ class Parser:
     ID_VAR_TOKENS = {
         TokenType.VAR,
         TokenType.ALTER,
+        TokenType.ALWAYS,
         TokenType.BEGIN,
+        TokenType.BOTH,
         TokenType.BUCKET,
         TokenType.CACHE,
         TokenType.COLLATE,
         TokenType.COMMIT,
         TokenType.CONSTRAINT,
-        TokenType.CONVERT,
         TokenType.DEFAULT,
         TokenType.DELETE,
         TokenType.ENGINE,
@@ -115,14 +123,19 @@ class Parser:
         TokenType.FALSE,
         TokenType.FIRST,
         TokenType.FOLLOWING,
+        TokenType.FOR,
         TokenType.FORMAT,
         TokenType.FUNCTION,
+        TokenType.GENERATED,
+        TokenType.IDENTITY,
         TokenType.IF,
         TokenType.INDEX,
         TokenType.ISNULL,
         TokenType.INTERVAL,
         TokenType.LAZY,
+        TokenType.LEADING,
         TokenType.LOCATION,
+        TokenType.NATURAL,
         TokenType.NEXT,
         TokenType.ONLY,
         TokenType.OPTIMIZE,
@@ -141,6 +154,7 @@ class Parser:
         TokenType.TABLE_FORMAT,
         TokenType.TEMPORARY,
         TokenType.TOP,
+        TokenType.TRAILING,
         TokenType.TRUNCATE,
         TokenType.TRUE,
         TokenType.UNBOUNDED,
@@ -150,18 +164,15 @@ class Parser:
         *TYPE_TOKENS,
     }

-    CASTS = {
-        TokenType.CAST,
-        TokenType.TRY_CAST,
-    }
+    TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL}
+
+    TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}

     FUNC_TOKENS = {
-        TokenType.CONVERT,
         TokenType.CURRENT_DATE,
         TokenType.CURRENT_DATETIME,
         TokenType.CURRENT_TIMESTAMP,
         TokenType.CURRENT_TIME,
-        TokenType.EXTRACT,
         TokenType.FILTER,
         TokenType.FIRST,
         TokenType.FORMAT,
@@ -178,7 +189,6 @@ class Parser:
         TokenType.DATETIME,
         TokenType.TIMESTAMP,
         TokenType.TIMESTAMPTZ,
-        *CASTS,
         *NESTED_TYPE_TOKENS,
         *SUBQUERY_PREDICATES,
     }
@@ -215,6 +225,7 @@ class Parser:
     FACTOR = {
         TokenType.DIV: exp.IntDiv,
+        TokenType.LR_ARROW: exp.Distance,
         TokenType.SLASH: exp.Div,
         TokenType.STAR: exp.Mul,
     }
@@ -299,14 +310,13 @@ class Parser:
     PRIMARY_PARSERS = {
         TokenType.STRING: lambda _, token: exp.Literal.string(token.text),
         TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text),
-        TokenType.STAR: lambda self, _: exp.Star(
-            **{"except": self._parse_except(), "replace": self._parse_replace()}
-        ),
+        TokenType.STAR: lambda self, _: exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}),
         TokenType.NULL: lambda *_: exp.Null(),
         TokenType.TRUE: lambda *_: exp.Boolean(this=True),
         TokenType.FALSE: lambda *_: exp.Boolean(this=False),
         TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(),
         TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
+        TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
         TokenType.INTRODUCER: lambda self, token: self.expression(
             exp.Introducer,
             this=token.text,
@@ -319,13 +329,16 @@ class Parser:
         TokenType.IN: lambda self, this: self._parse_in(this),
         TokenType.IS: lambda self, this: self._parse_is(this),
         TokenType.LIKE: lambda self, this: self._parse_escape(
-            self.expression(exp.Like, this=this, expression=self._parse_type())
+            self.expression(exp.Like, this=this, expression=self._parse_bitwise())
         ),
         TokenType.ILIKE: lambda self, this: self._parse_escape(
-            self.expression(exp.ILike, this=this, expression=self._parse_type())
+            self.expression(exp.ILike, this=this, expression=self._parse_bitwise())
         ),
         TokenType.RLIKE: lambda self, this: self.expression(
-            exp.RegexpLike, this=this, expression=self._parse_type()
+            exp.RegexpLike, this=this, expression=self._parse_bitwise()
+        ),
+        TokenType.SIMILAR_TO: lambda self, this: self.expression(
+            exp.SimilarTo, this=this, expression=self._parse_bitwise()
         ),
     }
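Note that LIKE/ILIKE/RLIKE patterns now parse at bitwise precedence, and SIMILAR TO gets its own exp.SimilarTo node, matched by similarto_sql in the generator above. A round-trip sketch, assuming the tokenizer emits SIMILAR_TO:

    >>> import sqlglot
    >>> sqlglot.transpile("SELECT * FROM t WHERE name SIMILAR TO '(b|c)%'")[0]
    "SELECT * FROM t WHERE name SIMILAR TO '(b|c)%'"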
@@ -363,28 +376,21 @@ class Parser:
     }

     FUNCTION_PARSERS = {
-        TokenType.CONVERT: lambda self, _: self._parse_convert(),
-        TokenType.EXTRACT: lambda self, _: self._parse_extract(),
-        **{
-            token_type: lambda self, token_type: self._parse_cast(
-                self.STRICT_CAST and token_type == TokenType.CAST
-            )
-            for token_type in CASTS
-        },
+        "CONVERT": lambda self: self._parse_convert(),
+        "EXTRACT": lambda self: self._parse_extract(),
+        "SUBSTRING": lambda self: self._parse_substring(),
+        "TRIM": lambda self: self._parse_trim(),
+        "CAST": lambda self: self._parse_cast(self.STRICT_CAST),
+        "TRY_CAST": lambda self: self._parse_cast(False),
     }

     QUERY_MODIFIER_PARSERS = {
-        "laterals": lambda self: self._parse_laterals(),
-        "joins": lambda self: self._parse_joins(),
         "where": lambda self: self._parse_where(),
         "group": lambda self: self._parse_group(),
         "having": lambda self: self._parse_having(),
         "qualify": lambda self: self._parse_qualify(),
-        "window": lambda self: self._match(TokenType.WINDOW)
-        and self._parse_window(self._parse_id_var(), alias=True),
-        "distribute": lambda self: self._parse_sort(
-            TokenType.DISTRIBUTE_BY, exp.Distribute
-        ),
+        "window": lambda self: self._match(TokenType.WINDOW) and self._parse_window(self._parse_id_var(), alias=True),
+        "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
         "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
         "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
         "order": lambda self: self._parse_order(),
@@ -392,6 +398,8 @@ class Parser:
         "offset": lambda self: self._parse_offset(),
     }

+    MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
+
     CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX}

     STRICT_CAST = True
@@ -457,9 +465,7 @@ class Parser:
         Returns
             the list of syntax trees (:class:`~sqlglot.expressions.Expression`).
         """
-        return self._parse(
-            parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql
-        )
+        return self._parse(parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql)

     def parse_into(self, expression_types, raw_tokens, sql=None):
         for expression_type in ensure_list(expression_types):
@@ -532,21 +538,13 @@ class Parser:
         for k in expression.args:
             if k not in expression.arg_types:
-                self.raise_error(
-                    f"Unexpected keyword: '{k}' for {expression.__class__}"
-                )
+                self.raise_error(f"Unexpected keyword: '{k}' for {expression.__class__}")
         for k, mandatory in expression.arg_types.items():
             v = expression.args.get(k)
             if mandatory and (v is None or (isinstance(v, list) and not v)):
-                self.raise_error(
-                    f"Required keyword: '{k}' missing for {expression.__class__}"
-                )
+                self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}")

-        if (
-            args
-            and len(args) > len(expression.arg_types)
-            and not expression.is_var_len_args
-        ):
+        if args and len(args) > len(expression.arg_types) and not expression.is_var_len_args:
             self.raise_error(
                 f"The number of provided arguments ({len(args)}) is greater than "
                 f"the maximum number of supported arguments ({len(expression.arg_types)})"
@@ -594,11 +592,7 @@ class Parser:
         )

         expression = self._parse_expression()
-        expression = (
-            self._parse_set_operations(expression)
-            if expression
-            else self._parse_select()
-        )
+        expression = self._parse_set_operations(expression) if expression else self._parse_select()
         self._parse_query_modifiers(expression)
         return expression
@@ -618,11 +612,7 @@ class Parser:
         )

     def _parse_exists(self, not_=False):
-        return (
-            self._match(TokenType.IF)
-            and (not not_ or self._match(TokenType.NOT))
-            and self._match(TokenType.EXISTS)
-        )
+        return self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS)

     def _parse_create(self):
         replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
@@ -647,11 +637,9 @@ class Parser:
             this = self._parse_index()
         elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW):
             this = self._parse_table(schema=True)
-            properties = self._parse_properties(
-                this if isinstance(this, exp.Schema) else None
-            )
+            properties = self._parse_properties(this if isinstance(this, exp.Schema) else None)
             if self._match(TokenType.ALIAS):
-                expression = self._parse_select()
+                expression = self._parse_select(nested=True)

         return self.expression(
             exp.Create,
@@ -682,9 +670,7 @@ class Parser:
         if schema and not isinstance(value, exp.Schema):
             columns = {v.name.upper() for v in value.expressions}
             partitions = [
-                expression
-                for expression in schema.expressions
-                if expression.this.name.upper() in columns
+                expression for expression in schema.expressions if expression.this.name.upper() in columns
             ]
             schema.set(
                 "expressions",
@@ -811,7 +797,7 @@ class Parser:
             this=self._parse_table(schema=True),
             exists=self._parse_exists(),
             partition=self._parse_partition(),
-            expression=self._parse_select(),
+            expression=self._parse_select(nested=True),
             overwrite=overwrite,
         )
@@ -829,8 +815,7 @@ class Parser:
             exp.Update,
             **{
                 "this": self._parse_table(schema=True),
-                "expressions": self._match(TokenType.SET)
-                and self._parse_csv(self._parse_equality),
+                "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
                 "from": self._parse_from(),
                 "where": self._parse_where(),
             },
@@ -865,7 +850,7 @@ class Parser:
             this=table,
             lazy=lazy,
             options=options,
-            expression=self._parse_select(),
+            expression=self._parse_select(nested=True),
         )

     def _parse_partition(self):
@@ -894,9 +879,7 @@ class Parser:
         self._match_r_paren()
         return self.expression(exp.Tuple, expressions=expressions)

-    def _parse_select(self, table=None):
-        index = self._index
-
+    def _parse_select(self, nested=False, table=False):
         if self._match(TokenType.SELECT):
             hint = self._parse_hint()
             all_ = self._match(TokenType.ALL)
@@ -912,9 +895,7 @@ class Parser:
                 self.raise_error("Cannot specify both ALL and DISTINCT after SELECT")

             limit = self._parse_limit(top=True)
-            expressions = self._parse_csv(
-                lambda: self._parse_annotation(self._parse_expression())
-            )
+            expressions = self._parse_csv(lambda: self._parse_annotation(self._parse_expression()))

             this = self.expression(
                 exp.Select,
@@ -960,19 +941,13 @@ class Parser:
                 )
             else:
                 self.raise_error(f"{this.key} does not support CTE")
-        elif self._match(TokenType.L_PAREN):
-            this = self._parse_table() if table else self._parse_select()
-
-            if this:
-                self._parse_query_modifiers(this)
-                self._match_r_paren()
-                this = self._parse_subquery(this)
-            else:
-                self._retreat(index)
+        elif (table or nested) and self._match(TokenType.L_PAREN):
+            this = self._parse_table() if table else self._parse_select(nested=True)
+            self._parse_query_modifiers(this)
+            self._match_r_paren()
+            this = self._parse_subquery(this)
         elif self._match(TokenType.VALUES):
-            this = self.expression(
-                exp.Values, expressions=self._parse_csv(self._parse_value)
-            )
+            this = self.expression(exp.Values, expressions=self._parse_csv(self._parse_value))
             alias = self._parse_table_alias()
             if alias:
                 this = self.expression(exp.Subquery, this=this, alias=alias)
@@ -1001,7 +976,7 @@ class Parser:

     def _parse_table_alias(self):
         any_token = self._match(TokenType.ALIAS)
-        alias = self._parse_id_var(any_token)
+        alias = self._parse_id_var(any_token=any_token, tokens=self.TABLE_ALIAS_TOKENS)
         columns = None

         if self._match(TokenType.L_PAREN):
@@ -1021,9 +996,24 @@ class Parser:
         return self.expression(exp.Subquery, this=this, alias=self._parse_table_alias())

     def _parse_query_modifiers(self, this):
-        if not isinstance(this, (exp.Subquery, exp.Subqueryable)):
+        if not isinstance(this, self.MODIFIABLES):
             return

+        table = isinstance(this, exp.Table)
+
+        while True:
+            lateral = self._parse_lateral()
+            join = self._parse_join()
+            comma = None if table else self._match(TokenType.COMMA)
+            if lateral:
+                this.append("laterals", lateral)
+            if join:
+                this.append("joins", join)
+            if comma:
+                this.args["from"].append("expressions", self._parse_table())
+            if not (lateral or join or comma):
+                break
+
         for key, parser in self.QUERY_MODIFIER_PARSERS.items():
             expression = parser(self)
@@ -1032,9 +1022,7 @@ class Parser:

     def _parse_annotation(self, expression):
         if self._match(TokenType.ANNOTATION):
-            return self.expression(
-                exp.Annotation, this=self._prev.text, expression=expression
-            )
+            return self.expression(exp.Annotation, this=self._prev.text, expression=expression)

         return expression
@@ -1052,16 +1040,16 @@ class Parser:

         return self.expression(exp.From, expressions=self._parse_csv(self._parse_table))

-    def _parse_laterals(self):
-        return self._parse_all(self._parse_lateral)
-
     def _parse_lateral(self):
         if not self._match(TokenType.LATERAL):
             return None

-        if not self._match(TokenType.VIEW):
-            self.raise_error("Expected VIEW after LATERAL")
+        subquery = self._parse_select(table=True)
+
+        if subquery:
+            return self.expression(exp.Lateral, this=subquery)

+        self._match(TokenType.VIEW)
         outer = self._match(TokenType.OUTER)

         return self.expression(
@@ -1071,31 +1059,27 @@ class Parser:
             alias=self.expression(
                 exp.TableAlias,
                 this=self._parse_id_var(any_token=False),
-                columns=(
-                    self._parse_csv(self._parse_id_var)
-                    if self._match(TokenType.ALIAS)
-                    else None
-                ),
+                columns=(self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else None),
             ),
         )

-    def _parse_joins(self):
-        return self._parse_all(self._parse_join)
-
     def _parse_join_side_and_kind(self):
         return (
+            self._match(TokenType.NATURAL) and self._prev,
             self._match_set(self.JOIN_SIDES) and self._prev,
             self._match_set(self.JOIN_KINDS) and self._prev,
         )

     def _parse_join(self):
-        side, kind = self._parse_join_side_and_kind()
+        natural, side, kind = self._parse_join_side_and_kind()

         if not self._match(TokenType.JOIN):
             return None

         kwargs = {"this": self._parse_table()}

+        if natural:
+            kwargs["natural"] = True
         if side:
             kwargs["side"] = side.text
         if kind:
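With NATURAL handled in _parse_join_side_and_kind, NATURAL JOIN now parses into a natural=True arg on the Join node. A sketch, assuming the generator side of this diff also emits NATURAL back out:

    >>> import sqlglot
    >>> sqlglot.parse_one("SELECT * FROM a NATURAL JOIN b").args["joins"][0].args.get("natural")
    True
    >>> sqlglot.transpile("SELECT * FROM a NATURAL JOIN b")[0]
    'SELECT * FROM a NATURAL JOIN b'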
@ -1120,6 +1104,11 @@ class Parser:
) )
def _parse_table(self, schema=False): def _parse_table(self, schema=False):
lateral = self._parse_lateral()
if lateral:
return lateral
unnest = self._parse_unnest() unnest = self._parse_unnest()
if unnest: if unnest:
@ -1172,9 +1161,7 @@ class Parser:
expressions = self._parse_csv(self._parse_column) expressions = self._parse_csv(self._parse_column)
self._match_r_paren() self._match_r_paren()
ordinality = bool( ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY))
self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY)
)
alias = self._parse_table_alias() alias = self._parse_table_alias()
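
A minimal sketch of what the NATURAL JOIN change above enables, using only names visible in this diff; an illustration, not part of the commit:

    import sqlglot

    # _parse_join_side_and_kind now also matches NATURAL, and _parse_join
    # records it on the join expression, so the flag survives parsing.
    join = sqlglot.parse_one("SELECT * FROM a NATURAL JOIN b").args["joins"][0]
    print(join.args.get("natural"))  # True
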
@@ -1280,17 +1267,13 @@ class Parser:
         if not self._match(TokenType.ORDER_BY):
             return this

-        return self.expression(
-            exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
-        )
+        return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered))

     def _parse_sort(self, token_type, exp_class):
         if not self._match(token_type):
             return None

-        return self.expression(
-            exp_class, expressions=self._parse_csv(self._parse_ordered)
-        )
+        return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))

     def _parse_ordered(self):
         this = self._parse_conjunction()
@@ -1305,22 +1288,17 @@ class Parser:
         if (
             not explicitly_null_ordered
             and (
-                (asc and self.null_ordering == "nulls_are_small")
-                or (desc and self.null_ordering != "nulls_are_small")
+                (asc and self.null_ordering == "nulls_are_small") or (desc and self.null_ordering != "nulls_are_small")
             )
             and self.null_ordering != "nulls_are_last"
         ):
             nulls_first = True

-        return self.expression(
-            exp.Ordered, this=this, desc=desc, nulls_first=nulls_first
-        )
+        return self.expression(exp.Ordered, this=this, desc=desc, nulls_first=nulls_first)

     def _parse_limit(self, this=None, top=False):
         if self._match(TokenType.TOP if top else TokenType.LIMIT):
-            return self.expression(
-                exp.Limit, this=this, expression=self._parse_number()
-            )
+            return self.expression(exp.Limit, this=this, expression=self._parse_number())
         if self._match(TokenType.FETCH):
             direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
             direction = self._prev.text if direction else "FIRST"
@@ -1354,7 +1332,7 @@ class Parser:
             expression,
             this=this,
             distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
-            expression=self._parse_select(),
+            expression=self._parse_select(nested=True),
         )

     def _parse_expression(self):
@@ -1396,9 +1374,7 @@ class Parser:
                 this = self.expression(exp.In, this=this, unnest=unnest)
             else:
                 self._match_l_paren()
-                expressions = self._parse_csv(
-                    lambda: self._parse_select() or self._parse_expression()
-                )
+                expressions = self._parse_csv(lambda: self._parse_select() or self._parse_expression())

                 if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
                     this = self.expression(exp.In, this=this, query=expressions[0])
@@ -1430,13 +1406,9 @@ class Parser:
                     expression=self._parse_term(),
                 )
             elif self._match_pair(TokenType.LT, TokenType.LT):
-                this = self.expression(
-                    exp.BitwiseLeftShift, this=this, expression=self._parse_term()
-                )
+                this = self.expression(exp.BitwiseLeftShift, this=this, expression=self._parse_term())
             elif self._match_pair(TokenType.GT, TokenType.GT):
-                this = self.expression(
-                    exp.BitwiseRightShift, this=this, expression=self._parse_term()
-                )
+                this = self.expression(exp.BitwiseRightShift, this=this, expression=self._parse_term())
             else:
                 break
@@ -1524,7 +1496,7 @@ class Parser:
                 self.raise_error("Expecting >")

         if type_token in self.TIMESTAMPS:
-            tz = self._match(TokenType.WITH_TIME_ZONE)
+            tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ
             self._match(TokenType.WITHOUT_TIME_ZONE)
             if tz:
                 return exp.DataType(
@@ -1594,16 +1566,14 @@ class Parser:
         if query:
             expressions = [query]
         else:
-            expressions = self._parse_csv(
-                lambda: self._parse_alias(self._parse_conjunction(), explicit=True)
-            )
+            expressions = self._parse_csv(lambda: self._parse_alias(self._parse_conjunction(), explicit=True))

         this = list_get(expressions, 0)
         self._parse_query_modifiers(this)
         self._match_r_paren()

         if isinstance(this, exp.Subqueryable):
-            return self._parse_subquery(this)
+            return self._parse_set_operations(self._parse_subquery(this))
         if len(expressions) > 1:
             return self.expression(exp.Tuple, expressions=expressions)
         return self.expression(exp.Paren, this=this)
@@ -1611,11 +1581,7 @@ class Parser:
         return None

     def _parse_field(self, any_token=False):
-        return (
-            self._parse_primary()
-            or self._parse_function()
-            or self._parse_id_var(any_token)
-        )
+        return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token)

     def _parse_function(self):
         if not self._curr:
@@ -1628,21 +1594,22 @@ class Parser:
         if not self._next or self._next.token_type != TokenType.L_PAREN:
             if token_type in self.NO_PAREN_FUNCTIONS:
-                return self.expression(
-                    self._advance() or self.NO_PAREN_FUNCTIONS[token_type]
-                )
+                return self.expression(self._advance() or self.NO_PAREN_FUNCTIONS[token_type])
             return None

         if token_type not in self.FUNC_TOKENS:
             return None

-        if self._match_set(self.FUNCTION_PARSERS):
-            self._advance()
-            this = self.FUNCTION_PARSERS[token_type](self, token_type)
+        this = self._curr.text
+        upper = this.upper()
+        self._advance(2)
+
+        parser = self.FUNCTION_PARSERS.get(upper)
+
+        if parser:
+            this = parser(self)
         else:
             subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type)
-            this = self._curr.text
-            self._advance(2)

             if subquery_predicate and self._curr.token_type in (
                 TokenType.SELECT,
@@ -1652,7 +1619,7 @@ class Parser:
                 self._match_r_paren()
                 return this

-            function = self.FUNCTIONS.get(this.upper())
+            function = self.FUNCTIONS.get(upper)
             args = self._parse_csv(self._parse_lambda)

             if function:
@@ -1700,10 +1667,7 @@ class Parser:
             self._retreat(index)
             return this

-        args = self._parse_csv(
-            lambda: self._parse_constraint()
-            or self._parse_column_def(self._parse_field())
-        )
+        args = self._parse_csv(lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True)))
         self._match_r_paren()
         return self.expression(exp.Schema, this=this, expressions=args)
@@ -1720,12 +1684,9 @@ class Parser:
                 break
             constraints.append(constraint)

-        return self.expression(
-            exp.ColumnDef, this=this, kind=kind, constraints=constraints
-        )
+        return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints)

     def _parse_column_constraint(self):
-        kind = None
         this = None

         if self._match(TokenType.CONSTRAINT):
@@ -1735,28 +1696,28 @@ class Parser:
             kind = exp.AutoIncrementColumnConstraint()
         elif self._match(TokenType.CHECK):
             self._match_l_paren()
-            kind = self.expression(
-                exp.CheckColumnConstraint, this=self._parse_conjunction()
-            )
+            kind = self.expression(exp.CheckColumnConstraint, this=self._parse_conjunction())
             self._match_r_paren()
         elif self._match(TokenType.COLLATE):
             kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
         elif self._match(TokenType.DEFAULT):
-            kind = self.expression(
-                exp.DefaultColumnConstraint, this=self._parse_field()
-            )
-        elif self._match(TokenType.NOT) and self._match(TokenType.NULL):
+            kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_field())
+        elif self._match_pair(TokenType.NOT, TokenType.NULL):
             kind = exp.NotNullColumnConstraint()
         elif self._match(TokenType.SCHEMA_COMMENT):
-            kind = self.expression(
-                exp.CommentColumnConstraint, this=self._parse_string()
-            )
+            kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
         elif self._match(TokenType.PRIMARY_KEY):
             kind = exp.PrimaryKeyColumnConstraint()
         elif self._match(TokenType.UNIQUE):
             kind = exp.UniqueColumnConstraint()
-
-        if kind is None:
+        elif self._match(TokenType.GENERATED):
+            if self._match(TokenType.BY_DEFAULT):
+                kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False)
+            else:
+                self._match(TokenType.ALWAYS)
+                kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
+            self._match_pair(TokenType.ALIAS, TokenType.IDENTITY)
+        else:
             return None

         return self.expression(exp.ColumnConstraint, this=this, kind=kind)
@@ -1864,9 +1825,7 @@ class Parser:
         if not self._match(TokenType.END):
             self.raise_error("Expected END after CASE", self._prev)

-        return self._parse_window(
-            self.expression(exp.Case, this=expression, ifs=ifs, default=default)
-        )
+        return self._parse_window(self.expression(exp.Case, this=expression, ifs=ifs, default=default))

     def _parse_if(self):
         if self._match(TokenType.L_PAREN):
@@ -1889,7 +1848,7 @@ class Parser:
         if not self._match(TokenType.FROM):
             self.raise_error("Expected FROM after EXTRACT", self._prev)

-        return self.expression(exp.Extract, this=this, expression=self._parse_type())
+        return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())

     def _parse_cast(self, strict):
         this = self._parse_conjunction()
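
The new GENERATED branch above maps both identity spellings onto exp.GeneratedAsIdentityColumnConstraint (this=False for BY DEFAULT, this=True for ALWAYS). A sketch of the resulting round trip, matching the Postgres tests later in this commit:

    import sqlglot

    sql = "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)"
    # The constraint parses into exp.GeneratedAsIdentityColumnConstraint and
    # is reproduced verbatim when writing Postgres again.
    print(sqlglot.transpile(sql, read="postgres", write="postgres")[0])
    # CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)
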
@@ -1917,12 +1876,54 @@ class Parser:
             to = None

         return self.expression(exp.Cast, this=this, to=to)

+    def _parse_substring(self):
+        # Postgres supports the form: substring(string [from int] [for int])
+        # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
+        args = self._parse_csv(self._parse_bitwise)
+
+        if self._match(TokenType.FROM):
+            args.append(self._parse_bitwise())
+            if self._match(TokenType.FOR):
+                args.append(self._parse_bitwise())
+
+        this = exp.Substring.from_arg_list(args)
+        self.validate_expression(this, args)
+
+        return this
+
+    def _parse_trim(self):
+        # https://www.w3resource.com/sql/character-functions/trim.php
+        # https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html
+        position = None
+        collation = None
+
+        if self._match_set(self.TRIM_TYPES):
+            position = self._prev.text.upper()
+
+        expression = self._parse_term()
+        if self._match(TokenType.FROM):
+            this = self._parse_term()
+        else:
+            this = expression
+            expression = None
+
+        if self._match(TokenType.COLLATE):
+            collation = self._parse_term()
+
+        return self.expression(
+            exp.Trim,
+            this=this,
+            position=position,
+            expression=expression,
+            collation=collation,
+        )
+
     def _parse_window(self, this, alias=False):
         if self._match(TokenType.FILTER):
             self._match_l_paren()
-            this = self.expression(
-                exp.Filter, this=this, expression=self._parse_where()
-            )
+            this = self.expression(exp.Filter, this=this, expression=self._parse_where())
             self._match_r_paren()

         if self._match(TokenType.WITHIN_GROUP):
@@ -1935,6 +1936,25 @@ class Parser:
             self._match_r_paren()
             return this

+        # SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER
+        # Some dialects choose to implement and some do not.
+        # https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html
+
+        # There is some code above in _parse_lambda that handles
+        #   SELECT FIRST_VALUE(TABLE.COLUMN IGNORE|RESPECT NULLS) OVER ...
+
+        # The below changes handle
+        #   SELECT FIRST_VALUE(TABLE.COLUMN) IGNORE|RESPECT NULLS OVER ...

+        # Oracle allows both formats
+        #   (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html)
+        # and Snowflake chose to do the same for familiarity
+        #   https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes
+        if self._match(TokenType.IGNORE_NULLS):
+            this = self.expression(exp.IgnoreNulls, this=this)
+        elif self._match(TokenType.RESPECT_NULLS):
+            this = self.expression(exp.RespectNulls, this=this)
+
         # bigquery select from window x AS (partition by ...)
         if alias:
             self._match(TokenType.ALIAS)
@@ -1992,13 +2012,9 @@ class Parser:
         self._match(TokenType.BETWEEN)

         return {
-            "value": (
-                self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW))
-                and self._prev.text
-            )
+            "value": (self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text)
             or self._parse_bitwise(),
-            "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING))
-            and self._prev.text,
+            "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text,
         }

     def _parse_alias(self, this, explicit=False):
@@ -2023,22 +2039,16 @@ class Parser:

         return this

-    def _parse_id_var(self, any_token=True):
+    def _parse_id_var(self, any_token=True, tokens=None):
         identifier = self._parse_identifier()

         if identifier:
             return identifier

-        if (
-            any_token
-            and self._curr
-            and self._curr.token_type not in self.RESERVED_KEYWORDS
-        ):
+        if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
             return self._advance() or exp.Identifier(this=self._prev.text, quoted=False)

-        return self._match_set(self.ID_VAR_TOKENS) and exp.Identifier(
-            this=self._prev.text, quoted=False
-        )
+        return self._match_set(tokens or self.ID_VAR_TOKENS) and exp.Identifier(this=self._prev.text, quoted=False)

     def _parse_string(self):
         if self._match(TokenType.STRING):
@@ -2077,9 +2087,7 @@ class Parser:

     def _parse_star(self):
         if self._match(TokenType.STAR):
-            return exp.Star(
-                **{"except": self._parse_except(), "replace": self._parse_replace()}
-            )
+            return exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()})
         return None

     def _parse_placeholder(self):
@@ -2117,15 +2125,10 @@ class Parser:
         this = parse()

         while self._match_set(expressions):
-            this = self.expression(
-                expressions[self._prev.token_type], this=this, expression=parse()
-            )
+            this = self.expression(expressions[self._prev.token_type], this=this, expression=parse())

         return this

-    def _parse_all(self, parse):
-        return list(iter(parse, None))
-
     def _parse_wrapped_id_vars(self):
         self._match_l_paren()
         expressions = self._parse_csv(self._parse_id_var)
@@ -2156,10 +2159,7 @@ class Parser:
         if not self._curr or not self._next:
             return None

-        if (
-            self._curr.token_type == token_type_a
-            and self._next.token_type == token_type_b
-        ):
+        if self._curr.token_type == token_type_a and self._next.token_type == token_type_b:
             if advance:
                 self._advance(2)
             return True
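
The new _parse_substring and _parse_trim above give the parser the SQL-standard FROM/FOR and LEADING/TRAILING/BOTH forms. The conversions below are a sketch mirroring the Postgres tests further down in this commit:

    import sqlglot

    # FROM/FOR becomes positional arguments in dialects without that syntax.
    print(sqlglot.transpile("SELECT SUBSTRING('abcdefg' FROM 1 FOR 2)", read="postgres", write="hive")[0])
    # SELECT SUBSTRING('abcdefg', 1, 2)

    # A one-sided TRIM collapses to LTRIM/RTRIM.
    print(sqlglot.transpile("TRIM(LEADING FROM ' XXX ')", read="postgres", write="presto")[0])
    # LTRIM(' XXX ')
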
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
@@ -72,9 +72,7 @@ class Step:
         if from_:
             from_ = from_.expressions
             if len(from_) > 1:
-                raise UnsupportedError(
-                    "Multi-from statements are unsupported. Run it through the optimizer"
-                )
+                raise UnsupportedError("Multi-from statements are unsupported. Run it through the optimizer")

             step = Scan.from_expression(from_[0], ctes)
         else:
@@ -104,9 +102,7 @@ class Step:
                        continue
                    if operand not in operands:
                        operands[operand] = f"_a_{next(sequence)}"
-                    operand.replace(
-                        exp.column(operands[operand], step.name, quoted=True)
-                    )
+                    operand.replace(exp.column(operands[operand], step.name, quoted=True))
            else:
                projections.append(e)
@@ -121,14 +117,9 @@ class Step:
             aggregate = Aggregate()
             aggregate.source = step.name
             aggregate.name = step.name
-            aggregate.operands = tuple(
-                alias(operand, alias_) for operand, alias_ in operands.items()
-            )
+            aggregate.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
             aggregate.aggregations = aggregations
-            aggregate.group = [
-                exp.column(e.alias_or_name, step.name, quoted=True)
-                for e in group.expressions
-            ]
+            aggregate.group = [exp.column(e.alias_or_name, step.name, quoted=True) for e in group.expressions]
             aggregate.add_dependency(step)
             step = aggregate
@@ -212,9 +203,7 @@ class Scan(Step):
         alias_ = expression.alias

         if not alias_:
-            raise UnsupportedError(
-                "Tables/Subqueries must be aliased. Run it through the optimizer"
-            )
+            raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer")

         if isinstance(expression, exp.Subquery):
             step = Step.from_expression(table, ctes)
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
@@ -38,6 +38,7 @@ class TokenType(AutoName):
     DARROW = auto()
     HASH_ARROW = auto()
     DHASH_ARROW = auto()
+    LR_ARROW = auto()
     ANNOTATION = auto()
     DOLLAR = auto()
@@ -53,6 +54,7 @@ class TokenType(AutoName):
     TABLE = auto()
     VAR = auto()
     BIT_STRING = auto()
+    HEX_STRING = auto()

     # types
     BOOLEAN = auto()
@@ -78,10 +80,17 @@ class TokenType(AutoName):
     UUID = auto()
     GEOGRAPHY = auto()
     NULLABLE = auto()
+    GEOMETRY = auto()
+    HLLSKETCH = auto()
+    SUPER = auto()
+    SERIAL = auto()
+    SMALLSERIAL = auto()
+    BIGSERIAL = auto()

     # keywords
     ADD_FILE = auto()
     ALIAS = auto()
+    ALWAYS = auto()
     ALL = auto()
     ALTER = auto()
     ANALYZE = auto()
@@ -92,11 +101,12 @@ class TokenType(AutoName):
     AUTO_INCREMENT = auto()
     BEGIN = auto()
     BETWEEN = auto()
+    BOTH = auto()
     BUCKET = auto()
+    BY_DEFAULT = auto()
     CACHE = auto()
     CALL = auto()
     CASE = auto()
-    CAST = auto()
     CHARACTER_SET = auto()
     CHECK = auto()
     CLUSTER_BY = auto()
@@ -104,7 +114,6 @@ class TokenType(AutoName):
     COMMENT = auto()
     COMMIT = auto()
     CONSTRAINT = auto()
-    CONVERT = auto()
     CREATE = auto()
     CROSS = auto()
     CUBE = auto()
@@ -127,22 +136,24 @@ class TokenType(AutoName):
     EXCEPT = auto()
     EXISTS = auto()
     EXPLAIN = auto()
-    EXTRACT = auto()
     FALSE = auto()
     FETCH = auto()
     FILTER = auto()
     FINAL = auto()
     FIRST = auto()
     FOLLOWING = auto()
+    FOR = auto()
     FOREIGN_KEY = auto()
     FORMAT = auto()
     FULL = auto()
     FUNCTION = auto()
     FROM = auto()
+    GENERATED = auto()
     GROUP_BY = auto()
     GROUPING_SETS = auto()
     HAVING = auto()
     HINT = auto()
+    IDENTITY = auto()
     IF = auto()
     IGNORE_NULLS = auto()
     ILIKE = auto()
@@ -159,12 +170,14 @@ class TokenType(AutoName):
     JOIN = auto()
     LATERAL = auto()
     LAZY = auto()
+    LEADING = auto()
     LEFT = auto()
     LIKE = auto()
     LIMIT = auto()
     LOCATION = auto()
     MAP = auto()
     MOD = auto()
+    NATURAL = auto()
     NEXT = auto()
     NO_ACTION = auto()
     NULL = auto()
@@ -204,8 +217,10 @@ class TokenType(AutoName):
     ROWS = auto()
     SCHEMA_COMMENT = auto()
     SELECT = auto()
+    SEPARATOR = auto()
     SET = auto()
     SHOW = auto()
+    SIMILAR_TO = auto()
     SOME = auto()
     SORT_BY = auto()
     STORED = auto()
@@ -213,12 +228,11 @@ class TokenType(AutoName):
     TABLE_FORMAT = auto()
     TABLE_SAMPLE = auto()
     TEMPORARY = auto()
-    TIME = auto()
     TOP = auto()
     THEN = auto()
     TRUE = auto()
+    TRAILING = auto()
     TRUNCATE = auto()
-    TRY_CAST = auto()
     UNBOUNDED = auto()
     UNCACHE = auto()
     UNION = auto()
@@ -272,35 +286,32 @@ class _Tokenizer(type):
     def __new__(cls, clsname, bases, attrs):
         klass = super().__new__(cls, clsname, bases, attrs)

-        klass.QUOTES = dict(
-            (quote, quote) if isinstance(quote, str) else (quote[0], quote[1])
-            for quote in klass.QUOTES
-        )
-
-        klass.IDENTIFIERS = dict(
-            (identifier, identifier)
-            if isinstance(identifier, str)
-            else (identifier[0], identifier[1])
-            for identifier in klass.IDENTIFIERS
-        )
-
-        klass.COMMENTS = dict(
-            (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
-            for comment in klass.COMMENTS
+        klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)
+        klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS)
+        klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
+        klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
+        klass._COMMENTS = dict(
+            (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS
         )

         klass.KEYWORD_TRIE = new_trie(
             key.upper()
             for key, value in {
                 **klass.KEYWORDS,
-                **{comment: TokenType.COMMENT for comment in klass.COMMENTS},
-                **{quote: TokenType.QUOTE for quote in klass.QUOTES},
+                **{comment: TokenType.COMMENT for comment in klass._COMMENTS},
+                **{quote: TokenType.QUOTE for quote in klass._QUOTES},
+                **{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
+                **{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
             }.items()
             if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
         )

         return klass

+    @staticmethod
+    def _delimeter_list_to_dict(list):
+        return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list)
+

 class Tokenizer(metaclass=_Tokenizer):
     SINGLE_TOKENS = {
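
The _delimeter_list_to_dict helper (spelled as in the diff) normalizes each delimiter spec: a bare string delimits itself, a pair is an explicit (start, end). A standalone sketch of the same logic:

    def delimiter_list_to_dict(items):
        # "'"        -> {"'": "'"}   same start and end marker
        # ("0x", "") -> {"0x": ""}   prefix-only delimiter, empty end
        return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in items)

    assert delimiter_list_to_dict(["'"]) == {"'": "'"}
    assert delimiter_list_to_dict([("0x", "")]) == {"0x": ""}
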
@@ -339,6 +350,10 @@ class Tokenizer(metaclass=_Tokenizer):
     QUOTES = ["'"]

+    BIT_STRINGS = []
+
+    HEX_STRINGS = []
+
     IDENTIFIERS = ['"']

     ESCAPE = "'"
@@ -357,6 +372,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "->>": TokenType.DARROW,
         "#>": TokenType.HASH_ARROW,
         "#>>": TokenType.DHASH_ARROW,
+        "<->": TokenType.LR_ARROW,
         "ADD ARCHIVE": TokenType.ADD_FILE,
         "ADD ARCHIVES": TokenType.ADD_FILE,
         "ADD FILE": TokenType.ADD_FILE,
@@ -374,12 +390,12 @@ class Tokenizer(metaclass=_Tokenizer):
         "AUTO_INCREMENT": TokenType.AUTO_INCREMENT,
         "BEGIN": TokenType.BEGIN,
         "BETWEEN": TokenType.BETWEEN,
+        "BOTH": TokenType.BOTH,
         "BUCKET": TokenType.BUCKET,
         "CALL": TokenType.CALL,
         "CACHE": TokenType.CACHE,
         "UNCACHE": TokenType.UNCACHE,
         "CASE": TokenType.CASE,
-        "CAST": TokenType.CAST,
         "CHARACTER SET": TokenType.CHARACTER_SET,
         "CHECK": TokenType.CHECK,
         "CLUSTER BY": TokenType.CLUSTER_BY,
@@ -387,7 +403,6 @@ class Tokenizer(metaclass=_Tokenizer):
         "COMMENT": TokenType.SCHEMA_COMMENT,
         "COMMIT": TokenType.COMMIT,
         "CONSTRAINT": TokenType.CONSTRAINT,
-        "CONVERT": TokenType.CONVERT,
         "CREATE": TokenType.CREATE,
         "CROSS": TokenType.CROSS,
         "CUBE": TokenType.CUBE,
@@ -408,7 +423,6 @@ class Tokenizer(metaclass=_Tokenizer):
         "EXCEPT": TokenType.EXCEPT,
         "EXISTS": TokenType.EXISTS,
         "EXPLAIN": TokenType.EXPLAIN,
-        "EXTRACT": TokenType.EXTRACT,
         "FALSE": TokenType.FALSE,
         "FETCH": TokenType.FETCH,
         "FILTER": TokenType.FILTER,
@@ -437,10 +451,12 @@ class Tokenizer(metaclass=_Tokenizer):
         "JOIN": TokenType.JOIN,
         "LATERAL": TokenType.LATERAL,
         "LAZY": TokenType.LAZY,
+        "LEADING": TokenType.LEADING,
         "LEFT": TokenType.LEFT,
         "LIKE": TokenType.LIKE,
         "LIMIT": TokenType.LIMIT,
         "LOCATION": TokenType.LOCATION,
+        "NATURAL": TokenType.NATURAL,
         "NEXT": TokenType.NEXT,
         "NO ACTION": TokenType.NO_ACTION,
         "NOT": TokenType.NOT,
@@ -490,8 +506,8 @@ class Tokenizer(metaclass=_Tokenizer):
         "TEMPORARY": TokenType.TEMPORARY,
         "THEN": TokenType.THEN,
         "TRUE": TokenType.TRUE,
+        "TRAILING": TokenType.TRAILING,
         "TRUNCATE": TokenType.TRUNCATE,
-        "TRY_CAST": TokenType.TRY_CAST,
         "UNBOUNDED": TokenType.UNBOUNDED,
         "UNION": TokenType.UNION,
         "UNNEST": TokenType.UNNEST,
@@ -626,14 +642,12 @@ class Tokenizer(metaclass=_Tokenizer):
                 break

             white_space = self.WHITE_SPACE.get(self._char)
-            identifier_end = self.IDENTIFIERS.get(self._char)
+            identifier_end = self._IDENTIFIERS.get(self._char)

             if white_space:
                 if white_space == TokenType.BREAK:
                     self._col = 1
                     self._line += 1
-            elif self._char == "0" and self._peek == "x":
-                self._scan_hex()
             elif self._char.isdigit():
                 self._scan_number()
             elif identifier_end:
@@ -666,9 +680,7 @@ class Tokenizer(metaclass=_Tokenizer):
         text = self._text if text is None else text
         self.tokens.append(Token(token_type, text, self._line, self._col))

-        if token_type in self.COMMANDS and (
-            len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
-        ):
+        if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON):
             self._start = self._current
             while not self._end and self._peek != ";":
                 self._advance()
@@ -725,6 +737,8 @@ class Tokenizer(metaclass=_Tokenizer):
         if self._scan_string(word):
             return
+        if self._scan_numeric_string(word):
+            return
         if self._scan_comment(word):
             return

@@ -732,10 +746,10 @@ class Tokenizer(metaclass=_Tokenizer):
         self._add(self.KEYWORDS[word.upper()])

     def _scan_comment(self, comment_start):
-        if comment_start not in self.COMMENTS:
+        if comment_start not in self._COMMENTS:
             return False

-        comment_end = self.COMMENTS[comment_start]
+        comment_end = self._COMMENTS[comment_start]

         if comment_end:
             comment_end_size = len(comment_end)
@@ -749,15 +763,18 @@ class Tokenizer(metaclass=_Tokenizer):
         return True

     def _scan_annotation(self):
-        while (
-            not self._end
-            and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK
-            and self._peek != ","
-        ):
+        while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK and self._peek != ",":
             self._advance()
         self._add(TokenType.ANNOTATION, self._text[1:])

     def _scan_number(self):
+        if self._char == "0":
+            peek = self._peek.upper()
+            if peek == "B":
+                return self._scan_bits()
+            elif peek == "X":
+                return self._scan_hex()
+
         decimal = False
         scientific = 0
@@ -788,57 +805,71 @@ class Tokenizer(metaclass=_Tokenizer):
         else:
             return self._add(TokenType.NUMBER)

+    def _scan_bits(self):
+        self._advance()
+        value = self._extract_value()
+        try:
+            self._add(TokenType.BIT_STRING, f"{int(value, 2)}")
+        except ValueError:
+            self._add(TokenType.IDENTIFIER)
+
     def _scan_hex(self):
         self._advance()
+        value = self._extract_value()
+        try:
+            self._add(TokenType.HEX_STRING, f"{int(value, 16)}")
+        except ValueError:
+            self._add(TokenType.IDENTIFIER)
+
+    def _extract_value(self):
         while True:
             char = self._peek.strip()
             if char and char not in self.SINGLE_TOKENS:
                 self._advance()
             else:
                 break

-        try:
-            self._add(TokenType.BIT_STRING, f"{int(self._text, 16):b}")
-        except ValueError:
-            self._add(TokenType.IDENTIFIER)
+        return self._text

     def _scan_string(self, quote):
-        quote_end = self.QUOTES.get(quote)
+        quote_end = self._QUOTES.get(quote)
         if quote_end is None:
             return False

-        text = ""
         self._advance(len(quote))
-        quote_end_size = len(quote_end)
-
-        while True:
-            if self._char == self.ESCAPE and self._peek == quote_end:
-                text += quote
-                self._advance(2)
-            else:
-                if self._chars(quote_end_size) == quote_end:
-                    if quote_end_size > 1:
-                        self._advance(quote_end_size - 1)
-                    break
-
-                if self._end:
-                    raise RuntimeError(
-                        f"Missing {quote} from {self._line}:{self._start}"
-                    )
-                text += self._char
-                self._advance()
+        text = self._extract_string(quote_end)

         text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
         text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text
         self._add(TokenType.STRING, text)
         return True

+    def _scan_numeric_string(self, string_start):
+        if string_start in self._HEX_STRINGS:
+            delimiters = self._HEX_STRINGS
+            token_type = TokenType.HEX_STRING
+            base = 16
+        elif string_start in self._BIT_STRINGS:
+            delimiters = self._BIT_STRINGS
+            token_type = TokenType.BIT_STRING
+            base = 2
+        else:
+            return False
+
+        self._advance(len(string_start))
+        string_end = delimiters.get(string_start)
+        text = self._extract_string(string_end)
+
+        try:
+            self._add(token_type, f"{int(text, base)}")
+        except ValueError:
+            raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")
+
+        return True
+
     def _scan_identifier(self, identifier_end):
         while self._peek != identifier_end:
             if self._end:
-                raise RuntimeError(
-                    f"Missing {identifier_end} from {self._line}:{self._start}"
-                )
+                raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")
             self._advance()
         self._advance()
         self._add(TokenType.IDENTIFIER, self._text[1:-1])
@@ -851,3 +882,24 @@ class Tokenizer(metaclass=_Tokenizer):
             else:
                 break

         self._add(self.KEYWORDS.get(self._text.upper(), TokenType.VAR))
+
+    def _extract_string(self, delimiter):
+        text = ""
+        delim_size = len(delimiter)
+
+        while True:
+            if self._char == self.ESCAPE and self._peek == delimiter:
+                text += delimiter
+                self._advance(2)
+            else:
+                if self._chars(delim_size) == delimiter:
+                    if delim_size > 1:
+                        self._advance(delim_size - 1)
+                    break
+
+                if self._end:
+                    raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
+
+                text += self._char
+                self._advance()
+
+        return text
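
The scanner now funnels every hex or bit literal, prefixed (0xCC, 0b1011) or quoted (x'CC', b'1011'), through int(text, base), so the token value is a plain decimal string that each dialect re-renders in its own notation. A sketch matching the MySQL tests below:

    import sqlglot

    # 0xCC is tokenized as HEX_STRING "204", then written back per dialect.
    print(sqlglot.transpile("SELECT 0xCC", read="mysql", write="spark")[0])    # SELECT X'CC'
    print(sqlglot.transpile("SELECT 0b1011", read="mysql", write="mysql")[0])  # SELECT b'1011'
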
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
@@ -12,9 +12,7 @@ def unalias_group(expression):
     """
     if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
         aliased_selects = {
-            e.alias: i
-            for i, e in enumerate(expression.parent.expressions, start=1)
-            if isinstance(e, exp.Alias)
+            e.alias: i for i, e in enumerate(expression.parent.expressions, start=1) if isinstance(e, exp.Alias)
         }

         expression = expression.copy()
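
unalias_group replaces a GROUP BY reference to a select alias with its 1-based projection index; the reflowed comprehension above builds that alias-to-position map. In effect, for dialects that apply this transform (a sketch, matching the test_alias case added below):

    import sqlglot

    # "b" is an output alias, so dialects using unalias_group emit its position.
    print(sqlglot.transpile("SELECT a AS b FROM x GROUP BY b", write="presto")[0])
    # SELECT a AS b FROM x GROUP BY 1
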
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
@@ -36,9 +36,7 @@ class Validator(unittest.TestCase):
         for read_dialect, read_sql in (read or {}).items():
             with self.subTest(f"{read_dialect} -> {sql}"):
                 self.assertEqual(
-                    parse_one(read_sql, read_dialect).sql(
-                        self.dialect, unsupported_level=ErrorLevel.IGNORE
-                    ),
+                    parse_one(read_sql, read_dialect).sql(self.dialect, unsupported_level=ErrorLevel.IGNORE),
                     sql,
                 )
@@ -46,9 +44,7 @@ class Validator(unittest.TestCase):
             with self.subTest(f"{sql} -> {write_dialect}"):
                 if write_sql is UnsupportedError:
                     with self.assertRaises(UnsupportedError):
-                        expression.sql(
-                            write_dialect, unsupported_level=ErrorLevel.RAISE
-                        )
+                        expression.sql(write_dialect, unsupported_level=ErrorLevel.RAISE)
                 else:
                     self.assertEqual(
                         expression.sql(
@@ -82,11 +78,19 @@ class TestDialect(Validator):
                 "oracle": "CAST(a AS CLOB)",
                 "postgres": "CAST(a AS TEXT)",
                 "presto": "CAST(a AS VARCHAR)",
+                "redshift": "CAST(a AS TEXT)",
                 "snowflake": "CAST(a AS TEXT)",
                 "spark": "CAST(a AS STRING)",
                 "starrocks": "CAST(a AS STRING)",
             },
         )
+        self.validate_all(
+            "CAST(a AS DATETIME)",
+            write={
+                "postgres": "CAST(a AS TIMESTAMP)",
+                "sqlite": "CAST(a AS DATETIME)",
+            },
+        )
         self.validate_all(
             "CAST(a AS STRING)",
             write={
@@ -97,6 +101,7 @@ class TestDialect(Validator):
                 "oracle": "CAST(a AS CLOB)",
                 "postgres": "CAST(a AS TEXT)",
                 "presto": "CAST(a AS VARCHAR)",
+                "redshift": "CAST(a AS TEXT)",
                 "snowflake": "CAST(a AS TEXT)",
                 "spark": "CAST(a AS STRING)",
                 "starrocks": "CAST(a AS STRING)",
@@ -112,6 +117,7 @@ class TestDialect(Validator):
                 "oracle": "CAST(a AS VARCHAR2)",
                 "postgres": "CAST(a AS VARCHAR)",
                 "presto": "CAST(a AS VARCHAR)",
+                "redshift": "CAST(a AS VARCHAR)",
                 "snowflake": "CAST(a AS VARCHAR)",
                 "spark": "CAST(a AS STRING)",
                 "starrocks": "CAST(a AS VARCHAR)",
@@ -127,6 +133,7 @@ class TestDialect(Validator):
                 "oracle": "CAST(a AS VARCHAR2(3))",
                 "postgres": "CAST(a AS VARCHAR(3))",
                 "presto": "CAST(a AS VARCHAR(3))",
+                "redshift": "CAST(a AS VARCHAR(3))",
                 "snowflake": "CAST(a AS VARCHAR(3))",
                 "spark": "CAST(a AS VARCHAR(3))",
                 "starrocks": "CAST(a AS VARCHAR(3))",
@@ -142,12 +149,26 @@ class TestDialect(Validator):
                 "oracle": "CAST(a AS NUMBER)",
                 "postgres": "CAST(a AS SMALLINT)",
                 "presto": "CAST(a AS SMALLINT)",
+                "redshift": "CAST(a AS SMALLINT)",
                 "snowflake": "CAST(a AS SMALLINT)",
                 "spark": "CAST(a AS SHORT)",
                 "sqlite": "CAST(a AS INTEGER)",
                 "starrocks": "CAST(a AS SMALLINT)",
             },
         )
+        self.validate_all(
+            "TRY_CAST(a AS DOUBLE)",
+            read={
+                "postgres": "CAST(a AS DOUBLE PRECISION)",
+                "redshift": "CAST(a AS DOUBLE PRECISION)",
+            },
+            write={
+                "duckdb": "TRY_CAST(a AS DOUBLE)",
+                "postgres": "CAST(a AS DOUBLE PRECISION)",
+                "redshift": "CAST(a AS DOUBLE PRECISION)",
+            },
+        )
         self.validate_all(
             "CAST(a AS DOUBLE)",
             write={
@@ -159,16 +180,32 @@ class TestDialect(Validator):
                 "oracle": "CAST(a AS DOUBLE PRECISION)",
                 "postgres": "CAST(a AS DOUBLE PRECISION)",
                 "presto": "CAST(a AS DOUBLE)",
+                "redshift": "CAST(a AS DOUBLE PRECISION)",
                 "snowflake": "CAST(a AS DOUBLE)",
                 "spark": "CAST(a AS DOUBLE)",
                 "starrocks": "CAST(a AS DOUBLE)",
             },
         )
         self.validate_all(
-            "CAST(a AS TIMESTAMP)", write={"starrocks": "CAST(a AS DATETIME)"}
+            "CAST('1 DAY' AS INTERVAL)",
+            write={
+                "postgres": "CAST('1 DAY' AS INTERVAL)",
+                "redshift": "CAST('1 DAY' AS INTERVAL)",
+            },
         )
         self.validate_all(
-            "CAST(a AS TIMESTAMPTZ)", write={"starrocks": "CAST(a AS DATETIME)"}
+            "CAST(a AS TIMESTAMP)",
+            write={
+                "starrocks": "CAST(a AS DATETIME)",
+                "redshift": "CAST(a AS TIMESTAMP)",
+            },
+        )
+        self.validate_all(
+            "CAST(a AS TIMESTAMPTZ)",
+            write={
+                "starrocks": "CAST(a AS DATETIME)",
+                "redshift": "CAST(a AS TIMESTAMPTZ)",
+            },
         )
         self.validate_all("CAST(a AS TINYINT)", write={"oracle": "CAST(a AS NUMBER)"})
         self.validate_all("CAST(a AS SMALLINT)", write={"oracle": "CAST(a AS NUMBER)"})
@@ -552,6 +589,7 @@ class TestDialect(Validator):
             write={
                 "bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
                 "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
+                "oracle": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
                 "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST",
                 "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
                 "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
@@ -566,6 +604,7 @@ class TestDialect(Validator):
                 "presto": "JSON_EXTRACT(x, 'y')",
             },
             write={
+                "oracle": "JSON_EXTRACT(x, 'y')",
                 "postgres": "x->'y'",
                 "presto": "JSON_EXTRACT(x, 'y')",
             },
@@ -623,6 +662,37 @@ class TestDialect(Validator):
             },
         )

+    # https://dev.mysql.com/doc/refman/8.0/en/join.html
+    # https://www.postgresql.org/docs/current/queries-table-expressions.html
+    def test_joined_tables(self):
+        self.validate_identity("SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)")
+        self.validate_identity("SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)")
+        self.validate_identity("SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)")
+        self.validate_identity("SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)")
+
+        self.validate_all(
+            "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
+            write={
+                "postgres": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
+                "mysql": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
+            },
+        )
+        self.validate_all(
+            "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
+            write={
+                "postgres": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
+                "mysql": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
+            },
+        )
+
+    def test_lateral_subquery(self):
+        self.validate_identity(
+            "SELECT art FROM tbl1 INNER JOIN LATERAL (SELECT art FROM tbl2) AS tbl2 ON tbl1.art = tbl2.art"
+        )
+        self.validate_identity(
+            "SELECT * FROM tbl AS t LEFT JOIN LATERAL (SELECT * FROM b WHERE b.t_id = t.t_id) AS t ON TRUE"
+        )
+
     def test_set_operators(self):
         self.validate_all(
             "SELECT * FROM a UNION SELECT * FROM b",
@@ -731,6 +801,9 @@ class TestDialect(Validator):
         )

     def test_operators(self):
+        self.validate_identity("some.column LIKE 'foo' || another.column || 'bar' || LOWER(x)")
+        self.validate_identity("some.column LIKE 'foo' + another.column + 'bar'")
+
         self.validate_all(
             "x ILIKE '%y'",
             read={
@@ -874,16 +947,8 @@ class TestDialect(Validator):
                 "spark": "FILTER(the_array, x -> x > 0)",
             },
         )
-        self.validate_all(
-            "SELECT a AS b FROM x GROUP BY b",
-            write={
-                "duckdb": "SELECT a AS b FROM x GROUP BY b",
-                "presto": "SELECT a AS b FROM x GROUP BY 1",
-                "hive": "SELECT a AS b FROM x GROUP BY 1",
-                "oracle": "SELECT a AS b FROM x GROUP BY 1",
-                "spark": "SELECT a AS b FROM x GROUP BY 1",
-            },
-        )
+
+    def test_limit(self):
         self.validate_all(
             "SELECT x FROM y LIMIT 10",
             write={
@@ -915,6 +980,7 @@ class TestDialect(Validator):
             read={
                 "clickhouse": '`x` + "y"',
                 "sqlite": '`x` + "y"',
+                "redshift": '"x" + "y"',
             },
         )
         self.validate_all(
@@ -977,5 +1043,36 @@ class TestDialect(Validator):
                 "oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))",
                 "postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))",
                 "sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))",
+                "redshift": "CREATE TABLE t (b1 VARBYTE, b2 VARBYTE(1024), c1 TEXT, c2 TEXT(1024))",
+            },
+        )
+
+    def test_alias(self):
+        self.validate_all(
+            "SELECT a AS b FROM x GROUP BY b",
+            write={
+                "duckdb": "SELECT a AS b FROM x GROUP BY b",
+                "presto": "SELECT a AS b FROM x GROUP BY 1",
+                "hive": "SELECT a AS b FROM x GROUP BY 1",
+                "oracle": "SELECT a AS b FROM x GROUP BY 1",
+                "spark": "SELECT a AS b FROM x GROUP BY 1",
+            },
+        )
+        self.validate_all(
+            "SELECT y x FROM my_table t",
+            write={
+                "hive": "SELECT y AS x FROM my_table AS t",
+                "oracle": "SELECT y AS x FROM my_table t",
+                "postgres": "SELECT y AS x FROM my_table AS t",
+                "sqlite": "SELECT y AS x FROM my_table AS t",
+            },
+        )
+        self.validate_all(
+            "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
+            write={
+                "hive": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
+                "oracle": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 t JOIN cte2 WHERE cte1.a = cte2.c",
+                "postgres": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
+                "sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
             },
         )
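
Each validate_all entry above is shorthand for a parse-then-generate round trip. The new TRY_CAST case, written out by hand (a sketch using parse_one and Expression.sql as the Validator does):

    from sqlglot import parse_one

    expression = parse_one("TRY_CAST(a AS DOUBLE)")
    print(expression.sql("duckdb"))    # TRY_CAST(a AS DOUBLE)
    print(expression.sql("postgres"))  # CAST(a AS DOUBLE PRECISION)
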
diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py
@@ -341,6 +341,21 @@ class TestHive(Validator):
                 "spark": "PERCENTILE(x, 0.5)",
             },
         )
+        self.validate_all(
+            "PERCENTILE_APPROX(x, 0.5)",
+            read={
+                "hive": "PERCENTILE_APPROX(x, 0.5)",
+                "presto": "APPROX_PERCENTILE(x, 0.5)",
+                "duckdb": "APPROX_QUANTILE(x, 0.5)",
+                "spark": "PERCENTILE_APPROX(x, 0.5)",
+            },
+            write={
+                "hive": "PERCENTILE_APPROX(x, 0.5)",
+                "presto": "APPROX_PERCENTILE(x, 0.5)",
+                "duckdb": "APPROX_QUANTILE(x, 0.5)",
+                "spark": "PERCENTILE_APPROX(x, 0.5)",
+            },
+        )
         self.validate_all(
             "APPROX_COUNT_DISTINCT(a)",
             write={
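
The new PERCENTILE_APPROX case pins down the approximate-percentile mapping across all four dialects; equivalently, combining a read and a write entry (an illustration, not asserted verbatim by the test):

    import sqlglot

    print(sqlglot.transpile("APPROX_PERCENTILE(x, 0.5)", read="presto", write="duckdb")[0])
    # APPROX_QUANTILE(x, 0.5)
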
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
@@ -15,6 +15,10 @@ class TestMySQL(Validator):
     def test_identity(self):
         self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")
+        self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ')")
+        self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')")
+        self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')")
+        self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')")

     def test_introducers(self):
         self.validate_all(
@@ -27,12 +31,22 @@ class TestMySQL(Validator):
             },
         )

-    def test_binary_literal(self):
+    def test_hexadecimal_literal(self):
         self.validate_all(
             "SELECT 0xCC",
             write={
-                "mysql": "SELECT b'11001100'",
-                "spark": "SELECT X'11001100'",
+                "mysql": "SELECT x'CC'",
+                "sqlite": "SELECT x'CC'",
+                "spark": "SELECT X'CC'",
+                "trino": "SELECT X'CC'",
+                "bigquery": "SELECT 0xCC",
+                "oracle": "SELECT 204",
+            },
+        )
+        self.validate_all(
+            "SELECT X'1A'",
+            write={
+                "mysql": "SELECT x'1A'",
             },
         )
         self.validate_all(
@@ -41,10 +55,22 @@ class TestMySQL(Validator):
                 "mysql": "SELECT `0xz`",
             },
         )

-        self.validate_all(
-            "SELECT 0XCC",
+    def test_bits_literal(self):
+        self.validate_all(
+            "SELECT 0b1011",
             write={
-                "mysql": "SELECT 0 AS XCC",
+                "mysql": "SELECT b'1011'",
+                "postgres": "SELECT b'1011'",
+                "oracle": "SELECT 11",
+            },
+        )
+        self.validate_all(
+            "SELECT B'1011'",
+            write={
+                "mysql": "SELECT b'1011'",
+                "postgres": "SELECT b'1011'",
+                "oracle": "SELECT 11",
             },
         )
@@ -77,3 +103,19 @@ class TestMySQL(Validator):
                 "mysql": "SELECT 1",
             },
         )
+
+    def test_mysql(self):
+        self.validate_all(
+            "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
+            write={
+                "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')",
+                "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
+            },
+        )
+        self.validate_all(
+            "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
+            write={
+                "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
+                "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')",
+            },
+        )
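
The GROUP_CONCAT cases show the separator handling: MySQL output always makes the separator explicit (defaulting to ','), while SQLite takes a non-default separator as a trailing argument. Reproduced with the public API:

    import sqlglot

    sql = "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)"
    print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])
    # GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')
    print(sqlglot.transpile(sql, read="mysql", write="sqlite")[0])
    # GROUP_CONCAT(DISTINCT x ORDER BY y DESC)
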
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
@@ -8,9 +8,7 @@ class TestPostgres(Validator):
     def test_ddl(self):
         self.validate_all(
             "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
-            write={
-                "postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"
-            },
+            write={"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"},
         )
         self.validate_all(
             "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)",
@@ -42,11 +40,17 @@ class TestPostgres(Validator):
                 " CONSTRAINT valid_discount CHECK (price > discounted_price))"
             },
         )
+        self.validate_all(
+            "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)",
+            write={"postgres": "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)"},
+        )
+        self.validate_all(
+            "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)",
+            write={"postgres": "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)"},
+        )

         with self.assertRaises(ParseError):
-            transpile(
-                "CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres"
-            )
+            transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres")
         with self.assertRaises(ParseError):
             transpile(
                 "CREATE TABLE products (price DECIMAL, CHECK price > 1)",
@@ -54,11 +58,16 @@ class TestPostgres(Validator):
             )

     def test_postgres(self):
-        self.validate_all(
-            "CREATE TABLE x (a INT SERIAL)",
-            read={"sqlite": "CREATE TABLE x (a INTEGER AUTOINCREMENT)"},
-            write={"sqlite": "CREATE TABLE x (a INTEGER AUTOINCREMENT)"},
-        )
+        self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
+        self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END")
+        self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END")
+        self.validate_identity('SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')')
+        self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')")
+        self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')")
+        self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))")
+        self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
+        self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
         self.validate_all(
             "CREATE TABLE x (a UUID, b BYTEA)",
             write={
@@ -91,3 +100,65 @@ class TestPostgres(Validator):
                 "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
             },
         )
+        self.validate_all(
+            "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END",
+            write={
+                "hive": "SELECT CASE WHEN SUBSTRING('abcdefg', 1, 2) IN ('ab') THEN 1 ELSE 0 END",
+                "spark": "SELECT CASE WHEN SUBSTRING('abcdefg', 1, 2) IN ('ab') THEN 1 ELSE 0 END",
+            },
+        )
+        self.validate_all(
+            "SELECT * FROM x WHERE SUBSTRING(col1 FROM 3 + LENGTH(col1) - 10 FOR 10) IN (col2)",
+            write={
+                "hive": "SELECT * FROM x WHERE SUBSTRING(col1, 3 + LENGTH(col1) - 10, 10) IN (col2)",
+                "spark": "SELECT * FROM x WHERE SUBSTRING(col1, 3 + LENGTH(col1) - 10, 10) IN (col2)",
+            },
+        )
+        self.validate_all(
+            "SELECT SUBSTRING(CAST(2022 AS CHAR(4)) || LPAD(CAST(3 AS CHAR(2)), 2, '0') FROM 3 FOR 4)",
+            read={
+                "postgres": "SELECT SUBSTRING(2022::CHAR(4) || LPAD(3::CHAR(2), 2, '0') FROM 3 FOR 4)",
+            },
+        )
+        self.validate_all(
+            "SELECT TRIM(BOTH ' XXX ')",
+            write={
+                "mysql": "SELECT TRIM(' XXX ')",
+                "postgres": "SELECT TRIM(' XXX ')",
+                "hive": "SELECT TRIM(' XXX ')",
+            },
+        )
+        self.validate_all(
+            "TRIM(LEADING FROM ' XXX ')",
+            write={
+                "mysql": "LTRIM(' XXX ')",
+                "postgres": "LTRIM(' XXX ')",
+                "hive": "LTRIM(' XXX ')",
+                "presto": "LTRIM(' XXX ')",
+            },
+        )
+        self.validate_all(
+            "TRIM(TRAILING FROM ' XXX ')",
+            write={
+                "mysql": "RTRIM(' XXX ')",
+                "postgres": "RTRIM(' XXX ')",
+                "hive": "RTRIM(' XXX ')",
+                "presto": "RTRIM(' XXX ')",
+            },
+        )
+        self.validate_all(
+            "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss",
+            read={"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"},
+        )
+        self.validate_all(
+            "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL",
+            read={
+                "postgres": "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL",
+            },
+        )
+        self.validate_all(
+            "SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id",
+            read={
+                "postgres": "SELECT p1.id, p2.id, v1, v2 FROM polygons p1, polygons p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id != p2.id",
+            },
+        )
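
The LATERAL cases above parse because _parse_table now tries _parse_lateral first and a LATERAL that is not a Hive-style LATERAL VIEW is kept as a correlated subquery. The identity round trip, spelled out as a sketch:

    from sqlglot import parse_one

    sql = "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"
    assert parse_one(sql, read="postgres").sql("postgres") == sql
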
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
@ -0,0 +1,64 @@
from tests.dialects.test_dialect import Validator
class TestRedshift(Validator):
dialect = "redshift"
def test_redshift(self):
self.validate_all(
'create table "group" ("col" char(10))',
write={
"redshift": 'CREATE TABLE "group" ("col" CHAR(10))',
"mysql": "CREATE TABLE `group` (`col` CHAR(10))",
},
)
self.validate_all(
'create table if not exists city_slash_id("city/id" integer not null, state char(2) not null)',
write={
"redshift": 'CREATE TABLE IF NOT EXISTS city_slash_id ("city/id" INTEGER NOT NULL, state CHAR(2) NOT NULL)',
"presto": 'CREATE TABLE IF NOT EXISTS city_slash_id ("city/id" INTEGER NOT NULL, state CHAR(2) NOT NULL)',
},
)
self.validate_all(
"SELECT ST_AsEWKT(ST_GeomFromEWKT('SRID=4326;POINT(10 20)')::geography)",
write={
"redshift": "SELECT ST_ASEWKT(CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))",
"bigquery": "SELECT ST_ASEWKT(CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))",
},
)
self.validate_all(
"SELECT ST_AsEWKT(ST_GeogFromText('LINESTRING(110 40, 2 3, -10 80, -7 9)')::geometry)",
write={
"redshift": "SELECT ST_ASEWKT(CAST(ST_GEOGFROMTEXT('LINESTRING(110 40, 2 3, -10 80, -7 9)') AS GEOMETRY))",
},
)
self.validate_all(
"SELECT 'abc'::BINARY",
write={
"redshift": "SELECT CAST('abc' AS VARBYTE)",
},
)
self.validate_all(
"SELECT * FROM venue WHERE (venuecity, venuestate) IN (('Miami', 'FL'), ('Tampa', 'FL')) ORDER BY venueid",
write={
"redshift": "SELECT * FROM venue WHERE (venuecity, venuestate) IN (('Miami', 'FL'), ('Tampa', 'FL')) ORDER BY venueid",
},
)
self.validate_all(
'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\_%\' LIMIT 5',
write={
"redshift": 'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\\\_%\' LIMIT 5'
},
)
def test_identity(self):
self.validate_identity("CAST('bla' AS SUPER)")
self.validate_identity("CREATE TABLE real1 (realcol REAL)")
self.validate_identity("CAST('foo' AS HLLSKETCH)")
self.validate_identity("SELECT DATEADD(day, 1, 'today')")
self.validate_identity("'abc' SIMILAR TO '(b|c)%'")
self.validate_identity(
"SELECT caldate + INTERVAL '1 second' AS dateplus FROM date WHERE caldate = '12-31-2008'"
)
self.validate_identity("CREATE TABLE datetable (start_date DATE, end_date DATE)")
self.validate_identity("SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'")
@@ -143,3 +143,35 @@ class TestSnowflake(Validator):
"snowflake": r"SELECT 'a \' \\ \\t \\x21 z $ '",
},
)
def test_null_treatment(self):
self.validate_all(
r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
write={
"snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
},
)
self.validate_all(
r"SELECT FIRST_VALUE(TABLE1.COLUMN1 RESPECT NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
write={
"snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
},
)
self.validate_all(
r"SELECT FIRST_VALUE(TABLE1.COLUMN1) RESPECT NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
write={
"snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) RESPECT NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
},
)
self.validate_all(
r"SELECT FIRST_VALUE(TABLE1.COLUMN1 IGNORE NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
write={
"snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1 IGNORE NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
},
)
self.validate_all(
r"SELECT FIRST_VALUE(TABLE1.COLUMN1) IGNORE NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
write={
"snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) IGNORE NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
},
)
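A minimal sketch of the null-treatment normalization these cases assert (RESPECT NULLS is the default and is dropped inside the call, while IGNORE NULLS is preserved):

import sqlglot

sql = "SELECT FIRST_VALUE(x RESPECT NULLS) OVER (ORDER BY y) FROM t"
# Per the fixtures above, the redundant RESPECT NULLS qualifier is elided on output.
print(sqlglot.transpile(sql, read="snowflake", write="snowflake")[0])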
@@ -34,6 +34,7 @@ class TestSQLite(Validator):
write={
"sqlite": "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)",
"mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)",
+"postgres": "CREATE TABLE z (a INT GENERATED BY DEFAULT AS IDENTITY NOT NULL UNIQUE PRIMARY KEY)",
},
)
self.validate_all(
@@ -70,3 +71,20 @@ class TestSQLite(Validator):
"sqlite": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
},
)
def test_hexadecimal_literal(self):
self.validate_all(
"SELECT 0XCC",
write={
"sqlite": "SELECT x'CC'",
"mysql": "SELECT x'CC'",
},
)
def test_window_null_treatment(self):
self.validate_all(
"SELECT FIRST_VALUE(Name) OVER (PARTITION BY AlbumId ORDER BY Bytes DESC) AS LargestTrack FROM tracks",
write={
"sqlite": "SELECT FIRST_VALUE(Name) OVER (PARTITION BY AlbumId ORDER BY Bytes DESC) AS LargestTrack FROM tracks"
},
)
@@ -318,6 +318,9 @@ SELECT 1 FROM a JOIN b ON a.foo = b.bar JOIN c ON a.foo = c.bar
SELECT 1 FROM a LEFT JOIN b ON a.foo = b.bar JOIN c ON a.foo = c.bar
SELECT 1 FROM a LEFT INNER JOIN b ON a.foo = b.bar
SELECT 1 FROM a LEFT OUTER JOIN b ON a.foo = b.bar
+SELECT 1 FROM a NATURAL JOIN b
+SELECT 1 FROM a NATURAL LEFT JOIN b
+SELECT 1 FROM a NATURAL LEFT OUTER JOIN b
SELECT 1 FROM a OUTER JOIN b ON a.foo = b.bar
SELECT 1 FROM a FULL JOIN b ON a.foo = b.bar
SELECT 1 UNION ALL SELECT 2
@@ -329,6 +332,7 @@ SELECT 1 AS delete, 2 AS alter
SELECT * FROM (x)
SELECT * FROM ((x))
SELECT * FROM ((SELECT 1))
+SELECT * FROM (x LATERAL VIEW EXPLODE(y) JOIN foo)
SELECT * FROM (SELECT 1) AS x
SELECT * FROM (SELECT 1 UNION SELECT 2) AS x
SELECT * FROM (SELECT 1 UNION ALL SELECT 2) AS x
@@ -430,6 +434,7 @@ CREATE TEMPORARY VIEW x AS SELECT a FROM d
CREATE TEMPORARY VIEW IF NOT EXISTS x AS SELECT a FROM d
CREATE TEMPORARY VIEW x AS WITH y AS (SELECT 1) SELECT * FROM y
CREATE TABLE z (a INT, b VARCHAR, c VARCHAR(100), d DECIMAL(5, 3))
+CREATE TABLE z (end INT)
CREATE TABLE z (a ARRAY<TEXT>, b MAP<TEXT, DOUBLE>, c DECIMAL(5, 3))
CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECIMAL(5, 3))
CREATE TABLE z (a INT(11) DEFAULT UUID())
@@ -466,6 +471,7 @@ CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1
CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS WITH a AS (SELECT 1) SELECT a.* FROM a
CACHE LAZY TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
CACHE TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
+CACHE TABLE x AS (SELECT 1 AS y)
CALL catalog.system.iceberg_procedure_name(named_arg_1 => 'arg_1', named_arg_2 => 'arg_2')
INSERT OVERWRITE TABLE a.b PARTITION(ds) SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y
@@ -512,3 +518,4 @@ SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ?
WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS foo, c.bar AS bar FROM c)) SELECT * FROM a
WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a
SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z
+SELECT ((SELECT 1) + 1)
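Each line above is an identity fixture: parsing it and regenerating SQL must reproduce the input exactly. A minimal sketch of the invariant, using one of the new fixtures:

import sqlglot

sql = "SELECT 1 FROM a NATURAL LEFT OUTER JOIN b"
# Round-tripping through the parser must be lossless for identity fixtures.
assert sqlglot.parse_one(sql).sql() == sql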
@@ -0,0 +1,63 @@
-- Simple
SELECT a, b FROM (SELECT a, b FROM x);
SELECT x.a AS a, x.b AS b FROM x AS x;
-- Inner table alias is merged
SELECT a, b FROM (SELECT a, b FROM x AS q) AS r;
SELECT q.a AS a, q.b AS b FROM x AS q;
-- Double nesting
SELECT a, b FROM (SELECT a, b FROM (SELECT a, b FROM x));
SELECT x.a AS a, x.b AS b FROM x AS x;
-- WHERE clause is merged
SELECT a, SUM(b) FROM (SELECT a, b FROM x WHERE a > 1) GROUP BY a;
SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a;
-- Outer query has join
SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;
-- Join on derived table
SELECT a, c FROM x JOIN (SELECT b, c FROM y) AS y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
-- Inner query has a join
SELECT a, c FROM (SELECT a, c FROM x JOIN y ON x.b = y.b);
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
-- Inner query has conflicting name in outer query
SELECT a, c FROM (SELECT q.a, q.b FROM x AS q) AS x JOIN y AS q ON x.b = q.b;
SELECT q_2.a AS a, q.c AS c FROM x AS q_2 JOIN y AS q ON q_2.b = q.b;
-- Inner query has conflicting name in joined source
SELECT x.a, q.c FROM (SELECT a, x.b FROM x JOIN y AS q ON x.b = q.b) AS x JOIN y AS q ON x.b = q.b;
SELECT x.a AS a, q.c AS c FROM x AS x JOIN y AS q_2 ON x.b = q_2.b JOIN y AS q ON x.b = q.b;
-- Inner query has multiple conflicting names
SELECT x.a, q.c, r.c FROM (SELECT q.a, r.b FROM x AS q JOIN y AS r ON q.b = r.b) AS x JOIN y AS q ON x.b = q.b JOIN y AS r ON x.b = r.b;
SELECT q_2.a AS a, q.c AS c, r.c AS c FROM x AS q_2 JOIN y AS r_2 ON q_2.b = r_2.b JOIN y AS q ON r_2.b = q.b JOIN y AS r ON r_2.b = r.b;
-- Inner queries have conflicting names with each other
SELECT r.b FROM (SELECT b FROM x AS x) AS q JOIN (SELECT b FROM x) AS r ON q.b = r.b;
SELECT x_2.b AS b FROM x AS x JOIN x AS x_2 ON x.b = x_2.b;
-- WHERE clause in joined derived table is merged
SELECT x.a, y.c FROM x JOIN (SELECT b, c FROM y WHERE c > 1) AS y;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y WHERE y.c > 1;
-- WHERE clause in outer joined derived table is merged to ON clause
SELECT x.a, y.c FROM x LEFT JOIN (SELECT b, c FROM y WHERE c > 1) AS y;
SELECT x.a AS a, y.c AS c FROM x AS x LEFT JOIN y AS y ON y.c > 1;
-- Comma JOIN in outer query
SELECT x.a, y.c FROM (SELECT a FROM x) AS x, (SELECT c FROM y) AS y;
SELECT x.a AS a, y.c AS c FROM x AS x, y AS y;
-- Comma JOIN in inner query
SELECT x.a, x.c FROM (SELECT x.a, z.c FROM x, y AS z) AS x;
SELECT x.a AS a, z.c AS c FROM x AS x CROSS JOIN y AS z;
-- (Regression) Column in ORDER BY
SELECT * FROM (SELECT * FROM (SELECT * FROM x)) ORDER BY a LIMIT 1;
SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a LIMIT 1;
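Each fixture pair above is an input query followed by the expected result once derived tables are merged. A rough sketch of reproducing a pair through the optimizer (the schema below is assumed for illustration, and the full pipeline may additionally quote identifiers, unlike the raw fixture output):

from sqlglot import parse_one
from sqlglot.optimizer import optimize

# Assumed column types for the fixture tables x and y.
schema = {"x": {"a": "INT", "b": "INT"}, "y": {"b": "INT", "c": "INT"}}
expression = parse_one("SELECT a, b FROM (SELECT a, b FROM x)")
# The optimize pipeline includes the derived-table merge exercised by these fixtures.
print(optimize(expression, schema=schema).sql())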
@@ -2,11 +2,7 @@ SELECT a, m FROM z LATERAL VIEW EXPLODE([1, 2]) q AS m;
SELECT
"z"."a" AS "a",
"q"."m" AS "m"
-FROM (
-SELECT
-"z"."a" AS "a"
-FROM "z" AS "z"
-) AS "z"
+FROM "z" AS "z"
LATERAL VIEW
EXPLODE(ARRAY(1, 2)) q AS "m";
@@ -91,41 +87,26 @@ FROM (
WHERE (TRUE AND TRUE OR 'a' = 'b') AND a > 1
GROUP BY a;
SELECT
-"d"."a" AS "a",
-SUM("d"."b") AS "_col_1"
-FROM (
-SELECT
"x"."a" AS "a",
-"y"."b" AS "b"
-FROM (
-SELECT
-"x"."a" AS "a"
-FROM "x" AS "x"
-WHERE
-"x"."a" > 1
-) AS "x"
-LEFT JOIN (
+SUM("y"."b") AS "_col_1"
+FROM "x" AS "x"
+LEFT JOIN (
SELECT
MAX("y"."b") AS "_col_0",
"y"."a" AS "_u_1"
FROM "y" AS "y"
GROUP BY
"y"."a"
) AS "_u_0"
ON "x"."a" = "_u_0"."_u_1"
-JOIN (
-SELECT
-"y"."a" AS "a",
-"y"."b" AS "b"
-FROM "y" AS "y"
-) AS "y"
+JOIN "y" AS "y"
ON "x"."a" = "y"."a"
WHERE
"_u_0"."_col_0" >= 0
+AND "x"."a" > 1
AND NOT "_u_0"."_u_1" IS NULL
-) AS "d"
GROUP BY
-"d"."a";
+"x"."a";
(SELECT a FROM x) LIMIT 1;
(
@@ -120,36 +120,16 @@ SELECT
"supplier"."s_address" AS "s_address",
"supplier"."s_phone" AS "s_phone",
"supplier"."s_comment" AS "s_comment"
-FROM (
-SELECT
-"part"."p_partkey" AS "p_partkey",
-"part"."p_mfgr" AS "p_mfgr",
-"part"."p_type" AS "p_type",
-"part"."p_size" AS "p_size"
-FROM "part" AS "part"
-WHERE
-"part"."p_size" = 15
-AND "part"."p_type" LIKE '%BRASS'
-) AS "part"
+FROM "part" AS "part"
LEFT JOIN (
SELECT
MIN("partsupp"."ps_supplycost") AS "_col_0",
"partsupp"."ps_partkey" AS "_u_1"
FROM "_e_0" AS "partsupp"
CROSS JOIN "_e_1" AS "region"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_regionkey" AS "n_regionkey"
-FROM "nation" AS "nation"
-) AS "nation"
+JOIN "nation" AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey"
-JOIN (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_nationkey" AS "s_nationkey"
-FROM "supplier" AS "supplier"
-) AS "supplier"
+JOIN "supplier" AS "supplier"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey"
GROUP BY
@@ -157,31 +137,17 @@ LEFT JOIN (
) AS "_u_0"
ON "part"."p_partkey" = "_u_0"."_u_1"
CROSS JOIN "_e_1" AS "region"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_name" AS "n_name",
-"nation"."n_regionkey" AS "n_regionkey"
-FROM "nation" AS "nation"
-) AS "nation"
+JOIN "nation" AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey"
JOIN "_e_0" AS "partsupp"
ON "part"."p_partkey" = "partsupp"."ps_partkey"
-JOIN (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_name" AS "s_name",
-"supplier"."s_address" AS "s_address",
-"supplier"."s_nationkey" AS "s_nationkey",
-"supplier"."s_phone" AS "s_phone",
-"supplier"."s_acctbal" AS "s_acctbal",
-"supplier"."s_comment" AS "s_comment"
-FROM "supplier" AS "supplier"
-) AS "supplier"
+JOIN "supplier" AS "supplier"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey"
WHERE
-"partsupp"."ps_supplycost" = "_u_0"."_col_0"
+"part"."p_size" = 15
+AND "part"."p_type" LIKE '%BRASS'
+AND "partsupp"."ps_supplycost" = "_u_0"."_col_0"
AND NOT "_u_0"."_u_1" IS NULL
ORDER BY
"s_acctbal" DESC,
@@ -224,36 +190,15 @@ SELECT
)) AS "revenue",
CAST("orders"."o_orderdate" AS TEXT) AS "o_orderdate",
"orders"."o_shippriority" AS "o_shippriority"
-FROM (
-SELECT
-"customer"."c_custkey" AS "c_custkey",
-"customer"."c_mktsegment" AS "c_mktsegment"
-FROM "customer" AS "customer"
-WHERE
-"customer"."c_mktsegment" = 'BUILDING'
-) AS "customer"
-JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_custkey" AS "o_custkey",
-"orders"."o_orderdate" AS "o_orderdate",
-"orders"."o_shippriority" AS "o_shippriority"
-FROM "orders" AS "orders"
-WHERE
-"orders"."o_orderdate" < '1995-03-15'
-) AS "orders"
+FROM "customer" AS "customer"
+JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_extendedprice" AS "l_extendedprice",
-"lineitem"."l_discount" AS "l_discount",
-"lineitem"."l_shipdate" AS "l_shipdate"
-FROM "lineitem" AS "lineitem"
-WHERE
-"lineitem"."l_shipdate" > '1995-03-15'
-) AS "lineitem"
+JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
+WHERE
+"customer"."c_mktsegment" = 'BUILDING'
+AND "lineitem"."l_shipdate" > '1995-03-15'
+AND "orders"."o_orderdate" < '1995-03-15'
GROUP BY
"lineitem"."l_orderkey",
"orders"."o_orderdate",
@@ -342,57 +287,22 @@ SELECT
SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
)) AS "revenue"
-FROM (
-SELECT
-"customer"."c_custkey" AS "c_custkey",
-"customer"."c_nationkey" AS "c_nationkey"
-FROM "customer" AS "customer"
-) AS "customer"
-JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_custkey" AS "o_custkey",
-"orders"."o_orderdate" AS "o_orderdate"
-FROM "orders" AS "orders"
-WHERE
-"orders"."o_orderdate" < CAST('1995-01-01' AS DATE)
-AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE)
-) AS "orders"
+FROM "customer" AS "customer"
+JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
-CROSS JOIN (
-SELECT
-"region"."r_regionkey" AS "r_regionkey",
-"region"."r_name" AS "r_name"
-FROM "region" AS "region"
-WHERE
-"region"."r_name" = 'ASIA'
-) AS "region"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_name" AS "n_name",
-"nation"."n_regionkey" AS "n_regionkey"
-FROM "nation" AS "nation"
-) AS "nation"
+CROSS JOIN "region" AS "region"
+JOIN "nation" AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey"
-JOIN (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_nationkey" AS "s_nationkey"
-FROM "supplier" AS "supplier"
-) AS "supplier"
+JOIN "supplier" AS "supplier"
ON "customer"."c_nationkey" = "supplier"."s_nationkey"
AND "supplier"."s_nationkey" = "nation"."n_nationkey"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_suppkey" AS "l_suppkey",
-"lineitem"."l_extendedprice" AS "l_extendedprice",
-"lineitem"."l_discount" AS "l_discount"
-FROM "lineitem" AS "lineitem"
-) AS "lineitem"
+JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "lineitem"."l_suppkey" = "supplier"."s_suppkey"
+WHERE
+"orders"."o_orderdate" < CAST('1995-01-01' AS DATE)
+AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE)
+AND "region"."r_name" = 'ASIA'
GROUP BY
"nation"."n_name"
ORDER BY
@@ -471,53 +381,22 @@ WITH "_e_0" AS (
OR "nation"."n_name" = 'GERMANY'
)
SELECT
-"shipping"."supp_nation" AS "supp_nation",
-"shipping"."cust_nation" AS "cust_nation",
-"shipping"."l_year" AS "l_year",
-SUM("shipping"."volume") AS "revenue"
-FROM (
-SELECT
"n1"."n_name" AS "supp_nation",
"n2"."n_name" AS "cust_nation",
EXTRACT(year FROM "lineitem"."l_shipdate") AS "l_year",
-"lineitem"."l_extendedprice" * (
+SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
-) AS "volume"
-FROM (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_nationkey" AS "s_nationkey"
-FROM "supplier" AS "supplier"
-) AS "supplier"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_suppkey" AS "l_suppkey",
-"lineitem"."l_extendedprice" AS "l_extendedprice",
-"lineitem"."l_discount" AS "l_discount",
-"lineitem"."l_shipdate" AS "l_shipdate"
-FROM "lineitem" AS "lineitem"
-WHERE
-"lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
-) AS "lineitem"
+)) AS "revenue"
+FROM "supplier" AS "supplier"
+JOIN "lineitem" AS "lineitem"
ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
-JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_custkey" AS "o_custkey"
-FROM "orders" AS "orders"
-) AS "orders"
+JOIN "orders" AS "orders"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
-JOIN (
-SELECT
-"customer"."c_custkey" AS "c_custkey",
-"customer"."c_nationkey" AS "c_nationkey"
-FROM "customer" AS "customer"
-) AS "customer"
+JOIN "customer" AS "customer"
ON "customer"."c_custkey" = "orders"."o_custkey"
JOIN "_e_0" AS "n1"
ON "supplier"."s_nationkey" = "n1"."n_nationkey"
JOIN "_e_0" AS "n2"
ON "customer"."c_nationkey" = "n2"."n_nationkey"
AND (
"n1"."n_name" = 'FRANCE'
@@ -527,11 +406,12 @@ FROM (
"n1"."n_name" = 'GERMANY'
OR "n2"."n_name" = 'GERMANY'
)
-) AS "shipping"
+WHERE
+"lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
GROUP BY
-"shipping"."supp_nation",
-"shipping"."cust_nation",
-"shipping"."l_year"
+"n1"."n_name",
+"n2"."n_name",
+EXTRACT(year FROM "lineitem"."l_shipdate")
ORDER BY
"supp_nation",
"cust_nation",
@@ -578,87 +458,37 @@ group by
order by
o_year;
SELECT
-"all_nations"."o_year" AS "o_year",
-SUM(CASE
-WHEN "all_nations"."nation" = 'BRAZIL'
-THEN "all_nations"."volume"
-ELSE 0
-END) / SUM("all_nations"."volume") AS "mkt_share"
-FROM (
-SELECT
EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
-"lineitem"."l_extendedprice" * (
+SUM(CASE
+WHEN "nation_2"."n_name" = 'BRAZIL'
+THEN "lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
-) AS "volume",
-"n2"."n_name" AS "nation"
-FROM (
-SELECT
-"part"."p_partkey" AS "p_partkey",
-"part"."p_type" AS "p_type"
-FROM "part" AS "part"
-WHERE
-"part"."p_type" = 'ECONOMY ANODIZED STEEL'
-) AS "part"
-CROSS JOIN (
-SELECT
-"region"."r_regionkey" AS "r_regionkey",
-"region"."r_name" AS "r_name"
-FROM "region" AS "region"
-WHERE
-"region"."r_name" = 'AMERICA'
-) AS "region"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_regionkey" AS "n_regionkey"
-FROM "nation" AS "nation"
-) AS "n1"
-ON "n1"."n_regionkey" = "region"."r_regionkey"
-JOIN (
-SELECT
-"customer"."c_custkey" AS "c_custkey",
-"customer"."c_nationkey" AS "c_nationkey"
-FROM "customer" AS "customer"
-) AS "customer"
-ON "customer"."c_nationkey" = "n1"."n_nationkey"
-JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_custkey" AS "o_custkey",
-"orders"."o_orderdate" AS "o_orderdate"
-FROM "orders" AS "orders"
-WHERE
-"orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
-) AS "orders"
+)
+ELSE 0
+END) / SUM("lineitem"."l_extendedprice" * (
+1 - "lineitem"."l_discount"
+)) AS "mkt_share"
+FROM "part" AS "part"
+CROSS JOIN "region" AS "region"
+JOIN "nation" AS "nation"
+ON "nation"."n_regionkey" = "region"."r_regionkey"
+JOIN "customer" AS "customer"
+ON "customer"."c_nationkey" = "nation"."n_nationkey"
+JOIN "orders" AS "orders"
ON "orders"."o_custkey" = "customer"."c_custkey"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_partkey" AS "l_partkey",
-"lineitem"."l_suppkey" AS "l_suppkey",
-"lineitem"."l_extendedprice" AS "l_extendedprice",
-"lineitem"."l_discount" AS "l_discount"
-FROM "lineitem" AS "lineitem"
-) AS "lineitem"
+JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "part"."p_partkey" = "lineitem"."l_partkey"
-JOIN (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_nationkey" AS "s_nationkey"
-FROM "supplier" AS "supplier"
-) AS "supplier"
+JOIN "supplier" AS "supplier"
ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_name" AS "n_name"
-FROM "nation" AS "nation"
-) AS "n2"
-ON "supplier"."s_nationkey" = "n2"."n_nationkey"
-) AS "all_nations"
+JOIN "nation" AS "nation_2"
+ON "supplier"."s_nationkey" = "nation_2"."n_nationkey"
+WHERE
+"orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+AND "part"."p_type" = 'ECONOMY ANODIZED STEEL'
+AND "region"."r_name" = 'AMERICA'
GROUP BY
-"all_nations"."o_year"
+EXTRACT(year FROM "orders"."o_orderdate")
ORDER BY
"o_year";
@@ -698,69 +528,28 @@ order by
nation,
o_year desc;
SELECT
-"profit"."nation" AS "nation",
-"profit"."o_year" AS "o_year",
-SUM("profit"."amount") AS "sum_profit"
-FROM (
-SELECT
"nation"."n_name" AS "nation",
EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
-"lineitem"."l_extendedprice" * (
+SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
-) - "partsupp"."ps_supplycost" * "lineitem"."l_quantity" AS "amount"
-FROM (
-SELECT
-"part"."p_partkey" AS "p_partkey",
-"part"."p_name" AS "p_name"
-FROM "part" AS "part"
-WHERE
-"part"."p_name" LIKE '%green%'
-) AS "part"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_partkey" AS "l_partkey",
-"lineitem"."l_suppkey" AS "l_suppkey",
-"lineitem"."l_quantity" AS "l_quantity",
-"lineitem"."l_extendedprice" AS "l_extendedprice",
-"lineitem"."l_discount" AS "l_discount"
-FROM "lineitem" AS "lineitem"
-) AS "lineitem"
+) - "partsupp"."ps_supplycost" * "lineitem"."l_quantity") AS "sum_profit"
+FROM "part" AS "part"
+JOIN "lineitem" AS "lineitem"
ON "part"."p_partkey" = "lineitem"."l_partkey"
-JOIN (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_nationkey" AS "s_nationkey"
-FROM "supplier" AS "supplier"
-) AS "supplier"
+JOIN "supplier" AS "supplier"
ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
-JOIN (
-SELECT
-"partsupp"."ps_partkey" AS "ps_partkey",
-"partsupp"."ps_suppkey" AS "ps_suppkey",
-"partsupp"."ps_supplycost" AS "ps_supplycost"
-FROM "partsupp" AS "partsupp"
-) AS "partsupp"
+JOIN "partsupp" AS "partsupp"
ON "partsupp"."ps_partkey" = "lineitem"."l_partkey"
AND "partsupp"."ps_suppkey" = "lineitem"."l_suppkey"
-JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_orderdate" AS "o_orderdate"
-FROM "orders" AS "orders"
-) AS "orders"
+JOIN "orders" AS "orders"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_name" AS "n_name"
-FROM "nation" AS "nation"
-) AS "nation"
+JOIN "nation" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
-) AS "profit"
+WHERE
+"part"."p_name" LIKE '%green%'
GROUP BY
-"profit"."nation",
-"profit"."o_year"
+"nation"."n_name",
+EXTRACT(year FROM "orders"."o_orderdate")
ORDER BY
"nation",
"o_year" DESC;
@@ -812,46 +601,17 @@ SELECT
"customer"."c_address" AS "c_address",
"customer"."c_phone" AS "c_phone",
"customer"."c_comment" AS "c_comment"
-FROM (
-SELECT
-"customer"."c_custkey" AS "c_custkey",
-"customer"."c_name" AS "c_name",
-"customer"."c_address" AS "c_address",
-"customer"."c_nationkey" AS "c_nationkey",
-"customer"."c_phone" AS "c_phone",
-"customer"."c_acctbal" AS "c_acctbal",
-"customer"."c_comment" AS "c_comment"
-FROM "customer" AS "customer"
-) AS "customer"
-JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_custkey" AS "o_custkey",
-"orders"."o_orderdate" AS "o_orderdate"
-FROM "orders" AS "orders"
-WHERE
-"orders"."o_orderdate" < CAST('1994-01-01' AS DATE)
-AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE)
-) AS "orders"
+FROM "customer" AS "customer"
+JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_extendedprice" AS "l_extendedprice",
-"lineitem"."l_discount" AS "l_discount",
-"lineitem"."l_returnflag" AS "l_returnflag"
-FROM "lineitem" AS "lineitem"
-WHERE
-"lineitem"."l_returnflag" = 'R'
-) AS "lineitem"
+JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_name" AS "n_name"
-FROM "nation" AS "nation"
-) AS "nation"
+JOIN "nation" AS "nation"
ON "customer"."c_nationkey" = "nation"."n_nationkey"
+WHERE
+"lineitem"."l_returnflag" = 'R'
+AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE)
+AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE)
GROUP BY
"customer"."c_custkey",
"customer"."c_name",
@@ -910,14 +670,7 @@ WITH "_e_0" AS (
SELECT
"partsupp"."ps_partkey" AS "ps_partkey",
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value"
-FROM (
-SELECT
-"partsupp"."ps_partkey" AS "ps_partkey",
-"partsupp"."ps_suppkey" AS "ps_suppkey",
-"partsupp"."ps_availqty" AS "ps_availqty",
-"partsupp"."ps_supplycost" AS "ps_supplycost"
-FROM "partsupp" AS "partsupp"
-) AS "partsupp"
+FROM "partsupp" AS "partsupp"
JOIN "_e_0" AS "supplier"
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "_e_1" AS "nation"
@@ -928,13 +681,7 @@ HAVING
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > (
SELECT
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0"
-FROM (
-SELECT
-"partsupp"."ps_suppkey" AS "ps_suppkey",
-"partsupp"."ps_availqty" AS "ps_availqty",
-"partsupp"."ps_supplycost" AS "ps_supplycost"
FROM "partsupp" AS "partsupp"
-) AS "partsupp"
JOIN "_e_0" AS "supplier"
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "_e_1" AS "nation"
@@ -988,28 +735,15 @@ SELECT
THEN 1
ELSE 0
END) AS "low_line_count"
-FROM (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_orderpriority" AS "o_orderpriority"
-FROM "orders" AS "orders"
-) AS "orders"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_shipdate" AS "l_shipdate",
-"lineitem"."l_commitdate" AS "l_commitdate",
-"lineitem"."l_receiptdate" AS "l_receiptdate",
-"lineitem"."l_shipmode" AS "l_shipmode"
-FROM "lineitem" AS "lineitem"
-WHERE
+FROM "orders" AS "orders"
+JOIN "lineitem" AS "lineitem"
+ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
+WHERE
"lineitem"."l_commitdate" < "lineitem"."l_receiptdate"
AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE)
AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE)
AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate"
AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP')
-) AS "lineitem"
-ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
GROUP BY
"lineitem"."l_shipmode"
ORDER BY
@@ -1044,21 +778,10 @@ SELECT
FROM (
SELECT
COUNT("orders"."o_orderkey") AS "c_count"
-FROM (
-SELECT
-"customer"."c_custkey" AS "c_custkey"
FROM "customer" AS "customer"
-) AS "customer"
-LEFT JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_custkey" AS "o_custkey",
-"orders"."o_comment" AS "o_comment"
-FROM "orders" AS "orders"
-WHERE
-NOT "orders"."o_comment" LIKE '%special%requests%'
-) AS "orders"
+LEFT JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
+AND NOT "orders"."o_comment" LIKE '%special%requests%'
GROUP BY
"customer"."c_custkey"
) AS "c_orders"
@@ -1094,24 +817,12 @@ SELECT
END) / SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
)) AS "promo_revenue"
-FROM (
-SELECT
-"lineitem"."l_partkey" AS "l_partkey",
-"lineitem"."l_extendedprice" AS "l_extendedprice",
-"lineitem"."l_discount" AS "l_discount",
-"lineitem"."l_shipdate" AS "l_shipdate"
-FROM "lineitem" AS "lineitem"
-WHERE
+FROM "lineitem" AS "lineitem"
+JOIN "part" AS "part"
+ON "lineitem"."l_partkey" = "part"."p_partkey"
+WHERE
"lineitem"."l_shipdate" < CAST('1995-10-01' AS DATE)
-AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE)
-) AS "lineitem"
-JOIN (
-SELECT
-"part"."p_partkey" AS "p_partkey",
-"part"."p_type" AS "p_type"
-FROM "part" AS "part"
-) AS "part"
-ON "lineitem"."l_partkey" = "part"."p_partkey";
+AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE);
--------------------------------------
-- TPC-H 15
@@ -1165,14 +876,7 @@ SELECT
"supplier"."s_address" AS "s_address",
"supplier"."s_phone" AS "s_phone",
"revenue"."total_revenue" AS "total_revenue"
-FROM (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_name" AS "s_name",
-"supplier"."s_address" AS "s_address",
-"supplier"."s_phone" AS "s_phone"
-FROM "supplier" AS "supplier"
-) AS "supplier"
+FROM "supplier" AS "supplier"
JOIN "revenue"
ON "revenue"."total_revenue" = (
SELECT
@@ -1221,12 +925,7 @@ SELECT
"part"."p_type" AS "p_type",
"part"."p_size" AS "p_size",
COUNT(DISTINCT "partsupp"."ps_suppkey") AS "supplier_cnt"
-FROM (
-SELECT
-"partsupp"."ps_partkey" AS "ps_partkey",
-"partsupp"."ps_suppkey" AS "ps_suppkey"
-FROM "partsupp" AS "partsupp"
-) AS "partsupp"
+FROM "partsupp" AS "partsupp"
LEFT JOIN (
SELECT
"supplier"."s_suppkey" AS "s_suppkey"
@@ -1237,21 +936,13 @@ LEFT JOIN (
"supplier"."s_suppkey"
) AS "_u_0"
ON "partsupp"."ps_suppkey" = "_u_0"."s_suppkey"
-JOIN (
-SELECT
-"part"."p_partkey" AS "p_partkey",
-"part"."p_brand" AS "p_brand",
-"part"."p_type" AS "p_type",
-"part"."p_size" AS "p_size"
-FROM "part" AS "part"
-WHERE
-"part"."p_brand" <> 'Brand#45'
-AND "part"."p_size" IN (49, 14, 23, 45, 19, 3, 36, 9)
-AND NOT "part"."p_type" LIKE 'MEDIUM POLISHED%'
-) AS "part"
+JOIN "part" AS "part"
ON "part"."p_partkey" = "partsupp"."ps_partkey"
WHERE
"_u_0"."s_suppkey" IS NULL
+AND "part"."p_brand" <> 'Brand#45'
+AND "part"."p_size" IN (49, 14, 23, 45, 19, 3, 36, 9)
+AND NOT "part"."p_type" LIKE 'MEDIUM POLISHED%'
GROUP BY
"part"."p_brand",
"part"."p_type",
@@ -1284,23 +975,8 @@ where
);
SELECT
SUM("lineitem"."l_extendedprice") / 7.0 AS "avg_yearly"
-FROM (
-SELECT
-"lineitem"."l_partkey" AS "l_partkey",
-"lineitem"."l_quantity" AS "l_quantity",
-"lineitem"."l_extendedprice" AS "l_extendedprice"
-FROM "lineitem" AS "lineitem"
-) AS "lineitem"
-JOIN (
-SELECT
-"part"."p_partkey" AS "p_partkey",
-"part"."p_brand" AS "p_brand",
-"part"."p_container" AS "p_container"
-FROM "part" AS "part"
-WHERE
-"part"."p_brand" = 'Brand#23'
-AND "part"."p_container" = 'MED BOX'
-) AS "part"
+FROM "lineitem" AS "lineitem"
+JOIN "part" AS "part"
ON "part"."p_partkey" = "lineitem"."l_partkey"
LEFT JOIN (
SELECT
@@ -1313,6 +989,8 @@ LEFT JOIN (
ON "_u_0"."_u_1" = "part"."p_partkey"
WHERE
"lineitem"."l_quantity" < "_u_0"."_col_0"
+AND "part"."p_brand" = 'Brand#23'
+AND "part"."p_container" = 'MED BOX'
AND NOT "_u_0"."_u_1" IS NULL;
--------------------------------------
@@ -1359,20 +1037,8 @@ SELECT
"orders"."o_orderdate" AS "o_orderdate",
"orders"."o_totalprice" AS "o_totalprice",
SUM("lineitem"."l_quantity") AS "_col_5"
-FROM (
-SELECT
-"customer"."c_custkey" AS "c_custkey",
-"customer"."c_name" AS "c_name"
-FROM "customer" AS "customer"
-) AS "customer"
-JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_custkey" AS "o_custkey",
-"orders"."o_totalprice" AS "o_totalprice",
-"orders"."o_orderdate" AS "o_orderdate"
-FROM "orders" AS "orders"
-) AS "orders"
+FROM "customer" AS "customer"
+JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
LEFT JOIN (
SELECT
@@ -1385,12 +1051,7 @@ LEFT JOIN (
SUM("lineitem"."l_quantity") > 300
) AS "_u_0"
ON "orders"."o_orderkey" = "_u_0"."l_orderkey"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_quantity" AS "l_quantity"
-FROM "lineitem" AS "lineitem"
-) AS "lineitem"
+JOIN "lineitem" AS "lineitem"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
WHERE
NOT "_u_0"."l_orderkey" IS NULL
@@ -1447,24 +1108,8 @@ SELECT
SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
)) AS "revenue"
-FROM (
-SELECT
-"lineitem"."l_partkey" AS "l_partkey",
-"lineitem"."l_quantity" AS "l_quantity",
-"lineitem"."l_extendedprice" AS "l_extendedprice",
-"lineitem"."l_discount" AS "l_discount",
-"lineitem"."l_shipinstruct" AS "l_shipinstruct",
-"lineitem"."l_shipmode" AS "l_shipmode"
-FROM "lineitem" AS "lineitem"
-) AS "lineitem"
-JOIN (
-SELECT
-"part"."p_partkey" AS "p_partkey",
-"part"."p_brand" AS "p_brand",
-"part"."p_size" AS "p_size",
-"part"."p_container" AS "p_container"
-FROM "part" AS "part"
-) AS "part"
+FROM "lineitem" AS "lineitem"
+JOIN "part" AS "part"
ON (
"part"."p_brand" = 'Brand#12'
AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG')
@@ -1558,14 +1203,7 @@ order by
SELECT
"supplier"."s_name" AS "s_name",
"supplier"."s_address" AS "s_address"
-FROM (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_name" AS "s_name",
-"supplier"."s_address" AS "s_address",
-"supplier"."s_nationkey" AS "s_nationkey"
-FROM "supplier" AS "supplier"
-) AS "supplier"
+FROM "supplier" AS "supplier"
LEFT JOIN (
SELECT
"partsupp"."ps_suppkey" AS "ps_suppkey"
@@ -1604,17 +1242,11 @@ LEFT JOIN (
"partsupp"."ps_suppkey"
) AS "_u_4"
ON "supplier"."s_suppkey" = "_u_4"."ps_suppkey"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_name" AS "n_name"
-FROM "nation" AS "nation"
-WHERE
-"nation"."n_name" = 'CANADA'
-) AS "nation"
+JOIN "nation" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
WHERE
-NOT "_u_4"."ps_suppkey" IS NULL
+"nation"."n_name" = 'CANADA'
+AND NOT "_u_4"."ps_suppkey" IS NULL
ORDER BY
"s_name";
@@ -1665,24 +1297,9 @@ limit
SELECT
"supplier"."s_name" AS "s_name",
COUNT(*) AS "numwait"
-FROM (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_name" AS "s_name",
-"supplier"."s_nationkey" AS "s_nationkey"
-FROM "supplier" AS "supplier"
-) AS "supplier"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_suppkey" AS "l_suppkey",
-"lineitem"."l_commitdate" AS "l_commitdate",
-"lineitem"."l_receiptdate" AS "l_receiptdate"
-FROM "lineitem" AS "lineitem"
-WHERE
-"lineitem"."l_receiptdate" > "lineitem"."l_commitdate"
-) AS "l1"
-ON "supplier"."s_suppkey" = "l1"."l_suppkey"
+FROM "supplier" AS "supplier"
+JOIN "lineitem" AS "lineitem"
+ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
LEFT JOIN (
SELECT
"l2"."l_orderkey" AS "l_orderkey",
@@ -1691,7 +1308,7 @@ LEFT JOIN (
GROUP BY
"l2"."l_orderkey"
) AS "_u_0"
-ON "_u_0"."l_orderkey" = "l1"."l_orderkey"
+ON "_u_0"."l_orderkey" = "lineitem"."l_orderkey"
LEFT JOIN (
SELECT
"l3"."l_orderkey" AS "l_orderkey",
@@ -1702,31 +1319,20 @@ LEFT JOIN (
GROUP BY
"l3"."l_orderkey"
) AS "_u_2"
-ON "_u_2"."l_orderkey" = "l1"."l_orderkey"
-JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_orderstatus" AS "o_orderstatus"
-FROM "orders" AS "orders"
-WHERE
-"orders"."o_orderstatus" = 'F'
-) AS "orders"
-ON "orders"."o_orderkey" = "l1"."l_orderkey"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_name" AS "n_name"
-FROM "nation" AS "nation"
-WHERE
-"nation"."n_name" = 'SAUDI ARABIA'
-) AS "nation"
+ON "_u_2"."l_orderkey" = "lineitem"."l_orderkey"
+JOIN "orders" AS "orders"
+ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
+JOIN "nation" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
WHERE
(
"_u_2"."l_orderkey" IS NULL
-OR NOT ARRAY_ANY("_u_2"."_u_3", "_x" -> "_x" <> "l1"."l_suppkey")
+OR NOT ARRAY_ANY("_u_2"."_u_3", "_x" -> "_x" <> "lineitem"."l_suppkey")
)
-AND ARRAY_ANY("_u_0"."_u_1", "_x" -> "_x" <> "l1"."l_suppkey")
+AND "lineitem"."l_receiptdate" > "lineitem"."l_commitdate"
+AND "nation"."n_name" = 'SAUDI ARABIA'
+AND "orders"."o_orderstatus" = 'F'
+AND ARRAY_ANY("_u_0"."_u_1", "_x" -> "_x" <> "lineitem"."l_suppkey")
AND NOT "_u_0"."l_orderkey" IS NULL
GROUP BY
"supplier"."s_name"
@@ -1776,23 +1382,19 @@ group by
order by
cntrycode;
SELECT
-"custsale"."cntrycode" AS "cntrycode",
-COUNT(*) AS "numcust",
-SUM("custsale"."c_acctbal") AS "totacctbal"
-FROM (
-SELECT
SUBSTRING("customer"."c_phone", 1, 2) AS "cntrycode",
-"customer"."c_acctbal" AS "c_acctbal"
-FROM "customer" AS "customer"
-LEFT JOIN (
+COUNT(*) AS "numcust",
+SUM("customer"."c_acctbal") AS "totacctbal"
+FROM "customer" AS "customer"
+LEFT JOIN (
SELECT
"orders"."o_custkey" AS "_u_1"
FROM "orders" AS "orders"
GROUP BY
"orders"."o_custkey"
) AS "_u_0"
ON "_u_0"."_u_1" = "customer"."c_custkey"
WHERE
"_u_0"."_u_1" IS NULL
AND "customer"."c_acctbal" > (
SELECT
@@ -1803,8 +1405,7 @@ FROM (
AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
)
AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
-) AS "custsale"
GROUP BY
-"custsale"."cntrycode"
+SUBSTRING("customer"."c_phone", 1, 2)
ORDER BY
"cntrycode";
@@ -5,9 +5,7 @@ FIXTURES_DIR = os.path.join(FILE_DIR, "fixtures")
def _filter_comments(s):
-return "\n".join(
-[line for line in s.splitlines() if line and not line.startswith("--")]
-)
+return "\n".join([line for line in s.splitlines() if line and not line.startswith("--")])
def _extract_meta(sql):
@@ -23,9 +21,7 @@ def _extract_meta(sql):
def assert_logger_contains(message, logger, level="error"):
-output = "\n".join(
-str(args[0][0]) for args in getattr(logger, level).call_args_list
-)
+output = "\n".join(str(args[0][0]) for args in getattr(logger, level).call_args_list)
assert message in output
@@ -46,10 +46,7 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl WHERE FALSE",
),
(
-lambda: select("x")
-.from_("tbl")
-.where("x > 0")
-.where("x < 9", append=False),
+lambda: select("x").from_("tbl").where("x > 0").where("x < 9", append=False),
"SELECT x FROM tbl WHERE x < 9",
),
(
@@ -61,10 +58,7 @@ class TestBuild(unittest.TestCase):
"SELECT x, y FROM tbl GROUP BY x, y",
),
(
-lambda: select("x", "y", "z", "a")
-.from_("tbl")
-.group_by("x, y", "z")
-.group_by("a"),
+lambda: select("x", "y", "z", "a").from_("tbl").group_by("x, y", "z").group_by("a"),
"SELECT x, y, z, a FROM tbl GROUP BY x, y, z, a",
),
(
@@ -85,9 +79,7 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl JOIN tbl2 ON tbl.y = tbl2.y",
),
(
-lambda: select("x")
-.from_("tbl")
-.join("tbl2", on=["tbl.y = tbl2.y", "a = b"]),
+lambda: select("x").from_("tbl").join("tbl2", on=["tbl.y = tbl2.y", "a = b"]),
"SELECT x FROM tbl JOIN tbl2 ON tbl.y = tbl2.y AND a = b",
),
(
@@ -95,21 +87,15 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT OUTER JOIN tbl2",
),
(
-lambda: select("x")
-.from_("tbl")
-.join(exp.Table(this="tbl2"), join_type="left outer"),
+lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer"),
"SELECT x FROM tbl LEFT OUTER JOIN tbl2",
),
(
-lambda: select("x")
-.from_("tbl")
-.join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
+lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
"SELECT x FROM tbl LEFT OUTER JOIN tbl2 AS foo",
),
(
-lambda: select("x")
-.from_("tbl")
-.join(select("y").from_("tbl2"), join_type="left outer"),
+lambda: select("x").from_("tbl").join(select("y").from_("tbl2"), join_type="left outer"),
"SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2)",
),
(
@@ -132,9 +118,7 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased",
),
(
-lambda: select("x")
-.from_("tbl")
-.join(parse_one("left join x", into=exp.Join), on="a=b"),
+lambda: select("x").from_("tbl").join(parse_one("left join x", into=exp.Join), on="a=b"),
"SELECT x FROM tbl LEFT JOIN x ON a = b",
),
(
@@ -142,9 +126,7 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT JOIN x ON a = b",
),
(
-lambda: select("x")
-.from_("tbl")
-.join("select b from tbl2", on="a=b", join_type="left"),
+lambda: select("x").from_("tbl").join("select b from tbl2", on="a=b", join_type="left"),
"SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) ON a = b",
),
(
@@ -159,10 +141,7 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) AS aliased ON a = b",
),
(
-lambda: select("x", "COUNT(y)")
-.from_("tbl")
-.group_by("x")
-.having("COUNT(y) > 0"),
+lambda: select("x", "COUNT(y)").from_("tbl").group_by("x").having("COUNT(y) > 0"),
"SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 0",
),
(
@@ -190,24 +169,15 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl SORT BY x, y DESC",
),
(
-lambda: select("x", "y", "z", "a")
-.from_("tbl")
-.order_by("x, y", "z")
-.order_by("a"),
+lambda: select("x", "y", "z", "a").from_("tbl").order_by("x, y", "z").order_by("a"),
"SELECT x, y, z, a FROM tbl ORDER BY x, y, z, a",
),
(
-lambda: select("x", "y", "z", "a")
-.from_("tbl")
-.cluster_by("x, y", "z")
-.cluster_by("a"),
+lambda: select("x", "y", "z", "a").from_("tbl").cluster_by("x, y", "z").cluster_by("a"),
"SELECT x, y, z, a FROM tbl CLUSTER BY x, y, z, a",
),
(
-lambda: select("x", "y", "z", "a")
-.from_("tbl")
-.sort_by("x, y", "z")
-.sort_by("a"),
+lambda: select("x", "y", "z", "a").from_("tbl").sort_by("x, y", "z").sort_by("a"),
"SELECT x, y, z, a FROM tbl SORT BY x, y, z, a",
),
(lambda: select("x").from_("tbl").limit(10), "SELECT x FROM tbl LIMIT 10"),
@@ -220,21 +190,15 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
-lambda: select("x")
-.from_("tbl")
-.with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
+lambda: select("x").from_("tbl").with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
"WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
-lambda: select("x")
-.from_("tbl")
-.with_("tbl", as_=select("x").from_("tbl2")),
+lambda: select("x").from_("tbl").with_("tbl", as_=select("x").from_("tbl2")),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
-lambda: select("x")
-.from_("tbl")
-.with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
+lambda: select("x").from_("tbl").with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
"WITH tbl(x, y) AS (SELECT x, y FROM tbl2) SELECT x FROM tbl",
),
(
@@ -245,72 +209,43 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2), tbl2 AS (SELECT x FROM tbl3) SELECT x FROM tbl",
(
-lambda: select("x")
-.from_("tbl")
-.with_("tbl", as_=select("x", "y").from_("tbl2"))
-.select("y"),
+lambda: select("x").from_("tbl").with_("tbl", as_=select("x", "y").from_("tbl2")).select("y"),
"WITH tbl AS (SELECT x, y FROM tbl2) SELECT x, y FROM tbl",
),
(
-lambda: select("x")
-.with_("tbl", as_=select("x").from_("tbl2"))
-.from_("tbl"),
+lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
-lambda: select("x")
-.with_("tbl", as_=select("x").from_("tbl2"))
-.from_("tbl")
-.group_by("x"),
+lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").group_by("x"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl GROUP BY x",
),
(
-lambda: select("x")
-.with_("tbl", as_=select("x").from_("tbl2"))
-.from_("tbl")
-.order_by("x"),
+lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").order_by("x"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl ORDER BY x",
),
(
-lambda: select("x")
-.with_("tbl", as_=select("x").from_("tbl2"))
-.from_("tbl")
-.limit(10),
+lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").limit(10),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl LIMIT 10",
),
(
-lambda: select("x")
-.with_("tbl", as_=select("x").from_("tbl2"))
-.from_("tbl")
-.offset(10),
+lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").offset(10),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl OFFSET 10",
),
(
-lambda: select("x")
-.with_("tbl", as_=select("x").from_("tbl2"))
-.from_("tbl")
-.join("tbl3"),
+lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").join("tbl3"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl JOIN tbl3",
),
(
-lambda: select("x")
-.with_("tbl", as_=select("x").from_("tbl2"))
-.from_("tbl")
-.distinct(),
+lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").distinct(),
"WITH tbl AS (SELECT x FROM tbl2) SELECT DISTINCT x FROM tbl",
),
(
-lambda: select("x")
-.with_("tbl", as_=select("x").from_("tbl2"))
-.from_("tbl")
-.where("x > 10"),
+lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").where("x > 10"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl WHERE x > 10",
),
(
-lambda: select("x")
-.with_("tbl", as_=select("x").from_("tbl2"))
-.from_("tbl")
-.having("x > 20"),
+lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").having("x > 20"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl HAVING x > 20",
),
(lambda: select("x").from_("tbl").subquery(), "(SELECT x FROM tbl)"),
@@ -324,9 +259,7 @@ class TestBuild(unittest.TestCase):
),
(lambda: from_("tbl").select("x"), "SELECT x FROM tbl"),
(
-lambda: parse_one("SELECT a FROM tbl")
-.assert_is(exp.Select)
-.select("b"),
+lambda: parse_one("SELECT a FROM tbl").assert_is(exp.Select).select("b"),
"SELECT a, b FROM tbl",
),
(
@@ -368,15 +301,11 @@ class TestBuild(unittest.TestCase):
"SELECT * FROM x WHERE y = 1 AND z = 1",
),
(
-lambda: exp.subquery("select x from tbl", "foo")
-.select("x")
-.where("x > 0"),
+lambda: exp.subquery("select x from tbl", "foo").select("x").where("x > 0"),
"SELECT x FROM (SELECT x FROM tbl) AS foo WHERE x > 0",
),
(
-lambda: exp.subquery(
-"select x from tbl UNION select x from bar", "unioned"
-).select("x"),
+lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select("x"),
"SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned",
),
]:
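The reflowed builder chains above behave identically on one line or many; a standalone sketch of the fluent API they exercise, using one of the cases as the expected output:

from sqlglot import select

query = select("x", "COUNT(y)").from_("tbl").group_by("x").having("COUNT(y) > 0")
# Generates: SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 0
print(query.sql())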
@@ -27,10 +27,7 @@ class TestExecutor(unittest.TestCase):
)
cls.cache = {}
-cls.sqls = [
-(sql, expected)
-for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")
-]
+cls.sqls = [(sql, expected) for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")]
@classmethod
def tearDownClass(cls):
@@ -50,7 +47,8 @@ class TestExecutor(unittest.TestCase):
self.assertEqual(Python().generate(parse_one("'x '''")), r"'x \''")
def test_optimized_tpch(self):
-for sql, optimized in self.sqls[0:20]:
+for i, (sql, optimized) in enumerate(self.sqls[:20], start=1):
+with self.subTest(f"{i}, {sql}"):
a = self.cached_execute(sql)
b = self.conn.execute(optimized).fetchdf()
self.rename_anonymous(b, a)
@@ -59,9 +57,7 @@ class TestExecutor(unittest.TestCase):
def test_execute_tpch(self):
def to_csv(expression):
if isinstance(expression, exp.Table):
-return parse_one(
-f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
-)
+return parse_one(f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}")
return expression
for sql, _ in self.sqls[0:3]:

tests/test_expressions.py

@@ -26,9 +26,7 @@ class TestExpressions(unittest.TestCase):
             parse_one("ROW() OVER(Partition by y)"),
             parse_one("ROW() OVER (partition BY y)"),
         )
-        self.assertEqual(
-            parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)")
-        )
+        self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)"))

     def test_find(self):
         expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y")
@@ -87,9 +85,7 @@ class TestExpressions(unittest.TestCase):
         self.assertIsNone(column.find_ancestor(exp.Join))

     def test_alias_or_name(self):
-        expression = parse_one(
-            "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
-        )
+        expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
         self.assertEqual(
             [e.alias_or_name for e in expression.expressions],
             ["a", "B", "e", "*", "zz", "z"],
@@ -118,9 +114,7 @@ class TestExpressions(unittest.TestCase):
         )

     def test_named_selects(self):
-        expression = parse_one(
-            "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
-        )
+        expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
         self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"])

         expression = parse_one(
@@ -196,15 +190,9 @@ class TestExpressions(unittest.TestCase):
     def test_sql(self):
         self.assertEqual(parse_one("x + y * 2").sql(), "x + y * 2")
-        self.assertEqual(
-            parse_one('select "x"').sql(dialect="hive", pretty=True), "SELECT\n  `x`"
-        )
-        self.assertEqual(
-            parse_one("X + y").sql(identify=True, normalize=True), '"x" + "y"'
-        )
-        self.assertEqual(
-            parse_one("SUM(X)").sql(identify=True, normalize=True), 'SUM("x")'
-        )
+        self.assertEqual(parse_one('select "x"').sql(dialect="hive", pretty=True), "SELECT\n  `x`")
+        self.assertEqual(parse_one("X + y").sql(identify=True, normalize=True), '"x" + "y"')
+        self.assertEqual(parse_one("SUM(X)").sql(identify=True, normalize=True), 'SUM("x")')

     def test_transform_with_arguments(self):
         expression = parse_one("a")
@@ -229,15 +217,11 @@ class TestExpressions(unittest.TestCase):
             return node

         actual_expression_1 = expression.transform(fun)
-        self.assertEqual(
-            actual_expression_1.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)"
-        )
+        self.assertEqual(actual_expression_1.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)")
         self.assertIsNot(actual_expression_1, expression)

         actual_expression_2 = expression.transform(fun, copy=False)
-        self.assertEqual(
-            actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)"
-        )
+        self.assertEqual(actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)")
         self.assertIs(actual_expression_2, expression)

         with self.assertRaises(ValueError):
@@ -274,12 +258,8 @@ class TestExpressions(unittest.TestCase):
         expression = parse_one("SELECT * FROM (SELECT * FROM x)")
         self.assertEqual(len(list(expression.walk())), 9)
         self.assertEqual(len(list(expression.walk(bfs=False))), 9)
-        self.assertTrue(
-            all(isinstance(e, exp.Expression) for e, _, _ in expression.walk())
-        )
-        self.assertTrue(
-            all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False))
-        )
+        self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk()))
+        self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False)))

     def test_functions(self):
         self.assertIsInstance(parse_one("ABS(a)"), exp.Abs)
@@ -303,9 +283,7 @@ class TestExpressions(unittest.TestCase):
         self.assertIsInstance(parse_one("IF(a, b, c)"), exp.If)
         self.assertIsInstance(parse_one("INITCAP(a)"), exp.Initcap)
         self.assertIsInstance(parse_one("JSON_EXTRACT(a, '$.name')"), exp.JSONExtract)
-        self.assertIsInstance(
-            parse_one("JSON_EXTRACT_SCALAR(a, '$.name')"), exp.JSONExtractScalar
-        )
+        self.assertIsInstance(parse_one("JSON_EXTRACT_SCALAR(a, '$.name')"), exp.JSONExtractScalar)
         self.assertIsInstance(parse_one("LEAST(a, b)"), exp.Least)
         self.assertIsInstance(parse_one("LN(a)"), exp.Ln)
         self.assertIsInstance(parse_one("LOG10(a)"), exp.Log10)
@@ -334,6 +312,7 @@ class TestExpressions(unittest.TestCase):
         self.assertIsInstance(parse_one("TIME_STR_TO_DATE(a)"), exp.TimeStrToDate)
         self.assertIsInstance(parse_one("TIME_STR_TO_TIME(a)"), exp.TimeStrToTime)
         self.assertIsInstance(parse_one("TIME_STR_TO_UNIX(a)"), exp.TimeStrToUnix)
+        self.assertIsInstance(parse_one("TRIM(LEADING 'b' FROM 'bla')"), exp.Trim)
         self.assertIsInstance(parse_one("TS_OR_DS_ADD(a, 1, 'day')"), exp.TsOrDsAdd)
         self.assertIsInstance(parse_one("TS_OR_DS_TO_DATE(a)"), exp.TsOrDsToDate)
         self.assertIsInstance(parse_one("TS_OR_DS_TO_DATE_STR(a)"), exp.Substring)
@@ -404,12 +383,8 @@ class TestExpressions(unittest.TestCase):
         self.assertFalse(exp.to_identifier("x").quoted)

     def test_function_normalizer(self):
-        self.assertEqual(
-            parse_one("HELLO()").sql(normalize_functions="lower"), "hello()"
-        )
-        self.assertEqual(
-            parse_one("hello()").sql(normalize_functions="upper"), "HELLO()"
-        )
+        self.assertEqual(parse_one("HELLO()").sql(normalize_functions="lower"), "hello()")
+        self.assertEqual(parse_one("hello()").sql(normalize_functions="upper"), "HELLO()")
         self.assertEqual(parse_one("heLLO()").sql(normalize_functions=None), "heLLO()")
         self.assertEqual(parse_one("SUM(x)").sql(normalize_functions="lower"), "sum(x)")
         self.assertEqual(parse_one("sum(x)").sql(normalize_functions="upper"), "SUM(x)")

tests/test_optimizer.py

@@ -31,9 +31,7 @@ class TestOptimizer(unittest.TestCase):
                 dialect = meta.get("dialect")
                 with self.subTest(sql):
                     self.assertEqual(
-                        func(parse_one(sql, read=dialect), **kwargs).sql(
-                            pretty=pretty, dialect=dialect
-                        ),
+                        func(parse_one(sql, read=dialect), **kwargs).sql(pretty=pretty, dialect=dialect),
                         expected,
                     )
@@ -86,9 +84,7 @@ class TestOptimizer(unittest.TestCase):
         for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"):
             with self.subTest(sql):
                 with self.assertRaises(OptimizeError):
-                    optimizer.qualify_columns.qualify_columns(
-                        parse_one(sql), schema=self.schema
-                    )
+                    optimizer.qualify_columns.qualify_columns(parse_one(sql), schema=self.schema)

     def test_quote_identities(self):
         self.check_file("quote_identities", optimizer.quote_identities.quote_identities)
@@ -100,9 +96,7 @@ class TestOptimizer(unittest.TestCase):
             expression = optimizer.pushdown_projections.pushdown_projections(expression)
             return expression

-        self.check_file(
-            "pushdown_projections", pushdown_projections, schema=self.schema
-        )
+        self.check_file("pushdown_projections", pushdown_projections, schema=self.schema)

     def test_simplify(self):
         self.check_file("simplify", optimizer.simplify.simplify)
@@ -115,9 +109,7 @@ class TestOptimizer(unittest.TestCase):
         )

     def test_pushdown_predicates(self):
-        self.check_file(
-            "pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates
-        )
+        self.check_file("pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates)

     def test_expand_multi_table_selects(self):
         self.check_file(
@@ -138,10 +130,17 @@ class TestOptimizer(unittest.TestCase):
             pretty=True,
         )

+    def test_merge_derived_tables(self):
+        def optimize(expression, **kwargs):
+            expression = optimizer.qualify_tables.qualify_tables(expression)
+            expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
+            expression = optimizer.merge_derived_tables.merge_derived_tables(expression)
+            return expression
+
+        self.check_file("merge_derived_tables", optimize, schema=self.schema)
+
     def test_tpch(self):
-        self.check_file(
-            "tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True
-        )
+        self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)

     def test_schema(self):
         schema = ensure_schema(
@@ -262,9 +261,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         self.assertEqual(len(scopes), 5)
         self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
         self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
-        self.assertEqual(
-            scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b"
-        )
+        self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
         self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y")
         self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())
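
The optimize helper in test_merge_derived_tables above spells out the recipe for the merge_derived_tables pass: qualify tables, qualify columns against a schema, then merge. A minimal standalone sketch under those assumptions (the query and schema here are illustrative, not taken from the fixtures):

    from sqlglot import parse_one
    from sqlglot.optimizer import merge_derived_tables, qualify_columns, qualify_tables

    expression = parse_one("SELECT x.a FROM (SELECT a FROM t) AS x")
    expression = qualify_tables.qualify_tables(expression)
    expression = qualify_columns.qualify_columns(expression, schema={"t": {"a": "INT"}})
    expression = merge_derived_tables.merge_derived_tables(expression)
    print(expression.sql())  # the derived table should fold into the outer SELECT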

tests/test_parser.py

@@ -16,28 +16,23 @@ class TestParser(unittest.TestCase):
         self.assertIsInstance(parse_one("array<int>", into=exp.DataType), exp.DataType)

     def test_column(self):
-        columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all(
-            exp.Column
-        )
+        columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all(exp.Column)
         assert len(list(columns)) == 1
         self.assertIsNotNone(parse_one("date").find(exp.Column))

     def test_table(self):
-        tables = [
-            t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)
-        ]
+        tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
         self.assertEqual(tables, ["a", "b.c", "d"])

     def test_select(self):
-        self.assertIsNotNone(
-            parse_one("select * from (select 1) x order by x.y").args["order"]
-        )
-        self.assertIsNotNone(
-            parse_one("select * from x where a = (select 1) order by x.y").args["order"]
-        )
+        self.assertIsNotNone(parse_one("select 1 natural"))
+        self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"])
+        self.assertIsNotNone(parse_one("select * from x where a = (select 1) order by x.y").args["order"])
+        self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1)
         self.assertEqual(
-            len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1
+            parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),
+            """SELECT * FROM x, z LATERAL VIEW EXPLODE(y) CROSS JOIN y""",
         )

     def test_command(self):
@@ -72,12 +67,8 @@ class TestParser(unittest.TestCase):
         )

         assert len(expressions) == 2
-        assert (
-            expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a"
-        )
-        assert (
-            expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b"
-        )
+        assert expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a"
+        assert expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b"

     def test_expression(self):
         ignore = Parser(error_level=ErrorLevel.IGNORE)
@@ -147,13 +138,9 @@ class TestParser(unittest.TestCase):
     def test_pretty_config_override(self):
         self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x")
         with patch("sqlglot.pretty", True):
-            self.assertEqual(
-                parse_one("SELECT col FROM x").sql(), "SELECT\n  col\nFROM x"
-            )
+            self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT\n  col\nFROM x")

-        self.assertEqual(
-            parse_one("SELECT col FROM x").sql(pretty=True), "SELECT\n  col\nFROM x"
-        )
+        self.assertEqual(parse_one("SELECT col FROM x").sql(pretty=True), "SELECT\n  col\nFROM x")

     @patch("sqlglot.parser.logger")
     def test_comment_error_n(self, logger):

tests/test_transpile.py

@@ -42,6 +42,20 @@ class TestTranspile(unittest.TestCase):
             "SELECT * FROM x WHERE a = ANY (SELECT 1)",
         )

+    def test_leading_comma(self):
+        self.validate(
+            "SELECT FOO, BAR, BAZ",
+            "SELECT\n  FOO\n  , BAR\n  , BAZ",
+            leading_comma=True,
+            pretty=True,
+        )
+        # without pretty, this should be a no-op
+        self.validate(
+            "SELECT FOO, BAR, BAZ",
+            "SELECT FOO, BAR, BAZ",
+            leading_comma=True,
+        )
+
     def test_space(self):
         self.validate("SELECT MIN(3)>MIN(2)", "SELECT MIN(3) > MIN(2)")
         self.validate("SELECT MIN(3)>=MIN(2)", "SELECT MIN(3) >= MIN(2)")
@@ -108,6 +122,11 @@ class TestTranspile(unittest.TestCase):
             "extract(month from '2021-01-31'::timestamp without time zone)",
             "EXTRACT(month FROM CAST('2021-01-31' AS TIMESTAMP))",
         )
+        self.validate("extract(week from current_date + 2)", "EXTRACT(week FROM CURRENT_DATE + 2)")
+        self.validate(
+            "EXTRACT(minute FROM datetime1 - datetime2)",
+            "EXTRACT(minute FROM datetime1 - datetime2)",
+        )

     def test_if(self):
         self.validate(
@@ -122,18 +141,14 @@ class TestTranspile(unittest.TestCase):
             "SELECT IF a > 1 THEN b ELSE c END",
             "SELECT CASE WHEN a > 1 THEN b ELSE c END",
         )
-        self.validate(
-            "SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo"
-        )
+        self.validate("SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo")

     def test_ignore_nulls(self):
         self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)")

     def test_time(self):
         self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)")
-        self.validate(
-            "TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)"
-        )
+        self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)")
         self.validate(
             "TIMESTAMP(9) WITH TIME ZONE '2020-01-01'",
             "CAST('2020-01-01' AS TIMESTAMPTZ(9))",
@@ -159,9 +174,7 @@ class TestTranspile(unittest.TestCase):
         self.validate("DATE '2020-01-01'", "CAST('2020-01-01' AS DATE)")
         self.validate("'2020-01-01'::DATE", "CAST('2020-01-01' AS DATE)")
         self.validate("STR_TO_TIME('x', 'y')", "STRPTIME('x', 'y')", write="duckdb")
-        self.validate(
-            "STR_TO_UNIX('x', 'y')", "EPOCH(STRPTIME('x', 'y'))", write="duckdb"
-        )
+        self.validate("STR_TO_UNIX('x', 'y')", "EPOCH(STRPTIME('x', 'y'))", write="duckdb")
         self.validate("TIME_TO_STR(x, 'y')", "STRFTIME(x, 'y')", write="duckdb")
         self.validate("TIME_TO_UNIX(x)", "EPOCH(x)", write="duckdb")
         self.validate(
@@ -209,12 +222,8 @@ class TestTranspile(unittest.TestCase):
         self.validate("TIME_STR_TO_DATE(x)", "TIME_STR_TO_DATE(x)", write=None)
         self.validate("TIME_STR_TO_DATE(x)", "TO_DATE(x)", write="hive")
-        self.validate(
-            "UNIX_TO_STR(x, 'yyyy-MM-dd HH:mm:ss')", "FROM_UNIXTIME(x)", write="hive"
-        )
-        self.validate(
-            "STR_TO_UNIX(x, 'yyyy-MM-dd HH:mm:ss')", "UNIX_TIMESTAMP(x)", write="hive"
-        )
+        self.validate("UNIX_TO_STR(x, 'yyyy-MM-dd HH:mm:ss')", "FROM_UNIXTIME(x)", write="hive")
+        self.validate("STR_TO_UNIX(x, 'yyyy-MM-dd HH:mm:ss')", "UNIX_TIMESTAMP(x)", write="hive")
         self.validate("IF(x > 1, x + 1)", "IF(x > 1, x + 1)", write="presto")
         self.validate("IF(x > 1, 1 + 1)", "IF(x > 1, 1 + 1)", write="hive")
         self.validate("IF(x > 1, 1, 0)", "IF(x > 1, 1, 0)", write="hive")
@@ -232,9 +241,7 @@ class TestTranspile(unittest.TestCase):
         )

         self.validate("STR_TO_TIME('x', 'y')", "DATE_PARSE('x', 'y')", write="presto")
-        self.validate(
-            "STR_TO_UNIX('x', 'y')", "TO_UNIXTIME(DATE_PARSE('x', 'y'))", write="presto"
-        )
+        self.validate("STR_TO_UNIX('x', 'y')", "TO_UNIXTIME(DATE_PARSE('x', 'y'))", write="presto")
         self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="presto")
         self.validate("TIME_TO_UNIX(x)", "TO_UNIXTIME(x)", write="presto")
         self.validate(
@@ -245,9 +252,7 @@ class TestTranspile(unittest.TestCase):
         self.validate("UNIX_TO_TIME(123)", "FROM_UNIXTIME(123)", write="presto")

         self.validate("STR_TO_TIME('x', 'y')", "TO_TIMESTAMP('x', 'y')", write="spark")
-        self.validate(
-            "STR_TO_UNIX('x', 'y')", "UNIX_TIMESTAMP('x', 'y')", write="spark"
-        )
+        self.validate("STR_TO_UNIX('x', 'y')", "UNIX_TIMESTAMP('x', 'y')", write="spark")
         self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="spark")
         self.validate(
@@ -283,9 +288,7 @@ class TestTranspile(unittest.TestCase):
     def test_partial(self):
         for sql in load_sql_fixtures("partial.sql"):
             with self.subTest(sql):
-                self.assertEqual(
-                    transpile(sql, error_level=ErrorLevel.IGNORE)[0], sql.strip()
-                )
+                self.assertEqual(transpile(sql, error_level=ErrorLevel.IGNORE)[0], sql.strip())

     def test_pretty(self):
         for _, sql, pretty in load_sql_fixture_pairs("pretty.sql"):