Merging upstream version 6.3.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 81e6900b0a
commit 393757f998
41 changed files with 1558 additions and 267 deletions
@@ -1,17 +1,55 @@
import unittest
from functools import partial

import duckdb
from pandas.testing import assert_frame_equal

import sqlglot
from sqlglot import exp, optimizer, parse_one, table
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.schema import MappingSchema, ensure_schema
from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures
from tests.helpers import (
    TPCH_SCHEMA,
    load_sql_fixture_pairs,
    load_sql_fixtures,
    string_to_bool,
)


class TestOptimizer(unittest.TestCase):
    maxDiff = None

    @classmethod
    def setUpClass(cls):
        cls.conn = duckdb.connect()
        cls.conn.execute(
            """
            CREATE TABLE x (a INT, b INT);
            CREATE TABLE y (b INT, c INT);
            CREATE TABLE z (b INT, c INT);

            INSERT INTO x VALUES (1, 1);
            INSERT INTO x VALUES (2, 2);
            INSERT INTO x VALUES (2, 2);
            INSERT INTO x VALUES (3, 3);
            INSERT INTO x VALUES (null, null);

            INSERT INTO y VALUES (2, 2);
            INSERT INTO y VALUES (2, 2);
            INSERT INTO y VALUES (3, 3);
            INSERT INTO y VALUES (4, 4);
            INSERT INTO y VALUES (null, null);

            INSERT INTO z VALUES (3, 3);
            INSERT INTO z VALUES (3, 3);
            INSERT INTO z VALUES (4, 4);
            INSERT INTO z VALUES (5, 5);
            INSERT INTO z VALUES (null, null);
            """
        )

    def setUp(self):
        self.schema = {
            "x": {
@@ -28,29 +66,42 @@ class TestOptimizer(unittest.TestCase):
            },
        }

    def check_file(self, file, func, pretty=False, **kwargs):
    def check_file(self, file, func, pretty=False, execute=False, **kwargs):
        for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1):
            title = meta.get("title") or f"{i}, {sql}"
            dialect = meta.get("dialect")
            leave_tables_isolated = meta.get("leave_tables_isolated")

            func_kwargs = {**kwargs}
            if leave_tables_isolated is not None:
                func_kwargs["leave_tables_isolated"] = leave_tables_isolated.lower() in ("true", "1")
                func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated)

            with self.subTest(f"{i}, {sql}"):
            optimized = func(parse_one(sql, read=dialect), **func_kwargs)

            with self.subTest(title):
                self.assertEqual(
                    func(parse_one(sql, read=dialect), **func_kwargs).sql(pretty=pretty, dialect=dialect),
                    optimized.sql(pretty=pretty, dialect=dialect),
                    expected,
                )

            should_execute = meta.get("execute")
            if should_execute is None:
                should_execute = execute

            if string_to_bool(should_execute):
                with self.subTest(f"(execute) {title}"):
                    df1 = self.conn.execute(sqlglot.transpile(sql, read=dialect, write="duckdb")[0]).df()
                    df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df()
                    assert_frame_equal(df1, df2)

    def test_optimize(self):
        schema = {
            "x": {"a": "INT", "b": "INT"},
            "y": {"a": "INT", "b": "INT"},
            "y": {"b": "INT", "c": "INT"},
            "z": {"a": "INT", "c": "INT"},
        }

        self.check_file("optimizer", optimizer.optimize, pretty=True, schema=schema)
        self.check_file("optimizer", optimizer.optimize, pretty=True, execute=True, schema=schema)

    def test_isolate_table_selects(self):
        self.check_file(
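The core addition in this hunk is the execute path of check_file: when a fixture opts in, the original SQL and the optimized SQL are both run against the in-memory DuckDB connection and the resulting dataframes must match. A minimal standalone sketch of that round trip, assuming duckdb and pandas are installed (the table and query here are illustrative, not taken from the fixtures):

import duckdb
from pandas.testing import assert_frame_equal

import sqlglot
from sqlglot import optimizer, parse_one

conn = duckdb.connect()
conn.execute("CREATE TABLE x (a INT, b INT)")
conn.execute("INSERT INTO x VALUES (1, 1), (2, 2), (3, 3)")

sql = "SELECT a FROM x WHERE b > 1"  # illustrative query
optimized = optimizer.optimize(parse_one(sql), schema={"x": {"a": "INT", "b": "INT"}})

# Run both forms and require identical results, as check_file(execute=True) does.
df1 = conn.execute(sqlglot.transpile(sql, write="duckdb")[0]).df()
df2 = conn.execute(optimized.sql(dialect="duckdb")).df()
assert_frame_equal(df1, df2)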
@@ -86,7 +137,16 @@ class TestOptimizer(unittest.TestCase):
            expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
            return expression

        self.check_file("qualify_columns", qualify_columns, schema=self.schema)
        self.check_file("qualify_columns", qualify_columns, execute=True, schema=self.schema)

    def test_qualify_columns__with_invisible(self):
        def qualify_columns(expression, **kwargs):
            expression = optimizer.qualify_tables.qualify_tables(expression)
            expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
            return expression

        schema = MappingSchema(self.schema, {"x": {"a"}, "y": {"b"}, "z": {"b"}})
        self.check_file("qualify_columns__with_invisible", qualify_columns, schema=schema)

    def test_qualify_columns__invalid(self):
        for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"):
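test_qualify_columns__with_invisible builds MappingSchema(self.schema, {...}), where the second argument appears to declare the visible columns per table so that star expansion and column resolution ignore the rest. A small sketch of that idea under the same assumption (the query is illustrative, not from the fixtures):

from sqlglot import parse_one
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.schema import MappingSchema

# Second MappingSchema argument: visible columns per table (assumption based on the test).
schema = MappingSchema({"x": {"a": "INT", "b": "INT"}}, {"x": {"a"}})

expression = qualify_tables(parse_one("SELECT * FROM x"))
print(qualify_columns(expression, schema=schema).sql())
# With only "a" visible, the star should expand to x.a alone.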
@@ -141,7 +201,7 @@ class TestOptimizer(unittest.TestCase):
            ],
        )

        self.check_file("merge_subqueries", optimize, schema=self.schema)
        self.check_file("merge_subqueries", optimize, execute=True, schema=self.schema)

    def test_eliminate_subqueries(self):
        self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries)
@@ -301,10 +361,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        }

        for sql, target_type in tests.items():
            expression = parse_one(sql)
            annotated_expression = annotate_types(expression)

            self.assertEqual(annotated_expression.find(exp.Literal).type, target_type)
            expression = annotate_types(parse_one(sql))
            self.assertEqual(expression.find(exp.Literal).type, target_type)

    def test_boolean_type_annotation(self):
        tests = {
@@ -313,14 +371,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        }

        for sql, target_type in tests.items():
            expression = parse_one(sql)
            annotated_expression = annotate_types(expression)

            self.assertEqual(annotated_expression.find(exp.Boolean).type, target_type)
            expression = annotate_types(parse_one(sql))
            self.assertEqual(expression.find(exp.Boolean).type, target_type)

    def test_cast_type_annotation(self):
        expression = parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")
        annotate_types(expression)
        expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))"))

        self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ)
        self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR)
@@ -328,16 +383,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)

    def test_cache_annotation(self):
        expression = parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
        annotated_expression = annotate_types(expression)

        self.assertEqual(annotated_expression.expression.expressions[0].type, exp.DataType.Type.INT)
        expression = annotate_types(parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1"))
        self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT)

    def test_binary_annotation(self):
        expression = parse_one("SELECT 0.0 + (2 + 3)")
        annotate_types(expression)

        expression = expression.expressions[0]
        expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0]

        self.assertEqual(expression.type, exp.DataType.Type.DOUBLE)
        self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE)
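The annotation tests above are reshaped around the fact that annotate_types returns the expression it annotates, so parsing, annotating, and inspecting can be chained in one line instead of mutating a variable in place. A short sketch of that pattern (the query and the expected types in the comments are illustrative, following the promotion rules exercised by these tests):

from sqlglot import exp, parse_one
from sqlglot.optimizer.annotate_types import annotate_types

# annotate_types returns the annotated tree, so no separate mutation step is needed.
add = annotate_types(parse_one("SELECT 0.0 + CAST(1 AS INT) AS col")).expressions[0].this

assert add.type == exp.DataType.Type.DOUBLE        # DOUBLE + INT -> DOUBLE
assert add.left.type == exp.DataType.Type.DOUBLE   # the literal 0.0
assert add.right.type == exp.DataType.Type.INT     # the cast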
@@ -345,3 +395,124 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        self.assertEqual(expression.right.this.type, exp.DataType.Type.INT)
        self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT)
        self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT)

    def test_derived_tables_column_annotation(self):
        schema = {"x": {"cola": "INT"}, "y": {"cola": "FLOAT"}}
        sql = """
            SELECT a.cola AS cola
            FROM (
                SELECT x.cola + y.cola AS cola
                FROM (
                    SELECT x.cola AS cola
                    FROM x AS x
                ) AS x
                JOIN (
                    SELECT y.cola AS cola
                    FROM y AS y
                ) AS y
            ) AS a
        """

        expression = annotate_types(parse_one(sql), schema=schema)
        self.assertEqual(expression.expressions[0].type, exp.DataType.Type.FLOAT)  # a.cola AS cola

        addition_alias = expression.args["from"].expressions[0].this.expressions[0]
        self.assertEqual(addition_alias.type, exp.DataType.Type.FLOAT)  # x.cola + y.cola AS cola

        addition = addition_alias.this
        self.assertEqual(addition.type, exp.DataType.Type.FLOAT)
        self.assertEqual(addition.this.type, exp.DataType.Type.INT)
        self.assertEqual(addition.expression.type, exp.DataType.Type.FLOAT)

    def test_cte_column_annotation(self):
        schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT"}}
        sql = """
            WITH tbl AS (
                SELECT x.cola + 'bla' AS cola, y.colb AS colb
                FROM (
                    SELECT x.cola AS cola
                    FROM x AS x
                ) AS x
                JOIN (
                    SELECT y.colb AS colb
                    FROM y AS y
                ) AS y
            )
            SELECT tbl.cola + tbl.colb + 'foo' AS col
            FROM tbl AS tbl
        """

        expression = annotate_types(parse_one(sql), schema=schema)
        self.assertEqual(expression.expressions[0].type, exp.DataType.Type.TEXT)  # tbl.cola + tbl.colb + 'foo' AS col

        outer_addition = expression.expressions[0].this  # (tbl.cola + tbl.colb) + 'foo'
        self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT)
        self.assertEqual(outer_addition.left.type, exp.DataType.Type.TEXT)
        self.assertEqual(outer_addition.right.type, exp.DataType.Type.VARCHAR)

        inner_addition = expression.expressions[0].this.left  # tbl.cola + tbl.colb
        self.assertEqual(inner_addition.left.type, exp.DataType.Type.VARCHAR)
        self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT)

        cte_select = expression.args["with"].expressions[0].this
        self.assertEqual(cte_select.expressions[0].type, exp.DataType.Type.VARCHAR)  # x.cola + 'bla' AS cola
        self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT)  # y.colb AS colb

        cte_select_addition = cte_select.expressions[0].this  # x.cola + 'bla'
        self.assertEqual(cte_select_addition.type, exp.DataType.Type.VARCHAR)
        self.assertEqual(cte_select_addition.left.type, exp.DataType.Type.CHAR)
        self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR)

        # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
        for d, t in zip(cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]):
            self.assertEqual(d.this.expressions[0].this.type, t)

    def test_function_annotation(self):
        schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}}
        sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x"

        concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
        self.assertEqual(concat_expr_alias.type, exp.DataType.Type.VARCHAR)

        concat_expr = concat_expr_alias.this
        self.assertEqual(concat_expr.type, exp.DataType.Type.VARCHAR)
        self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR)  # x.cola
        self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR)  # TRIM(x.colb)
        self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR)  # x.colb

    def test_unknown_annotation(self):
        schema = {"x": {"cola": "VARCHAR"}}
        sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x"

        concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
        self.assertEqual(concat_expr_alias.type, exp.DataType.Type.UNKNOWN)

        concat_expr = concat_expr_alias.this
        self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN)
        self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR)  # x.cola
        self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN)  # SOME_ANONYMOUS_FUNC(x.cola)
        self.assertEqual(concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR)  # x.cola (arg)

    def test_null_annotation(self):
        expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this
        self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
        self.assertEqual(expression.right.type, exp.DataType.Type.INT)

        # NULL <op> UNKNOWN should yield NULL
        sql = "SELECT NULL || SOME_ANONYMOUS_FUNC() AS result"

        concat_expr_alias = annotate_types(parse_one(sql)).expressions[0]
        self.assertEqual(concat_expr_alias.type, exp.DataType.Type.NULL)

        concat_expr = concat_expr_alias.this
        self.assertEqual(concat_expr.type, exp.DataType.Type.NULL)
        self.assertEqual(concat_expr.left.type, exp.DataType.Type.NULL)
        self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN)

    def test_nullable_annotation(self):
        nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
        expression = annotate_types(parse_one("NULL AND FALSE"))

        self.assertEqual(expression.type, nullable)
        self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
        self.assertEqual(expression.right.type, exp.DataType.Type.BOOLEAN)