Merging upstream version 10.0.8.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
407314e8d2
commit
efc1e37108
67 changed files with 2461 additions and 840 deletions
|
@ -1,12 +1,15 @@
|
|||
import unittest
|
||||
from datetime import date
|
||||
|
||||
import duckdb
|
||||
import pandas as pd
|
||||
from pandas.testing import assert_frame_equal
|
||||
|
||||
from sqlglot import exp, parse_one
|
||||
from sqlglot.errors import ExecuteError
|
||||
from sqlglot.executor import execute
|
||||
from sqlglot.executor.python import Python
|
||||
from sqlglot.executor.table import Table, ensure_tables
|
||||
from tests.helpers import (
|
||||
FIXTURES_DIR,
|
||||
SKIP_INTEGRATION,
|
||||
|
@ -67,13 +70,399 @@ class TestExecutor(unittest.TestCase):
|
|||
def to_csv(expression):
|
||||
if isinstance(expression, exp.Table):
|
||||
return parse_one(
|
||||
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
|
||||
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
|
||||
)
|
||||
return expression
|
||||
|
||||
for sql, _ in self.sqls[0:3]:
|
||||
a = self.cached_execute(sql)
|
||||
sql = parse_one(sql).transform(to_csv).sql(pretty=True)
|
||||
table = execute(sql, TPCH_SCHEMA)
|
||||
b = pd.DataFrame(table.rows, columns=table.columns)
|
||||
assert_frame_equal(a, b, check_dtype=False)
|
||||
for i, (sql, _) in enumerate(self.sqls[0:7]):
|
||||
with self.subTest(f"tpch-h {i + 1}"):
|
||||
a = self.cached_execute(sql)
|
||||
sql = parse_one(sql).transform(to_csv).sql(pretty=True)
|
||||
table = execute(sql, TPCH_SCHEMA)
|
||||
b = pd.DataFrame(table.rows, columns=table.columns)
|
||||
assert_frame_equal(a, b, check_dtype=False)
|
||||
|
||||
def test_execute_callable(self):
|
||||
tables = {
|
||||
"x": [
|
||||
{"a": "a", "b": "d"},
|
||||
{"a": "b", "b": "e"},
|
||||
{"a": "c", "b": "f"},
|
||||
],
|
||||
"y": [
|
||||
{"b": "d", "c": "g"},
|
||||
{"b": "e", "c": "h"},
|
||||
{"b": "f", "c": "i"},
|
||||
],
|
||||
"z": [],
|
||||
}
|
||||
schema = {
|
||||
"x": {
|
||||
"a": "VARCHAR",
|
||||
"b": "VARCHAR",
|
||||
},
|
||||
"y": {
|
||||
"b": "VARCHAR",
|
||||
"c": "VARCHAR",
|
||||
},
|
||||
"z": {"d": "VARCHAR"},
|
||||
}
|
||||
|
||||
for sql, cols, rows in [
|
||||
("SELECT * FROM x", ["a", "b"], [("a", "d"), ("b", "e"), ("c", "f")]),
|
||||
(
|
||||
"SELECT * FROM x JOIN y ON x.b = y.b",
|
||||
["a", "b", "b", "c"],
|
||||
[("a", "d", "d", "g"), ("b", "e", "e", "h"), ("c", "f", "f", "i")],
|
||||
),
|
||||
(
|
||||
"SELECT j.c AS d FROM x AS i JOIN y AS j ON i.b = j.b",
|
||||
["d"],
|
||||
[("g",), ("h",), ("i",)],
|
||||
),
|
||||
(
|
||||
"SELECT CONCAT(x.a, y.c) FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'",
|
||||
["_col_0"],
|
||||
[("bh",)],
|
||||
),
|
||||
(
|
||||
"SELECT * FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'",
|
||||
["a", "b", "b", "c"],
|
||||
[("b", "e", "e", "h")],
|
||||
),
|
||||
(
|
||||
"SELECT * FROM z",
|
||||
["d"],
|
||||
[],
|
||||
),
|
||||
(
|
||||
"SELECT d FROM z ORDER BY d",
|
||||
["d"],
|
||||
[],
|
||||
),
|
||||
(
|
||||
"SELECT a FROM x WHERE x.a <> 'b'",
|
||||
["a"],
|
||||
[("a",), ("c",)],
|
||||
),
|
||||
(
|
||||
"SELECT a AS i FROM x ORDER BY a",
|
||||
["i"],
|
||||
[("a",), ("b",), ("c",)],
|
||||
),
|
||||
(
|
||||
"SELECT a AS i FROM x ORDER BY i",
|
||||
["i"],
|
||||
[("a",), ("b",), ("c",)],
|
||||
),
|
||||
(
|
||||
"SELECT 100 - ORD(a) AS a, a AS i FROM x ORDER BY a",
|
||||
["a", "i"],
|
||||
[(1, "c"), (2, "b"), (3, "a")],
|
||||
),
|
||||
(
|
||||
"SELECT a /* test */ FROM x LIMIT 1",
|
||||
["a"],
|
||||
[("a",)],
|
||||
),
|
||||
]:
|
||||
with self.subTest(sql):
|
||||
result = execute(sql, schema=schema, tables=tables)
|
||||
self.assertEqual(result.columns, tuple(cols))
|
||||
self.assertEqual(result.rows, rows)
|
||||
|
||||
def test_set_operations(self):
|
||||
tables = {
|
||||
"x": [
|
||||
{"a": "a"},
|
||||
{"a": "b"},
|
||||
{"a": "c"},
|
||||
],
|
||||
"y": [
|
||||
{"a": "b"},
|
||||
{"a": "c"},
|
||||
{"a": "d"},
|
||||
],
|
||||
}
|
||||
schema = {
|
||||
"x": {
|
||||
"a": "VARCHAR",
|
||||
},
|
||||
"y": {
|
||||
"a": "VARCHAR",
|
||||
},
|
||||
}
|
||||
|
||||
for sql, cols, rows in [
|
||||
(
|
||||
"SELECT a FROM x UNION ALL SELECT a FROM y",
|
||||
["a"],
|
||||
[("a",), ("b",), ("c",), ("b",), ("c",), ("d",)],
|
||||
),
|
||||
(
|
||||
"SELECT a FROM x UNION SELECT a FROM y",
|
||||
["a"],
|
||||
[("a",), ("b",), ("c",), ("d",)],
|
||||
),
|
||||
(
|
||||
"SELECT a FROM x EXCEPT SELECT a FROM y",
|
||||
["a"],
|
||||
[("a",)],
|
||||
),
|
||||
(
|
||||
"SELECT a FROM x INTERSECT SELECT a FROM y",
|
||||
["a"],
|
||||
[("b",), ("c",)],
|
||||
),
|
||||
(
|
||||
"""SELECT i.a
|
||||
FROM (
|
||||
SELECT a FROM x UNION SELECT a FROM y
|
||||
) AS i
|
||||
JOIN (
|
||||
SELECT a FROM x UNION SELECT a FROM y
|
||||
) AS j
|
||||
ON i.a = j.a""",
|
||||
["a"],
|
||||
[("a",), ("b",), ("c",), ("d",)],
|
||||
),
|
||||
(
|
||||
"SELECT 1 AS a UNION SELECT 2 AS a UNION SELECT 3 AS a",
|
||||
["a"],
|
||||
[(1,), (2,), (3,)],
|
||||
),
|
||||
]:
|
||||
with self.subTest(sql):
|
||||
result = execute(sql, schema=schema, tables=tables)
|
||||
self.assertEqual(result.columns, tuple(cols))
|
||||
self.assertEqual(set(result.rows), set(rows))
|
||||
|
||||
def test_execute_catalog_db_table(self):
|
||||
tables = {
|
||||
"catalog": {
|
||||
"db": {
|
||||
"x": [
|
||||
{"a": "a"},
|
||||
{"a": "b"},
|
||||
{"a": "c"},
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
schema = {
|
||||
"catalog": {
|
||||
"db": {
|
||||
"x": {
|
||||
"a": "VARCHAR",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result1 = execute("SELECT * FROM x", schema=schema, tables=tables)
|
||||
result2 = execute("SELECT * FROM catalog.db.x", schema=schema, tables=tables)
|
||||
assert result1.columns == result2.columns
|
||||
assert result1.rows == result2.rows
|
||||
|
||||
def test_execute_tables(self):
|
||||
tables = {
|
||||
"sushi": [
|
||||
{"id": 1, "price": 1.0},
|
||||
{"id": 2, "price": 2.0},
|
||||
{"id": 3, "price": 3.0},
|
||||
],
|
||||
"order_items": [
|
||||
{"sushi_id": 1, "order_id": 1},
|
||||
{"sushi_id": 1, "order_id": 1},
|
||||
{"sushi_id": 2, "order_id": 1},
|
||||
{"sushi_id": 3, "order_id": 2},
|
||||
],
|
||||
"orders": [
|
||||
{"id": 1, "user_id": 1},
|
||||
{"id": 2, "user_id": 2},
|
||||
],
|
||||
}
|
||||
|
||||
self.assertEqual(
|
||||
execute(
|
||||
"""
|
||||
SELECT
|
||||
o.user_id,
|
||||
SUM(s.price) AS price
|
||||
FROM orders o
|
||||
JOIN order_items i
|
||||
ON o.id = i.order_id
|
||||
JOIN sushi s
|
||||
ON i.sushi_id = s.id
|
||||
GROUP BY o.user_id
|
||||
""",
|
||||
tables=tables,
|
||||
).rows,
|
||||
[
|
||||
(1, 4.0),
|
||||
(2, 3.0),
|
||||
],
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
execute(
|
||||
"""
|
||||
SELECT
|
||||
o.id, x.*
|
||||
FROM orders o
|
||||
LEFT JOIN (
|
||||
SELECT
|
||||
1 AS id, 'b' AS x
|
||||
UNION ALL
|
||||
SELECT
|
||||
3 AS id, 'c' AS x
|
||||
) x
|
||||
ON o.id = x.id
|
||||
""",
|
||||
tables=tables,
|
||||
).rows,
|
||||
[(1, 1, "b"), (2, None, None)],
|
||||
)
|
||||
self.assertEqual(
|
||||
execute(
|
||||
"""
|
||||
SELECT
|
||||
o.id, x.*
|
||||
FROM orders o
|
||||
RIGHT JOIN (
|
||||
SELECT
|
||||
1 AS id,
|
||||
'b' AS x
|
||||
UNION ALL
|
||||
SELECT
|
||||
3 AS id, 'c' AS x
|
||||
) x
|
||||
ON o.id = x.id
|
||||
""",
|
||||
tables=tables,
|
||||
).rows,
|
||||
[
|
||||
(1, 1, "b"),
|
||||
(None, 3, "c"),
|
||||
],
|
||||
)
|
||||
|
||||
def test_table_depth_mismatch(self):
|
||||
tables = {"table": []}
|
||||
schema = {"db": {"table": {"col": "VARCHAR"}}}
|
||||
with self.assertRaises(ExecuteError):
|
||||
execute("SELECT * FROM table", schema=schema, tables=tables)
|
||||
|
||||
def test_tables(self):
|
||||
tables = ensure_tables(
|
||||
{
|
||||
"catalog1": {
|
||||
"db1": {
|
||||
"t1": [
|
||||
{"a": 1},
|
||||
],
|
||||
"t2": [
|
||||
{"a": 1},
|
||||
],
|
||||
},
|
||||
"db2": {
|
||||
"t3": [
|
||||
{"a": 1},
|
||||
],
|
||||
"t4": [
|
||||
{"a": 1},
|
||||
],
|
||||
},
|
||||
},
|
||||
"catalog2": {
|
||||
"db3": {
|
||||
"t5": Table(columns=("a",), rows=[(1,)]),
|
||||
"t6": Table(columns=("a",), rows=[(1,)]),
|
||||
},
|
||||
"db4": {
|
||||
"t7": Table(columns=("a",), rows=[(1,)]),
|
||||
"t8": Table(columns=("a",), rows=[(1,)]),
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
t1 = tables.find(exp.table_(table="t1", db="db1", catalog="catalog1"))
|
||||
self.assertEqual(t1.columns, ("a",))
|
||||
self.assertEqual(t1.rows, [(1,)])
|
||||
|
||||
t8 = tables.find(exp.table_(table="t8"))
|
||||
self.assertEqual(t1.columns, t8.columns)
|
||||
self.assertEqual(t1.rows, t8.rows)
|
||||
|
||||
def test_static_queries(self):
|
||||
for sql, cols, rows in [
|
||||
("SELECT 1", ["_col_0"], [(1,)]),
|
||||
("SELECT 1 + 2 AS x", ["x"], [(3,)]),
|
||||
("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
|
||||
("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
|
||||
("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
|
||||
]:
|
||||
result = execute(sql)
|
||||
self.assertEqual(result.columns, tuple(cols))
|
||||
self.assertEqual(result.rows, rows)
|
||||
|
||||
def test_aggregate_without_group_by(self):
|
||||
result = execute("SELECT SUM(x) FROM t", tables={"t": [{"x": 1}, {"x": 2}]})
|
||||
self.assertEqual(result.columns, ("_col_0",))
|
||||
self.assertEqual(result.rows, [(3,)])
|
||||
|
||||
def test_scalar_functions(self):
|
||||
for sql, expected in [
|
||||
("CONCAT('a', 'b')", "ab"),
|
||||
("CONCAT('a', NULL)", None),
|
||||
("CONCAT_WS('_', 'a', 'b')", "a_b"),
|
||||
("STR_POSITION('bar', 'foobarbar')", 4),
|
||||
("STR_POSITION('bar', 'foobarbar', 5)", 7),
|
||||
("STR_POSITION(NULL, 'foobarbar')", None),
|
||||
("STR_POSITION('bar', NULL)", None),
|
||||
("UPPER('foo')", "FOO"),
|
||||
("UPPER(NULL)", None),
|
||||
("LOWER('FOO')", "foo"),
|
||||
("LOWER(NULL)", None),
|
||||
("IFNULL('a', 'b')", "a"),
|
||||
("IFNULL(NULL, 'b')", "b"),
|
||||
("IFNULL(NULL, NULL)", None),
|
||||
("SUBSTRING('12345')", "12345"),
|
||||
("SUBSTRING('12345', 3)", "345"),
|
||||
("SUBSTRING('12345', 3, 0)", ""),
|
||||
("SUBSTRING('12345', 3, 1)", "3"),
|
||||
("SUBSTRING('12345', 3, 2)", "34"),
|
||||
("SUBSTRING('12345', 3, 3)", "345"),
|
||||
("SUBSTRING('12345', 3, 4)", "345"),
|
||||
("SUBSTRING('12345', -3)", "345"),
|
||||
("SUBSTRING('12345', -3, 0)", ""),
|
||||
("SUBSTRING('12345', -3, 1)", "3"),
|
||||
("SUBSTRING('12345', -3, 2)", "34"),
|
||||
("SUBSTRING('12345', 0)", ""),
|
||||
("SUBSTRING('12345', 0, 1)", ""),
|
||||
("SUBSTRING(NULL)", None),
|
||||
("SUBSTRING(NULL, 1)", None),
|
||||
("CAST(1 AS TEXT)", "1"),
|
||||
("CAST('1' AS LONG)", 1),
|
||||
("CAST('1.1' AS FLOAT)", 1.1),
|
||||
("COALESCE(NULL)", None),
|
||||
("COALESCE(NULL, NULL)", None),
|
||||
("COALESCE(NULL, 'b')", "b"),
|
||||
("COALESCE('a', 'b')", "a"),
|
||||
("1 << 1", 2),
|
||||
("1 >> 1", 0),
|
||||
("1 & 1", 1),
|
||||
("1 | 1", 1),
|
||||
("1 < 1", False),
|
||||
("1 <= 1", True),
|
||||
("1 > 1", False),
|
||||
("1 >= 1", True),
|
||||
("1 + NULL", None),
|
||||
("IF(true, 1, 0)", 1),
|
||||
("IF(false, 1, 0)", 0),
|
||||
("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"),
|
||||
("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)),
|
||||
]:
|
||||
with self.subTest(sql):
|
||||
result = execute(f"SELECT {sql}")
|
||||
self.assertEqual(result.rows, [(expected,)])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue