Merging upstream version 6.1.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 3c6d649c90
commit 08ecea3adf
61 changed files with 1844 additions and 1555 deletions
CHANGELOG.md (new file, 11 lines)

@@ -0,0 +1,11 @@
+Changelog
+=========
+
+v6.1.0
+------
+
+Changes:
+
+- New: mysql group_concat separator [49a4099](https://github.com/tobymao/sqlglot/commit/49a4099adc93780eeffef8204af36559eab50a9f)
+
+- Improvement: Better nested select parsing [45603f](https://github.com/tobymao/sqlglot/commit/45603f14bf9146dc3f8b330b85a0e25b77630b9b)
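
For illustration (not part of the patch), a minimal sketch of the new separator support, assuming the 6.1.x public API:

import sqlglot

# MySQL GROUP_CONCAT with an explicit SEPARATOR now parses and round-trips.
sql = "SELECT GROUP_CONCAT(name SEPARATOR '-') FROM users"
print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])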
LICENSE (2 changes)

@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2021 Toby Mao
+Copyright (c) 2022 Toby Mao
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
@@ -8,5 +8,5 @@ python -m autoflake -i -r \
 --remove-unused-variables \
 sqlglot/ tests/
 python -m isort --profile black sqlglot/ tests/
-python -m black sqlglot/ tests/
+python -m black --line-length 120 sqlglot/ tests/
 python -m unittest
@@ -20,7 +20,7 @@ from sqlglot.generator import Generator
 from sqlglot.parser import Parser
 from sqlglot.tokens import Tokenizer, TokenType
 
-__version__ = "6.0.4"
+__version__ = "6.1.1"
 
 pretty = False
 
@@ -49,12 +49,7 @@ args = parser.parse_args()
 error_level = sqlglot.ErrorLevel[args.error_level.upper()]
 
 if args.parse:
-    sqls = [
-        repr(expression)
-        for expression in sqlglot.parse(
-            args.sql, read=args.read, error_level=error_level
-        )
-    ]
+    sqls = [repr(expression) for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)]
 else:
     sqls = sqlglot.transpile(
         args.sql,
@@ -7,6 +7,7 @@ from sqlglot.dialects.mysql import MySQL
 from sqlglot.dialects.oracle import Oracle
 from sqlglot.dialects.postgres import Postgres
 from sqlglot.dialects.presto import Presto
+from sqlglot.dialects.redshift import Redshift
 from sqlglot.dialects.snowflake import Snowflake
 from sqlglot.dialects.spark import Spark
 from sqlglot.dialects.sqlite import SQLite
@@ -44,6 +44,7 @@ class BigQuery(Dialect):
         ]
         IDENTIFIERS = ["`"]
         ESCAPE = "\\"
+        HEX_STRINGS = [("0x", ""), ("0X", "")]
 
         KEYWORDS = {
             **Tokenizer.KEYWORDS,

@@ -120,9 +121,5 @@ class BigQuery(Dialect):
 
         def intersect_op(self, expression):
             if not expression.args.get("distinct", False):
-                self.unsupported(
-                    "INTERSECT without DISTINCT is not supported in BigQuery"
-                )
-            return (
-                f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
-            )
+                self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
+            return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
@@ -20,6 +20,7 @@ class Dialects(str, Enum):
     ORACLE = "oracle"
     POSTGRES = "postgres"
     PRESTO = "presto"
+    REDSHIFT = "redshift"
     SNOWFLAKE = "snowflake"
     SPARK = "spark"
     SQLITE = "sqlite"

@@ -53,12 +54,19 @@ class _Dialect(type):
         klass.generator_class = getattr(klass, "Generator", Generator)
 
         klass.tokenizer = klass.tokenizer_class()
-        klass.quote_start, klass.quote_end = list(klass.tokenizer_class.QUOTES.items())[
-            0
-        ]
-        klass.identifier_start, klass.identifier_end = list(
-            klass.tokenizer_class.IDENTIFIERS.items()
-        )[0]
+        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
+        klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
+
+        if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS:
+            bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
+            klass.generator_class.TRANSFORMS[
+                exp.BitString
+            ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
+        if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS:
+            hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
+            klass.generator_class.TRANSFORMS[
+                exp.HexString
+            ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
 
         return klass
 
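
For illustration (not part of the patch): the metaclass hook above derives BitString/HexString rendering from each tokenizer's delimiters, so bit and hex literals can be converted between dialects. A sketch, assuming the 6.1.x API:

import sqlglot

# MySQL reads 0x... as a hex literal; SQLite renders hex literals as x'...'.
print(sqlglot.transpile("SELECT 0xCC", read="mysql", write="sqlite")[0])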
@@ -122,9 +130,7 @@ class Dialect(metaclass=_Dialect):
         return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
 
     def parse_into(self, expression_type, sql, **opts):
-        return self.parser(**opts).parse_into(
-            expression_type, self.tokenizer.tokenize(sql), sql
-        )
+        return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql)
 
     def generate(self, expression, **opts):
         return self.generator(**opts).generate(expression)
@@ -164,9 +170,7 @@ class Dialect(metaclass=_Dialect):
 
 
 def rename_func(name):
-    return (
-        lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})"
-    )
+    return lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})"
 
 
 def approx_count_distinct_sql(self, expression):
@@ -260,8 +264,7 @@ def format_time_lambda(exp_class, dialect, default=None):
         return exp_class(
             this=list_get(args, 0),
             format=Dialect[dialect].format_time(
-                list_get(args, 1)
-                or (Dialect[dialect].time_format if default is True else default)
+                list_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
             ),
         )
 
@@ -63,10 +63,7 @@ def _sort_array_reverse(args):
 
 
 def _struct_pack_sql(self, expression):
-    args = [
-        self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e)
-        for e in expression.expressions
-    ]
+    args = [self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) for e in expression.expressions]
     return f"STRUCT_PACK({', '.join(args)})"
 
 
@@ -109,9 +109,7 @@ def _unnest_to_explode_sql(self, expression):
                     alias=exp.TableAlias(this=alias.this, columns=[column]),
                 )
             )
-            for expression, column in zip(
-                unnest.expressions, alias.columns if alias else []
-            )
+            for expression, column in zip(unnest.expressions, alias.columns if alias else [])
         )
     return self.join_sql(expression)
 
@@ -206,14 +204,11 @@ class Hive(Dialect):
                 substr=list_get(args, 0),
                 position=list_get(args, 2),
             ),
-            "LOG": (
-                lambda args: exp.Log.from_arg_list(args)
-                if len(args) > 1
-                else exp.Ln.from_arg_list(args)
-            ),
+            "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
             "MAP": _parse_map,
             "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
             "PERCENTILE": exp.Quantile.from_arg_list,
+            "PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list,
             "COLLECT_SET": exp.SetAgg.from_arg_list,
             "SIZE": exp.ArraySize.from_arg_list,
             "SPLIT": exp.RegexpSplit.from_arg_list,
@@ -262,6 +257,7 @@ class Hive(Dialect):
             HiveMap: _map_sql,
             exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}",
             exp.Quantile: rename_func("PERCENTILE"),
+            exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
             exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
             exp.RegexpSplit: rename_func("SPLIT"),
             exp.SafeDivide: no_safe_divide_sql,
@@ -296,8 +292,7 @@ class Hive(Dialect):
 
         def datatype_sql(self, expression):
             if (
-                expression.this
-                in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
+                expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
                 and not expression.expressions
             ):
                 expression = exp.DataType.build("text")
@@ -49,6 +49,21 @@ def _str_to_date_sql(self, expression):
     return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
 
 
+def _trim_sql(self, expression):
+    target = self.sql(expression, "this")
+    trim_type = self.sql(expression, "position")
+    remove_chars = self.sql(expression, "expression")
+
+    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't mysql-specific
+    if not remove_chars:
+        return self.trim_sql(expression)
+
+    trim_type = f"{trim_type} " if trim_type else ""
+    remove_chars = f"{remove_chars} " if remove_chars else ""
+    from_part = "FROM " if trim_type or remove_chars else ""
+    return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
+
+
 def _date_add(expression_class):
     def func(args):
         interval = list_get(args, 1)
@@ -88,9 +103,12 @@ class MySQL(Dialect):
         QUOTES = ["'", '"']
         COMMENTS = ["--", "#", ("/*", "*/")]
         IDENTIFIERS = ["`"]
+        BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
+        HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]
 
         KEYWORDS = {
             **Tokenizer.KEYWORDS,
+            "SEPARATOR": TokenType.SEPARATOR,
             "_ARMSCII8": TokenType.INTRODUCER,
             "_ASCII": TokenType.INTRODUCER,
             "_BIG5": TokenType.INTRODUCER,
@@ -145,6 +163,15 @@ class MySQL(Dialect):
             "STR_TO_DATE": _str_to_date,
         }
 
+        FUNCTION_PARSERS = {
+            **Parser.FUNCTION_PARSERS,
+            "GROUP_CONCAT": lambda self: self.expression(
+                exp.GroupConcat,
+                this=self._parse_lambda(),
+                separator=self._match(TokenType.SEPARATOR) and self._parse_field(),
+            ),
+        }
+
     class Generator(Generator):
         NULL_ORDERING_SUPPORTED = False
 
@@ -158,6 +185,8 @@ class MySQL(Dialect):
             exp.DateAdd: _date_add_sql("ADD"),
             exp.DateSub: _date_add_sql("SUB"),
             exp.DateTrunc: _date_trunc_sql,
+            exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
             exp.StrToDate: _str_to_date_sql,
             exp.StrToTime: _str_to_date_sql,
+            exp.Trim: _trim_sql,
         }
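
A sketch of the new MySQL TRIM handling (not part of the patch; output shape is indicative):

import sqlglot

# MySQL-specific TRIM keeps the FROM form; a plain TRIM falls back to the
# generic TRIM/LTRIM/RTRIM rendering via self.trim_sql.
print(sqlglot.transpile("SELECT TRIM(BOTH 'x' FROM col)", read="mysql", write="mysql")[0])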
|
@ -51,6 +51,14 @@ class Oracle(Dialect):
|
||||||
sep="",
|
sep="",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def alias_sql(self, expression):
|
||||||
|
if isinstance(expression.this, exp.Table):
|
||||||
|
to_sql = self.sql(expression, "alias")
|
||||||
|
# oracle does not allow "AS" between table and alias
|
||||||
|
to_sql = f" {to_sql}" if to_sql else ""
|
||||||
|
return f"{self.sql(expression, 'this')}{to_sql}"
|
||||||
|
return super().alias_sql(expression)
|
||||||
|
|
||||||
def offset_sql(self, expression):
|
def offset_sql(self, expression):
|
||||||
return f"{super().offset_sql(expression)} ROWS"
|
return f"{super().offset_sql(expression)} ROWS"
|
||||||
|
|
||||||
|
|
|
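
For illustration (not part of the patch), the effect of the new alias_sql override, assuming the 6.1.x API:

import sqlglot

# Oracle does not allow AS between a table and its alias.
print(sqlglot.transpile("SELECT * FROM users AS u", write="oracle")[0])
# expected: SELECT * FROM users u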
@@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import (
 from sqlglot.generator import Generator
 from sqlglot.parser import Parser
 from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.transforms import delegate, preprocess
 
 
 def _date_add_sql(kind):
@@ -32,11 +33,96 @@ def _date_add_sql(kind):
     return func
 
 
+def _lateral_sql(self, expression):
+    this = self.sql(expression, "this")
+    if isinstance(expression.this, exp.Subquery):
+        return f"LATERAL{self.sep()}{this}"
+    alias = expression.args["alias"]
+    table = alias.name
+    table = f" {table}" if table else table
+    columns = self.expressions(alias, key="columns", flat=True)
+    columns = f" AS {columns}" if columns else ""
+    return f"LATERAL{self.sep()}{this}{table}{columns}"
+
+
+def _substring_sql(self, expression):
+    this = self.sql(expression, "this")
+    start = self.sql(expression, "start")
+    length = self.sql(expression, "length")
+
+    from_part = f" FROM {start}" if start else ""
+    for_part = f" FOR {length}" if length else ""
+
+    return f"SUBSTRING({this}{from_part}{for_part})"
+
+
+def _trim_sql(self, expression):
+    target = self.sql(expression, "this")
+    trim_type = self.sql(expression, "position")
+    remove_chars = self.sql(expression, "expression")
+    collation = self.sql(expression, "collation")
+
+    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't postgres-specific
+    if not remove_chars and not collation:
+        return self.trim_sql(expression)
+
+    trim_type = f"{trim_type} " if trim_type else ""
+    remove_chars = f"{remove_chars} " if remove_chars else ""
+    from_part = "FROM " if trim_type or remove_chars else ""
+    collation = f" COLLATE {collation}" if collation else ""
+    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
+
+
+def _auto_increment_to_serial(expression):
+    auto = expression.find(exp.AutoIncrementColumnConstraint)
+
+    if auto:
+        expression = expression.copy()
+        expression.args["constraints"].remove(auto.parent)
+        kind = expression.args["kind"]
+
+        if kind.this == exp.DataType.Type.INT:
+            kind.replace(exp.DataType(this=exp.DataType.Type.SERIAL))
+        elif kind.this == exp.DataType.Type.SMALLINT:
+            kind.replace(exp.DataType(this=exp.DataType.Type.SMALLSERIAL))
+        elif kind.this == exp.DataType.Type.BIGINT:
+            kind.replace(exp.DataType(this=exp.DataType.Type.BIGSERIAL))
+
+    return expression
+
+
+def _serial_to_generated(expression):
+    kind = expression.args["kind"]
+
+    if kind.this == exp.DataType.Type.SERIAL:
+        data_type = exp.DataType(this=exp.DataType.Type.INT)
+    elif kind.this == exp.DataType.Type.SMALLSERIAL:
+        data_type = exp.DataType(this=exp.DataType.Type.SMALLINT)
+    elif kind.this == exp.DataType.Type.BIGSERIAL:
+        data_type = exp.DataType(this=exp.DataType.Type.BIGINT)
+    else:
+        data_type = None
+
+    if data_type:
+        expression = expression.copy()
+        expression.args["kind"].replace(data_type)
+        constraints = expression.args["constraints"]
+        generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False))
+        notnull = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint())
+        if notnull not in constraints:
+            constraints.insert(0, notnull)
+        if generated not in constraints:
+            constraints.insert(0, generated)
+
+    return expression
+
+
 class Postgres(Dialect):
     null_ordering = "nulls_are_large"
     time_format = "'YYYY-MM-DD HH24:MI:SS'"
     time_mapping = {
-        "AM": "%p",  # AM or PM
+        "AM": "%p",
+        "PM": "%p",
         "D": "%w",  # 1-based day of week
         "DD": "%d",  # day of month
         "DDD": "%j",  # zero padded day of year
@@ -65,14 +151,25 @@ class Postgres(Dialect):
     }
 
     class Tokenizer(Tokenizer):
+        BIT_STRINGS = [("b'", "'"), ("B'", "'")]
+        HEX_STRINGS = [("x'", "'"), ("X'", "'")]
         KEYWORDS = {
             **Tokenizer.KEYWORDS,
-            "SERIAL": TokenType.AUTO_INCREMENT,
+            "ALWAYS": TokenType.ALWAYS,
+            "BY DEFAULT": TokenType.BY_DEFAULT,
+            "IDENTITY": TokenType.IDENTITY,
+            "FOR": TokenType.FOR,
+            "GENERATED": TokenType.GENERATED,
+            "DOUBLE PRECISION": TokenType.DOUBLE,
+            "BIGSERIAL": TokenType.BIGSERIAL,
+            "SERIAL": TokenType.SERIAL,
+            "SMALLSERIAL": TokenType.SMALLSERIAL,
             "UUID": TokenType.UUID,
         }
 
     class Parser(Parser):
        STRICT_CAST = False
 
        FUNCTIONS = {
            **Parser.FUNCTIONS,
            "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"),
@@ -86,14 +183,18 @@ class Postgres(Dialect):
             exp.DataType.Type.FLOAT: "REAL",
             exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
             exp.DataType.Type.BINARY: "BYTEA",
+            exp.DataType.Type.DATETIME: "TIMESTAMP",
         }
 
-        TOKEN_MAPPING = {
-            TokenType.AUTO_INCREMENT: "SERIAL",
-        }
-
         TRANSFORMS = {
             **Generator.TRANSFORMS,
+            exp.ColumnDef: preprocess(
+                [
+                    _auto_increment_to_serial,
+                    _serial_to_generated,
+                ],
+                delegate("columndef_sql"),
+            ),
             exp.JSONExtract: arrow_json_extract_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
             exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}",
|
@ -102,8 +203,11 @@ class Postgres(Dialect):
|
||||||
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
|
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
|
||||||
exp.DateAdd: _date_add_sql("+"),
|
exp.DateAdd: _date_add_sql("+"),
|
||||||
exp.DateSub: _date_add_sql("-"),
|
exp.DateSub: _date_add_sql("-"),
|
||||||
|
exp.Lateral: _lateral_sql,
|
||||||
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||||
|
exp.Substring: _substring_sql,
|
||||||
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
|
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||||
exp.TableSample: no_tablesample_sql,
|
exp.TableSample: no_tablesample_sql,
|
||||||
|
exp.Trim: _trim_sql,
|
||||||
exp.TryCast: no_trycast_sql,
|
exp.TryCast: no_trycast_sql,
|
||||||
}
|
}
|
||||||
|
|
|
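
A sketch of the new ColumnDef preprocessing (not part of the patch; output shape is indicative):

import sqlglot

# AUTO_INCREMENT / SERIAL column definitions are rewritten when writing
# Postgres via the _auto_increment_to_serial / _serial_to_generated passes.
print(sqlglot.transpile("CREATE TABLE t (id INT AUTO_INCREMENT)", read="mysql", write="postgres")[0])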
@@ -96,9 +96,7 @@ def _ts_or_ds_to_date_sql(self, expression):
     time_format = self.format_time(expression)
     if time_format and time_format not in (Presto.time_format, Presto.date_format):
         return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
-    return (
-        f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
-    )
+    return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
 
 
 def _ts_or_ds_add_sql(self, expression):
@@ -141,6 +139,7 @@ class Presto(Dialect):
             "FROM_UNIXTIME": exp.UnixToTime.from_arg_list,
             "STRPOS": exp.StrPosition.from_arg_list,
             "TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
+            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
         }
 
     class Generator(Generator):

@@ -193,6 +192,7 @@ class Presto(Dialect):
             exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}",
             exp.Quantile: _quantile_sql,
+            exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
             exp.SafeDivide: no_safe_divide_sql,
             exp.Schema: _schema_sql,
             exp.SortArray: _no_sort_array,
sqlglot/dialects/redshift.py (new file, 34 lines)

@@ -0,0 +1,34 @@
+from sqlglot import exp
+from sqlglot.dialects.postgres import Postgres
+from sqlglot.tokens import TokenType
+
+
+class Redshift(Postgres):
+    time_format = "'YYYY-MM-DD HH:MI:SS'"
+    time_mapping = {
+        **Postgres.time_mapping,
+        "MON": "%b",
+        "HH": "%H",
+    }
+
+    class Tokenizer(Postgres.Tokenizer):
+        ESCAPE = "\\"
+
+        KEYWORDS = {
+            **Postgres.Tokenizer.KEYWORDS,
+            "GEOMETRY": TokenType.GEOMETRY,
+            "GEOGRAPHY": TokenType.GEOGRAPHY,
+            "HLLSKETCH": TokenType.HLLSKETCH,
+            "SUPER": TokenType.SUPER,
+            "TIME": TokenType.TIMESTAMP,
+            "TIMETZ": TokenType.TIMESTAMPTZ,
+            "VARBYTE": TokenType.BINARY,
+            "SIMILAR TO": TokenType.SIMILAR_TO,
+        }
+
+    class Generator(Postgres.Generator):
+        TYPE_MAPPING = {
+            **Postgres.Generator.TYPE_MAPPING,
+            exp.DataType.Type.BINARY: "VARBYTE",
+            exp.DataType.Type.INT: "INTEGER",
+        }
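
For illustration (not part of the patch), "redshift" is now a usable dialect name; a sketch assuming the 6.1.x API:

import sqlglot

# BINARY maps to VARBYTE in the Redshift generator's TYPE_MAPPING.
print(sqlglot.transpile("SELECT CAST(x AS BINARY)", write="redshift")[0])
# expected: SELECT CAST(x AS VARBYTE)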
@@ -23,9 +23,7 @@ def _snowflake_to_timestamp(args):
 
         # case: <numeric_expr> [ , <scale> ]
         if second_arg.name not in ["0", "3", "9"]:
-            raise ValueError(
-                f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
-            )
+            raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9")
 
         if second_arg.name == "0":
             timescale = exp.UnixToTime.SECONDS
@@ -65,12 +65,11 @@ class Spark(Hive):
                 this=list_get(args, 0),
                 start=exp.Sub(
                     this=exp.Length(this=list_get(args, 0)),
-                    expression=exp.Add(
-                        this=list_get(args, 1), expression=exp.Literal.number(1)
-                    ),
+                    expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)),
                 ),
                 length=list_get(args, 1),
             ),
+            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
         }
 
     class Generator(Hive.Generator):

@@ -82,11 +81,7 @@ class Spark(Hive):
         }
 
         TRANSFORMS = {
-            **{
-                k: v
-                for k, v in Hive.Generator.TRANSFORMS.items()
-                if k not in {exp.ArraySort}
-            },
+            **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}},
             exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
             exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
             exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),

@@ -102,5 +97,5 @@ class Spark(Hive):
             HiveMap: _map_sql,
         }
 
-    def bitstring_sql(self, expression):
-        return f"X'{self.sql(expression, 'this')}'"
+    class Tokenizer(Hive.Tokenizer):
+        HEX_STRINGS = [("X'", "'")]
@@ -16,6 +16,7 @@ from sqlglot.tokens import Tokenizer, TokenType
 class SQLite(Dialect):
     class Tokenizer(Tokenizer):
         IDENTIFIERS = ['"', ("[", "]"), "`"]
+        HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
 
         KEYWORDS = {
             **Tokenizer.KEYWORDS,
@@ -8,3 +8,6 @@ class Trino(Presto):
             **Presto.Generator.TRANSFORMS,
             exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
         }
+
+    class Tokenizer(Presto.Tokenizer):
+        HEX_STRINGS = [("X'", "'")]
@@ -115,13 +115,8 @@ class ChangeDistiller:
         for kept_source_node_id, kept_target_node_id in matching_set:
             source_node = self._source_index[kept_source_node_id]
             target_node = self._target_index[kept_target_node_id]
-            if (
-                not isinstance(source_node, LEAF_EXPRESSION_TYPES)
-                or source_node == target_node
-            ):
-                edit_script.extend(
-                    self._generate_move_edits(source_node, target_node, matching_set)
-                )
+            if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node:
+                edit_script.extend(self._generate_move_edits(source_node, target_node, matching_set))
                 edit_script.append(Keep(source_node, target_node))
             else:
                 edit_script.append(Update(source_node, target_node))

@@ -132,9 +127,7 @@ class ChangeDistiller:
         source_args = [id(e) for e in _expression_only_args(source)]
         target_args = [id(e) for e in _expression_only_args(target)]
 
-        args_lcs = set(
-            _lcs(source_args, target_args, lambda l, r: (l, r) in matching_set)
-        )
+        args_lcs = set(_lcs(source_args, target_args, lambda l, r: (l, r) in matching_set))
 
         move_edits = []
         for a in source_args:

@@ -148,14 +141,10 @@ class ChangeDistiller:
         matching_set = leaves_matching_set.copy()
 
         ordered_unmatched_source_nodes = {
-            id(n[0]): None
-            for n in self._source.bfs()
-            if id(n[0]) in self._unmatched_source_nodes
+            id(n[0]): None for n in self._source.bfs() if id(n[0]) in self._unmatched_source_nodes
         }
         ordered_unmatched_target_nodes = {
-            id(n[0]): None
-            for n in self._target.bfs()
-            if id(n[0]) in self._unmatched_target_nodes
+            id(n[0]): None for n in self._target.bfs() if id(n[0]) in self._unmatched_target_nodes
         }
 
         for source_node_id in ordered_unmatched_source_nodes:

@@ -169,18 +158,13 @@ class ChangeDistiller:
             max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
             if max_leaves_num:
                 common_leaves_num = sum(
-                    1 if s in source_leaf_ids and t in target_leaf_ids else 0
-                    for s, t in leaves_matching_set
+                    1 if s in source_leaf_ids and t in target_leaf_ids else 0 for s, t in leaves_matching_set
                 )
                 leaf_similarity_score = common_leaves_num / max_leaves_num
             else:
                 leaf_similarity_score = 0.0
 
-            adjusted_t = (
-                self.t
-                if min(len(source_leaf_ids), len(target_leaf_ids)) > 4
-                else 0.4
-            )
+            adjusted_t = self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4
 
             if leaf_similarity_score >= 0.8 or (
                 leaf_similarity_score >= adjusted_t

@@ -217,10 +201,7 @@ class ChangeDistiller:
         matching_set = set()
         while candidate_matchings:
             _, _, source_leaf, target_leaf = heappop(candidate_matchings)
-            if (
-                id(source_leaf) in self._unmatched_source_nodes
-                and id(target_leaf) in self._unmatched_target_nodes
-            ):
+            if id(source_leaf) in self._unmatched_source_nodes and id(target_leaf) in self._unmatched_target_nodes:
                 matching_set.add((id(source_leaf), id(target_leaf)))
                 self._unmatched_source_nodes.remove(id(source_leaf))
                 self._unmatched_target_nodes.remove(id(target_leaf))
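
For context (not part of the patch), the ChangeDistiller above powers sqlglot's semantic diff; a sketch assuming the 6.1.x API:

from sqlglot import parse_one
from sqlglot.diff import diff

# Edit script (Keep/Update/Move/Insert/Remove) between two parsed trees.
for edit in diff(parse_one("SELECT a FROM t"), parse_one("SELECT a, b FROM t")):
    print(edit)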
@@ -3,11 +3,17 @@ import time
 
 from sqlglot import parse_one
 from sqlglot.executor.python import PythonExecutor
-from sqlglot.optimizer import optimize
+from sqlglot.optimizer import RULES, optimize
+from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
 from sqlglot.planner import Plan
 
 logger = logging.getLogger("sqlglot")
 
+OPTIMIZER_RULES = list(RULES)
+
+# The executor needs isolated table selects
+OPTIMIZER_RULES.remove(merge_derived_tables)
+
 
 def execute(sql, schema, read=None):
     """

@@ -28,7 +34,7 @@ def execute(sql, schema, read=None):
     """
     expression = parse_one(sql, read=read)
     now = time.time()
-    expression = optimize(expression, schema)
+    expression = optimize(expression, schema, rules=OPTIMIZER_RULES)
     logger.debug("Optimization finished: %f", time.time() - now)
     logger.debug("Optimized SQL: %s", expression.sql(pretty=True))
     plan = Plan(expression)
@@ -19,9 +19,7 @@ class Context:
             env (Optional[dict]): dictionary of functions within the execution context
         """
         self.tables = tables
-        self.range_readers = {
-            name: table.range_reader for name, table in self.tables.items()
-        }
+        self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
        self.row_readers = {name: table.reader for name, table in tables.items()}
        self.env = {**(env or {}), "scope": self.row_readers}
 
@@ -26,11 +26,7 @@ class PythonExecutor:
         while queue:
             node = queue.pop()
             context = self.context(
-                {
-                    name: table
-                    for dep in node.dependencies
-                    for name, table in contexts[dep].tables.items()
-                }
+                {name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()}
             )
             running.add(node)
 

@@ -151,9 +147,7 @@ class PythonExecutor:
             return self.context({name: table for name in ctx.tables})
 
         for name, join in step.joins.items():
-            join_context = self.context(
-                {**join_context.tables, name: context.tables[name]}
-            )
+            join_context = self.context({**join_context.tables, name: context.tables[name]})
 
             if join.get("source_key"):
                 table = self.hash_join(join, source, name, join_context)

@@ -247,9 +241,7 @@ class PythonExecutor:
 
         if step.operands:
             source_table = context.tables[source]
-            operand_table = Table(
-                source_table.columns + self.table(step.operands).columns
-            )
+            operand_table = Table(source_table.columns + self.table(step.operands).columns)
 
             for reader, ctx in context:
                 operand_table.append(reader.row + ctx.eval_tuple(operands))
@@ -37,10 +37,7 @@ class Table:
                 break
 
             lines.append(
-                " ".join(
-                    str(row[column]).rjust(widths[column])[0 : widths[column]]
-                    for column in self.columns
-                )
+                " ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns)
             )
         return "\n".join(lines)
 
@@ -47,10 +47,7 @@ class Expression(metaclass=_Expression):
         return hash(
             (
                 self.key,
-                tuple(
-                    (k, tuple(v) if isinstance(v, list) else v)
-                    for k, v in _norm_args(self).items()
-                ),
+                tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()),
             )
         )
 

@@ -116,9 +113,22 @@ class Expression(metaclass=_Expression):
             item.parent = parent
         return new
 
+    def append(self, arg_key, value):
+        """
+        Appends value to arg_key if it's a list or sets it as a new list.
+
+        Args:
+            arg_key (str): name of the list expression arg
+            value (Any): value to append to the list
+        """
+        if not isinstance(self.args.get(arg_key), list):
+            self.args[arg_key] = []
+        self.args[arg_key].append(value)
+        self._set_parent(arg_key, value)
+
     def set(self, arg_key, value):
         """
-        Sets `arg` to `value`.
+        Sets `arg_key` to `value`.
 
         Args:
             arg_key (str): name of the expression arg

@@ -267,6 +277,14 @@ class Expression(metaclass=_Expression):
             expression = expression.this
         return expression
 
+    def unalias(self):
+        """
+        Returns the inner expression if this is an Alias.
+        """
+        if isinstance(self, Alias):
+            return self.this
+        return self
+
     def unnest_operands(self):
         """
         Returns unnested operands as a tuple.
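
For illustration (not part of the patch), the two new Expression helpers, assuming the 6.1.x API:

from sqlglot import exp, parse_one

# unalias() strips an Alias wrapper; append() pushes into a list-valued arg
# and wires up the parent pointer.
select = parse_one("SELECT a AS x FROM t")
print(select.expressions[0].unalias())

select.append("expressions", exp.Column(this=exp.to_identifier("b")))
print(select.sql())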
@@ -279,9 +297,7 @@ class Expression(metaclass=_Expression):
 
             A AND B AND C -> [A, B, C]
         """
-        for node, _, _ in self.dfs(
-            prune=lambda n, p, *_: p and not isinstance(n, self.__class__)
-        ):
+        for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)):
             if not isinstance(node, self.__class__):
                 yield node.unnest() if unnest else node
 

@@ -314,9 +330,7 @@ class Expression(metaclass=_Expression):
 
         args = {
             k: ", ".join(
-                v.to_s(hide_missing=hide_missing, level=level + 1)
-                if hasattr(v, "to_s")
-                else str(v)
+                v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
                 for v in ensure_list(vs)
                 if v is not None
             )

@@ -354,9 +368,7 @@ class Expression(metaclass=_Expression):
             new_node.parent = node.parent
             return new_node
 
-        replace_children(
-            new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs)
-        )
+        replace_children(new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs))
         return new_node
 
     def replace(self, expression):
@@ -546,6 +558,10 @@ class BitString(Condition):
     pass
 
 
+class HexString(Condition):
+    pass
+
+
 class Column(Condition):
     arg_types = {"this": True, "table": False}
 

@@ -566,35 +582,44 @@ class ColumnConstraint(Expression):
     arg_types = {"this": False, "kind": True}
 
 
-class AutoIncrementColumnConstraint(Expression):
+class ColumnConstraintKind(Expression):
     pass
 
 
-class CheckColumnConstraint(Expression):
+class AutoIncrementColumnConstraint(ColumnConstraintKind):
     pass
 
 
-class CollateColumnConstraint(Expression):
+class CheckColumnConstraint(ColumnConstraintKind):
     pass
 
 
-class CommentColumnConstraint(Expression):
+class CollateColumnConstraint(ColumnConstraintKind):
     pass
 
 
-class DefaultColumnConstraint(Expression):
+class CommentColumnConstraint(ColumnConstraintKind):
     pass
 
 
-class NotNullColumnConstraint(Expression):
+class DefaultColumnConstraint(ColumnConstraintKind):
     pass
 
 
-class PrimaryKeyColumnConstraint(Expression):
+class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
+    # this: True -> ALWAYS, this: False -> BY DEFAULT
+    arg_types = {"this": True, "expression": False}
+
+
+class NotNullColumnConstraint(ColumnConstraintKind):
     pass
 
 
-class UniqueColumnConstraint(Expression):
+class PrimaryKeyColumnConstraint(ColumnConstraintKind):
+    pass
+
+
+class UniqueColumnConstraint(ColumnConstraintKind):
     pass
 
 
@@ -651,9 +676,7 @@ class Identifier(Expression):
         return bool(self.args.get("quoted"))
 
     def __eq__(self, other):
-        return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(
-            other.this
-        )
+        return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this)
 
     def __hash__(self):
         return hash((self.key, self.this.lower()))

@@ -709,9 +732,7 @@ class Literal(Condition):
 
     def __eq__(self, other):
         return (
-            isinstance(other, Literal)
-            and self.this == other.this
-            and self.args["is_string"] == other.args["is_string"]
+            isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"]
         )
 
     def __hash__(self):

@@ -733,6 +754,7 @@ class Join(Expression):
         "side": False,
         "kind": False,
         "using": False,
+        "natural": False,
     }
 
     @property

@@ -743,6 +765,10 @@ class Join(Expression):
     def side(self):
         return self.text("side").upper()
 
+    @property
+    def alias_or_name(self):
+        return self.this.alias_or_name
+
     def on(self, *expressions, append=True, dialect=None, copy=True, **opts):
         """
         Append to or set the ON expressions.
@@ -873,10 +899,6 @@ class Reference(Expression):
     arg_types = {"this": True, "expressions": True}
 
 
-class Table(Expression):
-    arg_types = {"this": True, "db": False, "catalog": False}
-
-
 class Tuple(Expression):
     arg_types = {"expressions": False}
 

@@ -986,6 +1008,16 @@ QUERY_MODIFIERS = {
 }
 
 
+class Table(Expression):
+    arg_types = {
+        "this": True,
+        "db": False,
+        "catalog": False,
+        "laterals": False,
+        "joins": False,
+    }
+
+
 class Union(Subqueryable, Expression):
     arg_types = {
         "with": False,

@@ -1396,7 +1428,9 @@ class Select(Subqueryable, Expression):
             join.this.replace(join.this.subquery())
 
         if join_type:
-            side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
+            natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
+            if natural:
+                join.set("natural", True)
             if side:
                 join.set("side", side.text)
             if kind:

@@ -1529,10 +1563,7 @@ class Select(Subqueryable, Expression):
         properties_expression = None
         if properties:
             properties_str = " ".join(
-                [
-                    f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}"
-                    for k, v in properties.items()
-                ]
+                [f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in properties.items()]
             )
             properties_expression = maybe_parse(
                 properties_str,
@@ -1654,6 +1685,7 @@ class DataType(Expression):
         DECIMAL = auto()
         BOOLEAN = auto()
         JSON = auto()
+        INTERVAL = auto()
         TIMESTAMP = auto()
         TIMESTAMPTZ = auto()
         DATE = auto()

@@ -1662,15 +1694,19 @@ class DataType(Expression):
         MAP = auto()
         UUID = auto()
         GEOGRAPHY = auto()
+        GEOMETRY = auto()
         STRUCT = auto()
         NULLABLE = auto()
+        HLLSKETCH = auto()
+        SUPER = auto()
+        SERIAL = auto()
+        SMALLSERIAL = auto()
+        BIGSERIAL = auto()
 
     @classmethod
     def build(cls, dtype, **kwargs):
         return DataType(
-            this=dtype
-            if isinstance(dtype, DataType.Type)
-            else DataType.Type[dtype.upper()],
+            this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
             **kwargs,
         )
 

@@ -1798,6 +1834,14 @@ class Like(Binary, Predicate):
     pass
 
 
+class SimilarTo(Binary, Predicate):
+    pass
+
+
+class Distance(Binary):
+    pass
+
+
 class LT(Binary, Predicate):
     pass
 
@@ -1899,6 +1943,10 @@ class IgnoreNulls(Expression):
     pass
 
 
+class RespectNulls(Expression):
+    pass
+
+
 # Functions
 class Func(Condition):
     """

@@ -1924,9 +1972,7 @@ class Func(Condition):
 
         all_arg_keys = list(cls.arg_types)
         # If this function supports variable length argument treat the last argument as such.
-        non_var_len_arg_keys = (
-            all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
-        )
+        non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
 
         args_dict = {}
         arg_idx = 0

@@ -1944,9 +1990,7 @@ class Func(Condition):
     @classmethod
     def sql_names(cls):
         if cls is Func:
-            raise NotImplementedError(
-                "SQL name is only supported by concrete function implementations"
-            )
+            raise NotImplementedError("SQL name is only supported by concrete function implementations")
         if not hasattr(cls, "_sql_names"):
             cls._sql_names = [camel_to_snake_case(cls.__name__)]
         return cls._sql_names
@@ -2178,6 +2222,10 @@ class Greatest(Func):
     is_var_len_args = True
 
 
+class GroupConcat(Func):
+    arg_types = {"this": True, "separator": False}
+
+
 class If(Func):
     arg_types = {"this": True, "true": True, "false": False}
 

@@ -2274,6 +2322,10 @@ class Quantile(AggFunc):
     arg_types = {"this": True, "quantile": True}
 
 
+class ApproxQuantile(Quantile):
+    pass
+
+
 class Reduce(Func):
     arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
 

@@ -2306,8 +2358,10 @@ class Split(Func):
     arg_types = {"this": True, "expression": True}
 
 
+# Start may be omitted in the case of postgres
+# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
 class Substring(Func):
-    arg_types = {"this": True, "start": True, "length": False}
+    arg_types = {"this": True, "start": False, "length": False}
 
 
 class StrPosition(Func):
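
For illustration (not part of the patch), Substring's start argument is now optional so the Postgres generator can emit the FROM/FOR forms; a sketch assuming the 6.1.x API:

import sqlglot

print(sqlglot.transpile("SELECT SUBSTRING(s, 2, 3)", write="postgres")[0])
# expected: SELECT SUBSTRING(s FROM 2 FOR 3)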
@@ -2379,6 +2433,15 @@ class TimeStrToUnix(Func):
     pass
 
 
+class Trim(Func):
+    arg_types = {
+        "this": True,
+        "position": False,
+        "expression": False,
+        "collation": False,
+    }
+
+
 class TsOrDsAdd(Func, TimeUnit):
     arg_types = {"this": True, "expression": True, "unit": False}
 

@@ -2455,9 +2518,7 @@ def _all_functions():
         obj
         for _, obj in inspect.getmembers(
             sys.modules[__name__],
-            lambda obj: inspect.isclass(obj)
-            and issubclass(obj, Func)
-            and obj not in (AggFunc, Anonymous, Func),
+            lambda obj: inspect.isclass(obj) and issubclass(obj, Func) and obj not in (AggFunc, Anonymous, Func),
         )
     ]
 

@@ -2633,9 +2694,7 @@ def _apply_conjunction_builder(
 
 
 def _combine(expressions, operator, dialect=None, **opts):
-    expressions = [
-        condition(expression, dialect=dialect, **opts) for expression in expressions
-    ]
+    expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
     this = expressions[0]
     if expressions[1:]:
         this = _wrap_operator(this)

@@ -2809,9 +2868,7 @@ def to_identifier(alias, quoted=None):
         quoted = not re.match(SAFE_IDENTIFIER_RE, alias)
         identifier = Identifier(this=alias, quoted=quoted)
     else:
-        raise ValueError(
-            f"Alias needs to be a string or an Identifier, got: {alias.__class__}"
-        )
+        raise ValueError(f"Alias needs to be a string or an Identifier, got: {alias.__class__}")
     return identifier
 
 
@ -41,6 +41,8 @@ class Generator:
|
||||||
max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError.
|
max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError.
|
||||||
This is only relevant if unsupported_level is ErrorLevel.RAISE.
|
This is only relevant if unsupported_level is ErrorLevel.RAISE.
|
||||||
Default: 3
|
Default: 3
|
||||||
|
leading_comma (bool): if the the comma is leading or trailing in select statements
|
||||||
|
Default: False
|
||||||
"""
|
"""
|
||||||
|
|
||||||
TRANSFORMS = {
|
TRANSFORMS = {
|
||||||
|
@@ -108,6 +110,7 @@ class Generator:
         "_indent",
         "_replace_backslash",
         "_escaped_quote_end",
+        "_leading_comma",
     )

     def __init__(
@@ -131,6 +134,7 @@ class Generator:
         unsupported_level=ErrorLevel.WARN,
         null_ordering=None,
         max_unsupported=3,
+        leading_comma=False,
     ):
         import sqlglot

@@ -157,6 +161,7 @@ class Generator:
         self._indent = indent
         self._replace_backslash = self.escape == "\\"
         self._escaped_quote_end = self.escape + self.quote_end
+        self._leading_comma = leading_comma

     def generate(self, expression):
         """
@@ -178,9 +183,7 @@ class Generator:
             for msg in self.unsupported_messages:
                 logger.warning(msg)
         elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
-            raise UnsupportedError(
-                concat_errors(self.unsupported_messages, self.max_unsupported)
-            )
+            raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported))

         return sql

@@ -197,9 +200,7 @@ class Generator:

     def wrap(self, expression):
         this_sql = self.indent(
-            self.sql(expression)
-            if isinstance(expression, (exp.Select, exp.Union))
-            else self.sql(expression, "this"),
+            self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"),
             level=1,
             pad=0,
         )
@@ -251,9 +252,7 @@ class Generator:
         return transform

         if not isinstance(expression, exp.Expression):
-            raise ValueError(
-                f"Expected an Expression. Received {type(expression)}: {expression}"
-            )
+            raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")

         exp_handler_name = f"{expression.key}_sql"
         if hasattr(self, exp_handler_name):
@@ -276,11 +275,7 @@ class Generator:
         lazy = " LAZY" if expression.args.get("lazy") else ""
         table = self.sql(expression, "this")
         options = expression.args.get("options")
-        options = (
-            f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})"
-            if options
-            else ""
-        )
+        options = f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" if options else ""
         sql = self.sql(expression, "expression")
         sql = f" AS{self.sep()}{sql}" if sql else ""
         sql = f"CACHE{lazy} TABLE {table}{options}{sql}"
@@ -306,9 +301,7 @@ class Generator:
     def columndef_sql(self, expression):
         column = self.sql(expression, "this")
         kind = self.sql(expression, "kind")
-        constraints = self.expressions(
-            expression, key="constraints", sep=" ", flat=True
-        )
+        constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)

         if not constraints:
             return f"{column} {kind}"
@@ -338,6 +331,9 @@ class Generator:
         default = self.sql(expression, "this")
         return f"DEFAULT {default}"

+    def generatedasidentitycolumnconstraint_sql(self, expression):
+        return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY"
+
     def notnullcolumnconstraint_sql(self, _):
         return "NOT NULL"
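For orientation, the new identity constraint renders from the boolean stored on the node. A hedged sketch (it assumes exp.GeneratedAsIdentityColumnConstraint is the matching expression class):

    from sqlglot import expressions as exp
    from sqlglot.generator import Generator

    # this=True selects ALWAYS, anything falsy selects BY DEFAULT
    constraint = exp.GeneratedAsIdentityColumnConstraint(this=True)
    print(Generator().generate(constraint))  # GENERATED ALWAYS AS IDENTITY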
@@ -384,7 +380,10 @@ class Generator:
         return f"{alias}{columns}"

     def bitstring_sql(self, expression):
-        return f"b'{self.sql(expression, 'this')}'"
+        return self.sql(expression, "this")
+
+    def hexstring_sql(self, expression):
+        return self.sql(expression, "this")

     def datatype_sql(self, expression):
         type_value = expression.this
@@ -452,10 +451,7 @@ class Generator:

     def partition_sql(self, expression):
         keys = csv(
-            *[
-                f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"]
-                for k, v in expression.args.get("this")
-            ]
+            *[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")]
         )
         return f"PARTITION({keys})"
@@ -470,9 +466,9 @@ class Generator:
         elif p_class in self.WITH_PROPERTIES:
             with_properties.append(p)

-        return self.root_properties(
-            exp.Properties(expressions=root_properties)
-        ) + self.with_properties(exp.Properties(expressions=with_properties))
+        return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties(
+            exp.Properties(expressions=with_properties)
+        )

     def root_properties(self, properties):
         if properties.expressions:
@@ -508,11 +504,7 @@ class Generator:
         kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO"
         this = self.sql(expression, "this")
         exists = " IF EXISTS " if expression.args.get("exists") else " "
-        partition_sql = (
-            self.sql(expression, "partition")
-            if expression.args.get("partition")
-            else ""
-        )
+        partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else ""
         expression_sql = self.sql(expression, "expression")
         sep = self.sep() if partition_sql else ""
         sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}"
@@ -531,7 +523,7 @@ class Generator:
         return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"

     def table_sql(self, expression):
-        return ".".join(
+        table = ".".join(
             part
             for part in [
                 self.sql(expression, "catalog"),
@@ -541,6 +533,10 @@ class Generator:
             if part
         )

+        laterals = self.expressions(expression, key="laterals", sep="")
+        joins = self.expressions(expression, key="joins", sep="")
+        return f"{table}{laterals}{joins}"
+
     def tablesample_sql(self, expression):
         if self.alias_post_tablesample and isinstance(expression.this, exp.Alias):
             this = self.sql(expression.this, "this")
@@ -586,11 +582,7 @@ class Generator:
     def group_sql(self, expression):
         group_by = self.op_expressions("GROUP BY", expression)
         grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
-        grouping_sets = (
-            f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}"
-            if grouping_sets
-            else ""
-        )
+        grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
         cube = self.expressions(expression, key="cube", indent=False)
         cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
         rollup = self.expressions(expression, key="rollup", indent=False)
@@ -603,7 +595,16 @@ class Generator:

     def join_sql(self, expression):
         op_sql = self.seg(
-            " ".join(op for op in (expression.side, expression.kind, "JOIN") if op)
+            " ".join(
+                op
+                for op in (
+                    "NATURAL" if expression.args.get("natural") else None,
+                    expression.side,
+                    expression.kind,
+                    "JOIN",
+                )
+                if op
+            )
         )
         on_sql = self.sql(expression, "on")
         using = expression.args.get("using")
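The effect is that natural joins now survive a round trip. A hedged example (it assumes the parser sets the join's "natural" arg for this input):

    import sqlglot

    # "NATURAL" is emitted ahead of side/kind whenever the arg is set
    print(sqlglot.transpile("SELECT * FROM x NATURAL JOIN y")[0])
    # expected along the lines of: SELECT * FROM x NATURAL JOIN y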
@@ -630,9 +631,9 @@ class Generator:

     def lateral_sql(self, expression):
         this = self.sql(expression, "this")
-        op_sql = self.seg(
-            f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}"
-        )
+        if isinstance(expression.this, exp.Subquery):
+            return f"LATERAL{self.sep()}{this}"
+        op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}")
         alias = expression.args["alias"]
         table = alias.name
         table = f" {table}" if table else table
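So a Postgres-style lateral subquery is emitted as LATERAL (...) rather than going down the Hive LATERAL VIEW path. A sketch (assuming the parser wraps the subquery in exp.Lateral for this input):

    import sqlglot

    # the Subquery branch returns early, so no LATERAL VIEW/alias handling runs
    print(sqlglot.transpile("SELECT * FROM x, LATERAL (SELECT y.a FROM y) t")[0])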
@@ -688,21 +689,13 @@ class Generator:

         sort_order = " DESC" if desc else ""
         nulls_sort_change = ""
-        if nulls_first and (
-            (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
-        ):
+        if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last):
             nulls_sort_change = " NULLS FIRST"
-        elif (
-            nulls_last
-            and ((asc and nulls_are_small) or (desc and nulls_are_large))
-            and not nulls_are_last
-        ):
+        elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last:
             nulls_sort_change = " NULLS LAST"

         if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
-            self.unsupported(
-                "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
-            )
+            self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect")
             nulls_sort_change = ""

         return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
@@ -798,14 +791,20 @@ class Generator:

     def window_sql(self, expression):
         this = self.sql(expression, "this")
+
         partition = self.expressions(expression, key="partition_by", flat=True)
         partition = f"PARTITION BY {partition}" if partition else ""
+
         order = expression.args.get("order")
         order_sql = self.order_sql(order, flat=True) if order else ""
+
         partition_sql = partition + " " if partition and order else partition
+
         spec = expression.args.get("spec")
         spec_sql = " " + self.window_spec_sql(spec) if spec else ""
+
         alias = self.sql(expression, "alias")
+
         if expression.arg_key == "window":
             this = this = f"{self.seg('WINDOW')} {this} AS"
         else:
@@ -818,13 +817,8 @@ class Generator:

     def window_spec_sql(self, expression):
         kind = self.sql(expression, "kind")
-        start = csv(
-            self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" "
-        )
-        end = (
-            csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
-            or "CURRENT ROW"
-        )
+        start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
+        end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW"
         return f"{kind} BETWEEN {start} AND {end}"

     def withingroup_sql(self, expression):
@@ -879,6 +873,17 @@ class Generator:
         expression_sql = self.sql(expression, "expression")
         return f"EXTRACT({this} FROM {expression_sql})"

+    def trim_sql(self, expression):
+        target = self.sql(expression, "this")
+        trim_type = self.sql(expression, "position")
+
+        if trim_type == "LEADING":
+            return f"LTRIM({target})"
+        elif trim_type == "TRAILING":
+            return f"RTRIM({target})"
+        else:
+            return f"TRIM({target})"
+
     def check_sql(self, expression):
         this = self.sql(expression, key="this")
         return f"CHECK ({this})"
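In other words, the base generator lowers positional TRIM onto LTRIM/RTRIM, dropping the position keyword; note that it also ignores any explicit trim character. An illustrative sketch (output indicative only):

    import sqlglot

    print(sqlglot.transpile("SELECT TRIM(LEADING ' ' FROM col)")[0])
    # roughly: SELECT LTRIM(col)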
@@ -898,9 +903,7 @@ class Generator:
         return f"UNIQUE ({columns})"

     def if_sql(self, expression):
-        return self.case_sql(
-            exp.Case(ifs=[expression], default=expression.args.get("false"))
-        )
+        return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))

     def in_sql(self, expression):
         query = expression.args.get("query")
@@ -917,7 +920,9 @@ class Generator:
         return f"(SELECT {self.sql(unnest)})"

     def interval_sql(self, expression):
-        return f"INTERVAL {self.sql(expression, 'this')} {self.sql(expression, 'unit')}"
+        unit = self.sql(expression, "unit")
+        unit = f" {unit}" if unit else ""
+        return f"INTERVAL {self.sql(expression, 'this')}{unit}"

     def reference_sql(self, expression):
         this = self.sql(expression, "this")
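A unit-less interval therefore no longer picks up a trailing space. Illustrative (outputs indicative):

    import sqlglot

    print(sqlglot.transpile("SELECT INTERVAL '1 day'")[0])   # SELECT INTERVAL '1 day'
    print(sqlglot.transpile("SELECT INTERVAL '1' DAY")[0])   # SELECT INTERVAL '1' DAY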
@@ -925,9 +930,7 @@ class Generator:
         return f"REFERENCES {this}({expressions})"

     def anonymous_sql(self, expression):
-        args = self.indent(
-            self.expressions(expression, flat=True), skip_first=True, skip_last=True
-        )
+        args = self.indent(self.expressions(expression, flat=True), skip_first=True, skip_last=True)
         return f"{self.normalize_func(self.sql(expression, 'this'))}({args})"

     def paren_sql(self, expression):
@@ -1006,6 +1009,9 @@ class Generator:
     def ignorenulls_sql(self, expression):
         return f"{self.sql(expression, 'this')} IGNORE NULLS"

+    def respectnulls_sql(self, expression):
+        return f"{self.sql(expression, 'this')} RESPECT NULLS"
+
     def intdiv_sql(self, expression):
         return self.sql(
             exp.Cast(
@@ -1023,6 +1029,9 @@ class Generator:
     def div_sql(self, expression):
         return self.binary(expression, "/")

+    def distance_sql(self, expression):
+        return self.binary(expression, "<->")
+
     def dot_sql(self, expression):
         return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}"
@@ -1047,6 +1056,9 @@ class Generator:
     def like_sql(self, expression):
         return self.binary(expression, "LIKE")

+    def similarto_sql(self, expression):
+        return self.binary(expression, "SIMILAR TO")
+
     def lt_sql(self, expression):
         return self.binary(expression, "<")
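Since SimilarTo is just another binary node, it renders through the shared binary() helper. A hedged example (assuming the parser produces exp.SimilarTo for the operator):

    import sqlglot

    print(sqlglot.transpile("SELECT * FROM t WHERE x SIMILAR TO '%(b|d)%'")[0])
    # expected along the lines of: SELECT * FROM t WHERE x SIMILAR TO '%(b|d)%'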
@@ -1069,14 +1081,10 @@ class Generator:
         return self.binary(expression, "-")

     def trycast_sql(self, expression):
-        return (
-            f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
-        )
+        return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"

     def binary(self, expression, op):
-        return (
-            f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
-        )
+        return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"

     def function_fallback_sql(self, expression):
         args = []
@@ -1089,9 +1097,7 @@ class Generator:
         return f"{self.normalize_func(expression.sql_name())}({args_str})"

     def format_time(self, expression):
-        return format_time(
-            self.sql(expression, "format"), self.time_mapping, self.time_trie
-        )
+        return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)

     def expressions(self, expression, key=None, flat=False, indent=True, sep=", "):
         expressions = expression.args.get(key or "expressions")
@@ -1102,7 +1108,14 @@ class Generator:
         if flat:
             return sep.join(self.sql(e) for e in expressions)

-        expressions = self.sep(sep).join(self.sql(e) for e in expressions)
+        sql = (self.sql(e) for e in expressions)
+        # the only time leading_comma changes the output is if pretty print is enabled
+        if self._leading_comma and self.pretty:
+            pad = " " * self.pad
+            expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql))
+        else:
+            expressions = self.sep(sep).join(sql)

         if indent:
             return self.indent(expressions, skip_first=False)
         return expressions
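Putting the new option together with pretty printing (a hedged example; it assumes transpile forwards generator kwargs such as leading_comma):

    import sqlglot

    print(sqlglot.transpile("SELECT a, b, c FROM x", pretty=True, leading_comma=True)[0])
    # output shaped roughly like:
    # SELECT
    #   a
    #   , b
    #   , c
    # FROM x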
@@ -1116,9 +1129,7 @@ class Generator:
     def set_operation(self, expression, op):
         this = self.sql(expression, "this")
         op = self.seg(op)
-        return self.query_modifiers(
-            expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
-        )
+        return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}")

     def token_sql(self, token_type):
         return self.TOKEN_MAPPING.get(token_type, token_type.name)
@@ -1,2 +1,2 @@
-from sqlglot.optimizer.optimizer import optimize
+from sqlglot.optimizer.optimizer import RULES, optimize
 from sqlglot.optimizer.schema import Schema
@@ -13,9 +13,7 @@ def isolate_table_selects(expression):
             continue

         if not isinstance(source.parent, exp.Alias):
-            raise OptimizeError(
-                "Tables require an alias. Run qualify_tables optimization."
-            )
+            raise OptimizeError("Tables require an alias. Run qualify_tables optimization.")

         parent = source.parent
232
sqlglot/optimizer/merge_derived_tables.py
Normal file

@@ -0,0 +1,232 @@
+from collections import defaultdict
+
+from sqlglot import expressions as exp
+from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.simplify import simplify
+
+
+def merge_derived_tables(expression):
+    """
+    Rewrite sqlglot AST to merge derived tables into the outer query.
+
+    Example:
+        >>> import sqlglot
+        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)")
+        >>> merge_derived_tables(expression).sql()
+        'SELECT x.a FROM x'
+
+    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
+
+    Args:
+        expression (sqlglot.Expression): expression to optimize
+    Returns:
+        sqlglot.Expression: optimized expression
+    """
+    for outer_scope in traverse_scope(expression):
+        for subquery in outer_scope.derived_tables:
+            inner_select = subquery.unnest()
+            if (
+                isinstance(outer_scope.expression, exp.Select)
+                and isinstance(inner_select, exp.Select)
+                and _mergeable(inner_select)
+            ):
+                alias = subquery.alias_or_name
+                from_or_join = subquery.find_ancestor(exp.From, exp.Join)
+                inner_scope = outer_scope.sources[alias]
+
+                _rename_inner_sources(outer_scope, inner_scope, alias)
+                _merge_from(outer_scope, inner_scope, subquery)
+                _merge_joins(outer_scope, inner_scope, from_or_join)
+                _merge_expressions(outer_scope, inner_scope, alias)
+                _merge_where(outer_scope, inner_scope, from_or_join)
+                _merge_order(outer_scope, inner_scope)
+    return expression
+
+
+# If a derived table has these Select args, it can't be merged
+UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
+    "expressions",
+    "from",
+    "joins",
+    "where",
+    "order",
+}
+
+
+def _mergeable(inner_select):
+    """
+    Return True if `inner_select` can be merged into outer query.
+
+    Args:
+        inner_select (exp.Select)
+    Returns:
+        bool: True if can be merged
+    """
+    return (
+        isinstance(inner_select, exp.Select)
+        and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
+        and inner_select.args.get("from")
+        and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
+    )
+
+
+def _rename_inner_sources(outer_scope, inner_scope, alias):
+    """
+    Renames any sources in the inner query that conflict with names in the outer query.
+
+    Args:
+        outer_scope (sqlglot.optimizer.scope.Scope)
+        inner_scope (sqlglot.optimizer.scope.Scope)
+        alias (str)
+    """
+    taken = set(outer_scope.selected_sources)
+    conflicts = taken.intersection(set(inner_scope.selected_sources))
+    conflicts = conflicts - {alias}
+
+    for conflict in conflicts:
+        new_name = _find_new_name(taken, conflict)
+
+        source, _ = inner_scope.selected_sources[conflict]
+        new_alias = exp.to_identifier(new_name)
+
+        if isinstance(source, exp.Subquery):
+            source.set("alias", exp.TableAlias(this=new_alias))
+        elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias):
+            source.parent.set("alias", new_alias)
+        elif isinstance(source, exp.Table):
+            source.replace(exp.alias_(source.copy(), new_alias))
+
+        for column in inner_scope.source_columns(conflict):
+            column.set("table", exp.to_identifier(new_name))
+
+        inner_scope.rename_source(conflict, new_name)
+
+
+def _find_new_name(taken, base):
+    """
+    Searches for a new source name.
+
+    Args:
+        taken (set[str]): set of taken names
+        base (str): base name to alter
+    """
+    i = 2
+    new = f"{base}_{i}"
+    while new in taken:
+        i += 1
+        new = f"{base}_{i}"
+    return new
+
+
+def _merge_from(outer_scope, inner_scope, subquery):
+    """
+    Merge FROM clause of inner query into outer query.
+
+    Args:
+        outer_scope (sqlglot.optimizer.scope.Scope)
+        inner_scope (sqlglot.optimizer.scope.Scope)
+        subquery (exp.Subquery)
+    """
+    new_subquery = inner_scope.expression.args.get("from").expressions[0]
+    subquery.replace(new_subquery)
+    outer_scope.remove_source(subquery.alias_or_name)
+    outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
+
+
+def _merge_joins(outer_scope, inner_scope, from_or_join):
+    """
+    Merge JOIN clauses of inner query into outer query.
+
+    Args:
+        outer_scope (sqlglot.optimizer.scope.Scope)
+        inner_scope (sqlglot.optimizer.scope.Scope)
+        from_or_join (exp.From|exp.Join)
+    """
+
+    new_joins = []
+    comma_joins = inner_scope.expression.args.get("from").expressions[1:]
+    for subquery in comma_joins:
+        new_joins.append(exp.Join(this=subquery, kind="CROSS"))
+        outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])
+
+    joins = inner_scope.expression.args.get("joins") or []
+    for join in joins:
+        new_joins.append(join)
+        outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])
+
+    if new_joins:
+        outer_joins = outer_scope.expression.args.get("joins", [])
+
+        # Maintain the join order
+        if isinstance(from_or_join, exp.From):
+            position = 0
+        else:
+            position = outer_joins.index(from_or_join) + 1
+        outer_joins[position:position] = new_joins
+
+        outer_scope.expression.set("joins", outer_joins)
+
+
+def _merge_expressions(outer_scope, inner_scope, alias):
+    """
+    Merge projections of inner query into outer query.
+
+    Args:
+        outer_scope (sqlglot.optimizer.scope.Scope)
+        inner_scope (sqlglot.optimizer.scope.Scope)
+        alias (str)
+    """
+    # Collect all columns that reference the alias of the inner query
+    outer_columns = defaultdict(list)
+    for column in outer_scope.columns:
+        if column.table == alias:
+            outer_columns[column.name].append(column)
+
+    # Replace columns with the projection expression in the inner query
+    for expression in inner_scope.expression.expressions:
+        projection_name = expression.alias_or_name
+        if not projection_name:
+            continue
+        columns_to_replace = outer_columns.get(projection_name, [])
+        for column in columns_to_replace:
+            column.replace(expression.unalias())
+
+
+def _merge_where(outer_scope, inner_scope, from_or_join):
+    """
+    Merge WHERE clause of inner query into outer query.
+
+    Args:
+        outer_scope (sqlglot.optimizer.scope.Scope)
+        inner_scope (sqlglot.optimizer.scope.Scope)
+        from_or_join (exp.From|exp.Join)
+    """
+    where = inner_scope.expression.args.get("where")
+    if not where or not where.this:
+        return
+
+    if isinstance(from_or_join, exp.Join) and from_or_join.side:
+        # Merge predicates from an outer join to the ON clause
+        from_or_join.on(where.this, copy=False)
+        from_or_join.set("on", simplify(from_or_join.args.get("on")))
+    else:
+        outer_scope.expression.where(where.this, copy=False)
+        outer_scope.expression.set("where", simplify(outer_scope.expression.args.get("where")))
+
+
+def _merge_order(outer_scope, inner_scope):
+    """
+    Merge ORDER clause of inner query into outer query.
+
+    Args:
+        outer_scope (sqlglot.optimizer.scope.Scope)
+        inner_scope (sqlglot.optimizer.scope.Scope)
+    """
+    if (
+        any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"])
+        or len(outer_scope.selected_sources) != 1
+        or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
+    ):
+        return
+
+    outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
@@ -22,18 +22,14 @@ def normalize(expression, dnf=False, max_distance=128):
     """
     expression = simplify(expression)

-    expression = while_changing(
-        expression, lambda e: distributive_law(e, dnf, max_distance)
-    )
+    expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance))
     return simplify(expression)


 def normalized(expression, dnf=False):
     ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)

-    return not any(
-        connector.find_ancestor(ancestor) for connector in expression.find_all(root)
-    )
+    return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root))


 def normalization_distance(expression, dnf=False):
@@ -54,9 +50,7 @@ def normalization_distance(expression, dnf=False):
     Returns:
         int: difference
     """
-    return sum(_predicate_lengths(expression, dnf)) - (
-        len(list(expression.find_all(exp.Connector))) + 1
-    )
+    return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1)


 def _predicate_lengths(expression, dnf):
@@ -73,11 +67,7 @@ def _predicate_lengths(expression, dnf):
     left, right = expression.args.values()

     if isinstance(expression, exp.And if dnf else exp.Or):
-        x = [
-            a + b
-            for a in _predicate_lengths(left, dnf)
-            for b in _predicate_lengths(right, dnf)
-        ]
+        x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)]
         return x
     return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)

@@ -102,9 +92,7 @@ def distributive_law(expression, dnf, max_distance):
     to_func = exp.and_ if to_exp == exp.And else exp.or_

     if isinstance(a, to_exp) and isinstance(b, to_exp):
-        if len(tuple(a.find_all(exp.Connector))) > len(
-            tuple(b.find_all(exp.Connector))
-        ):
+        if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
             return _distribute(a, b, from_func, to_func)
         return _distribute(b, a, from_func, to_func)
     if isinstance(a, to_exp):
@@ -68,8 +68,4 @@ def normalize(expression):


 def other_table_names(join, exclude):
-    return [
-        name
-        for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
-        if name != exclude
-    ]
+    return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude]
@@ -1,6 +1,7 @@
 from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
 from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
 from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
+from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
 from sqlglot.optimizer.normalize import normalize
 from sqlglot.optimizer.optimize_joins import optimize_joins
 from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
@@ -10,8 +11,23 @@ from sqlglot.optimizer.qualify_tables import qualify_tables
 from sqlglot.optimizer.quote_identities import quote_identities
 from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

+RULES = (
+    qualify_tables,
+    isolate_table_selects,
+    qualify_columns,
+    pushdown_projections,
+    normalize,
+    unnest_subqueries,
+    expand_multi_table_selects,
+    pushdown_predicates,
+    optimize_joins,
+    eliminate_subqueries,
+    merge_derived_tables,
+    quote_identities,
+)
+

-def optimize(expression, schema=None, db=None, catalog=None):
+def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwargs):
     """
     Rewrite a sqlglot AST into an optimized form.
@@ -25,19 +41,18 @@ def optimize(expression, schema=None, db=None, catalog=None):
             3. {catalog: {db: {table: {col: type}}}}
         db (str): specify the default database, as might be set by a `USE DATABASE db` statement
         catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
+        rules (list): sequence of optimizer rules to use
+        **kwargs: If a rule has a keyword argument with the same name in **kwargs, it will be passed in.
     Returns:
         sqlglot.Expression: optimized expression
     """
+    possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
     expression = expression.copy()
-    expression = qualify_tables(expression, db=db, catalog=catalog)
-    expression = isolate_table_selects(expression)
-    expression = qualify_columns(expression, schema)
-    expression = pushdown_projections(expression)
-    expression = normalize(expression)
-    expression = unnest_subqueries(expression)
-    expression = expand_multi_table_selects(expression)
-    expression = pushdown_predicates(expression)
-    expression = optimize_joins(expression)
-    expression = eliminate_subqueries(expression)
-    expression = quote_identities(expression)
+    for rule in rules:
+        # Find any additional rule parameters, beyond `expression`
+        rule_params = rule.__code__.co_varnames
+        rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
+        expression = rule(expression, **rule_kwargs)
     return expression
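The rule loop makes it straightforward to run a custom pipeline; for instance (a sketch, not from the commit):

    import sqlglot
    from sqlglot.optimizer import RULES, optimize

    expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)")
    # drop a single rule; db/catalog/schema are only handed to rules whose
    # signatures declare a parameter of that name
    rules = [rule for rule in RULES if rule.__name__ != "normalize"]
    print(optimize(expression, rules=rules).sql())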
@@ -42,11 +42,7 @@ def pushdown(condition, sources):
     condition = condition.replace(simplify(condition))
     cnf_like = normalized(condition) or not normalized(condition, dnf=True)

-    predicates = list(
-        condition.flatten()
-        if isinstance(condition, exp.And if cnf_like else exp.Or)
-        else [condition]
-    )
+    predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])

     if cnf_like:
         pushdown_cnf(predicates, sources)
@@ -105,17 +101,11 @@ def pushdown_dnf(predicates, scope):
             for column in predicate.find_all(exp.Column):
                 if column.table == table:
                     condition = column.find_ancestor(exp.Condition)
-                    predicate_condition = (
-                        exp.and_(predicate_condition, condition)
-                        if predicate_condition
-                        else condition
-                    )
+                    predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition

         if predicate_condition:
             conditions[table] = (
-                exp.or_(conditions[table], predicate_condition)
-                if table in conditions
-                else predicate_condition
+                exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition
             )

     for name, node in nodes.items():
@@ -133,9 +123,7 @@ def pushdown_dnf(predicates, scope):
 def nodes_for_predicate(predicate, sources):
     nodes = {}
     tables = exp.column_table_names(predicate)
-    where_condition = isinstance(
-        predicate.find_ancestor(exp.Join, exp.Where), exp.Where
-    )
+    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

     for table in tables:
         node, source = sources.get(table) or (None, None)
@@ -226,9 +226,7 @@ def _expand_stars(scope, resolver):
             tables = list(scope.selected_sources)
             _add_except_columns(expression, tables, except_columns)
             _add_replace_columns(expression, tables, replace_columns)
-        elif isinstance(expression, exp.Column) and isinstance(
-            expression.this, exp.Star
-        ):
+        elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star):
             tables = [expression.table]
             _add_except_columns(expression.this, tables, except_columns)
             _add_replace_columns(expression.this, tables, replace_columns)
@@ -245,9 +243,7 @@ def _expand_stars(scope, resolver):
             if name not in except_columns.get(table_id, set()):
                 alias_ = replace_columns.get(table_id, {}).get(name, name)
                 column = exp.column(name, table)
-                new_selections.append(
-                    alias(column, alias_) if alias_ != name else column
-                )
+                new_selections.append(alias(column, alias_) if alias_ != name else column)

     scope.expression.set("expressions", new_selections)

@@ -280,9 +276,7 @@ def _qualify_outputs(scope):
     """Ensure all output columns are aliased"""
     new_selections = []

-    for i, (selection, aliased_column) in enumerate(
-        itertools.zip_longest(scope.selects, scope.outer_column_list)
-    ):
+    for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)):
         if isinstance(selection, exp.Column):
             # convoluted setter because a simple selection.replace(alias) would require a copy
             alias_ = alias(exp.column(""), alias=selection.name)
@@ -302,11 +296,7 @@ def _qualify_outputs(scope):


 def _check_unknown_tables(scope):
-    if (
-        scope.external_columns
-        and not scope.is_unnest
-        and not scope.is_correlated_subquery
-    ):
+    if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery:
         raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
@@ -334,20 +324,14 @@ class _Resolver:
         (str) table name
         """
         if self._unambiguous_columns is None:
-            self._unambiguous_columns = self._get_unambiguous_columns(
-                self._get_all_source_columns()
-            )
+            self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns())
         return self._unambiguous_columns.get(column_name)

     @property
     def all_columns(self):
         """All available columns of all sources in this scope"""
         if self._all_columns is None:
-            self._all_columns = set(
-                column
-                for columns in self._get_all_source_columns().values()
-                for column in columns
-            )
+            self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
         return self._all_columns

     def get_source_columns(self, name):
@@ -369,9 +353,7 @@ class _Resolver:

     def _get_all_source_columns(self):
         if self._source_columns is None:
-            self._source_columns = {
-                k: self.get_source_columns(k) for k in self.scope.selected_sources
-            }
+            self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources}
         return self._source_columns

     def _get_unambiguous_columns(self, source_columns):
@@ -389,9 +371,7 @@ class _Resolver:
         source_columns = list(source_columns.items())

         first_table, first_columns = source_columns[0]
-        unambiguous_columns = {
-            col: first_table for col in self._find_unique_columns(first_columns)
-        }
+        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
         all_columns = set(unambiguous_columns)

         for table, columns in source_columns[1:]:
@@ -27,9 +27,7 @@ def qualify_tables(expression, db=None, catalog=None):
     for derived_table in scope.ctes + scope.derived_tables:
         if not derived_table.args.get("alias"):
             alias_ = f"_q_{next(sequence)}"
-            derived_table.set(
-                "alias", exp.TableAlias(this=exp.to_identifier(alias_))
-            )
+            derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
             scope.rename_source(None, alias_)

     for source in scope.sources.values():
@@ -57,9 +57,7 @@ class MappingSchema(Schema):

         for forbidden in self.forbidden_args:
             if table.text(forbidden):
-                raise ValueError(
-                    f"Schema doesn't support {forbidden}. Received: {table.sql()}"
-                )
+                raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
         return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
@@ -104,9 +104,7 @@ class Scope:
         elif isinstance(node, exp.CTE):
             self._ctes.append(node)
             prune = True
-        elif isinstance(node, exp.Subquery) and isinstance(
-            parent, (exp.From, exp.Join)
-        ):
+        elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
             self._derived_tables.append(node)
             prune = True
         elif isinstance(node, exp.Subqueryable):
@@ -195,20 +193,14 @@ class Scope:
         self._ensure_collected()
         columns = self._raw_columns

-        external_columns = [
-            column
-            for scope in self.subquery_scopes
-            for column in scope.external_columns
-        ]
+        external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns]

         named_outputs = {e.alias_or_name for e in self.expression.expressions}

         self._columns = [
             c
             for c in columns + external_columns
-            if not (
-                c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs
-            )
+            if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs)
         ]
         return self._columns
@@ -229,9 +221,7 @@ class Scope:
         for table in self.tables:
             referenced_names.append(
                 (
-                    table.parent.alias
-                    if isinstance(table.parent, exp.Alias)
-                    else table.name,
+                    table.parent.alias if isinstance(table.parent, exp.Alias) else table.name,
                     table,
                 )
             )
@@ -274,9 +264,7 @@ class Scope:
         sources in the current scope.
         """
         if self._external_columns is None:
-            self._external_columns = [
-                c for c in self.columns if c.table not in self.selected_sources
-            ]
+            self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
         return self._external_columns

     def source_columns(self, source_name):
@@ -310,6 +298,16 @@ class Scope:
         columns = self.sources.pop(old_name or "", [])
         self.sources[new_name] = columns

+    def add_source(self, name, source):
+        """Add a source to this scope"""
+        self.sources[name] = source
+        self.clear_cache()
+
+    def remove_source(self, name):
+        """Remove a source from this scope"""
+        self.sources.pop(name, None)
+        self.clear_cache()
+

 def traverse_scope(expression):
     """
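These helpers exist so that rewriting rules, like merge_derived_tables above, can mutate a scope's source map without leaving stale cached column info behind. Roughly (a sketch):

    from sqlglot import parse_one
    from sqlglot.optimizer.scope import traverse_scope

    scope = traverse_scope(parse_one("SELECT x.a FROM x"))[-1]
    scope.add_source("y", parse_one("SELECT 1"))  # caches invalidated on mutation
    scope.remove_source("y")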
@@ -334,7 +332,7 @@ def traverse_scope(expression):
     Args:
         expression (exp.Expression): expression to traverse
     Returns:
-        List[Scope]: scope instances
+        list[Scope]: scope instances
     """
     return list(_traverse_scope(Scope(expression)))
@@ -356,9 +354,7 @@ def _traverse_scope(scope):
 def _traverse_select(scope):
     yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
     yield from _traverse_subqueries(scope)
-    yield from _traverse_derived_tables(
-        scope.derived_tables, scope, ScopeType.DERIVED_TABLE
-    )
+    yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
     _add_table_sources(scope)
@@ -367,15 +363,11 @@ def _traverse_union(scope):

     # The last scope to be yielded should be the topmost scope
     left = None
-    for left in _traverse_scope(
-        scope.branch(scope.expression.left, scope_type=ScopeType.UNION)
-    ):
+    for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
         yield left

     right = None
-    for right in _traverse_scope(
-        scope.branch(scope.expression.right, scope_type=ScopeType.UNION)
-    ):
+    for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
         yield right

     scope.union = (left, right)
@@ -387,14 +379,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
     for derived_table in derived_tables:
         for child_scope in _traverse_scope(
             scope.branch(
-                derived_table
-                if isinstance(derived_table, (exp.Unnest, exp.Lateral))
-                else derived_table.this,
+                derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
                 add_sources=sources if scope_type == ScopeType.CTE else None,
                 outer_column_list=derived_table.alias_column_names,
-                scope_type=ScopeType.UNNEST
-                if isinstance(derived_table, exp.Unnest)
-                else scope_type,
+                scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
             )
         ):
             yield child_scope
@@ -430,9 +418,7 @@ def _add_table_sources(scope):
 def _traverse_subqueries(scope):
     for subquery in scope.subqueries:
         top = None
-        for child_scope in _traverse_scope(
-            scope.branch(subquery, scope_type=ScopeType.SUBQUERY)
-        ):
+        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
             yield child_scope
             top = child_scope
         scope.subquery_scopes.append(top)
@ -188,9 +188,7 @@ def absorb_and_eliminate(expression):
|
||||||
aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
|
aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
|
||||||
elif is_complement(b, ab):
|
elif is_complement(b, ab):
|
||||||
ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
|
ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
|
||||||
elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(
|
elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
|
||||||
a.flatten()
|
|
||||||
):
|
|
||||||
a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
|
a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
|
||||||
elif isinstance(b, kind):
|
elif isinstance(b, kind):
|
||||||
# eliminate
|
# eliminate
|
||||||
|
@ -227,9 +225,7 @@ def simplify_literals(expression):
|
||||||
operands.append(a)
|
operands.append(a)
|
||||||
|
|
||||||
if len(operands) < size:
|
if len(operands) < size:
|
||||||
return functools.reduce(
|
return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
|
||||||
lambda a, b: expression.__class__(this=a, expression=b), operands
|
|
||||||
)
|
|
||||||
elif isinstance(expression, exp.Neg):
|
elif isinstance(expression, exp.Neg):
|
||||||
this = expression.this
|
this = expression.this
|
||||||
if this.is_number:
|
if this.is_number:
|
||||||
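
(A hypothetical sketch of the simplifier these reflowed hunks belong to; it assumes `simplify` is the module-level entry point that applies absorb_and_eliminate and simplify_literals, matching the function names in the hunk headers:)

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    # Absorption/elimination: (A AND B) OR (A AND NOT B) should collapse to A.
    print(simplify(sqlglot.parse_one("(A AND B) OR (A AND NOT B)")).sql())
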
@@ -89,11 +89,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
             return

         if isinstance(predicate, exp.Binary):
-            key = (
-                predicate.right
-                if any(node is column for node, *_ in predicate.left.walk())
-                else predicate.left
-            )
+            key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left
         else:
             return

@@ -124,9 +120,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
     # if the value of the subquery is not an agg or a key, we need to collect it into an array
     # so that it can be grouped
     if not value.find(exp.AggFunc) and value.this not in group_by:
-        select.select(
-            f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False
-        )
+        select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False)

     # exists queries should not have any selects as it only checks if there are any rows
     # all selects will be added by the optimizer and only used for join keys

@@ -151,16 +145,12 @@ def decorrelate(select, parent_select, external_columns, sequence):
         else:
             parent_predicate = _replace(parent_predicate, "TRUE")
     elif isinstance(parent_predicate, exp.All):
-        parent_predicate = _replace(
-            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
-        )
+        parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})")
     elif isinstance(parent_predicate, exp.Any):
         if value.this in group_by:
             parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
         else:
-            parent_predicate = _replace(
-                parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})"
-            )
+            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
     elif isinstance(parent_predicate, exp.In):
         if value.this in group_by:
             parent_predicate = _replace(parent_predicate, f"{other} = {alias}")

@@ -178,9 +168,7 @@ def decorrelate(select, parent_select, external_columns, sequence):

     if key in group_by:
         key.replace(nested)
-        parent_predicate = _replace(
-            parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
-        )
+        parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)")
     elif isinstance(predicate, exp.EQ):
         parent_predicate = _replace(
             parent_predicate,
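
(The decorrelation logic above rewrites correlated subqueries into ARRAY_AGG/ARRAY_ANY/ARRAY_ALL forms that can be joined and grouped. A hypothetical way to drive it directly, assuming `unnest_subqueries` is the module-level wrapper around `decorrelate`:)

    import sqlglot
    from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

    sql = "SELECT * FROM x WHERE x.a > (SELECT SUM(y.a) FROM y WHERE y.b = x.b)"
    print(unnest_subqueries(sqlglot.parse_one(sql)).sql())
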
@@ -78,6 +78,7 @@ class Parser:
         TokenType.TEXT,
         TokenType.BINARY,
         TokenType.JSON,
+        TokenType.INTERVAL,
         TokenType.TIMESTAMP,
         TokenType.TIMESTAMPTZ,
         TokenType.DATETIME,

@@ -85,6 +86,12 @@ class Parser:
         TokenType.DECIMAL,
         TokenType.UUID,
         TokenType.GEOGRAPHY,
+        TokenType.GEOMETRY,
+        TokenType.HLLSKETCH,
+        TokenType.SUPER,
+        TokenType.SERIAL,
+        TokenType.SMALLSERIAL,
+        TokenType.BIGSERIAL,
         *NESTED_TYPE_TOKENS,
     }

@@ -100,13 +107,14 @@ class Parser:
     ID_VAR_TOKENS = {
         TokenType.VAR,
         TokenType.ALTER,
+        TokenType.ALWAYS,
         TokenType.BEGIN,
+        TokenType.BOTH,
         TokenType.BUCKET,
         TokenType.CACHE,
         TokenType.COLLATE,
         TokenType.COMMIT,
         TokenType.CONSTRAINT,
-        TokenType.CONVERT,
         TokenType.DEFAULT,
         TokenType.DELETE,
         TokenType.ENGINE,

@@ -115,14 +123,19 @@ class Parser:
         TokenType.FALSE,
         TokenType.FIRST,
         TokenType.FOLLOWING,
+        TokenType.FOR,
         TokenType.FORMAT,
         TokenType.FUNCTION,
+        TokenType.GENERATED,
+        TokenType.IDENTITY,
         TokenType.IF,
         TokenType.INDEX,
         TokenType.ISNULL,
         TokenType.INTERVAL,
         TokenType.LAZY,
+        TokenType.LEADING,
         TokenType.LOCATION,
+        TokenType.NATURAL,
         TokenType.NEXT,
         TokenType.ONLY,
         TokenType.OPTIMIZE,

@@ -141,6 +154,7 @@ class Parser:
         TokenType.TABLE_FORMAT,
         TokenType.TEMPORARY,
         TokenType.TOP,
+        TokenType.TRAILING,
         TokenType.TRUNCATE,
         TokenType.TRUE,
         TokenType.UNBOUNDED,

@@ -150,18 +164,15 @@ class Parser:
         *TYPE_TOKENS,
     }

-    CASTS = {
-        TokenType.CAST,
-        TokenType.TRY_CAST,
-    }
+    TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL}
+
+    TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}

     FUNC_TOKENS = {
-        TokenType.CONVERT,
         TokenType.CURRENT_DATE,
         TokenType.CURRENT_DATETIME,
         TokenType.CURRENT_TIMESTAMP,
         TokenType.CURRENT_TIME,
-        TokenType.EXTRACT,
         TokenType.FILTER,
         TokenType.FIRST,
         TokenType.FORMAT,

@@ -178,7 +189,6 @@ class Parser:
         TokenType.DATETIME,
         TokenType.TIMESTAMP,
         TokenType.TIMESTAMPTZ,
-        *CASTS,
         *NESTED_TYPE_TOKENS,
         *SUBQUERY_PREDICATES,
     }

@@ -215,6 +225,7 @@ class Parser:

     FACTOR = {
         TokenType.DIV: exp.IntDiv,
+        TokenType.LR_ARROW: exp.Distance,
         TokenType.SLASH: exp.Div,
         TokenType.STAR: exp.Mul,
     }

@@ -299,14 +310,13 @@ class Parser:
     PRIMARY_PARSERS = {
         TokenType.STRING: lambda _, token: exp.Literal.string(token.text),
         TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text),
-        TokenType.STAR: lambda self, _: exp.Star(
-            **{"except": self._parse_except(), "replace": self._parse_replace()}
-        ),
+        TokenType.STAR: lambda self, _: exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}),
         TokenType.NULL: lambda *_: exp.Null(),
         TokenType.TRUE: lambda *_: exp.Boolean(this=True),
         TokenType.FALSE: lambda *_: exp.Boolean(this=False),
         TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(),
         TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
+        TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
         TokenType.INTRODUCER: lambda self, token: self.expression(
             exp.Introducer,
             this=token.text,

@@ -319,13 +329,16 @@ class Parser:
         TokenType.IN: lambda self, this: self._parse_in(this),
         TokenType.IS: lambda self, this: self._parse_is(this),
         TokenType.LIKE: lambda self, this: self._parse_escape(
-            self.expression(exp.Like, this=this, expression=self._parse_type())
+            self.expression(exp.Like, this=this, expression=self._parse_bitwise())
         ),
         TokenType.ILIKE: lambda self, this: self._parse_escape(
-            self.expression(exp.ILike, this=this, expression=self._parse_type())
+            self.expression(exp.ILike, this=this, expression=self._parse_bitwise())
         ),
         TokenType.RLIKE: lambda self, this: self.expression(
-            exp.RegexpLike, this=this, expression=self._parse_type()
+            exp.RegexpLike, this=this, expression=self._parse_bitwise()
+        ),
+        TokenType.SIMILAR_TO: lambda self, this: self.expression(
+            exp.SimilarTo, this=this, expression=self._parse_bitwise()
         ),
     }

@@ -363,28 +376,21 @@ class Parser:
     }

     FUNCTION_PARSERS = {
-        TokenType.CONVERT: lambda self, _: self._parse_convert(),
-        TokenType.EXTRACT: lambda self, _: self._parse_extract(),
-        **{
-            token_type: lambda self, token_type: self._parse_cast(
-                self.STRICT_CAST and token_type == TokenType.CAST
-            )
-            for token_type in CASTS
-        },
+        "CONVERT": lambda self: self._parse_convert(),
+        "EXTRACT": lambda self: self._parse_extract(),
+        "SUBSTRING": lambda self: self._parse_substring(),
+        "TRIM": lambda self: self._parse_trim(),
+        "CAST": lambda self: self._parse_cast(self.STRICT_CAST),
+        "TRY_CAST": lambda self: self._parse_cast(False),
     }

     QUERY_MODIFIER_PARSERS = {
-        "laterals": lambda self: self._parse_laterals(),
-        "joins": lambda self: self._parse_joins(),
         "where": lambda self: self._parse_where(),
         "group": lambda self: self._parse_group(),
         "having": lambda self: self._parse_having(),
         "qualify": lambda self: self._parse_qualify(),
-        "window": lambda self: self._match(TokenType.WINDOW)
-        and self._parse_window(self._parse_id_var(), alias=True),
-        "distribute": lambda self: self._parse_sort(
-            TokenType.DISTRIBUTE_BY, exp.Distribute
-        ),
+        "window": lambda self: self._match(TokenType.WINDOW) and self._parse_window(self._parse_id_var(), alias=True),
+        "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
         "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
         "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
         "order": lambda self: self._parse_order(),

@@ -392,6 +398,8 @@ class Parser:
         "offset": lambda self: self._parse_offset(),
     }

+    MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
+
     CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX}

     STRICT_CAST = True

@@ -457,9 +465,7 @@ class Parser:
         Returns
             the list of syntax trees (:class:`~sqlglot.expressions.Expression`).
         """
-        return self._parse(
-            parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql
-        )
+        return self._parse(parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql)

     def parse_into(self, expression_types, raw_tokens, sql=None):
         for expression_type in ensure_list(expression_types):

@@ -532,21 +538,13 @@ class Parser:

         for k in expression.args:
             if k not in expression.arg_types:
-                self.raise_error(
-                    f"Unexpected keyword: '{k}' for {expression.__class__}"
-                )
+                self.raise_error(f"Unexpected keyword: '{k}' for {expression.__class__}")
         for k, mandatory in expression.arg_types.items():
             v = expression.args.get(k)
             if mandatory and (v is None or (isinstance(v, list) and not v)):
-                self.raise_error(
-                    f"Required keyword: '{k}' missing for {expression.__class__}"
-                )
+                self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}")

-        if (
-            args
-            and len(args) > len(expression.arg_types)
-            and not expression.is_var_len_args
-        ):
+        if args and len(args) > len(expression.arg_types) and not expression.is_var_len_args:
             self.raise_error(
                 f"The number of provided arguments ({len(args)}) is greater than "
                 f"the maximum number of supported arguments ({len(expression.arg_types)})"

@@ -594,11 +592,7 @@ class Parser:
         )

         expression = self._parse_expression()
-        expression = (
-            self._parse_set_operations(expression)
-            if expression
-            else self._parse_select()
-        )
+        expression = self._parse_set_operations(expression) if expression else self._parse_select()
         self._parse_query_modifiers(expression)
         return expression

@@ -618,11 +612,7 @@ class Parser:
         )

     def _parse_exists(self, not_=False):
-        return (
-            self._match(TokenType.IF)
-            and (not not_ or self._match(TokenType.NOT))
-            and self._match(TokenType.EXISTS)
-        )
+        return self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS)

     def _parse_create(self):
         replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)

@@ -647,11 +637,9 @@ class Parser:
             this = self._parse_index()
         elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW):
             this = self._parse_table(schema=True)
-            properties = self._parse_properties(
-                this if isinstance(this, exp.Schema) else None
-            )
+            properties = self._parse_properties(this if isinstance(this, exp.Schema) else None)
             if self._match(TokenType.ALIAS):
-                expression = self._parse_select()
+                expression = self._parse_select(nested=True)

         return self.expression(
             exp.Create,

@@ -682,9 +670,7 @@ class Parser:
         if schema and not isinstance(value, exp.Schema):
             columns = {v.name.upper() for v in value.expressions}
             partitions = [
-                expression
-                for expression in schema.expressions
-                if expression.this.name.upper() in columns
+                expression for expression in schema.expressions if expression.this.name.upper() in columns
             ]
             schema.set(
                 "expressions",

@@ -811,7 +797,7 @@ class Parser:
             this=self._parse_table(schema=True),
             exists=self._parse_exists(),
             partition=self._parse_partition(),
-            expression=self._parse_select(),
+            expression=self._parse_select(nested=True),
             overwrite=overwrite,
         )

@@ -829,8 +815,7 @@ class Parser:
             exp.Update,
             **{
                 "this": self._parse_table(schema=True),
-                "expressions": self._match(TokenType.SET)
-                and self._parse_csv(self._parse_equality),
+                "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
                 "from": self._parse_from(),
                 "where": self._parse_where(),
             },

@@ -865,7 +850,7 @@ class Parser:
             this=table,
             lazy=lazy,
             options=options,
-            expression=self._parse_select(),
+            expression=self._parse_select(nested=True),
         )

     def _parse_partition(self):

@@ -894,9 +879,7 @@ class Parser:
         self._match_r_paren()
         return self.expression(exp.Tuple, expressions=expressions)

-    def _parse_select(self, table=None):
-        index = self._index
+    def _parse_select(self, nested=False, table=False):

         if self._match(TokenType.SELECT):
             hint = self._parse_hint()
             all_ = self._match(TokenType.ALL)

@@ -912,9 +895,7 @@ class Parser:
                 self.raise_error("Cannot specify both ALL and DISTINCT after SELECT")

             limit = self._parse_limit(top=True)
-            expressions = self._parse_csv(
-                lambda: self._parse_annotation(self._parse_expression())
-            )
+            expressions = self._parse_csv(lambda: self._parse_annotation(self._parse_expression()))

             this = self.expression(
                 exp.Select,

@@ -960,19 +941,13 @@ class Parser:
                )
            else:
                self.raise_error(f"{this.key} does not support CTE")
-        elif self._match(TokenType.L_PAREN):
-            this = self._parse_table() if table else self._parse_select()
+        elif (table or nested) and self._match(TokenType.L_PAREN):
+            this = self._parse_table() if table else self._parse_select(nested=True)

-            if this:
             self._parse_query_modifiers(this)
             self._match_r_paren()
             this = self._parse_subquery(this)
-            else:
-                self._retreat(index)
         elif self._match(TokenType.VALUES):
-            this = self.expression(
-                exp.Values, expressions=self._parse_csv(self._parse_value)
-            )
+            this = self.expression(exp.Values, expressions=self._parse_csv(self._parse_value))
             alias = self._parse_table_alias()
             if alias:
                 this = self.expression(exp.Subquery, this=this, alias=alias)

@@ -1001,7 +976,7 @@ class Parser:

     def _parse_table_alias(self):
         any_token = self._match(TokenType.ALIAS)
-        alias = self._parse_id_var(any_token)
+        alias = self._parse_id_var(any_token=any_token, tokens=self.TABLE_ALIAS_TOKENS)
         columns = None

         if self._match(TokenType.L_PAREN):

@@ -1021,9 +996,24 @@ class Parser:
         return self.expression(exp.Subquery, this=this, alias=self._parse_table_alias())

     def _parse_query_modifiers(self, this):
-        if not isinstance(this, (exp.Subquery, exp.Subqueryable)):
+        if not isinstance(this, self.MODIFIABLES):
             return

+        table = isinstance(this, exp.Table)
+
+        while True:
+            lateral = self._parse_lateral()
+            join = self._parse_join()
+            comma = None if table else self._match(TokenType.COMMA)
+            if lateral:
+                this.append("laterals", lateral)
+            if join:
+                this.append("joins", join)
+            if comma:
+                this.args["from"].append("expressions", self._parse_table())
+            if not (lateral or join or comma):
+                break
+
         for key, parser in self.QUERY_MODIFIER_PARSERS.items():
             expression = parser(self)

@@ -1032,9 +1022,7 @@ class Parser:

     def _parse_annotation(self, expression):
         if self._match(TokenType.ANNOTATION):
-            return self.expression(
-                exp.Annotation, this=self._prev.text, expression=expression
-            )
+            return self.expression(exp.Annotation, this=self._prev.text, expression=expression)

         return expression

@@ -1052,16 +1040,16 @@ class Parser:

         return self.expression(exp.From, expressions=self._parse_csv(self._parse_table))

-    def _parse_laterals(self):
-        return self._parse_all(self._parse_lateral)
-
     def _parse_lateral(self):
         if not self._match(TokenType.LATERAL):
             return None

-        if not self._match(TokenType.VIEW):
-            self.raise_error("Expected VIEW after LATERAL")
+        subquery = self._parse_select(table=True)

+        if subquery:
+            return self.expression(exp.Lateral, this=subquery)
+
+        self._match(TokenType.VIEW)
         outer = self._match(TokenType.OUTER)

         return self.expression(

@@ -1071,31 +1059,27 @@ class Parser:
             alias=self.expression(
                 exp.TableAlias,
                 this=self._parse_id_var(any_token=False),
-                columns=(
-                    self._parse_csv(self._parse_id_var)
-                    if self._match(TokenType.ALIAS)
-                    else None
-                ),
+                columns=(self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else None),
             ),
         )

-    def _parse_joins(self):
-        return self._parse_all(self._parse_join)
-
     def _parse_join_side_and_kind(self):
         return (
+            self._match(TokenType.NATURAL) and self._prev,
             self._match_set(self.JOIN_SIDES) and self._prev,
             self._match_set(self.JOIN_KINDS) and self._prev,
         )

     def _parse_join(self):
-        side, kind = self._parse_join_side_and_kind()
+        natural, side, kind = self._parse_join_side_and_kind()

         if not self._match(TokenType.JOIN):
             return None

         kwargs = {"this": self._parse_table()}

+        if natural:
+            kwargs["natural"] = True
         if side:
             kwargs["side"] = side.text
         if kind:
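
(With `natural` threaded through `_parse_join_side_and_kind` and into the Join kwargs, NATURAL joins should round-trip. A hypothetical check, using the same `sqlglot.transpile` list-of-strings API the CLI in this commit relies on:)

    import sqlglot

    # No read/write dialect given: parse and regenerate in the default dialect.
    print(sqlglot.transpile("SELECT * FROM a NATURAL JOIN b")[0])
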
@@ -1120,6 +1104,11 @@ class Parser:
         )

     def _parse_table(self, schema=False):
+        lateral = self._parse_lateral()
+
+        if lateral:
+            return lateral
+
         unnest = self._parse_unnest()

         if unnest:

@@ -1172,9 +1161,7 @@ class Parser:
         expressions = self._parse_csv(self._parse_column)
         self._match_r_paren()

-        ordinality = bool(
-            self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY)
-        )
+        ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY))

         alias = self._parse_table_alias()

@@ -1280,17 +1267,13 @@ class Parser:
         if not self._match(TokenType.ORDER_BY):
             return this

-        return self.expression(
-            exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
-        )
+        return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered))

     def _parse_sort(self, token_type, exp_class):
         if not self._match(token_type):
             return None

-        return self.expression(
-            exp_class, expressions=self._parse_csv(self._parse_ordered)
-        )
+        return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))

     def _parse_ordered(self):
         this = self._parse_conjunction()

@@ -1305,22 +1288,17 @@ class Parser:
         if (
             not explicitly_null_ordered
             and (
-                (asc and self.null_ordering == "nulls_are_small")
-                or (desc and self.null_ordering != "nulls_are_small")
+                (asc and self.null_ordering == "nulls_are_small") or (desc and self.null_ordering != "nulls_are_small")
             )
             and self.null_ordering != "nulls_are_last"
         ):
             nulls_first = True

-        return self.expression(
-            exp.Ordered, this=this, desc=desc, nulls_first=nulls_first
-        )
+        return self.expression(exp.Ordered, this=this, desc=desc, nulls_first=nulls_first)

     def _parse_limit(self, this=None, top=False):
         if self._match(TokenType.TOP if top else TokenType.LIMIT):
-            return self.expression(
-                exp.Limit, this=this, expression=self._parse_number()
-            )
+            return self.expression(exp.Limit, this=this, expression=self._parse_number())
         if self._match(TokenType.FETCH):
             direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
             direction = self._prev.text if direction else "FIRST"

@@ -1354,7 +1332,7 @@ class Parser:
             expression,
             this=this,
             distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
-            expression=self._parse_select(),
+            expression=self._parse_select(nested=True),
         )

     def _parse_expression(self):

@@ -1396,9 +1374,7 @@ class Parser:
             this = self.expression(exp.In, this=this, unnest=unnest)
         else:
             self._match_l_paren()
-            expressions = self._parse_csv(
-                lambda: self._parse_select() or self._parse_expression()
-            )
+            expressions = self._parse_csv(lambda: self._parse_select() or self._parse_expression())

             if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
                 this = self.expression(exp.In, this=this, query=expressions[0])

@@ -1430,13 +1406,9 @@ class Parser:
                     expression=self._parse_term(),
                 )
             elif self._match_pair(TokenType.LT, TokenType.LT):
-                this = self.expression(
-                    exp.BitwiseLeftShift, this=this, expression=self._parse_term()
-                )
+                this = self.expression(exp.BitwiseLeftShift, this=this, expression=self._parse_term())
             elif self._match_pair(TokenType.GT, TokenType.GT):
-                this = self.expression(
-                    exp.BitwiseRightShift, this=this, expression=self._parse_term()
-                )
+                this = self.expression(exp.BitwiseRightShift, this=this, expression=self._parse_term())
             else:
                 break

@@ -1524,7 +1496,7 @@ class Parser:
             self.raise_error("Expecting >")

         if type_token in self.TIMESTAMPS:
-            tz = self._match(TokenType.WITH_TIME_ZONE)
+            tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ
             self._match(TokenType.WITHOUT_TIME_ZONE)
             if tz:
                 return exp.DataType(

@@ -1594,16 +1566,14 @@ class Parser:
             if query:
                 expressions = [query]
             else:
-                expressions = self._parse_csv(
-                    lambda: self._parse_alias(self._parse_conjunction(), explicit=True)
-                )
+                expressions = self._parse_csv(lambda: self._parse_alias(self._parse_conjunction(), explicit=True))

             this = list_get(expressions, 0)
             self._parse_query_modifiers(this)
             self._match_r_paren()

             if isinstance(this, exp.Subqueryable):
-                return self._parse_subquery(this)
+                return self._parse_set_operations(self._parse_subquery(this))
             if len(expressions) > 1:
                 return self.expression(exp.Tuple, expressions=expressions)
             return self.expression(exp.Paren, this=this)

@@ -1611,11 +1581,7 @@ class Parser:
         return None

     def _parse_field(self, any_token=False):
-        return (
-            self._parse_primary()
-            or self._parse_function()
-            or self._parse_id_var(any_token)
-        )
+        return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token)

     def _parse_function(self):
         if not self._curr:

@@ -1628,21 +1594,22 @@ class Parser:

         if not self._next or self._next.token_type != TokenType.L_PAREN:
             if token_type in self.NO_PAREN_FUNCTIONS:
-                return self.expression(
-                    self._advance() or self.NO_PAREN_FUNCTIONS[token_type]
-                )
+                return self.expression(self._advance() or self.NO_PAREN_FUNCTIONS[token_type])
             return None

         if token_type not in self.FUNC_TOKENS:
             return None

-        if self._match_set(self.FUNCTION_PARSERS):
-            self._advance()
-            this = self.FUNCTION_PARSERS[token_type](self, token_type)
+        this = self._curr.text
+        upper = this.upper()
+        self._advance(2)
+
+        parser = self.FUNCTION_PARSERS.get(upper)
+
+        if parser:
+            this = parser(self)
         else:
             subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type)
-            this = self._curr.text
-            self._advance(2)

             if subquery_predicate and self._curr.token_type in (
                 TokenType.SELECT,

@@ -1652,7 +1619,7 @@ class Parser:
             self._match_r_paren()
             return this

-        function = self.FUNCTIONS.get(this.upper())
+        function = self.FUNCTIONS.get(upper)
         args = self._parse_csv(self._parse_lambda)

         if function:

@@ -1700,10 +1667,7 @@ class Parser:
             self._retreat(index)
             return this

-        args = self._parse_csv(
-            lambda: self._parse_constraint()
-            or self._parse_column_def(self._parse_field())
-        )
+        args = self._parse_csv(lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True)))
         self._match_r_paren()
         return self.expression(exp.Schema, this=this, expressions=args)

@@ -1720,12 +1684,9 @@ class Parser:
                 break
             constraints.append(constraint)

-        return self.expression(
-            exp.ColumnDef, this=this, kind=kind, constraints=constraints
-        )
+        return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints)

     def _parse_column_constraint(self):
-        kind = None
         this = None

         if self._match(TokenType.CONSTRAINT):

@@ -1735,28 +1696,28 @@ class Parser:
             kind = exp.AutoIncrementColumnConstraint()
         elif self._match(TokenType.CHECK):
             self._match_l_paren()
-            kind = self.expression(
-                exp.CheckColumnConstraint, this=self._parse_conjunction()
-            )
+            kind = self.expression(exp.CheckColumnConstraint, this=self._parse_conjunction())
             self._match_r_paren()
         elif self._match(TokenType.COLLATE):
             kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
         elif self._match(TokenType.DEFAULT):
-            kind = self.expression(
-                exp.DefaultColumnConstraint, this=self._parse_field()
-            )
-        elif self._match(TokenType.NOT) and self._match(TokenType.NULL):
+            kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_field())
+        elif self._match_pair(TokenType.NOT, TokenType.NULL):
             kind = exp.NotNullColumnConstraint()
         elif self._match(TokenType.SCHEMA_COMMENT):
-            kind = self.expression(
-                exp.CommentColumnConstraint, this=self._parse_string()
-            )
+            kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
         elif self._match(TokenType.PRIMARY_KEY):
             kind = exp.PrimaryKeyColumnConstraint()
         elif self._match(TokenType.UNIQUE):
             kind = exp.UniqueColumnConstraint()
-        if kind is None:
+        elif self._match(TokenType.GENERATED):
+            if self._match(TokenType.BY_DEFAULT):
+                kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False)
+            else:
+                self._match(TokenType.ALWAYS)
+                kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
+            self._match_pair(TokenType.ALIAS, TokenType.IDENTITY)
+        else:
             return None

         return self.expression(exp.ColumnConstraint, this=this, kind=kind)
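
(The GENERATED branch above maps `GENERATED {ALWAYS | BY DEFAULT} AS IDENTITY` to exp.GeneratedAsIdentityColumnConstraint, with this=True for ALWAYS. A hypothetical round-trip, assuming the matching generator support added elsewhere in this commit:)

    import sqlglot

    sql = "CREATE TABLE t (id INT GENERATED ALWAYS AS IDENTITY)"
    print(sqlglot.parse_one(sql).sql())
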
@@ -1864,9 +1825,7 @@ class Parser:
         if not self._match(TokenType.END):
             self.raise_error("Expected END after CASE", self._prev)

-        return self._parse_window(
-            self.expression(exp.Case, this=expression, ifs=ifs, default=default)
-        )
+        return self._parse_window(self.expression(exp.Case, this=expression, ifs=ifs, default=default))

     def _parse_if(self):
         if self._match(TokenType.L_PAREN):

@@ -1889,7 +1848,7 @@ class Parser:
         if not self._match(TokenType.FROM):
             self.raise_error("Expected FROM after EXTRACT", self._prev)

-        return self.expression(exp.Extract, this=this, expression=self._parse_type())
+        return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())

     def _parse_cast(self, strict):
         this = self._parse_conjunction()

@@ -1917,12 +1876,54 @@ class Parser:
             to = None
         return self.expression(exp.Cast, this=this, to=to)

+    def _parse_substring(self):
+        # Postgres supports the form: substring(string [from int] [for int])
+        # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
+
+        args = self._parse_csv(self._parse_bitwise)
+
+        if self._match(TokenType.FROM):
+            args.append(self._parse_bitwise())
+            if self._match(TokenType.FOR):
+                args.append(self._parse_bitwise())
+
+        this = exp.Substring.from_arg_list(args)
+        self.validate_expression(this, args)
+
+        return this
+
+    def _parse_trim(self):
+        # https://www.w3resource.com/sql/character-functions/trim.php
+        # https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html
+
+        position = None
+        collation = None
+
+        if self._match_set(self.TRIM_TYPES):
+            position = self._prev.text.upper()
+
+        expression = self._parse_term()
+        if self._match(TokenType.FROM):
+            this = self._parse_term()
+        else:
+            this = expression
+            expression = None
+
+        if self._match(TokenType.COLLATE):
+            collation = self._parse_term()
+
+        return self.expression(
+            exp.Trim,
+            this=this,
+            position=position,
+            expression=expression,
+            collation=collation,
+        )
+
     def _parse_window(self, this, alias=False):
         if self._match(TokenType.FILTER):
             self._match_l_paren()
-            this = self.expression(
-                exp.Filter, this=this, expression=self._parse_where()
-            )
+            this = self.expression(exp.Filter, this=this, expression=self._parse_where())
             self._match_r_paren()

         if self._match(TokenType.WITHIN_GROUP):
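
(The two new methods above accept the keyword-argument call forms: SUBSTRING's `FROM ... FOR ...` and TRIM's `{LEADING | TRAILING | BOTH} ... FROM ...` with an optional COLLATE. A hypothetical sketch:)

    import sqlglot

    print(sqlglot.transpile("SELECT SUBSTRING(x FROM 2 FOR 3)", read="postgres")[0])
    print(sqlglot.transpile("SELECT TRIM(LEADING 'x' FROM y)")[0])
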
@@ -1935,6 +1936,25 @@ class Parser:
             self._match_r_paren()
             return this

+        # SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER
+        # Some dialects choose to implement and some do not.
+        # https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html
+
+        # There is some code above in _parse_lambda that handles
+        #   SELECT FIRST_VALUE(TABLE.COLUMN IGNORE|RESPECT NULLS) OVER ...
+
+        # The below changes handle
+        #   SELECT FIRST_VALUE(TABLE.COLUMN) IGNORE|RESPECT NULLS OVER ...
+
+        # Oracle allows both formats
+        #   (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html)
+        # and Snowflake chose to do the same for familiarity
+        #   https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes
+        if self._match(TokenType.IGNORE_NULLS):
+            this = self.expression(exp.IgnoreNulls, this=this)
+        elif self._match(TokenType.RESPECT_NULLS):
+            this = self.expression(exp.RespectNulls, this=this)
+
         # bigquery select from window x AS (partition by ...)
         if alias:
             self._match(TokenType.ALIAS)
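
(The block above handles the postfix spelling, where IGNORE/RESPECT NULLS follows the closing paren of the function call rather than sitting inside it. A hypothetical sketch:)

    import sqlglot

    sql = "SELECT FIRST_VALUE(x) IGNORE NULLS OVER (ORDER BY y) FROM t"
    print(sqlglot.parse_one(sql).sql())
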
@@ -1992,13 +2012,9 @@ class Parser:
         self._match(TokenType.BETWEEN)

         return {
-            "value": (
-                self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW))
-                and self._prev.text
-            )
+            "value": (self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text)
             or self._parse_bitwise(),
-            "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING))
-            and self._prev.text,
+            "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text,
         }

     def _parse_alias(self, this, explicit=False):

@@ -2023,22 +2039,16 @@ class Parser:

         return this

-    def _parse_id_var(self, any_token=True):
+    def _parse_id_var(self, any_token=True, tokens=None):
         identifier = self._parse_identifier()

         if identifier:
             return identifier

-        if (
-            any_token
-            and self._curr
-            and self._curr.token_type not in self.RESERVED_KEYWORDS
-        ):
+        if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
             return self._advance() or exp.Identifier(this=self._prev.text, quoted=False)

-        return self._match_set(self.ID_VAR_TOKENS) and exp.Identifier(
-            this=self._prev.text, quoted=False
-        )
+        return self._match_set(tokens or self.ID_VAR_TOKENS) and exp.Identifier(this=self._prev.text, quoted=False)

     def _parse_string(self):
         if self._match(TokenType.STRING):

@@ -2077,9 +2087,7 @@ class Parser:

     def _parse_star(self):
         if self._match(TokenType.STAR):
-            return exp.Star(
-                **{"except": self._parse_except(), "replace": self._parse_replace()}
-            )
+            return exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()})
         return None

     def _parse_placeholder(self):

@@ -2117,15 +2125,10 @@ class Parser:
         this = parse()

         while self._match_set(expressions):
-            this = self.expression(
-                expressions[self._prev.token_type], this=this, expression=parse()
-            )
+            this = self.expression(expressions[self._prev.token_type], this=this, expression=parse())

         return this

-    def _parse_all(self, parse):
-        return list(iter(parse, None))
-
     def _parse_wrapped_id_vars(self):
         self._match_l_paren()
         expressions = self._parse_csv(self._parse_id_var)

@@ -2156,10 +2159,7 @@ class Parser:
         if not self._curr or not self._next:
             return None

-        if (
-            self._curr.token_type == token_type_a
-            and self._next.token_type == token_type_b
-        ):
+        if self._curr.token_type == token_type_a and self._next.token_type == token_type_b:
             if advance:
                 self._advance(2)
             return True

@@ -72,9 +72,7 @@ class Step:
         if from_:
             from_ = from_.expressions
             if len(from_) > 1:
-                raise UnsupportedError(
-                    "Multi-from statements are unsupported. Run it through the optimizer"
-                )
+                raise UnsupportedError("Multi-from statements are unsupported. Run it through the optimizer")

             step = Scan.from_expression(from_[0], ctes)
         else:

@@ -104,9 +102,7 @@ class Step:
                     continue
                 if operand not in operands:
                     operands[operand] = f"_a_{next(sequence)}"
-                operand.replace(
-                    exp.column(operands[operand], step.name, quoted=True)
-                )
+                operand.replace(exp.column(operands[operand], step.name, quoted=True))
             else:
                 projections.append(e)

@@ -121,14 +117,9 @@ class Step:
             aggregate = Aggregate()
             aggregate.source = step.name
             aggregate.name = step.name
-            aggregate.operands = tuple(
-                alias(operand, alias_) for operand, alias_ in operands.items()
-            )
+            aggregate.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
             aggregate.aggregations = aggregations
-            aggregate.group = [
-                exp.column(e.alias_or_name, step.name, quoted=True)
-                for e in group.expressions
-            ]
+            aggregate.group = [exp.column(e.alias_or_name, step.name, quoted=True) for e in group.expressions]
             aggregate.add_dependency(step)
             step = aggregate

@@ -212,9 +203,7 @@ class Scan(Step):
         alias_ = expression.alias

         if not alias_:
-            raise UnsupportedError(
-                "Tables/Subqueries must be aliased. Run it through the optimizer"
-            )
+            raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer")

         if isinstance(expression, exp.Subquery):
             step = Step.from_expression(table, ctes)

@ -38,6 +38,7 @@ class TokenType(AutoName):
|
||||||
DARROW = auto()
|
DARROW = auto()
|
||||||
HASH_ARROW = auto()
|
HASH_ARROW = auto()
|
||||||
DHASH_ARROW = auto()
|
DHASH_ARROW = auto()
|
||||||
|
LR_ARROW = auto()
|
||||||
ANNOTATION = auto()
|
ANNOTATION = auto()
|
||||||
DOLLAR = auto()
|
DOLLAR = auto()
|
||||||
|
|
||||||
|
@ -53,6 +54,7 @@ class TokenType(AutoName):
|
||||||
TABLE = auto()
|
TABLE = auto()
|
||||||
VAR = auto()
|
VAR = auto()
|
||||||
BIT_STRING = auto()
|
BIT_STRING = auto()
|
||||||
|
HEX_STRING = auto()
|
||||||
|
|
||||||
# types
|
# types
|
||||||
BOOLEAN = auto()
|
BOOLEAN = auto()
|
||||||
|
@ -78,10 +80,17 @@ class TokenType(AutoName):
|
||||||
UUID = auto()
|
UUID = auto()
|
||||||
GEOGRAPHY = auto()
|
GEOGRAPHY = auto()
|
||||||
NULLABLE = auto()
|
NULLABLE = auto()
|
||||||
|
+    GEOMETRY = auto()
+    HLLSKETCH = auto()
+    SUPER = auto()
+    SERIAL = auto()
+    SMALLSERIAL = auto()
+    BIGSERIAL = auto()
 
     # keywords
     ADD_FILE = auto()
     ALIAS = auto()
+    ALWAYS = auto()
     ALL = auto()
     ALTER = auto()
     ANALYZE = auto()
@@ -92,11 +101,12 @@ class TokenType(AutoName):
     AUTO_INCREMENT = auto()
     BEGIN = auto()
     BETWEEN = auto()
+    BOTH = auto()
     BUCKET = auto()
+    BY_DEFAULT = auto()
     CACHE = auto()
     CALL = auto()
     CASE = auto()
-    CAST = auto()
     CHARACTER_SET = auto()
     CHECK = auto()
     CLUSTER_BY = auto()
@@ -104,7 +114,6 @@ class TokenType(AutoName):
     COMMENT = auto()
     COMMIT = auto()
     CONSTRAINT = auto()
-    CONVERT = auto()
     CREATE = auto()
     CROSS = auto()
     CUBE = auto()
@@ -127,22 +136,24 @@ class TokenType(AutoName):
     EXCEPT = auto()
     EXISTS = auto()
     EXPLAIN = auto()
-    EXTRACT = auto()
     FALSE = auto()
     FETCH = auto()
     FILTER = auto()
     FINAL = auto()
     FIRST = auto()
     FOLLOWING = auto()
+    FOR = auto()
     FOREIGN_KEY = auto()
     FORMAT = auto()
     FULL = auto()
     FUNCTION = auto()
     FROM = auto()
+    GENERATED = auto()
     GROUP_BY = auto()
     GROUPING_SETS = auto()
     HAVING = auto()
     HINT = auto()
+    IDENTITY = auto()
     IF = auto()
     IGNORE_NULLS = auto()
     ILIKE = auto()
@@ -159,12 +170,14 @@ class TokenType(AutoName):
     JOIN = auto()
     LATERAL = auto()
     LAZY = auto()
+    LEADING = auto()
     LEFT = auto()
     LIKE = auto()
     LIMIT = auto()
     LOCATION = auto()
     MAP = auto()
     MOD = auto()
+    NATURAL = auto()
     NEXT = auto()
     NO_ACTION = auto()
     NULL = auto()
@@ -204,8 +217,10 @@ class TokenType(AutoName):
     ROWS = auto()
     SCHEMA_COMMENT = auto()
     SELECT = auto()
+    SEPARATOR = auto()
     SET = auto()
     SHOW = auto()
+    SIMILAR_TO = auto()
     SOME = auto()
     SORT_BY = auto()
     STORED = auto()
@@ -213,12 +228,11 @@ class TokenType(AutoName):
     TABLE_FORMAT = auto()
     TABLE_SAMPLE = auto()
     TEMPORARY = auto()
-    TIME = auto()
     TOP = auto()
     THEN = auto()
     TRUE = auto()
+    TRAILING = auto()
     TRUNCATE = auto()
-    TRY_CAST = auto()
     UNBOUNDED = auto()
     UNCACHE = auto()
     UNION = auto()
@@ -272,35 +286,32 @@ class _Tokenizer(type):
     def __new__(cls, clsname, bases, attrs):
         klass = super().__new__(cls, clsname, bases, attrs)
 
-        klass.QUOTES = dict(
-            (quote, quote) if isinstance(quote, str) else (quote[0], quote[1])
-            for quote in klass.QUOTES
-        )
-
-        klass.IDENTIFIERS = dict(
-            (identifier, identifier)
-            if isinstance(identifier, str)
-            else (identifier[0], identifier[1])
-            for identifier in klass.IDENTIFIERS
-        )
-
-        klass.COMMENTS = dict(
-            (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
-            for comment in klass.COMMENTS
+        klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)
+        klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS)
+        klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
+        klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
+        klass._COMMENTS = dict(
+            (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS
         )
 
         klass.KEYWORD_TRIE = new_trie(
             key.upper()
             for key, value in {
                 **klass.KEYWORDS,
-                **{comment: TokenType.COMMENT for comment in klass.COMMENTS},
-                **{quote: TokenType.QUOTE for quote in klass.QUOTES},
+                **{comment: TokenType.COMMENT for comment in klass._COMMENTS},
+                **{quote: TokenType.QUOTE for quote in klass._QUOTES},
+                **{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
+                **{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
             }.items()
             if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
        )
 
         return klass
 
+    @staticmethod
+    def _delimeter_list_to_dict(list):
+        return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list)
+
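As a quick illustration of what the new helper normalizes (a standalone sketch; the input lists below are illustrative, not taken from any particular dialect):

    # Plain strings map to themselves; (start, end) pairs map start -> end.
    def delimeter_list_to_dict(items):
        return dict((i, i) if isinstance(i, str) else (i[0], i[1]) for i in items)

    assert delimeter_list_to_dict(["'"]) == {"'": "'"}
    assert delimeter_list_to_dict([("x'", "'"), "0x"]) == {"x'": "'", "0x": "0x"}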
 
 class Tokenizer(metaclass=_Tokenizer):
     SINGLE_TOKENS = {
@@ -339,6 +350,10 @@ class Tokenizer(metaclass=_Tokenizer):
 
     QUOTES = ["'"]
 
+    BIT_STRINGS = []
+
+    HEX_STRINGS = []
+
     IDENTIFIERS = ['"']
 
     ESCAPE = "'"
@@ -357,6 +372,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "->>": TokenType.DARROW,
         "#>": TokenType.HASH_ARROW,
         "#>>": TokenType.DHASH_ARROW,
+        "<->": TokenType.LR_ARROW,
         "ADD ARCHIVE": TokenType.ADD_FILE,
         "ADD ARCHIVES": TokenType.ADD_FILE,
         "ADD FILE": TokenType.ADD_FILE,
@@ -374,12 +390,12 @@ class Tokenizer(metaclass=_Tokenizer):
         "AUTO_INCREMENT": TokenType.AUTO_INCREMENT,
         "BEGIN": TokenType.BEGIN,
         "BETWEEN": TokenType.BETWEEN,
+        "BOTH": TokenType.BOTH,
         "BUCKET": TokenType.BUCKET,
         "CALL": TokenType.CALL,
         "CACHE": TokenType.CACHE,
         "UNCACHE": TokenType.UNCACHE,
         "CASE": TokenType.CASE,
-        "CAST": TokenType.CAST,
         "CHARACTER SET": TokenType.CHARACTER_SET,
         "CHECK": TokenType.CHECK,
         "CLUSTER BY": TokenType.CLUSTER_BY,
@@ -387,7 +403,6 @@ class Tokenizer(metaclass=_Tokenizer):
         "COMMENT": TokenType.SCHEMA_COMMENT,
         "COMMIT": TokenType.COMMIT,
         "CONSTRAINT": TokenType.CONSTRAINT,
-        "CONVERT": TokenType.CONVERT,
         "CREATE": TokenType.CREATE,
         "CROSS": TokenType.CROSS,
         "CUBE": TokenType.CUBE,
@@ -408,7 +423,6 @@ class Tokenizer(metaclass=_Tokenizer):
         "EXCEPT": TokenType.EXCEPT,
         "EXISTS": TokenType.EXISTS,
         "EXPLAIN": TokenType.EXPLAIN,
-        "EXTRACT": TokenType.EXTRACT,
         "FALSE": TokenType.FALSE,
         "FETCH": TokenType.FETCH,
         "FILTER": TokenType.FILTER,
@@ -437,10 +451,12 @@ class Tokenizer(metaclass=_Tokenizer):
         "JOIN": TokenType.JOIN,
         "LATERAL": TokenType.LATERAL,
         "LAZY": TokenType.LAZY,
+        "LEADING": TokenType.LEADING,
         "LEFT": TokenType.LEFT,
         "LIKE": TokenType.LIKE,
         "LIMIT": TokenType.LIMIT,
         "LOCATION": TokenType.LOCATION,
+        "NATURAL": TokenType.NATURAL,
         "NEXT": TokenType.NEXT,
         "NO ACTION": TokenType.NO_ACTION,
         "NOT": TokenType.NOT,
@@ -490,8 +506,8 @@ class Tokenizer(metaclass=_Tokenizer):
         "TEMPORARY": TokenType.TEMPORARY,
         "THEN": TokenType.THEN,
         "TRUE": TokenType.TRUE,
+        "TRAILING": TokenType.TRAILING,
         "TRUNCATE": TokenType.TRUNCATE,
-        "TRY_CAST": TokenType.TRY_CAST,
         "UNBOUNDED": TokenType.UNBOUNDED,
         "UNION": TokenType.UNION,
         "UNNEST": TokenType.UNNEST,
@@ -626,14 +642,12 @@ class Tokenizer(metaclass=_Tokenizer):
                 break
 
             white_space = self.WHITE_SPACE.get(self._char)
-            identifier_end = self.IDENTIFIERS.get(self._char)
+            identifier_end = self._IDENTIFIERS.get(self._char)
 
             if white_space:
                 if white_space == TokenType.BREAK:
                     self._col = 1
                     self._line += 1
-            elif self._char == "0" and self._peek == "x":
-                self._scan_hex()
             elif self._char.isdigit():
                 self._scan_number()
             elif identifier_end:
@@ -666,9 +680,7 @@ class Tokenizer(metaclass=_Tokenizer):
         text = self._text if text is None else text
         self.tokens.append(Token(token_type, text, self._line, self._col))
 
-        if token_type in self.COMMANDS and (
-            len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
-        ):
+        if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON):
             self._start = self._current
             while not self._end and self._peek != ";":
                 self._advance()
@@ -725,6 +737,8 @@ class Tokenizer(metaclass=_Tokenizer):
 
         if self._scan_string(word):
             return
+        if self._scan_numeric_string(word):
+            return
         if self._scan_comment(word):
             return
 
@@ -732,10 +746,10 @@ class Tokenizer(metaclass=_Tokenizer):
         self._add(self.KEYWORDS[word.upper()])
 
     def _scan_comment(self, comment_start):
-        if comment_start not in self.COMMENTS:
+        if comment_start not in self._COMMENTS:
             return False
 
-        comment_end = self.COMMENTS[comment_start]
+        comment_end = self._COMMENTS[comment_start]
 
         if comment_end:
             comment_end_size = len(comment_end)
@@ -749,15 +763,18 @@ class Tokenizer(metaclass=_Tokenizer):
         return True
 
     def _scan_annotation(self):
-        while (
-            not self._end
-            and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK
-            and self._peek != ","
-        ):
+        while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK and self._peek != ",":
             self._advance()
         self._add(TokenType.ANNOTATION, self._text[1:])
 
     def _scan_number(self):
+        if self._char == "0":
+            peek = self._peek.upper()
+            if peek == "B":
+                return self._scan_bits()
+            elif peek == "X":
+                return self._scan_hex()
+
         decimal = False
         scientific = 0
 
@@ -788,57 +805,71 @@ class Tokenizer(metaclass=_Tokenizer):
         else:
             return self._add(TokenType.NUMBER)
 
+    def _scan_bits(self):
+        self._advance()
+        value = self._extract_value()
+        try:
+            self._add(TokenType.BIT_STRING, f"{int(value, 2)}")
+        except ValueError:
+            self._add(TokenType.IDENTIFIER)
+
     def _scan_hex(self):
         self._advance()
+        value = self._extract_value()
+        try:
+            self._add(TokenType.HEX_STRING, f"{int(value, 16)}")
+        except ValueError:
+            self._add(TokenType.IDENTIFIER)
+
+    def _extract_value(self):
         while True:
             char = self._peek.strip()
             if char and char not in self.SINGLE_TOKENS:
                 self._advance()
             else:
                 break
-        try:
-            self._add(TokenType.BIT_STRING, f"{int(self._text, 16):b}")
-        except ValueError:
-            self._add(TokenType.IDENTIFIER)
+        return self._text
 
     def _scan_string(self, quote):
-        quote_end = self.QUOTES.get(quote)
+        quote_end = self._QUOTES.get(quote)
         if quote_end is None:
             return False
 
-        text = ""
         self._advance(len(quote))
-        quote_end_size = len(quote_end)
-
-        while True:
-            if self._char == self.ESCAPE and self._peek == quote_end:
-                text += quote
-                self._advance(2)
-            else:
-                if self._chars(quote_end_size) == quote_end:
-                    if quote_end_size > 1:
-                        self._advance(quote_end_size - 1)
-                    break
-
-                if self._end:
-                    raise RuntimeError(
-                        f"Missing {quote} from {self._line}:{self._start}"
-                    )
-                text += self._char
-                self._advance()
+        text = self._extract_string(quote_end)
 
         text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
         text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text
         self._add(TokenType.STRING, text)
         return True
 
+    def _scan_numeric_string(self, string_start):
+        if string_start in self._HEX_STRINGS:
+            delimiters = self._HEX_STRINGS
+            token_type = TokenType.HEX_STRING
+            base = 16
+        elif string_start in self._BIT_STRINGS:
+            delimiters = self._BIT_STRINGS
+            token_type = TokenType.BIT_STRING
+            base = 2
+        else:
+            return False
+
+        self._advance(len(string_start))
+        string_end = delimiters.get(string_start)
+        text = self._extract_string(string_end)
+
+        try:
+            self._add(token_type, f"{int(text, base)}")
+        except ValueError:
+            raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")
+        return True
+
     def _scan_identifier(self, identifier_end):
         while self._peek != identifier_end:
             if self._end:
-                raise RuntimeError(
-                    f"Missing {identifier_end} from {self._line}:{self._start}"
-                )
+                raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")
             self._advance()
         self._advance()
         self._add(TokenType.IDENTIFIER, self._text[1:-1])
@@ -851,3 +882,24 @@ class Tokenizer(metaclass=_Tokenizer):
             else:
                 break
         self._add(self.KEYWORDS.get(self._text.upper(), TokenType.VAR))
+
+    def _extract_string(self, delimiter):
+        text = ""
+        delim_size = len(delimiter)
+
+        while True:
+            if self._char == self.ESCAPE and self._peek == delimiter:
+                text += delimiter
+                self._advance(2)
+            else:
+                if self._chars(delim_size) == delimiter:
+                    if delim_size > 1:
+                        self._advance(delim_size - 1)
+                    break
+
+                if self._end:
+                    raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
+                text += self._char
+                self._advance()
+
+        return text
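To make the new scanning concrete, a minimal sketch under stated assumptions: the subclass below is hypothetical (real dialects declare BIT_STRINGS/HEX_STRINGS in the same way), and it assumes the Tokenizer().tokenize(sql) entry point used elsewhere in the library.

    from sqlglot.tokens import Tokenizer

    class MyTokenizer(Tokenizer):  # hypothetical dialect-style tokenizer
        BIT_STRINGS = [("b'", "'")]
        HEX_STRINGS = [("x'", "'")]

    # x'CC' -> HEX_STRING with decimal text "204"; 0b1011 -> BIT_STRING "11"
    for sql in ("SELECT x'CC'", "SELECT 0b1011"):
        print([(t.token_type.name, t.text) for t in MyTokenizer().tokenize(sql)])

Note that both paths normalize the literal to its decimal value, which is what lets the generators below re-emit it in each dialect's own syntax.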
@@ -12,9 +12,7 @@ def unalias_group(expression):
     """
     if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
         aliased_selects = {
-            e.alias: i
-            for i, e in enumerate(expression.parent.expressions, start=1)
-            if isinstance(e, exp.Alias)
+            e.alias: i for i, e in enumerate(expression.parent.expressions, start=1) if isinstance(e, exp.Alias)
         }
 
         expression = expression.copy()
@@ -36,9 +36,7 @@ class Validator(unittest.TestCase):
         for read_dialect, read_sql in (read or {}).items():
             with self.subTest(f"{read_dialect} -> {sql}"):
                 self.assertEqual(
-                    parse_one(read_sql, read_dialect).sql(
-                        self.dialect, unsupported_level=ErrorLevel.IGNORE
-                    ),
+                    parse_one(read_sql, read_dialect).sql(self.dialect, unsupported_level=ErrorLevel.IGNORE),
                     sql,
                 )
 
@@ -46,9 +44,7 @@ class Validator(unittest.TestCase):
             with self.subTest(f"{sql} -> {write_dialect}"):
                 if write_sql is UnsupportedError:
                     with self.assertRaises(UnsupportedError):
-                        expression.sql(
-                            write_dialect, unsupported_level=ErrorLevel.RAISE
-                        )
+                        expression.sql(write_dialect, unsupported_level=ErrorLevel.RAISE)
                 else:
                     self.assertEqual(
                         expression.sql(
@@ -82,11 +78,19 @@ class TestDialect(Validator):
                 "oracle": "CAST(a AS CLOB)",
                 "postgres": "CAST(a AS TEXT)",
                 "presto": "CAST(a AS VARCHAR)",
+                "redshift": "CAST(a AS TEXT)",
                 "snowflake": "CAST(a AS TEXT)",
                 "spark": "CAST(a AS STRING)",
                 "starrocks": "CAST(a AS STRING)",
             },
         )
+        self.validate_all(
+            "CAST(a AS DATETIME)",
+            write={
+                "postgres": "CAST(a AS TIMESTAMP)",
+                "sqlite": "CAST(a AS DATETIME)",
+            },
+        )
         self.validate_all(
             "CAST(a AS STRING)",
             write={
@@ -97,6 +101,7 @@ class TestDialect(Validator):
                 "oracle": "CAST(a AS CLOB)",
                 "postgres": "CAST(a AS TEXT)",
                 "presto": "CAST(a AS VARCHAR)",
+                "redshift": "CAST(a AS TEXT)",
                 "snowflake": "CAST(a AS TEXT)",
                 "spark": "CAST(a AS STRING)",
                 "starrocks": "CAST(a AS STRING)",
@@ -112,6 +117,7 @@ class TestDialect(Validator):
                 "oracle": "CAST(a AS VARCHAR2)",
                 "postgres": "CAST(a AS VARCHAR)",
                 "presto": "CAST(a AS VARCHAR)",
+                "redshift": "CAST(a AS VARCHAR)",
                 "snowflake": "CAST(a AS VARCHAR)",
                 "spark": "CAST(a AS STRING)",
                 "starrocks": "CAST(a AS VARCHAR)",
@@ -127,6 +133,7 @@ class TestDialect(Validator):
                 "oracle": "CAST(a AS VARCHAR2(3))",
                 "postgres": "CAST(a AS VARCHAR(3))",
                 "presto": "CAST(a AS VARCHAR(3))",
+                "redshift": "CAST(a AS VARCHAR(3))",
                 "snowflake": "CAST(a AS VARCHAR(3))",
                 "spark": "CAST(a AS VARCHAR(3))",
                 "starrocks": "CAST(a AS VARCHAR(3))",
@@ -142,12 +149,26 @@ class TestDialect(Validator):
                 "oracle": "CAST(a AS NUMBER)",
                 "postgres": "CAST(a AS SMALLINT)",
                 "presto": "CAST(a AS SMALLINT)",
+                "redshift": "CAST(a AS SMALLINT)",
                 "snowflake": "CAST(a AS SMALLINT)",
                 "spark": "CAST(a AS SHORT)",
                 "sqlite": "CAST(a AS INTEGER)",
                 "starrocks": "CAST(a AS SMALLINT)",
             },
         )
+        self.validate_all(
+            "TRY_CAST(a AS DOUBLE)",
+            read={
+                "postgres": "CAST(a AS DOUBLE PRECISION)",
+                "redshift": "CAST(a AS DOUBLE PRECISION)",
+            },
+            write={
+                "duckdb": "TRY_CAST(a AS DOUBLE)",
+                "postgres": "CAST(a AS DOUBLE PRECISION)",
+                "redshift": "CAST(a AS DOUBLE PRECISION)",
+            },
+        )
         self.validate_all(
             "CAST(a AS DOUBLE)",
             write={
@@ -159,16 +180,32 @@ class TestDialect(Validator):
                 "oracle": "CAST(a AS DOUBLE PRECISION)",
                 "postgres": "CAST(a AS DOUBLE PRECISION)",
                 "presto": "CAST(a AS DOUBLE)",
+                "redshift": "CAST(a AS DOUBLE PRECISION)",
                 "snowflake": "CAST(a AS DOUBLE)",
                 "spark": "CAST(a AS DOUBLE)",
                 "starrocks": "CAST(a AS DOUBLE)",
             },
         )
         self.validate_all(
-            "CAST(a AS TIMESTAMP)", write={"starrocks": "CAST(a AS DATETIME)"}
+            "CAST('1 DAY' AS INTERVAL)",
+            write={
+                "postgres": "CAST('1 DAY' AS INTERVAL)",
+                "redshift": "CAST('1 DAY' AS INTERVAL)",
+            },
         )
         self.validate_all(
-            "CAST(a AS TIMESTAMPTZ)", write={"starrocks": "CAST(a AS DATETIME)"}
+            "CAST(a AS TIMESTAMP)",
+            write={
+                "starrocks": "CAST(a AS DATETIME)",
+                "redshift": "CAST(a AS TIMESTAMP)",
+            },
+        )
+        self.validate_all(
+            "CAST(a AS TIMESTAMPTZ)",
+            write={
+                "starrocks": "CAST(a AS DATETIME)",
+                "redshift": "CAST(a AS TIMESTAMPTZ)",
+            },
         )
         self.validate_all("CAST(a AS TINYINT)", write={"oracle": "CAST(a AS NUMBER)"})
         self.validate_all("CAST(a AS SMALLINT)", write={"oracle": "CAST(a AS NUMBER)"})
@@ -552,6 +589,7 @@ class TestDialect(Validator):
             write={
                 "bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
                 "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
+                "oracle": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
                 "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST",
                 "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
                 "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
@@ -566,6 +604,7 @@ class TestDialect(Validator):
                 "presto": "JSON_EXTRACT(x, 'y')",
             },
             write={
+                "oracle": "JSON_EXTRACT(x, 'y')",
                 "postgres": "x->'y'",
                 "presto": "JSON_EXTRACT(x, 'y')",
             },
@@ -623,6 +662,37 @@ class TestDialect(Validator):
             },
         )
 
+    # https://dev.mysql.com/doc/refman/8.0/en/join.html
+    # https://www.postgresql.org/docs/current/queries-table-expressions.html
+    def test_joined_tables(self):
+        self.validate_identity("SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)")
+        self.validate_identity("SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)")
+        self.validate_identity("SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)")
+        self.validate_identity("SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)")
+
+        self.validate_all(
+            "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
+            write={
+                "postgres": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
+                "mysql": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
+            },
+        )
+        self.validate_all(
+            "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
+            write={
+                "postgres": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
+                "mysql": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
+            },
+        )
+
+    def test_lateral_subquery(self):
+        self.validate_identity(
+            "SELECT art FROM tbl1 INNER JOIN LATERAL (SELECT art FROM tbl2) AS tbl2 ON tbl1.art = tbl2.art"
+        )
+        self.validate_identity(
+            "SELECT * FROM tbl AS t LEFT JOIN LATERAL (SELECT * FROM b WHERE b.t_id = t.t_id) AS t ON TRUE"
+        )
+
     def test_set_operators(self):
         self.validate_all(
             "SELECT * FROM a UNION SELECT * FROM b",
@@ -731,6 +801,9 @@ class TestDialect(Validator):
         )
 
     def test_operators(self):
+        self.validate_identity("some.column LIKE 'foo' || another.column || 'bar' || LOWER(x)")
+        self.validate_identity("some.column LIKE 'foo' + another.column + 'bar'")
+
         self.validate_all(
             "x ILIKE '%y'",
             read={
@@ -874,16 +947,8 @@ class TestDialect(Validator):
                 "spark": "FILTER(the_array, x -> x > 0)",
             },
         )
-        self.validate_all(
-            "SELECT a AS b FROM x GROUP BY b",
-            write={
-                "duckdb": "SELECT a AS b FROM x GROUP BY b",
-                "presto": "SELECT a AS b FROM x GROUP BY 1",
-                "hive": "SELECT a AS b FROM x GROUP BY 1",
-                "oracle": "SELECT a AS b FROM x GROUP BY 1",
-                "spark": "SELECT a AS b FROM x GROUP BY 1",
-            },
-        )
+
+    def test_limit(self):
         self.validate_all(
             "SELECT x FROM y LIMIT 10",
             write={
@@ -915,6 +980,7 @@ class TestDialect(Validator):
             read={
                 "clickhouse": '`x` + "y"',
                 "sqlite": '`x` + "y"',
+                "redshift": '"x" + "y"',
             },
         )
         self.validate_all(
@@ -977,5 +1043,36 @@ class TestDialect(Validator):
                 "oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))",
                 "postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))",
                 "sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))",
+                "redshift": "CREATE TABLE t (b1 VARBYTE, b2 VARBYTE(1024), c1 TEXT, c2 TEXT(1024))",
+            },
+        )
+
+    def test_alias(self):
+        self.validate_all(
+            "SELECT a AS b FROM x GROUP BY b",
+            write={
+                "duckdb": "SELECT a AS b FROM x GROUP BY b",
+                "presto": "SELECT a AS b FROM x GROUP BY 1",
+                "hive": "SELECT a AS b FROM x GROUP BY 1",
+                "oracle": "SELECT a AS b FROM x GROUP BY 1",
+                "spark": "SELECT a AS b FROM x GROUP BY 1",
+            },
+        )
+        self.validate_all(
+            "SELECT y x FROM my_table t",
+            write={
+                "hive": "SELECT y AS x FROM my_table AS t",
+                "oracle": "SELECT y AS x FROM my_table t",
+                "postgres": "SELECT y AS x FROM my_table AS t",
+                "sqlite": "SELECT y AS x FROM my_table AS t",
+            },
+        )
+        self.validate_all(
+            "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
+            write={
+                "hive": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
+                "oracle": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 t JOIN cte2 WHERE cte1.a = cte2.c",
+                "postgres": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
+                "sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
             },
         )
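These expectations map straight onto the public API; a hedged sketch, with outputs taken from the cases above:

    import sqlglot

    print(sqlglot.transpile("CAST(a AS SMALLINT)", write="spark")[0])       # CAST(a AS SHORT)
    print(sqlglot.transpile("CAST(a AS TIMESTAMP)", write="starrocks")[0])  # CAST(a AS DATETIME)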
@@ -341,6 +341,21 @@ class TestHive(Validator):
                 "spark": "PERCENTILE(x, 0.5)",
             },
         )
+        self.validate_all(
+            "PERCENTILE_APPROX(x, 0.5)",
+            read={
+                "hive": "PERCENTILE_APPROX(x, 0.5)",
+                "presto": "APPROX_PERCENTILE(x, 0.5)",
+                "duckdb": "APPROX_QUANTILE(x, 0.5)",
+                "spark": "PERCENTILE_APPROX(x, 0.5)",
+            },
+            write={
+                "hive": "PERCENTILE_APPROX(x, 0.5)",
+                "presto": "APPROX_PERCENTILE(x, 0.5)",
+                "duckdb": "APPROX_QUANTILE(x, 0.5)",
+                "spark": "PERCENTILE_APPROX(x, 0.5)",
+            },
+        )
         self.validate_all(
             "APPROX_COUNT_DISTINCT(a)",
             write={
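A quick sketch of the same mapping through transpile (output per the case above):

    import sqlglot

    # Presto's APPROX_PERCENTILE now round-trips through the other engines.
    print(sqlglot.transpile("APPROX_PERCENTILE(x, 0.5)", read="presto", write="duckdb")[0])
    # APPROX_QUANTILE(x, 0.5)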
@@ -15,6 +15,10 @@ class TestMySQL(Validator):
 
     def test_identity(self):
         self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")
+        self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ')")
+        self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')")
+        self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')")
+        self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')")
 
     def test_introducers(self):
         self.validate_all(
@@ -27,12 +31,22 @@ class TestMySQL(Validator):
             },
         )
 
-    def test_binary_literal(self):
+    def test_hexadecimal_literal(self):
         self.validate_all(
             "SELECT 0xCC",
             write={
-                "mysql": "SELECT b'11001100'",
-                "spark": "SELECT X'11001100'",
+                "mysql": "SELECT x'CC'",
+                "sqlite": "SELECT x'CC'",
+                "spark": "SELECT X'CC'",
+                "trino": "SELECT X'CC'",
+                "bigquery": "SELECT 0xCC",
+                "oracle": "SELECT 204",
+            },
+        )
+        self.validate_all(
+            "SELECT X'1A'",
+            write={
+                "mysql": "SELECT x'1A'",
             },
         )
         self.validate_all(
@@ -41,10 +55,22 @@ class TestMySQL(Validator):
                 "mysql": "SELECT `0xz`",
             },
         )
+
+    def test_bits_literal(self):
         self.validate_all(
-            "SELECT 0XCC",
+            "SELECT 0b1011",
             write={
-                "mysql": "SELECT 0 AS XCC",
+                "mysql": "SELECT b'1011'",
+                "postgres": "SELECT b'1011'",
+                "oracle": "SELECT 11",
+            },
+        )
+        self.validate_all(
+            "SELECT B'1011'",
+            write={
+                "mysql": "SELECT b'1011'",
+                "postgres": "SELECT b'1011'",
+                "oracle": "SELECT 11",
             },
         )
 
@@ -77,3 +103,19 @@ class TestMySQL(Validator):
                 "mysql": "SELECT 1",
             },
         )
+
+    def test_mysql(self):
+        self.validate_all(
+            "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
+            write={
+                "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')",
+                "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
+            },
+        )
+        self.validate_all(
+            "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
+            write={
+                "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
+                "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')",
+            },
+        )
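The new SEPARATOR support from the changelog, sketched through transpile (expected output per the test above):

    import sqlglot

    # MySQL's implicit separator is rendered explicitly on write.
    print(sqlglot.transpile("GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", read="mysql", write="mysql")[0])
    # GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')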
@@ -8,9 +8,7 @@ class TestPostgres(Validator):
     def test_ddl(self):
         self.validate_all(
             "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
-            write={
-                "postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"
-            },
+            write={"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"},
         )
         self.validate_all(
             "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)",
@@ -42,11 +40,17 @@ class TestPostgres(Validator):
                 " CONSTRAINT valid_discount CHECK (price > discounted_price))"
             },
         )
+        self.validate_all(
+            "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)",
+            write={"postgres": "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)"},
+        )
+        self.validate_all(
+            "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)",
+            write={"postgres": "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)"},
+        )
 
         with self.assertRaises(ParseError):
-            transpile(
-                "CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres"
-            )
+            transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres")
         with self.assertRaises(ParseError):
             transpile(
                 "CREATE TABLE products (price DECIMAL, CHECK price > 1)",
@@ -54,11 +58,16 @@ class TestPostgres(Validator):
             )
 
     def test_postgres(self):
-        self.validate_all(
-            "CREATE TABLE x (a INT SERIAL)",
-            read={"sqlite": "CREATE TABLE x (a INTEGER AUTOINCREMENT)"},
-            write={"sqlite": "CREATE TABLE x (a INTEGER AUTOINCREMENT)"},
-        )
+        self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
+        self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END")
+        self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END")
+        self.validate_identity('SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')')
+        self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')")
+        self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')")
+        self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))")
+        self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
+        self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
+
         self.validate_all(
             "CREATE TABLE x (a UUID, b BYTEA)",
             write={
@@ -91,3 +100,65 @@ class TestPostgres(Validator):
                 "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
             },
         )
+        self.validate_all(
+            "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END",
+            write={
+                "hive": "SELECT CASE WHEN SUBSTRING('abcdefg', 1, 2) IN ('ab') THEN 1 ELSE 0 END",
+                "spark": "SELECT CASE WHEN SUBSTRING('abcdefg', 1, 2) IN ('ab') THEN 1 ELSE 0 END",
+            },
+        )
+        self.validate_all(
+            "SELECT * FROM x WHERE SUBSTRING(col1 FROM 3 + LENGTH(col1) - 10 FOR 10) IN (col2)",
+            write={
+                "hive": "SELECT * FROM x WHERE SUBSTRING(col1, 3 + LENGTH(col1) - 10, 10) IN (col2)",
+                "spark": "SELECT * FROM x WHERE SUBSTRING(col1, 3 + LENGTH(col1) - 10, 10) IN (col2)",
+            },
+        )
+        self.validate_all(
+            "SELECT SUBSTRING(CAST(2022 AS CHAR(4)) || LPAD(CAST(3 AS CHAR(2)), 2, '0') FROM 3 FOR 4)",
+            read={
+                "postgres": "SELECT SUBSTRING(2022::CHAR(4) || LPAD(3::CHAR(2), 2, '0') FROM 3 FOR 4)",
+            },
+        )
+        self.validate_all(
+            "SELECT TRIM(BOTH ' XXX ')",
+            write={
+                "mysql": "SELECT TRIM(' XXX ')",
+                "postgres": "SELECT TRIM(' XXX ')",
+                "hive": "SELECT TRIM(' XXX ')",
+            },
+        )
+        self.validate_all(
+            "TRIM(LEADING FROM ' XXX ')",
+            write={
+                "mysql": "LTRIM(' XXX ')",
+                "postgres": "LTRIM(' XXX ')",
+                "hive": "LTRIM(' XXX ')",
+                "presto": "LTRIM(' XXX ')",
+            },
+        )
+        self.validate_all(
+            "TRIM(TRAILING FROM ' XXX ')",
+            write={
+                "mysql": "RTRIM(' XXX ')",
+                "postgres": "RTRIM(' XXX ')",
+                "hive": "RTRIM(' XXX ')",
+                "presto": "RTRIM(' XXX ')",
+            },
+        )
+        self.validate_all(
+            "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss",
+            read={"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"},
+        )
+        self.validate_all(
+            "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL",
+            read={
+                "postgres": "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL",
+            },
+        )
+        self.validate_all(
+            "SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id",
+            read={
+                "postgres": "SELECT p1.id, p2.id, v1, v2 FROM polygons p1, polygons p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id != p2.id",
+            },
+        )
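The keyword-form TRIM handling, sketched through transpile (expected output per the tests above):

    import sqlglot

    # TRIM(LEADING ...) lowers to LTRIM where the keyword form is unsupported.
    print(sqlglot.transpile("TRIM(LEADING FROM ' XXX ')", read="postgres", write="presto")[0])
    # LTRIM(' XXX ')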
64
tests/dialects/test_redshift.py
Normal file
@@ -0,0 +1,64 @@
+from tests.dialects.test_dialect import Validator
+
+
+class TestRedshift(Validator):
+    dialect = "redshift"
+
+    def test_redshift(self):
+        self.validate_all(
+            'create table "group" ("col" char(10))',
+            write={
+                "redshift": 'CREATE TABLE "group" ("col" CHAR(10))',
+                "mysql": "CREATE TABLE `group` (`col` CHAR(10))",
+            },
+        )
+        self.validate_all(
+            'create table if not exists city_slash_id("city/id" integer not null, state char(2) not null)',
+            write={
+                "redshift": 'CREATE TABLE IF NOT EXISTS city_slash_id ("city/id" INTEGER NOT NULL, state CHAR(2) NOT NULL)',
+                "presto": 'CREATE TABLE IF NOT EXISTS city_slash_id ("city/id" INTEGER NOT NULL, state CHAR(2) NOT NULL)',
+            },
+        )
+        self.validate_all(
+            "SELECT ST_AsEWKT(ST_GeomFromEWKT('SRID=4326;POINT(10 20)')::geography)",
+            write={
+                "redshift": "SELECT ST_ASEWKT(CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))",
+                "bigquery": "SELECT ST_ASEWKT(CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))",
+            },
+        )
+        self.validate_all(
+            "SELECT ST_AsEWKT(ST_GeogFromText('LINESTRING(110 40, 2 3, -10 80, -7 9)')::geometry)",
+            write={
+                "redshift": "SELECT ST_ASEWKT(CAST(ST_GEOGFROMTEXT('LINESTRING(110 40, 2 3, -10 80, -7 9)') AS GEOMETRY))",
+            },
+        )
+        self.validate_all(
+            "SELECT 'abc'::BINARY",
+            write={
+                "redshift": "SELECT CAST('abc' AS VARBYTE)",
+            },
+        )
+        self.validate_all(
+            "SELECT * FROM venue WHERE (venuecity, venuestate) IN (('Miami', 'FL'), ('Tampa', 'FL')) ORDER BY venueid",
+            write={
+                "redshift": "SELECT * FROM venue WHERE (venuecity, venuestate) IN (('Miami', 'FL'), ('Tampa', 'FL')) ORDER BY venueid",
+            },
+        )
+        self.validate_all(
+            'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\_%\' LIMIT 5',
+            write={
+                "redshift": 'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\\\_%\' LIMIT 5'
+            },
+        )
+
+    def test_identity(self):
+        self.validate_identity("CAST('bla' AS SUPER)")
+        self.validate_identity("CREATE TABLE real1 (realcol REAL)")
+        self.validate_identity("CAST('foo' AS HLLSKETCH)")
+        self.validate_identity("SELECT DATEADD(day, 1, 'today')")
+        self.validate_identity("'abc' SIMILAR TO '(b|c)%'")
+        self.validate_identity(
+            "SELECT caldate + INTERVAL '1 second' AS dateplus FROM date WHERE caldate = '12-31-2008'"
+        )
+        self.validate_identity("CREATE TABLE datetable (start_date DATE, end_date DATE)")
+        self.validate_identity("SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'")
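The new Redshift dialect from this release, sketched through transpile (expected output per the test above):

    import sqlglot

    # Redshift's BINARY maps to VARBYTE on write.
    print(sqlglot.transpile("SELECT 'abc'::BINARY", read="redshift", write="redshift")[0])
    # SELECT CAST('abc' AS VARBYTE)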
@@ -143,3 +143,35 @@ class TestSnowflake(Validator):
                 "snowflake": r"SELECT 'a \' \\ \\t \\x21 z $ '",
             },
         )
+
+    def test_null_treatment(self):
+        self.validate_all(
+            r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
+            write={
+                "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
+            },
+        )
+        self.validate_all(
+            r"SELECT FIRST_VALUE(TABLE1.COLUMN1 RESPECT NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
+            write={
+                "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
+            },
+        )
+        self.validate_all(
+            r"SELECT FIRST_VALUE(TABLE1.COLUMN1) RESPECT NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
+            write={
+                "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) RESPECT NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
+            },
+        )
+        self.validate_all(
+            r"SELECT FIRST_VALUE(TABLE1.COLUMN1 IGNORE NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
+            write={
+                "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1 IGNORE NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
+            },
+        )
+        self.validate_all(
+            r"SELECT FIRST_VALUE(TABLE1.COLUMN1) IGNORE NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
+            write={
+                "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) IGNORE NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
+            },
+        )
@@ -34,6 +34,7 @@ class TestSQLite(Validator):
             write={
                 "sqlite": "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)",
                 "mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)",
+                "postgres": "CREATE TABLE z (a INT GENERATED BY DEFAULT AS IDENTITY NOT NULL UNIQUE PRIMARY KEY)",
             },
         )
         self.validate_all(
@@ -70,3 +71,20 @@ class TestSQLite(Validator):
                 "sqlite": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
             },
         )
+
+    def test_hexadecimal_literal(self):
+        self.validate_all(
+            "SELECT 0XCC",
+            write={
+                "sqlite": "SELECT x'CC'",
+                "mysql": "SELECT x'CC'",
+            },
+        )
+
+    def test_window_null_treatment(self):
+        self.validate_all(
+            "SELECT FIRST_VALUE(Name) OVER (PARTITION BY AlbumId ORDER BY Bytes DESC) AS LargestTrack FROM tracks",
+            write={
+                "sqlite": "SELECT FIRST_VALUE(Name) OVER (PARTITION BY AlbumId ORDER BY Bytes DESC) AS LargestTrack FROM tracks"
+            },
+        )
7
tests/fixtures/identity.sql
vendored
@@ -318,6 +318,9 @@ SELECT 1 FROM a JOIN b ON a.foo = b.bar JOIN c ON a.foo = c.bar
 SELECT 1 FROM a LEFT JOIN b ON a.foo = b.bar JOIN c ON a.foo = c.bar
 SELECT 1 FROM a LEFT INNER JOIN b ON a.foo = b.bar
 SELECT 1 FROM a LEFT OUTER JOIN b ON a.foo = b.bar
+SELECT 1 FROM a NATURAL JOIN b
+SELECT 1 FROM a NATURAL LEFT JOIN b
+SELECT 1 FROM a NATURAL LEFT OUTER JOIN b
 SELECT 1 FROM a OUTER JOIN b ON a.foo = b.bar
 SELECT 1 FROM a FULL JOIN b ON a.foo = b.bar
 SELECT 1 UNION ALL SELECT 2
@@ -329,6 +332,7 @@ SELECT 1 AS delete, 2 AS alter
 SELECT * FROM (x)
 SELECT * FROM ((x))
 SELECT * FROM ((SELECT 1))
+SELECT * FROM (x LATERAL VIEW EXPLODE(y) JOIN foo)
 SELECT * FROM (SELECT 1) AS x
 SELECT * FROM (SELECT 1 UNION SELECT 2) AS x
 SELECT * FROM (SELECT 1 UNION ALL SELECT 2) AS x
@@ -430,6 +434,7 @@ CREATE TEMPORARY VIEW x AS SELECT a FROM d
 CREATE TEMPORARY VIEW IF NOT EXISTS x AS SELECT a FROM d
 CREATE TEMPORARY VIEW x AS WITH y AS (SELECT 1) SELECT * FROM y
 CREATE TABLE z (a INT, b VARCHAR, c VARCHAR(100), d DECIMAL(5, 3))
+CREATE TABLE z (end INT)
 CREATE TABLE z (a ARRAY<TEXT>, b MAP<TEXT, DOUBLE>, c DECIMAL(5, 3))
 CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECIMAL(5, 3))
 CREATE TABLE z (a INT(11) DEFAULT UUID())
@@ -466,6 +471,7 @@ CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1
 CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS WITH a AS (SELECT 1) SELECT a.* FROM a
 CACHE LAZY TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
 CACHE TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
+CACHE TABLE x AS (SELECT 1 AS y)
 CALL catalog.system.iceberg_procedure_name(named_arg_1 => 'arg_1', named_arg_2 => 'arg_2')
 INSERT OVERWRITE TABLE a.b PARTITION(ds) SELECT x FROM y
 INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y
@@ -512,3 +518,4 @@ SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ?
 WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS foo, c.bar AS bar FROM c)) SELECT * FROM a
 WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a
 SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z
+SELECT ((SELECT 1) + 1)
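Each identity fixture must survive a parse/generate cycle unchanged; a minimal sketch of the check these lines imply:

    import sqlglot

    sql = "SELECT 1 FROM a NATURAL LEFT OUTER JOIN b"
    assert sqlglot.transpile(sql)[0] == sql  # round-trips byte for byte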
63
tests/fixtures/optimizer/merge_derived_tables.sql
vendored
Normal file
@@ -0,0 +1,63 @@
+-- Simple
+SELECT a, b FROM (SELECT a, b FROM x);
+SELECT x.a AS a, x.b AS b FROM x AS x;
+
+-- Inner table alias is merged
+SELECT a, b FROM (SELECT a, b FROM x AS q) AS r;
+SELECT q.a AS a, q.b AS b FROM x AS q;
+
+-- Double nesting
+SELECT a, b FROM (SELECT a, b FROM (SELECT a, b FROM x));
+SELECT x.a AS a, x.b AS b FROM x AS x;
+
+-- WHERE clause is merged
+SELECT a, SUM(b) FROM (SELECT a, b FROM x WHERE a > 1) GROUP BY a;
+SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a;
+
+-- Outer query has join
+SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
+SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;
+
+-- Join on derived table
+SELECT a, c FROM x JOIN (SELECT b, c FROM y) AS y ON x.b = y.b;
+SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
+
+-- Inner query has a join
+SELECT a, c FROM (SELECT a, c FROM x JOIN y ON x.b = y.b);
+SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
+
+-- Inner query has conflicting name in outer query
+SELECT a, c FROM (SELECT q.a, q.b FROM x AS q) AS x JOIN y AS q ON x.b = q.b;
+SELECT q_2.a AS a, q.c AS c FROM x AS q_2 JOIN y AS q ON q_2.b = q.b;
+
+-- Inner query has conflicting name in joined source
+SELECT x.a, q.c FROM (SELECT a, x.b FROM x JOIN y AS q ON x.b = q.b) AS x JOIN y AS q ON x.b = q.b;
+SELECT x.a AS a, q.c AS c FROM x AS x JOIN y AS q_2 ON x.b = q_2.b JOIN y AS q ON x.b = q.b;
+
+-- Inner query has multiple conflicting names
+SELECT x.a, q.c, r.c FROM (SELECT q.a, r.b FROM x AS q JOIN y AS r ON q.b = r.b) AS x JOIN y AS q ON x.b = q.b JOIN y AS r ON x.b = r.b;
+SELECT q_2.a AS a, q.c AS c, r.c AS c FROM x AS q_2 JOIN y AS r_2 ON q_2.b = r_2.b JOIN y AS q ON r_2.b = q.b JOIN y AS r ON r_2.b = r.b;
+
+-- Inner queries have conflicting names with each other
+SELECT r.b FROM (SELECT b FROM x AS x) AS q JOIN (SELECT b FROM x) AS r ON q.b = r.b;
+SELECT x_2.b AS b FROM x AS x JOIN x AS x_2 ON x.b = x_2.b;
+
+-- WHERE clause in joined derived table is merged
+SELECT x.a, y.c FROM x JOIN (SELECT b, c FROM y WHERE c > 1) AS y;
+SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y WHERE y.c > 1;
+
+-- WHERE clause in outer joined derived table is merged to ON clause
+SELECT x.a, y.c FROM x LEFT JOIN (SELECT b, c FROM y WHERE c > 1) AS y;
+SELECT x.a AS a, y.c AS c FROM x AS x LEFT JOIN y AS y ON y.c > 1;
+
+-- Comma JOIN in outer query
+SELECT x.a, y.c FROM (SELECT a FROM x) AS x, (SELECT c FROM y) AS y;
+SELECT x.a AS a, y.c AS c FROM x AS x, y AS y;
+
+-- Comma JOIN in inner query
+SELECT x.a, x.c FROM (SELECT x.a, z.c FROM x, y AS z) AS x;
+SELECT x.a AS a, z.c AS c FROM x AS x CROSS JOIN y AS z;
+
+-- (Regression) Column in ORDER BY
+SELECT * FROM (SELECT * FROM (SELECT * FROM x)) ORDER BY a LIMIT 1;
+SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a LIMIT 1;
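A sketch of how these fixtures surface through the optimizer, assuming the derived-table merge rule is wired into the default optimize pipeline; the schema argument below is a toy stand-in:

    import sqlglot
    from sqlglot.optimizer import optimize

    # With a schema so columns qualify, the derived table should merge away.
    expression = sqlglot.parse_one("SELECT a, b FROM (SELECT a, b FROM x)")
    print(optimize(expression, schema={"x": {"a": "INT", "b": "INT"}}).sql())
    # SELECT "x"."a" AS "a", "x"."b" AS "b" FROM "x" AS "x"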
37
tests/fixtures/optimizer/optimizer.sql
vendored
37
tests/fixtures/optimizer/optimizer.sql
vendored
|
@@ -2,11 +2,7 @@ SELECT a, m FROM z LATERAL VIEW EXPLODE([1, 2]) q AS m;
 SELECT
 "z"."a" AS "a",
 "q"."m" AS "m"
-FROM (
-SELECT
-"z"."a" AS "a"
-FROM "z" AS "z"
-) AS "z"
+FROM "z" AS "z"
 LATERAL VIEW
 EXPLODE(ARRAY(1, 2)) q AS "m";

@@ -91,41 +87,26 @@ FROM (
 WHERE (TRUE AND TRUE OR 'a' = 'b') AND a > 1
 GROUP BY a;
 SELECT
-"d"."a" AS "a",
-SUM("d"."b") AS "_col_1"
-FROM (
-SELECT
 "x"."a" AS "a",
-"y"."b" AS "b"
-FROM (
-SELECT
-"x"."a" AS "a"
-FROM "x" AS "x"
-WHERE
-"x"."a" > 1
-) AS "x"
-LEFT JOIN (
+SUM("y"."b") AS "_col_1"
+FROM "x" AS "x"
+LEFT JOIN (
 SELECT
 MAX("y"."b") AS "_col_0",
 "y"."a" AS "_u_1"
 FROM "y" AS "y"
 GROUP BY
 "y"."a"
 ) AS "_u_0"
 ON "x"."a" = "_u_0"."_u_1"
-JOIN (
-SELECT
-"y"."a" AS "a",
-"y"."b" AS "b"
-FROM "y" AS "y"
-) AS "y"
+JOIN "y" AS "y"
 ON "x"."a" = "y"."a"
 WHERE
 "_u_0"."_col_0" >= 0
+AND "x"."a" > 1
 AND NOT "_u_0"."_u_1" IS NULL
-) AS "d"
 GROUP BY
-"d"."a";
+"x"."a";

 (SELECT a FROM x) LIMIT 1;
 (
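The rewritten expectations above all fall out of rerunning the optimizer: a derived table that only projects columns is now merged into its parent scope. A hypothetical regeneration loop for fixtures of this shape; regenerate and its pairs argument are illustrative names, not the repo's actual tooling:

import sqlglot
from sqlglot.optimizer import optimize

def regenerate(pairs, schema):
    # For each (input, old_expected) pair, emit the freshly optimized,
    # pretty-printed SQL that would become the new expected text.
    for sql, _old_expected in pairs:
        yield optimize(sqlglot.parse_one(sql), schema=schema).sql(pretty=True)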
665 tests/fixtures/optimizer/tpc-h/tpc-h.sql vendored
@@ -120,36 +120,16 @@ SELECT
 "supplier"."s_address" AS "s_address",
 "supplier"."s_phone" AS "s_phone",
 "supplier"."s_comment" AS "s_comment"
-FROM (
-SELECT
-"part"."p_partkey" AS "p_partkey",
-"part"."p_mfgr" AS "p_mfgr",
-"part"."p_type" AS "p_type",
-"part"."p_size" AS "p_size"
-FROM "part" AS "part"
-WHERE
-"part"."p_size" = 15
-AND "part"."p_type" LIKE '%BRASS'
-) AS "part"
+FROM "part" AS "part"
 LEFT JOIN (
 SELECT
 MIN("partsupp"."ps_supplycost") AS "_col_0",
 "partsupp"."ps_partkey" AS "_u_1"
 FROM "_e_0" AS "partsupp"
 CROSS JOIN "_e_1" AS "region"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_regionkey" AS "n_regionkey"
-FROM "nation" AS "nation"
-) AS "nation"
+JOIN "nation" AS "nation"
 ON "nation"."n_regionkey" = "region"."r_regionkey"
-JOIN (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_nationkey" AS "s_nationkey"
-FROM "supplier" AS "supplier"
-) AS "supplier"
+JOIN "supplier" AS "supplier"
 ON "supplier"."s_nationkey" = "nation"."n_nationkey"
 AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey"
 GROUP BY
@@ -157,31 +137,17 @@ LEFT JOIN (
 ) AS "_u_0"
 ON "part"."p_partkey" = "_u_0"."_u_1"
 CROSS JOIN "_e_1" AS "region"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_name" AS "n_name",
-"nation"."n_regionkey" AS "n_regionkey"
-FROM "nation" AS "nation"
-) AS "nation"
+JOIN "nation" AS "nation"
 ON "nation"."n_regionkey" = "region"."r_regionkey"
 JOIN "_e_0" AS "partsupp"
 ON "part"."p_partkey" = "partsupp"."ps_partkey"
-JOIN (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_name" AS "s_name",
-"supplier"."s_address" AS "s_address",
-"supplier"."s_nationkey" AS "s_nationkey",
-"supplier"."s_phone" AS "s_phone",
-"supplier"."s_acctbal" AS "s_acctbal",
-"supplier"."s_comment" AS "s_comment"
-FROM "supplier" AS "supplier"
-) AS "supplier"
+JOIN "supplier" AS "supplier"
 ON "supplier"."s_nationkey" = "nation"."n_nationkey"
 AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey"
 WHERE
-"partsupp"."ps_supplycost" = "_u_0"."_col_0"
+"part"."p_size" = 15
+AND "part"."p_type" LIKE '%BRASS'
+AND "partsupp"."ps_supplycost" = "_u_0"."_col_0"
 AND NOT "_u_0"."_u_1" IS NULL
 ORDER BY
 "s_acctbal" DESC,
@@ -224,36 +190,15 @@ SELECT
 )) AS "revenue",
 CAST("orders"."o_orderdate" AS TEXT) AS "o_orderdate",
 "orders"."o_shippriority" AS "o_shippriority"
-FROM (
-SELECT
-"customer"."c_custkey" AS "c_custkey",
-"customer"."c_mktsegment" AS "c_mktsegment"
-FROM "customer" AS "customer"
-WHERE
-"customer"."c_mktsegment" = 'BUILDING'
-) AS "customer"
-JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_custkey" AS "o_custkey",
-"orders"."o_orderdate" AS "o_orderdate",
-"orders"."o_shippriority" AS "o_shippriority"
-FROM "orders" AS "orders"
-WHERE
-"orders"."o_orderdate" < '1995-03-15'
-) AS "orders"
+FROM "customer" AS "customer"
+JOIN "orders" AS "orders"
 ON "customer"."c_custkey" = "orders"."o_custkey"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_extendedprice" AS "l_extendedprice",
-"lineitem"."l_discount" AS "l_discount",
-"lineitem"."l_shipdate" AS "l_shipdate"
-FROM "lineitem" AS "lineitem"
-WHERE
-"lineitem"."l_shipdate" > '1995-03-15'
-) AS "lineitem"
+JOIN "lineitem" AS "lineitem"
 ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
+WHERE
+"customer"."c_mktsegment" = 'BUILDING'
+AND "lineitem"."l_shipdate" > '1995-03-15'
+AND "orders"."o_orderdate" < '1995-03-15'
 GROUP BY
 "lineitem"."l_orderkey",
 "orders"."o_orderdate",
@@ -342,57 +287,22 @@ SELECT
 SUM("lineitem"."l_extendedprice" * (
 1 - "lineitem"."l_discount"
 )) AS "revenue"
-FROM (
-SELECT
-"customer"."c_custkey" AS "c_custkey",
-"customer"."c_nationkey" AS "c_nationkey"
-FROM "customer" AS "customer"
-) AS "customer"
-JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_custkey" AS "o_custkey",
-"orders"."o_orderdate" AS "o_orderdate"
-FROM "orders" AS "orders"
-WHERE
-"orders"."o_orderdate" < CAST('1995-01-01' AS DATE)
-AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE)
-) AS "orders"
+FROM "customer" AS "customer"
+JOIN "orders" AS "orders"
 ON "customer"."c_custkey" = "orders"."o_custkey"
-CROSS JOIN (
-SELECT
-"region"."r_regionkey" AS "r_regionkey",
-"region"."r_name" AS "r_name"
-FROM "region" AS "region"
-WHERE
-"region"."r_name" = 'ASIA'
-) AS "region"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_name" AS "n_name",
-"nation"."n_regionkey" AS "n_regionkey"
-FROM "nation" AS "nation"
-) AS "nation"
+CROSS JOIN "region" AS "region"
+JOIN "nation" AS "nation"
 ON "nation"."n_regionkey" = "region"."r_regionkey"
-JOIN (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_nationkey" AS "s_nationkey"
-FROM "supplier" AS "supplier"
-) AS "supplier"
+JOIN "supplier" AS "supplier"
 ON "customer"."c_nationkey" = "supplier"."s_nationkey"
 AND "supplier"."s_nationkey" = "nation"."n_nationkey"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_suppkey" AS "l_suppkey",
-"lineitem"."l_extendedprice" AS "l_extendedprice",
-"lineitem"."l_discount" AS "l_discount"
-FROM "lineitem" AS "lineitem"
-) AS "lineitem"
+JOIN "lineitem" AS "lineitem"
 ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
 AND "lineitem"."l_suppkey" = "supplier"."s_suppkey"
+WHERE
+"orders"."o_orderdate" < CAST('1995-01-01' AS DATE)
+AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE)
+AND "region"."r_name" = 'ASIA'
 GROUP BY
 "nation"."n_name"
 ORDER BY
@@ -471,53 +381,22 @@ WITH "_e_0" AS (
 OR "nation"."n_name" = 'GERMANY'
 )
 SELECT
-"shipping"."supp_nation" AS "supp_nation",
-"shipping"."cust_nation" AS "cust_nation",
-"shipping"."l_year" AS "l_year",
-SUM("shipping"."volume") AS "revenue"
-FROM (
-SELECT
 "n1"."n_name" AS "supp_nation",
 "n2"."n_name" AS "cust_nation",
 EXTRACT(year FROM "lineitem"."l_shipdate") AS "l_year",
-"lineitem"."l_extendedprice" * (
+SUM("lineitem"."l_extendedprice" * (
 1 - "lineitem"."l_discount"
-) AS "volume"
-FROM (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_nationkey" AS "s_nationkey"
-FROM "supplier" AS "supplier"
-) AS "supplier"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_suppkey" AS "l_suppkey",
-"lineitem"."l_extendedprice" AS "l_extendedprice",
-"lineitem"."l_discount" AS "l_discount",
-"lineitem"."l_shipdate" AS "l_shipdate"
-FROM "lineitem" AS "lineitem"
-WHERE
-"lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
-) AS "lineitem"
+)) AS "revenue"
+FROM "supplier" AS "supplier"
+JOIN "lineitem" AS "lineitem"
 ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
-JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_custkey" AS "o_custkey"
-FROM "orders" AS "orders"
-) AS "orders"
+JOIN "orders" AS "orders"
 ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
-JOIN (
-SELECT
-"customer"."c_custkey" AS "c_custkey",
-"customer"."c_nationkey" AS "c_nationkey"
-FROM "customer" AS "customer"
-) AS "customer"
+JOIN "customer" AS "customer"
 ON "customer"."c_custkey" = "orders"."o_custkey"
 JOIN "_e_0" AS "n1"
 ON "supplier"."s_nationkey" = "n1"."n_nationkey"
 JOIN "_e_0" AS "n2"
 ON "customer"."c_nationkey" = "n2"."n_nationkey"
 AND (
 "n1"."n_name" = 'FRANCE'
@@ -527,11 +406,12 @@ FROM (
 "n1"."n_name" = 'GERMANY'
 OR "n2"."n_name" = 'GERMANY'
 )
-) AS "shipping"
+WHERE
+"lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
 GROUP BY
-"shipping"."supp_nation",
-"shipping"."cust_nation",
-"shipping"."l_year"
+"n1"."n_name",
+"n2"."n_name",
+EXTRACT(year FROM "lineitem"."l_shipdate")
 ORDER BY
 "supp_nation",
 "cust_nation",
@@ -578,87 +458,37 @@ group by
 order by
 o_year;
 SELECT
-"all_nations"."o_year" AS "o_year",
-SUM(CASE
-WHEN "all_nations"."nation" = 'BRAZIL'
-THEN "all_nations"."volume"
-ELSE 0
-END) / SUM("all_nations"."volume") AS "mkt_share"
-FROM (
-SELECT
 EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
-"lineitem"."l_extendedprice" * (
+SUM(CASE
+WHEN "nation_2"."n_name" = 'BRAZIL'
+THEN "lineitem"."l_extendedprice" * (
 1 - "lineitem"."l_discount"
-) AS "volume",
-"n2"."n_name" AS "nation"
-FROM (
-SELECT
-"part"."p_partkey" AS "p_partkey",
-"part"."p_type" AS "p_type"
-FROM "part" AS "part"
-WHERE
-"part"."p_type" = 'ECONOMY ANODIZED STEEL'
-) AS "part"
-CROSS JOIN (
-SELECT
-"region"."r_regionkey" AS "r_regionkey",
-"region"."r_name" AS "r_name"
-FROM "region" AS "region"
-WHERE
-"region"."r_name" = 'AMERICA'
-) AS "region"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_regionkey" AS "n_regionkey"
-FROM "nation" AS "nation"
-) AS "n1"
-ON "n1"."n_regionkey" = "region"."r_regionkey"
-JOIN (
-SELECT
-"customer"."c_custkey" AS "c_custkey",
-"customer"."c_nationkey" AS "c_nationkey"
-FROM "customer" AS "customer"
-) AS "customer"
-ON "customer"."c_nationkey" = "n1"."n_nationkey"
-JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_custkey" AS "o_custkey",
-"orders"."o_orderdate" AS "o_orderdate"
-FROM "orders" AS "orders"
-WHERE
-"orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
-) AS "orders"
+)
+ELSE 0
+END) / SUM("lineitem"."l_extendedprice" * (
+1 - "lineitem"."l_discount"
+)) AS "mkt_share"
+FROM "part" AS "part"
+CROSS JOIN "region" AS "region"
+JOIN "nation" AS "nation"
+ON "nation"."n_regionkey" = "region"."r_regionkey"
+JOIN "customer" AS "customer"
+ON "customer"."c_nationkey" = "nation"."n_nationkey"
+JOIN "orders" AS "orders"
 ON "orders"."o_custkey" = "customer"."c_custkey"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_partkey" AS "l_partkey",
-"lineitem"."l_suppkey" AS "l_suppkey",
-"lineitem"."l_extendedprice" AS "l_extendedprice",
-"lineitem"."l_discount" AS "l_discount"
-FROM "lineitem" AS "lineitem"
-) AS "lineitem"
+JOIN "lineitem" AS "lineitem"
 ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
 AND "part"."p_partkey" = "lineitem"."l_partkey"
-JOIN (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_nationkey" AS "s_nationkey"
-FROM "supplier" AS "supplier"
-) AS "supplier"
+JOIN "supplier" AS "supplier"
 ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_name" AS "n_name"
-FROM "nation" AS "nation"
-) AS "n2"
-ON "supplier"."s_nationkey" = "n2"."n_nationkey"
-) AS "all_nations"
+JOIN "nation" AS "nation_2"
+ON "supplier"."s_nationkey" = "nation_2"."n_nationkey"
+WHERE
+"orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+AND "part"."p_type" = 'ECONOMY ANODIZED STEEL'
+AND "region"."r_name" = 'AMERICA'
 GROUP BY
-"all_nations"."o_year"
+EXTRACT(year FROM "orders"."o_orderdate")
 ORDER BY
 "o_year";

@@ -698,69 +528,28 @@ order by
 nation,
 o_year desc;
 SELECT
-"profit"."nation" AS "nation",
-"profit"."o_year" AS "o_year",
-SUM("profit"."amount") AS "sum_profit"
-FROM (
-SELECT
 "nation"."n_name" AS "nation",
 EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
-"lineitem"."l_extendedprice" * (
+SUM("lineitem"."l_extendedprice" * (
 1 - "lineitem"."l_discount"
-) - "partsupp"."ps_supplycost" * "lineitem"."l_quantity" AS "amount"
-FROM (
-SELECT
-"part"."p_partkey" AS "p_partkey",
-"part"."p_name" AS "p_name"
-FROM "part" AS "part"
-WHERE
-"part"."p_name" LIKE '%green%'
-) AS "part"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_partkey" AS "l_partkey",
-"lineitem"."l_suppkey" AS "l_suppkey",
-"lineitem"."l_quantity" AS "l_quantity",
-"lineitem"."l_extendedprice" AS "l_extendedprice",
-"lineitem"."l_discount" AS "l_discount"
-FROM "lineitem" AS "lineitem"
-) AS "lineitem"
+) - "partsupp"."ps_supplycost" * "lineitem"."l_quantity") AS "sum_profit"
+FROM "part" AS "part"
+JOIN "lineitem" AS "lineitem"
 ON "part"."p_partkey" = "lineitem"."l_partkey"
-JOIN (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_nationkey" AS "s_nationkey"
-FROM "supplier" AS "supplier"
-) AS "supplier"
+JOIN "supplier" AS "supplier"
 ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
-JOIN (
-SELECT
-"partsupp"."ps_partkey" AS "ps_partkey",
-"partsupp"."ps_suppkey" AS "ps_suppkey",
-"partsupp"."ps_supplycost" AS "ps_supplycost"
-FROM "partsupp" AS "partsupp"
-) AS "partsupp"
+JOIN "partsupp" AS "partsupp"
 ON "partsupp"."ps_partkey" = "lineitem"."l_partkey"
 AND "partsupp"."ps_suppkey" = "lineitem"."l_suppkey"
-JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_orderdate" AS "o_orderdate"
-FROM "orders" AS "orders"
-) AS "orders"
+JOIN "orders" AS "orders"
 ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_name" AS "n_name"
-FROM "nation" AS "nation"
-) AS "nation"
+JOIN "nation" AS "nation"
 ON "supplier"."s_nationkey" = "nation"."n_nationkey"
-) AS "profit"
+WHERE
+"part"."p_name" LIKE '%green%'
 GROUP BY
-"profit"."nation",
-"profit"."o_year"
+"nation"."n_name",
+EXTRACT(year FROM "orders"."o_orderdate")
 ORDER BY
 "nation",
 "o_year" DESC;
@@ -812,46 +601,17 @@ SELECT
 "customer"."c_address" AS "c_address",
 "customer"."c_phone" AS "c_phone",
 "customer"."c_comment" AS "c_comment"
-FROM (
-SELECT
-"customer"."c_custkey" AS "c_custkey",
-"customer"."c_name" AS "c_name",
-"customer"."c_address" AS "c_address",
-"customer"."c_nationkey" AS "c_nationkey",
-"customer"."c_phone" AS "c_phone",
-"customer"."c_acctbal" AS "c_acctbal",
-"customer"."c_comment" AS "c_comment"
-FROM "customer" AS "customer"
-) AS "customer"
-JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_custkey" AS "o_custkey",
-"orders"."o_orderdate" AS "o_orderdate"
-FROM "orders" AS "orders"
-WHERE
-"orders"."o_orderdate" < CAST('1994-01-01' AS DATE)
-AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE)
-) AS "orders"
+FROM "customer" AS "customer"
+JOIN "orders" AS "orders"
 ON "customer"."c_custkey" = "orders"."o_custkey"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_extendedprice" AS "l_extendedprice",
-"lineitem"."l_discount" AS "l_discount",
-"lineitem"."l_returnflag" AS "l_returnflag"
-FROM "lineitem" AS "lineitem"
-WHERE
-"lineitem"."l_returnflag" = 'R'
-) AS "lineitem"
+JOIN "lineitem" AS "lineitem"
 ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_name" AS "n_name"
-FROM "nation" AS "nation"
-) AS "nation"
+JOIN "nation" AS "nation"
 ON "customer"."c_nationkey" = "nation"."n_nationkey"
+WHERE
+"lineitem"."l_returnflag" = 'R'
+AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE)
+AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE)
 GROUP BY
 "customer"."c_custkey",
 "customer"."c_name",
@@ -910,14 +670,7 @@ WITH "_e_0" AS (
 SELECT
 "partsupp"."ps_partkey" AS "ps_partkey",
 SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value"
-FROM (
-SELECT
-"partsupp"."ps_partkey" AS "ps_partkey",
-"partsupp"."ps_suppkey" AS "ps_suppkey",
-"partsupp"."ps_availqty" AS "ps_availqty",
-"partsupp"."ps_supplycost" AS "ps_supplycost"
-FROM "partsupp" AS "partsupp"
-) AS "partsupp"
+FROM "partsupp" AS "partsupp"
 JOIN "_e_0" AS "supplier"
 ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
 JOIN "_e_1" AS "nation"
|
@ -928,13 +681,7 @@ HAVING
|
||||||
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > (
|
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > (
|
||||||
SELECT
|
SELECT
|
||||||
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0"
|
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0"
|
||||||
FROM (
|
|
||||||
SELECT
|
|
||||||
"partsupp"."ps_suppkey" AS "ps_suppkey",
|
|
||||||
"partsupp"."ps_availqty" AS "ps_availqty",
|
|
||||||
"partsupp"."ps_supplycost" AS "ps_supplycost"
|
|
||||||
FROM "partsupp" AS "partsupp"
|
FROM "partsupp" AS "partsupp"
|
||||||
) AS "partsupp"
|
|
||||||
JOIN "_e_0" AS "supplier"
|
JOIN "_e_0" AS "supplier"
|
||||||
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
|
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
|
||||||
JOIN "_e_1" AS "nation"
|
JOIN "_e_1" AS "nation"
|
||||||
|
@@ -988,28 +735,15 @@ SELECT
 THEN 1
 ELSE 0
 END) AS "low_line_count"
-FROM (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_orderpriority" AS "o_orderpriority"
-FROM "orders" AS "orders"
-) AS "orders"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_shipdate" AS "l_shipdate",
-"lineitem"."l_commitdate" AS "l_commitdate",
-"lineitem"."l_receiptdate" AS "l_receiptdate",
-"lineitem"."l_shipmode" AS "l_shipmode"
-FROM "lineitem" AS "lineitem"
-WHERE
+FROM "orders" AS "orders"
+JOIN "lineitem" AS "lineitem"
+ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
+WHERE
 "lineitem"."l_commitdate" < "lineitem"."l_receiptdate"
 AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE)
 AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE)
 AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate"
 AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP')
-) AS "lineitem"
-ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
 GROUP BY
 "lineitem"."l_shipmode"
 ORDER BY
@@ -1044,21 +778,10 @@ SELECT
 FROM (
 SELECT
 COUNT("orders"."o_orderkey") AS "c_count"
-FROM (
-SELECT
-"customer"."c_custkey" AS "c_custkey"
 FROM "customer" AS "customer"
-) AS "customer"
-LEFT JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_custkey" AS "o_custkey",
-"orders"."o_comment" AS "o_comment"
-FROM "orders" AS "orders"
-WHERE
-NOT "orders"."o_comment" LIKE '%special%requests%'
-) AS "orders"
+LEFT JOIN "orders" AS "orders"
 ON "customer"."c_custkey" = "orders"."o_custkey"
+AND NOT "orders"."o_comment" LIKE '%special%requests%'
 GROUP BY
 "customer"."c_custkey"
 ) AS "c_orders"
@@ -1094,24 +817,12 @@ SELECT
 END) / SUM("lineitem"."l_extendedprice" * (
 1 - "lineitem"."l_discount"
 )) AS "promo_revenue"
-FROM (
-SELECT
-"lineitem"."l_partkey" AS "l_partkey",
-"lineitem"."l_extendedprice" AS "l_extendedprice",
-"lineitem"."l_discount" AS "l_discount",
-"lineitem"."l_shipdate" AS "l_shipdate"
-FROM "lineitem" AS "lineitem"
-WHERE
+FROM "lineitem" AS "lineitem"
+JOIN "part" AS "part"
+ON "lineitem"."l_partkey" = "part"."p_partkey"
+WHERE
 "lineitem"."l_shipdate" < CAST('1995-10-01' AS DATE)
-AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE)
-) AS "lineitem"
-JOIN (
-SELECT
-"part"."p_partkey" AS "p_partkey",
-"part"."p_type" AS "p_type"
-FROM "part" AS "part"
-) AS "part"
-ON "lineitem"."l_partkey" = "part"."p_partkey";
+AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE);

 --------------------------------------
 -- TPC-H 15
@@ -1165,14 +876,7 @@ SELECT
 "supplier"."s_address" AS "s_address",
 "supplier"."s_phone" AS "s_phone",
 "revenue"."total_revenue" AS "total_revenue"
-FROM (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_name" AS "s_name",
-"supplier"."s_address" AS "s_address",
-"supplier"."s_phone" AS "s_phone"
-FROM "supplier" AS "supplier"
-) AS "supplier"
+FROM "supplier" AS "supplier"
 JOIN "revenue"
 ON "revenue"."total_revenue" = (
 SELECT
@@ -1221,12 +925,7 @@ SELECT
 "part"."p_type" AS "p_type",
 "part"."p_size" AS "p_size",
 COUNT(DISTINCT "partsupp"."ps_suppkey") AS "supplier_cnt"
-FROM (
-SELECT
-"partsupp"."ps_partkey" AS "ps_partkey",
-"partsupp"."ps_suppkey" AS "ps_suppkey"
-FROM "partsupp" AS "partsupp"
-) AS "partsupp"
+FROM "partsupp" AS "partsupp"
 LEFT JOIN (
 SELECT
 "supplier"."s_suppkey" AS "s_suppkey"
|
@ -1237,21 +936,13 @@ LEFT JOIN (
|
||||||
"supplier"."s_suppkey"
|
"supplier"."s_suppkey"
|
||||||
) AS "_u_0"
|
) AS "_u_0"
|
||||||
ON "partsupp"."ps_suppkey" = "_u_0"."s_suppkey"
|
ON "partsupp"."ps_suppkey" = "_u_0"."s_suppkey"
|
||||||
JOIN (
|
JOIN "part" AS "part"
|
||||||
SELECT
|
|
||||||
"part"."p_partkey" AS "p_partkey",
|
|
||||||
"part"."p_brand" AS "p_brand",
|
|
||||||
"part"."p_type" AS "p_type",
|
|
||||||
"part"."p_size" AS "p_size"
|
|
||||||
FROM "part" AS "part"
|
|
||||||
WHERE
|
|
||||||
"part"."p_brand" <> 'Brand#45'
|
|
||||||
AND "part"."p_size" IN (49, 14, 23, 45, 19, 3, 36, 9)
|
|
||||||
AND NOT "part"."p_type" LIKE 'MEDIUM POLISHED%'
|
|
||||||
) AS "part"
|
|
||||||
ON "part"."p_partkey" = "partsupp"."ps_partkey"
|
ON "part"."p_partkey" = "partsupp"."ps_partkey"
|
||||||
WHERE
|
WHERE
|
||||||
"_u_0"."s_suppkey" IS NULL
|
"_u_0"."s_suppkey" IS NULL
|
||||||
|
AND "part"."p_brand" <> 'Brand#45'
|
||||||
|
AND "part"."p_size" IN (49, 14, 23, 45, 19, 3, 36, 9)
|
||||||
|
AND NOT "part"."p_type" LIKE 'MEDIUM POLISHED%'
|
||||||
GROUP BY
|
GROUP BY
|
||||||
"part"."p_brand",
|
"part"."p_brand",
|
||||||
"part"."p_type",
|
"part"."p_type",
|
||||||
|
@@ -1284,23 +975,8 @@ where
 );
 SELECT
 SUM("lineitem"."l_extendedprice") / 7.0 AS "avg_yearly"
-FROM (
-SELECT
-"lineitem"."l_partkey" AS "l_partkey",
-"lineitem"."l_quantity" AS "l_quantity",
-"lineitem"."l_extendedprice" AS "l_extendedprice"
-FROM "lineitem" AS "lineitem"
-) AS "lineitem"
-JOIN (
-SELECT
-"part"."p_partkey" AS "p_partkey",
-"part"."p_brand" AS "p_brand",
-"part"."p_container" AS "p_container"
-FROM "part" AS "part"
-WHERE
-"part"."p_brand" = 'Brand#23'
-AND "part"."p_container" = 'MED BOX'
-) AS "part"
+FROM "lineitem" AS "lineitem"
+JOIN "part" AS "part"
 ON "part"."p_partkey" = "lineitem"."l_partkey"
 LEFT JOIN (
 SELECT
@@ -1313,6 +989,8 @@ LEFT JOIN (
 ON "_u_0"."_u_1" = "part"."p_partkey"
 WHERE
 "lineitem"."l_quantity" < "_u_0"."_col_0"
+AND "part"."p_brand" = 'Brand#23'
+AND "part"."p_container" = 'MED BOX'
 AND NOT "_u_0"."_u_1" IS NULL;

 --------------------------------------
@@ -1359,20 +1037,8 @@ SELECT
 "orders"."o_orderdate" AS "o_orderdate",
 "orders"."o_totalprice" AS "o_totalprice",
 SUM("lineitem"."l_quantity") AS "_col_5"
-FROM (
-SELECT
-"customer"."c_custkey" AS "c_custkey",
-"customer"."c_name" AS "c_name"
-FROM "customer" AS "customer"
-) AS "customer"
-JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_custkey" AS "o_custkey",
-"orders"."o_totalprice" AS "o_totalprice",
-"orders"."o_orderdate" AS "o_orderdate"
-FROM "orders" AS "orders"
-) AS "orders"
+FROM "customer" AS "customer"
+JOIN "orders" AS "orders"
 ON "customer"."c_custkey" = "orders"."o_custkey"
 LEFT JOIN (
 SELECT
@@ -1385,12 +1051,7 @@ LEFT JOIN (
 SUM("lineitem"."l_quantity") > 300
 ) AS "_u_0"
 ON "orders"."o_orderkey" = "_u_0"."l_orderkey"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_quantity" AS "l_quantity"
-FROM "lineitem" AS "lineitem"
-) AS "lineitem"
+JOIN "lineitem" AS "lineitem"
 ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
 WHERE
 NOT "_u_0"."l_orderkey" IS NULL
@@ -1447,24 +1108,8 @@ SELECT
 SUM("lineitem"."l_extendedprice" * (
 1 - "lineitem"."l_discount"
 )) AS "revenue"
-FROM (
-SELECT
-"lineitem"."l_partkey" AS "l_partkey",
-"lineitem"."l_quantity" AS "l_quantity",
-"lineitem"."l_extendedprice" AS "l_extendedprice",
-"lineitem"."l_discount" AS "l_discount",
-"lineitem"."l_shipinstruct" AS "l_shipinstruct",
-"lineitem"."l_shipmode" AS "l_shipmode"
-FROM "lineitem" AS "lineitem"
-) AS "lineitem"
-JOIN (
-SELECT
-"part"."p_partkey" AS "p_partkey",
-"part"."p_brand" AS "p_brand",
-"part"."p_size" AS "p_size",
-"part"."p_container" AS "p_container"
-FROM "part" AS "part"
-) AS "part"
+FROM "lineitem" AS "lineitem"
+JOIN "part" AS "part"
 ON (
 "part"."p_brand" = 'Brand#12'
 AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG')
@@ -1558,14 +1203,7 @@ order by
 SELECT
 "supplier"."s_name" AS "s_name",
 "supplier"."s_address" AS "s_address"
-FROM (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_name" AS "s_name",
-"supplier"."s_address" AS "s_address",
-"supplier"."s_nationkey" AS "s_nationkey"
-FROM "supplier" AS "supplier"
-) AS "supplier"
+FROM "supplier" AS "supplier"
 LEFT JOIN (
 SELECT
 "partsupp"."ps_suppkey" AS "ps_suppkey"
@@ -1604,17 +1242,11 @@ LEFT JOIN (
 "partsupp"."ps_suppkey"
 ) AS "_u_4"
 ON "supplier"."s_suppkey" = "_u_4"."ps_suppkey"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_name" AS "n_name"
-FROM "nation" AS "nation"
-WHERE
-"nation"."n_name" = 'CANADA'
-) AS "nation"
+JOIN "nation" AS "nation"
 ON "supplier"."s_nationkey" = "nation"."n_nationkey"
 WHERE
-NOT "_u_4"."ps_suppkey" IS NULL
+"nation"."n_name" = 'CANADA'
+AND NOT "_u_4"."ps_suppkey" IS NULL
 ORDER BY
 "s_name";

@@ -1665,24 +1297,9 @@ limit
 SELECT
 "supplier"."s_name" AS "s_name",
 COUNT(*) AS "numwait"
-FROM (
-SELECT
-"supplier"."s_suppkey" AS "s_suppkey",
-"supplier"."s_name" AS "s_name",
-"supplier"."s_nationkey" AS "s_nationkey"
-FROM "supplier" AS "supplier"
-) AS "supplier"
-JOIN (
-SELECT
-"lineitem"."l_orderkey" AS "l_orderkey",
-"lineitem"."l_suppkey" AS "l_suppkey",
-"lineitem"."l_commitdate" AS "l_commitdate",
-"lineitem"."l_receiptdate" AS "l_receiptdate"
-FROM "lineitem" AS "lineitem"
-WHERE
-"lineitem"."l_receiptdate" > "lineitem"."l_commitdate"
-) AS "l1"
-ON "supplier"."s_suppkey" = "l1"."l_suppkey"
+FROM "supplier" AS "supplier"
+JOIN "lineitem" AS "lineitem"
+ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
 LEFT JOIN (
 SELECT
 "l2"."l_orderkey" AS "l_orderkey",
@@ -1691,7 +1308,7 @@ LEFT JOIN (
 GROUP BY
 "l2"."l_orderkey"
 ) AS "_u_0"
-ON "_u_0"."l_orderkey" = "l1"."l_orderkey"
+ON "_u_0"."l_orderkey" = "lineitem"."l_orderkey"
 LEFT JOIN (
 SELECT
 "l3"."l_orderkey" AS "l_orderkey",
@@ -1702,31 +1319,20 @@ LEFT JOIN (
 GROUP BY
 "l3"."l_orderkey"
 ) AS "_u_2"
-ON "_u_2"."l_orderkey" = "l1"."l_orderkey"
-JOIN (
-SELECT
-"orders"."o_orderkey" AS "o_orderkey",
-"orders"."o_orderstatus" AS "o_orderstatus"
-FROM "orders" AS "orders"
-WHERE
-"orders"."o_orderstatus" = 'F'
-) AS "orders"
-ON "orders"."o_orderkey" = "l1"."l_orderkey"
-JOIN (
-SELECT
-"nation"."n_nationkey" AS "n_nationkey",
-"nation"."n_name" AS "n_name"
-FROM "nation" AS "nation"
-WHERE
-"nation"."n_name" = 'SAUDI ARABIA'
-) AS "nation"
+ON "_u_2"."l_orderkey" = "lineitem"."l_orderkey"
+JOIN "orders" AS "orders"
+ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
+JOIN "nation" AS "nation"
 ON "supplier"."s_nationkey" = "nation"."n_nationkey"
 WHERE
 (
 "_u_2"."l_orderkey" IS NULL
-OR NOT ARRAY_ANY("_u_2"."_u_3", "_x" -> "_x" <> "l1"."l_suppkey")
+OR NOT ARRAY_ANY("_u_2"."_u_3", "_x" -> "_x" <> "lineitem"."l_suppkey")
 )
-AND ARRAY_ANY("_u_0"."_u_1", "_x" -> "_x" <> "l1"."l_suppkey")
+AND "lineitem"."l_receiptdate" > "lineitem"."l_commitdate"
+AND "nation"."n_name" = 'SAUDI ARABIA'
+AND "orders"."o_orderstatus" = 'F'
+AND ARRAY_ANY("_u_0"."_u_1", "_x" -> "_x" <> "lineitem"."l_suppkey")
 AND NOT "_u_0"."l_orderkey" IS NULL
 GROUP BY
 "supplier"."s_name"
@@ -1776,23 +1382,19 @@ group by
 order by
 cntrycode;
 SELECT
-"custsale"."cntrycode" AS "cntrycode",
-COUNT(*) AS "numcust",
-SUM("custsale"."c_acctbal") AS "totacctbal"
-FROM (
-SELECT
 SUBSTRING("customer"."c_phone", 1, 2) AS "cntrycode",
-"customer"."c_acctbal" AS "c_acctbal"
-FROM "customer" AS "customer"
-LEFT JOIN (
+COUNT(*) AS "numcust",
+SUM("customer"."c_acctbal") AS "totacctbal"
+FROM "customer" AS "customer"
+LEFT JOIN (
 SELECT
 "orders"."o_custkey" AS "_u_1"
 FROM "orders" AS "orders"
 GROUP BY
 "orders"."o_custkey"
 ) AS "_u_0"
 ON "_u_0"."_u_1" = "customer"."c_custkey"
 WHERE
 "_u_0"."_u_1" IS NULL
 AND "customer"."c_acctbal" > (
 SELECT
@@ -1803,8 +1405,7 @@ FROM (
 AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
 )
 AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
-) AS "custsale"
 GROUP BY
-"custsale"."cntrycode"
+SUBSTRING("customer"."c_phone", 1, 2)
 ORDER BY
 "cntrycode";
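Each TPC-H fixture above keeps the original query and its expected optimized form side by side in one file. A simplified reader for that layout, assuming statements simply alternate; the repo's actual helper is load_sql_fixture_pairs (used in the executor tests below), which also yields a meta field:

def iter_fixture_pairs(path):
    # Drop '--' comment headers and blank lines, then split on ';' so the
    # remaining statements alternate: input query, expected optimized query.
    with open(path) as f:
        lines = [l for l in f.read().splitlines() if l and not l.startswith("--")]
    statements = [s.strip() for s in "\n".join(lines).split(";") if s.strip()]
    return list(zip(statements[::2], statements[1::2]))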
tests/helpers.py
@@ -5,9 +5,7 @@ FIXTURES_DIR = os.path.join(FILE_DIR, "fixtures")


 def _filter_comments(s):
-    return "\n".join(
-        [line for line in s.splitlines() if line and not line.startswith("--")]
-    )
+    return "\n".join([line for line in s.splitlines() if line and not line.startswith("--")])


 def _extract_meta(sql):
@@ -23,9 +21,7 @@ def _extract_meta(sql):


 def assert_logger_contains(message, logger, level="error"):
-    output = "\n".join(
-        str(args[0][0]) for args in getattr(logger, level).call_args_list
-    )
+    output = "\n".join(str(args[0][0]) for args in getattr(logger, level).call_args_list)
     assert message in output


@ -46,10 +46,7 @@ class TestBuild(unittest.TestCase):
|
||||||
"SELECT x FROM tbl WHERE FALSE",
|
"SELECT x FROM tbl WHERE FALSE",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").from_("tbl").where("x > 0").where("x < 9", append=False),
|
||||||
.from_("tbl")
|
|
||||||
.where("x > 0")
|
|
||||||
.where("x < 9", append=False),
|
|
||||||
"SELECT x FROM tbl WHERE x < 9",
|
"SELECT x FROM tbl WHERE x < 9",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
|
@ -61,10 +58,7 @@ class TestBuild(unittest.TestCase):
|
||||||
"SELECT x, y FROM tbl GROUP BY x, y",
|
"SELECT x, y FROM tbl GROUP BY x, y",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x", "y", "z", "a")
|
lambda: select("x", "y", "z", "a").from_("tbl").group_by("x, y", "z").group_by("a"),
|
||||||
.from_("tbl")
|
|
||||||
.group_by("x, y", "z")
|
|
||||||
.group_by("a"),
|
|
||||||
"SELECT x, y, z, a FROM tbl GROUP BY x, y, z, a",
|
"SELECT x, y, z, a FROM tbl GROUP BY x, y, z, a",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
|
@ -85,9 +79,7 @@ class TestBuild(unittest.TestCase):
|
||||||
"SELECT x FROM tbl JOIN tbl2 ON tbl.y = tbl2.y",
|
"SELECT x FROM tbl JOIN tbl2 ON tbl.y = tbl2.y",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").from_("tbl").join("tbl2", on=["tbl.y = tbl2.y", "a = b"]),
|
||||||
.from_("tbl")
|
|
||||||
.join("tbl2", on=["tbl.y = tbl2.y", "a = b"]),
|
|
||||||
"SELECT x FROM tbl JOIN tbl2 ON tbl.y = tbl2.y AND a = b",
|
"SELECT x FROM tbl JOIN tbl2 ON tbl.y = tbl2.y AND a = b",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
|
@ -95,21 +87,15 @@ class TestBuild(unittest.TestCase):
|
||||||
"SELECT x FROM tbl LEFT OUTER JOIN tbl2",
|
"SELECT x FROM tbl LEFT OUTER JOIN tbl2",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer"),
|
||||||
.from_("tbl")
|
|
||||||
.join(exp.Table(this="tbl2"), join_type="left outer"),
|
|
||||||
"SELECT x FROM tbl LEFT OUTER JOIN tbl2",
|
"SELECT x FROM tbl LEFT OUTER JOIN tbl2",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
|
||||||
.from_("tbl")
|
|
||||||
.join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
|
|
||||||
"SELECT x FROM tbl LEFT OUTER JOIN tbl2 AS foo",
|
"SELECT x FROM tbl LEFT OUTER JOIN tbl2 AS foo",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").from_("tbl").join(select("y").from_("tbl2"), join_type="left outer"),
|
||||||
.from_("tbl")
|
|
||||||
.join(select("y").from_("tbl2"), join_type="left outer"),
|
|
||||||
"SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2)",
|
"SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2)",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
|
@ -132,9 +118,7 @@ class TestBuild(unittest.TestCase):
|
||||||
"SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased",
|
"SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").from_("tbl").join(parse_one("left join x", into=exp.Join), on="a=b"),
|
||||||
.from_("tbl")
|
|
||||||
.join(parse_one("left join x", into=exp.Join), on="a=b"),
|
|
||||||
"SELECT x FROM tbl LEFT JOIN x ON a = b",
|
"SELECT x FROM tbl LEFT JOIN x ON a = b",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
|
@ -142,9 +126,7 @@ class TestBuild(unittest.TestCase):
|
||||||
"SELECT x FROM tbl LEFT JOIN x ON a = b",
|
"SELECT x FROM tbl LEFT JOIN x ON a = b",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").from_("tbl").join("select b from tbl2", on="a=b", join_type="left"),
|
||||||
.from_("tbl")
|
|
||||||
.join("select b from tbl2", on="a=b", join_type="left"),
|
|
||||||
"SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) ON a = b",
|
"SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) ON a = b",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
|
@ -159,10 +141,7 @@ class TestBuild(unittest.TestCase):
|
||||||
"SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) AS aliased ON a = b",
|
"SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) AS aliased ON a = b",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x", "COUNT(y)")
|
lambda: select("x", "COUNT(y)").from_("tbl").group_by("x").having("COUNT(y) > 0"),
|
||||||
.from_("tbl")
|
|
||||||
.group_by("x")
|
|
||||||
.having("COUNT(y) > 0"),
|
|
||||||
"SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 0",
|
"SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 0",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
|
@ -190,24 +169,15 @@ class TestBuild(unittest.TestCase):
|
||||||
"SELECT x FROM tbl SORT BY x, y DESC",
|
"SELECT x FROM tbl SORT BY x, y DESC",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x", "y", "z", "a")
|
lambda: select("x", "y", "z", "a").from_("tbl").order_by("x, y", "z").order_by("a"),
|
||||||
.from_("tbl")
|
|
||||||
.order_by("x, y", "z")
|
|
||||||
.order_by("a"),
|
|
||||||
"SELECT x, y, z, a FROM tbl ORDER BY x, y, z, a",
|
"SELECT x, y, z, a FROM tbl ORDER BY x, y, z, a",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x", "y", "z", "a")
|
lambda: select("x", "y", "z", "a").from_("tbl").cluster_by("x, y", "z").cluster_by("a"),
|
||||||
.from_("tbl")
|
|
||||||
.cluster_by("x, y", "z")
|
|
||||||
.cluster_by("a"),
|
|
||||||
"SELECT x, y, z, a FROM tbl CLUSTER BY x, y, z, a",
|
"SELECT x, y, z, a FROM tbl CLUSTER BY x, y, z, a",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x", "y", "z", "a")
|
lambda: select("x", "y", "z", "a").from_("tbl").sort_by("x, y", "z").sort_by("a"),
|
||||||
.from_("tbl")
|
|
||||||
.sort_by("x, y", "z")
|
|
||||||
.sort_by("a"),
|
|
||||||
"SELECT x, y, z, a FROM tbl SORT BY x, y, z, a",
|
"SELECT x, y, z, a FROM tbl SORT BY x, y, z, a",
|
||||||
),
|
),
|
||||||
(lambda: select("x").from_("tbl").limit(10), "SELECT x FROM tbl LIMIT 10"),
|
(lambda: select("x").from_("tbl").limit(10), "SELECT x FROM tbl LIMIT 10"),
|
||||||
|
@ -220,21 +190,15 @@ class TestBuild(unittest.TestCase):
|
||||||
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
|
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").from_("tbl").with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
|
||||||
.from_("tbl")
|
|
||||||
.with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
|
|
||||||
"WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
|
"WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").from_("tbl").with_("tbl", as_=select("x").from_("tbl2")),
|
||||||
.from_("tbl")
|
|
||||||
.with_("tbl", as_=select("x").from_("tbl2")),
|
|
||||||
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
|
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").from_("tbl").with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
|
||||||
.from_("tbl")
|
|
||||||
.with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
|
|
||||||
"WITH tbl(x, y) AS (SELECT x, y FROM tbl2) SELECT x FROM tbl",
|
"WITH tbl(x, y) AS (SELECT x, y FROM tbl2) SELECT x FROM tbl",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
|
@ -245,72 +209,43 @@ class TestBuild(unittest.TestCase):
|
||||||
"WITH tbl AS (SELECT x FROM tbl2), tbl2 AS (SELECT x FROM tbl3) SELECT x FROM tbl",
|
"WITH tbl AS (SELECT x FROM tbl2), tbl2 AS (SELECT x FROM tbl3) SELECT x FROM tbl",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").from_("tbl").with_("tbl", as_=select("x", "y").from_("tbl2")).select("y"),
|
||||||
.from_("tbl")
|
|
||||||
.with_("tbl", as_=select("x", "y").from_("tbl2"))
|
|
||||||
.select("y"),
|
|
||||||
"WITH tbl AS (SELECT x, y FROM tbl2) SELECT x, y FROM tbl",
|
"WITH tbl AS (SELECT x, y FROM tbl2) SELECT x, y FROM tbl",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl"),
|
||||||
.with_("tbl", as_=select("x").from_("tbl2"))
|
|
||||||
.from_("tbl"),
|
|
||||||
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
|
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").group_by("x"),
|
||||||
.with_("tbl", as_=select("x").from_("tbl2"))
|
|
||||||
.from_("tbl")
|
|
||||||
.group_by("x"),
|
|
||||||
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl GROUP BY x",
|
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl GROUP BY x",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").order_by("x"),
|
||||||
.with_("tbl", as_=select("x").from_("tbl2"))
|
|
||||||
.from_("tbl")
|
|
||||||
.order_by("x"),
|
|
||||||
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl ORDER BY x",
|
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl ORDER BY x",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").limit(10),
|
||||||
.with_("tbl", as_=select("x").from_("tbl2"))
|
|
||||||
.from_("tbl")
|
|
||||||
.limit(10),
|
|
||||||
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl LIMIT 10",
|
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl LIMIT 10",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").offset(10),
|
||||||
.with_("tbl", as_=select("x").from_("tbl2"))
|
|
||||||
.from_("tbl")
|
|
||||||
.offset(10),
|
|
||||||
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl OFFSET 10",
|
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl OFFSET 10",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").join("tbl3"),
|
||||||
.with_("tbl", as_=select("x").from_("tbl2"))
|
|
||||||
.from_("tbl")
|
|
||||||
.join("tbl3"),
|
|
||||||
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl JOIN tbl3",
|
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl JOIN tbl3",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").distinct(),
|
||||||
.with_("tbl", as_=select("x").from_("tbl2"))
|
|
||||||
.from_("tbl")
|
|
||||||
.distinct(),
|
|
||||||
"WITH tbl AS (SELECT x FROM tbl2) SELECT DISTINCT x FROM tbl",
|
"WITH tbl AS (SELECT x FROM tbl2) SELECT DISTINCT x FROM tbl",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").where("x > 10"),
|
||||||
.with_("tbl", as_=select("x").from_("tbl2"))
|
|
||||||
.from_("tbl")
|
|
||||||
.where("x > 10"),
|
|
||||||
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl WHERE x > 10",
|
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl WHERE x > 10",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
lambda: select("x")
|
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").having("x > 20"),
|
||||||
.with_("tbl", as_=select("x").from_("tbl2"))
|
|
||||||
.from_("tbl")
|
|
||||||
.having("x > 20"),
|
|
||||||
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl HAVING x > 20",
|
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl HAVING x > 20",
|
||||||
),
|
),
|
||||||
(lambda: select("x").from_("tbl").subquery(), "(SELECT x FROM tbl)"),
|
(lambda: select("x").from_("tbl").subquery(), "(SELECT x FROM tbl)"),
|
||||||
|
@@ -324,9 +259,7 @@ class TestBuild(unittest.TestCase):
             ),
             (lambda: from_("tbl").select("x"), "SELECT x FROM tbl"),
             (
-                lambda: parse_one("SELECT a FROM tbl")
-                .assert_is(exp.Select)
-                .select("b"),
+                lambda: parse_one("SELECT a FROM tbl").assert_is(exp.Select).select("b"),
                 "SELECT a, b FROM tbl",
             ),
             (
@@ -368,15 +301,11 @@ class TestBuild(unittest.TestCase):
                 "SELECT * FROM x WHERE y = 1 AND z = 1",
             ),
             (
-                lambda: exp.subquery("select x from tbl", "foo")
-                .select("x")
-                .where("x > 0"),
+                lambda: exp.subquery("select x from tbl", "foo").select("x").where("x > 0"),
                 "SELECT x FROM (SELECT x FROM tbl) AS foo WHERE x > 0",
             ),
             (
-                lambda: exp.subquery(
-                    "select x from tbl UNION select x from bar", "unioned"
-                ).select("x"),
+                lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select("x"),
                 "SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned",
             ),
         ]:
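The test_build hunks above are pure reformatting: black's new 120-character line length (see run_checks.sh) collapses the fluent builder chains onto single lines without touching behavior. For orientation, a minimal sketch of the builder API these tuples exercise, using the same names and calls as the tests:

    from sqlglot import exp, parse_one, select

    # Fluent builder: a CTE, a FROM clause, and a GROUP BY, chained in one expression.
    query = select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").group_by("x")
    print(query.sql())
    # WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl GROUP BY x

    # Parsed expressions can be extended with the same methods.
    print(parse_one("SELECT a FROM tbl").assert_is(exp.Select).select("b").sql())
    # SELECT a, b FROM tbl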

@@ -27,10 +27,7 @@ class TestExecutor(unittest.TestCase):
         )

         cls.cache = {}
-        cls.sqls = [
-            (sql, expected)
-            for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")
-        ]
+        cls.sqls = [(sql, expected) for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")]

     @classmethod
     def tearDownClass(cls):
@@ -50,7 +47,8 @@ class TestExecutor(unittest.TestCase):
         self.assertEqual(Python().generate(parse_one("'x '''")), r"'x \''")

     def test_optimized_tpch(self):
-        for sql, optimized in self.sqls[0:20]:
+        for i, (sql, optimized) in enumerate(self.sqls[:20], start=1):
+            with self.subTest(f"{i}, {sql}"):
                 a = self.cached_execute(sql)
                 b = self.conn.execute(optimized).fetchdf()
                 self.rename_anonymous(b, a)
@@ -59,9 +57,7 @@ class TestExecutor(unittest.TestCase):
     def test_execute_tpch(self):
         def to_csv(expression):
             if isinstance(expression, exp.Table):
-                return parse_one(
-                    f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
-                )
+                return parse_one(f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}")
             return expression

         for sql, _ in self.sqls[0:3]:
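One behavioral nicety in test_optimized_tpch: each of the twenty TPC-H queries now runs inside its own numbered subTest, so a failing query is reported individually by index and text instead of aborting the whole loop. A small self-contained sketch of the pattern (the data here is illustrative, not the fixture set):

    import unittest

    class SubTestDemo(unittest.TestCase):
        def test_pairs(self):
            pairs = [("SELECT 1", "SELECT 1"), ("SELECT 2", "SELECT 2")]
            for i, (sql, optimized) in enumerate(pairs, start=1):
                # Each iteration is reported separately, labeled "1, SELECT 1", etc.
                with self.subTest(f"{i}, {sql}"):
                    self.assertEqual(sql, optimized)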

@@ -26,9 +26,7 @@ class TestExpressions(unittest.TestCase):
             parse_one("ROW() OVER(Partition by y)"),
             parse_one("ROW() OVER (partition BY y)"),
         )
-        self.assertEqual(
-            parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)")
-        )
+        self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)"))

     def test_find(self):
         expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y")
@@ -87,9 +85,7 @@ class TestExpressions(unittest.TestCase):
         self.assertIsNone(column.find_ancestor(exp.Join))

     def test_alias_or_name(self):
-        expression = parse_one(
-            "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
-        )
+        expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
         self.assertEqual(
             [e.alias_or_name for e in expression.expressions],
             ["a", "B", "e", "*", "zz", "z"],
@@ -118,9 +114,7 @@ class TestExpressions(unittest.TestCase):
         )

     def test_named_selects(self):
-        expression = parse_one(
-            "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
-        )
+        expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
         self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"])

         expression = parse_one(
@@ -196,15 +190,9 @@ class TestExpressions(unittest.TestCase):

     def test_sql(self):
         self.assertEqual(parse_one("x + y * 2").sql(), "x + y * 2")
-        self.assertEqual(
-            parse_one('select "x"').sql(dialect="hive", pretty=True), "SELECT\n  `x`"
-        )
-        self.assertEqual(
-            parse_one("X + y").sql(identify=True, normalize=True), '"x" + "y"'
-        )
-        self.assertEqual(
-            parse_one("SUM(X)").sql(identify=True, normalize=True), 'SUM("x")'
-        )
+        self.assertEqual(parse_one('select "x"').sql(dialect="hive", pretty=True), "SELECT\n  `x`")
+        self.assertEqual(parse_one("X + y").sql(identify=True, normalize=True), '"x" + "y"')
+        self.assertEqual(parse_one("SUM(X)").sql(identify=True, normalize=True), 'SUM("x")')

     def test_transform_with_arguments(self):
         expression = parse_one("a")
@@ -229,15 +217,11 @@ class TestExpressions(unittest.TestCase):
             return node

         actual_expression_1 = expression.transform(fun)
-        self.assertEqual(
-            actual_expression_1.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)"
-        )
+        self.assertEqual(actual_expression_1.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)")
         self.assertIsNot(actual_expression_1, expression)

         actual_expression_2 = expression.transform(fun, copy=False)
-        self.assertEqual(
-            actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)"
-        )
+        self.assertEqual(actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)")
         self.assertIs(actual_expression_2, expression)

         with self.assertRaises(ValueError):
@@ -274,12 +258,8 @@ class TestExpressions(unittest.TestCase):
         expression = parse_one("SELECT * FROM (SELECT * FROM x)")
         self.assertEqual(len(list(expression.walk())), 9)
         self.assertEqual(len(list(expression.walk(bfs=False))), 9)
-        self.assertTrue(
-            all(isinstance(e, exp.Expression) for e, _, _ in expression.walk())
-        )
-        self.assertTrue(
-            all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False))
-        )
+        self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk()))
+        self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False)))

     def test_functions(self):
         self.assertIsInstance(parse_one("ABS(a)"), exp.Abs)
@@ -303,9 +283,7 @@ class TestExpressions(unittest.TestCase):
         self.assertIsInstance(parse_one("IF(a, b, c)"), exp.If)
         self.assertIsInstance(parse_one("INITCAP(a)"), exp.Initcap)
         self.assertIsInstance(parse_one("JSON_EXTRACT(a, '$.name')"), exp.JSONExtract)
-        self.assertIsInstance(
-            parse_one("JSON_EXTRACT_SCALAR(a, '$.name')"), exp.JSONExtractScalar
-        )
+        self.assertIsInstance(parse_one("JSON_EXTRACT_SCALAR(a, '$.name')"), exp.JSONExtractScalar)
         self.assertIsInstance(parse_one("LEAST(a, b)"), exp.Least)
         self.assertIsInstance(parse_one("LN(a)"), exp.Ln)
         self.assertIsInstance(parse_one("LOG10(a)"), exp.Log10)
@@ -334,6 +312,7 @@ class TestExpressions(unittest.TestCase):
         self.assertIsInstance(parse_one("TIME_STR_TO_DATE(a)"), exp.TimeStrToDate)
         self.assertIsInstance(parse_one("TIME_STR_TO_TIME(a)"), exp.TimeStrToTime)
         self.assertIsInstance(parse_one("TIME_STR_TO_UNIX(a)"), exp.TimeStrToUnix)
+        self.assertIsInstance(parse_one("TRIM(LEADING 'b' FROM 'bla')"), exp.Trim)
         self.assertIsInstance(parse_one("TS_OR_DS_ADD(a, 1, 'day')"), exp.TsOrDsAdd)
         self.assertIsInstance(parse_one("TS_OR_DS_TO_DATE(a)"), exp.TsOrDsToDate)
         self.assertIsInstance(parse_one("TS_OR_DS_TO_DATE_STR(a)"), exp.Substring)
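The only substantive change in the hunk above is the new TRIM assertion: the parser now maps the ANSI FROM-style TRIM syntax onto exp.Trim. A quick standalone check, grounded in the added line:

    from sqlglot import exp, parse_one

    # The LEADING/TRAILING ... FROM ... form of TRIM parses into an exp.Trim node.
    assert isinstance(parse_one("TRIM(LEADING 'b' FROM 'bla')"), exp.Trim)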
@@ -404,12 +383,8 @@ class TestExpressions(unittest.TestCase):
         self.assertFalse(exp.to_identifier("x").quoted)

     def test_function_normalizer(self):
-        self.assertEqual(
-            parse_one("HELLO()").sql(normalize_functions="lower"), "hello()"
-        )
-        self.assertEqual(
-            parse_one("hello()").sql(normalize_functions="upper"), "HELLO()"
-        )
+        self.assertEqual(parse_one("HELLO()").sql(normalize_functions="lower"), "hello()")
+        self.assertEqual(parse_one("hello()").sql(normalize_functions="upper"), "HELLO()")
         self.assertEqual(parse_one("heLLO()").sql(normalize_functions=None), "heLLO()")
         self.assertEqual(parse_one("SUM(x)").sql(normalize_functions="lower"), "sum(x)")
         self.assertEqual(parse_one("sum(x)").sql(normalize_functions="upper"), "SUM(x)")
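These assertions document the normalize_functions generation option, which rewrites function names at SQL-generation time. Condensed from the assertions above:

    from sqlglot import parse_one

    parse_one("HELLO()").sql(normalize_functions="lower")  # 'hello()'
    parse_one("hello()").sql(normalize_functions="upper")  # 'HELLO()'
    parse_one("heLLO()").sql(normalize_functions=None)     # 'heLLO()' (names left as written)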

@@ -31,9 +31,7 @@ class TestOptimizer(unittest.TestCase):
             dialect = meta.get("dialect")
             with self.subTest(sql):
                 self.assertEqual(
-                    func(parse_one(sql, read=dialect), **kwargs).sql(
-                        pretty=pretty, dialect=dialect
-                    ),
+                    func(parse_one(sql, read=dialect), **kwargs).sql(pretty=pretty, dialect=dialect),
                     expected,
                 )

@@ -86,9 +84,7 @@ class TestOptimizer(unittest.TestCase):
         for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"):
             with self.subTest(sql):
                 with self.assertRaises(OptimizeError):
-                    optimizer.qualify_columns.qualify_columns(
-                        parse_one(sql), schema=self.schema
-                    )
+                    optimizer.qualify_columns.qualify_columns(parse_one(sql), schema=self.schema)

     def test_quote_identities(self):
         self.check_file("quote_identities", optimizer.quote_identities.quote_identities)
@@ -100,9 +96,7 @@ class TestOptimizer(unittest.TestCase):
             expression = optimizer.pushdown_projections.pushdown_projections(expression)
             return expression

-        self.check_file(
-            "pushdown_projections", pushdown_projections, schema=self.schema
-        )
+        self.check_file("pushdown_projections", pushdown_projections, schema=self.schema)

     def test_simplify(self):
         self.check_file("simplify", optimizer.simplify.simplify)
@@ -115,9 +109,7 @@ class TestOptimizer(unittest.TestCase):
         )

     def test_pushdown_predicates(self):
-        self.check_file(
-            "pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates
-        )
+        self.check_file("pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates)

     def test_expand_multi_table_selects(self):
         self.check_file(
@@ -138,10 +130,17 @@ class TestOptimizer(unittest.TestCase):
             pretty=True,
         )

+    def test_merge_derived_tables(self):
+        def optimize(expression, **kwargs):
+            expression = optimizer.qualify_tables.qualify_tables(expression)
+            expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
+            expression = optimizer.merge_derived_tables.merge_derived_tables(expression)
+            return expression
+
+        self.check_file("merge_derived_tables", optimize, schema=self.schema)
+
     def test_tpch(self):
-        self.check_file(
-            "tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True
-        )
+        self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)

     def test_schema(self):
         schema = ensure_schema(
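The new test_merge_derived_tables fixture wires three optimizer passes into a pipeline: qualify tables, qualify columns, then merge derived tables into their parent scope. A hedged, standalone sketch of that same pipeline; the sample query and inline schema here are illustrative, not from the fixtures:

    from sqlglot import parse_one, optimizer

    expression = parse_one("SELECT a FROM (SELECT x.a FROM x) AS y")

    # Same pass order as the new test: qualify, then merge the derived table.
    expression = optimizer.qualify_tables.qualify_tables(expression)
    expression = optimizer.qualify_columns.qualify_columns(expression, schema={"x": {"a": "INT"}})
    expression = optimizer.merge_derived_tables.merge_derived_tables(expression)

    print(expression.sql())  # the subquery is folded into the outer SELECT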
@@ -262,9 +261,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         self.assertEqual(len(scopes), 5)
         self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
         self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
-        self.assertEqual(
-            scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b"
-        )
+        self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
         self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y")
         self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())
@@ -16,28 +16,23 @@ class TestParser(unittest.TestCase):
         self.assertIsInstance(parse_one("array<int>", into=exp.DataType), exp.DataType)

     def test_column(self):
-        columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all(
-            exp.Column
-        )
+        columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all(exp.Column)
         assert len(list(columns)) == 1

         self.assertIsNotNone(parse_one("date").find(exp.Column))

     def test_table(self):
-        tables = [
-            t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)
-        ]
+        tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
         self.assertEqual(tables, ["a", "b.c", "d"])

     def test_select(self):
-        self.assertIsNotNone(
-            parse_one("select * from (select 1) x order by x.y").args["order"]
-        )
-        self.assertIsNotNone(
-            parse_one("select * from x where a = (select 1) order by x.y").args["order"]
-        )
-        self.assertEqual(
-            len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1
-        )
+        self.assertIsNotNone(parse_one("select 1 natural"))
+        self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"])
+        self.assertIsNotNone(parse_one("select * from x where a = (select 1) order by x.y").args["order"])
+        self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1)
+        self.assertEqual(
+            parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),
+            """SELECT * FROM x, z LATERAL VIEW EXPLODE(y) CROSS JOIN y""",
+        )

     def test_command(self):
@@ -72,12 +67,8 @@
         )

         assert len(expressions) == 2
-        assert (
-            expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a"
-        )
-        assert (
-            expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b"
-        )
+        assert expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a"
+        assert expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b"

     def test_expression(self):
         ignore = Parser(error_level=ErrorLevel.IGNORE)
@@ -147,13 +138,9 @@
     def test_pretty_config_override(self):
         self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x")
         with patch("sqlglot.pretty", True):
-            self.assertEqual(
-                parse_one("SELECT col FROM x").sql(), "SELECT\n  col\nFROM x"
-            )
+            self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT\n  col\nFROM x")

-        self.assertEqual(
-            parse_one("SELECT col FROM x").sql(pretty=True), "SELECT\n  col\nFROM x"
-        )
+        self.assertEqual(parse_one("SELECT col FROM x").sql(pretty=True), "SELECT\n  col\nFROM x")

     @patch("sqlglot.parser.logger")
     def test_comment_error_n(self, logger):
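Two new select assertions stand out above: a bare `select 1 natural` must now parse, and a query mixing CROSS JOIN with a Hive LATERAL VIEW must round-trip with the join re-emitted after the lateral. A short check lifted directly from the new assertions:

    from sqlglot import parse_one

    # Joins are generated after laterals on round-trip, per the new assertion.
    sql = parse_one("SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)").sql()
    assert sql == "SELECT * FROM x, z LATERAL VIEW EXPLODE(y) CROSS JOIN y"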

@@ -42,6 +42,20 @@ class TestTranspile(unittest.TestCase):
             "SELECT * FROM x WHERE a = ANY (SELECT 1)",
         )

+    def test_leading_comma(self):
+        self.validate(
+            "SELECT FOO, BAR, BAZ",
+            "SELECT\n  FOO\n  , BAR\n  , BAZ",
+            leading_comma=True,
+            pretty=True,
+        )
+        # without pretty, this should be a no-op
+        self.validate(
+            "SELECT FOO, BAR, BAZ",
+            "SELECT FOO, BAR, BAZ",
+            leading_comma=True,
+        )
+
     def test_space(self):
         self.validate("SELECT MIN(3)>MIN(2)", "SELECT MIN(3) > MIN(2)")
         self.validate("SELECT MIN(3)>=MIN(2)", "SELECT MIN(3) >= MIN(2)")
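test_leading_comma exercises a new pretty-printing option: with pretty=True, leading_comma=True emits the comma at the head of each subsequent select item; without pretty it changes nothing. The test goes through a validate helper; assuming transpile forwards generator options the same way, the equivalent public-API call would look like:

    import sqlglot

    print(sqlglot.transpile("SELECT FOO, BAR, BAZ", pretty=True, leading_comma=True)[0])
    # SELECT
    #   FOO
    #   , BAR
    #   , BAZ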
@@ -108,6 +122,11 @@
             "extract(month from '2021-01-31'::timestamp without time zone)",
             "EXTRACT(month FROM CAST('2021-01-31' AS TIMESTAMP))",
         )
+        self.validate("extract(week from current_date + 2)", "EXTRACT(week FROM CURRENT_DATE + 2)")
+        self.validate(
+            "EXTRACT(minute FROM datetime1 - datetime2)",
+            "EXTRACT(minute FROM datetime1 - datetime2)",
+        )

     def test_if(self):
         self.validate(
@@ -122,18 +141,14 @@
             "SELECT IF a > 1 THEN b ELSE c END",
             "SELECT CASE WHEN a > 1 THEN b ELSE c END",
         )
-        self.validate(
-            "SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo"
-        )
+        self.validate("SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo")

     def test_ignore_nulls(self):
         self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)")

     def test_time(self):
         self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)")
-        self.validate(
-            "TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)"
-        )
+        self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)")
         self.validate(
             "TIMESTAMP(9) WITH TIME ZONE '2020-01-01'",
             "CAST('2020-01-01' AS TIMESTAMPTZ(9))",
@@ -159,9 +174,7 @@
         self.validate("DATE '2020-01-01'", "CAST('2020-01-01' AS DATE)")
         self.validate("'2020-01-01'::DATE", "CAST('2020-01-01' AS DATE)")
         self.validate("STR_TO_TIME('x', 'y')", "STRPTIME('x', 'y')", write="duckdb")
-        self.validate(
-            "STR_TO_UNIX('x', 'y')", "EPOCH(STRPTIME('x', 'y'))", write="duckdb"
-        )
+        self.validate("STR_TO_UNIX('x', 'y')", "EPOCH(STRPTIME('x', 'y'))", write="duckdb")
         self.validate("TIME_TO_STR(x, 'y')", "STRFTIME(x, 'y')", write="duckdb")
         self.validate("TIME_TO_UNIX(x)", "EPOCH(x)", write="duckdb")
         self.validate(
@@ -209,12 +222,8 @@
         self.validate("TIME_STR_TO_DATE(x)", "TIME_STR_TO_DATE(x)", write=None)

         self.validate("TIME_STR_TO_DATE(x)", "TO_DATE(x)", write="hive")
-        self.validate(
-            "UNIX_TO_STR(x, 'yyyy-MM-dd HH:mm:ss')", "FROM_UNIXTIME(x)", write="hive"
-        )
-        self.validate(
-            "STR_TO_UNIX(x, 'yyyy-MM-dd HH:mm:ss')", "UNIX_TIMESTAMP(x)", write="hive"
-        )
+        self.validate("UNIX_TO_STR(x, 'yyyy-MM-dd HH:mm:ss')", "FROM_UNIXTIME(x)", write="hive")
+        self.validate("STR_TO_UNIX(x, 'yyyy-MM-dd HH:mm:ss')", "UNIX_TIMESTAMP(x)", write="hive")
         self.validate("IF(x > 1, x + 1)", "IF(x > 1, x + 1)", write="presto")
         self.validate("IF(x > 1, 1 + 1)", "IF(x > 1, 1 + 1)", write="hive")
         self.validate("IF(x > 1, 1, 0)", "IF(x > 1, 1, 0)", write="hive")
@@ -232,9 +241,7 @@
         )

         self.validate("STR_TO_TIME('x', 'y')", "DATE_PARSE('x', 'y')", write="presto")
-        self.validate(
-            "STR_TO_UNIX('x', 'y')", "TO_UNIXTIME(DATE_PARSE('x', 'y'))", write="presto"
-        )
+        self.validate("STR_TO_UNIX('x', 'y')", "TO_UNIXTIME(DATE_PARSE('x', 'y'))", write="presto")
         self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="presto")
         self.validate("TIME_TO_UNIX(x)", "TO_UNIXTIME(x)", write="presto")
         self.validate(
@@ -245,9 +252,7 @@
         self.validate("UNIX_TO_TIME(123)", "FROM_UNIXTIME(123)", write="presto")

         self.validate("STR_TO_TIME('x', 'y')", "TO_TIMESTAMP('x', 'y')", write="spark")
-        self.validate(
-            "STR_TO_UNIX('x', 'y')", "UNIX_TIMESTAMP('x', 'y')", write="spark"
-        )
+        self.validate("STR_TO_UNIX('x', 'y')", "UNIX_TIMESTAMP('x', 'y')", write="spark")
         self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="spark")

         self.validate(
@@ -283,9 +288,7 @@
     def test_partial(self):
         for sql in load_sql_fixtures("partial.sql"):
             with self.subTest(sql):
-                self.assertEqual(
-                    transpile(sql, error_level=ErrorLevel.IGNORE)[0], sql.strip()
-                )
+                self.assertEqual(transpile(sql, error_level=ErrorLevel.IGNORE)[0], sql.strip())

     def test_pretty(self):
         for _, sql, pretty in load_sql_fixture_pairs("pretty.sql"):