Merging upstream version 9.0.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
ebb36a5fc5
commit
4483b8ff47
87 changed files with 7994 additions and 421 deletions
|
@ -5,11 +5,11 @@ import duckdb
|
|||
from pandas.testing import assert_frame_equal
|
||||
|
||||
import sqlglot
|
||||
from sqlglot import exp, optimizer, parse_one, table
|
||||
from sqlglot import exp, optimizer, parse_one
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
from sqlglot.optimizer.schema import MappingSchema, ensure_schema
|
||||
from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
|
||||
from sqlglot.schema import MappingSchema
|
||||
from tests.helpers import (
|
||||
TPCH_SCHEMA,
|
||||
load_sql_fixture_pairs,
|
||||
|
@ -29,19 +29,19 @@ class TestOptimizer(unittest.TestCase):
|
|||
CREATE TABLE x (a INT, b INT);
|
||||
CREATE TABLE y (b INT, c INT);
|
||||
CREATE TABLE z (b INT, c INT);
|
||||
|
||||
|
||||
INSERT INTO x VALUES (1, 1);
|
||||
INSERT INTO x VALUES (2, 2);
|
||||
INSERT INTO x VALUES (2, 2);
|
||||
INSERT INTO x VALUES (3, 3);
|
||||
INSERT INTO x VALUES (null, null);
|
||||
|
||||
|
||||
INSERT INTO y VALUES (2, 2);
|
||||
INSERT INTO y VALUES (2, 2);
|
||||
INSERT INTO y VALUES (3, 3);
|
||||
INSERT INTO y VALUES (4, 4);
|
||||
INSERT INTO y VALUES (null, null);
|
||||
|
||||
|
||||
INSERT INTO y VALUES (3, 3);
|
||||
INSERT INTO y VALUES (3, 3);
|
||||
INSERT INTO y VALUES (4, 4);
|
||||
|
@ -80,8 +80,8 @@ class TestOptimizer(unittest.TestCase):
|
|||
|
||||
with self.subTest(title):
|
||||
self.assertEqual(
|
||||
optimized.sql(pretty=pretty, dialect=dialect),
|
||||
expected,
|
||||
optimized.sql(pretty=pretty, dialect=dialect),
|
||||
)
|
||||
|
||||
should_execute = meta.get("execute")
|
||||
|
@ -223,85 +223,6 @@ class TestOptimizer(unittest.TestCase):
|
|||
def test_tpch(self):
|
||||
self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)
|
||||
|
||||
def test_schema(self):
|
||||
schema = ensure_schema(
|
||||
{
|
||||
"x": {
|
||||
"a": "uint64",
|
||||
}
|
||||
}
|
||||
)
|
||||
self.assertEqual(
|
||||
schema.column_names(
|
||||
table(
|
||||
"x",
|
||||
)
|
||||
),
|
||||
["a"],
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
schema.column_names(table("x", db="db", catalog="c"))
|
||||
with self.assertRaises(ValueError):
|
||||
schema.column_names(table("x", db="db"))
|
||||
with self.assertRaises(ValueError):
|
||||
schema.column_names(table("x2"))
|
||||
|
||||
schema = ensure_schema(
|
||||
{
|
||||
"db": {
|
||||
"x": {
|
||||
"a": "uint64",
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
self.assertEqual(schema.column_names(table("x", db="db")), ["a"])
|
||||
with self.assertRaises(ValueError):
|
||||
schema.column_names(table("x", db="db", catalog="c"))
|
||||
with self.assertRaises(ValueError):
|
||||
schema.column_names(table("x"))
|
||||
with self.assertRaises(ValueError):
|
||||
schema.column_names(table("x", db="db2"))
|
||||
with self.assertRaises(ValueError):
|
||||
schema.column_names(table("x2", db="db"))
|
||||
|
||||
schema = ensure_schema(
|
||||
{
|
||||
"c": {
|
||||
"db": {
|
||||
"x": {
|
||||
"a": "uint64",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"])
|
||||
with self.assertRaises(ValueError):
|
||||
schema.column_names(table("x", db="db"))
|
||||
with self.assertRaises(ValueError):
|
||||
schema.column_names(table("x"))
|
||||
with self.assertRaises(ValueError):
|
||||
schema.column_names(table("x", db="db", catalog="c2"))
|
||||
with self.assertRaises(ValueError):
|
||||
schema.column_names(table("x", db="db2"))
|
||||
with self.assertRaises(ValueError):
|
||||
schema.column_names(table("x2", db="db"))
|
||||
|
||||
schema = ensure_schema(
|
||||
MappingSchema(
|
||||
{
|
||||
"x": {
|
||||
"a": "uint64",
|
||||
}
|
||||
}
|
||||
)
|
||||
)
|
||||
self.assertEqual(schema.column_names(table("x")), ["a"])
|
||||
|
||||
with self.assertRaises(OptimizeError):
|
||||
ensure_schema({})
|
||||
|
||||
def test_file_schema(self):
|
||||
expression = parse_one(
|
||||
"""
|
||||
|
@ -327,6 +248,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
|
|||
SELECT x.b FROM x
|
||||
), r AS (
|
||||
SELECT y.b FROM y
|
||||
), z as (
|
||||
SELECT cola, colb FROM (VALUES(1, 'test')) AS tab(cola, colb)
|
||||
)
|
||||
SELECT
|
||||
r.b,
|
||||
|
@ -340,19 +263,23 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
|
|||
"""
|
||||
expression = parse_one(sql)
|
||||
for scopes in traverse_scope(expression), list(build_scope(expression).traverse()):
|
||||
self.assertEqual(len(scopes), 5)
|
||||
self.assertEqual(len(scopes), 7)
|
||||
self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
|
||||
self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
|
||||
self.assertEqual(scopes[2].expression.sql(), "SELECT y.c AS b FROM y")
|
||||
self.assertEqual(scopes[3].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
|
||||
self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())
|
||||
self.assertEqual(scopes[2].expression.sql(), "(VALUES (1, 'test')) AS tab(cola, colb)")
|
||||
self.assertEqual(
|
||||
scopes[3].expression.sql(), "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)"
|
||||
)
|
||||
self.assertEqual(scopes[4].expression.sql(), "SELECT y.c AS b FROM y")
|
||||
self.assertEqual(scopes[5].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
|
||||
self.assertEqual(scopes[6].expression.sql(), parse_one(sql).sql())
|
||||
|
||||
self.assertEqual(set(scopes[4].sources), {"q", "r", "s"})
|
||||
self.assertEqual(len(scopes[4].columns), 6)
|
||||
self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"})
|
||||
self.assertEqual(scopes[4].source_columns("q"), [])
|
||||
self.assertEqual(len(scopes[4].source_columns("r")), 2)
|
||||
self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"})
|
||||
self.assertEqual(set(scopes[6].sources), {"q", "z", "r", "s"})
|
||||
self.assertEqual(len(scopes[6].columns), 6)
|
||||
self.assertEqual(set(c.table for c in scopes[6].columns), {"r", "s"})
|
||||
self.assertEqual(scopes[6].source_columns("q"), [])
|
||||
self.assertEqual(len(scopes[6].source_columns("r")), 2)
|
||||
self.assertEqual(set(c.table for c in scopes[6].source_columns("r")), {"r"})
|
||||
|
||||
self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"})
|
||||
self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue