
Adding upstream version 21.1.2.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:28:14 +01:00
parent 92ffd7746f
commit b01402dc30
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
103 changed files with 18237 additions and 17794 deletions

View file

@@ -18,88 +18,6 @@ class TestBigQuery(Validator):
maxDiff = None
def test_bigquery(self):
with self.assertLogs(helper_logger) as cm:
statements = parse(
"""
BEGIN
DECLARE 1;
IF from_date IS NULL THEN SET x = 1;
END IF;
END
""",
read="bigquery",
)
self.assertIn("unsupported syntax", cm.output[0])
for actual, expected in zip(
statements, ("BEGIN DECLARE 1", "IF from_date IS NULL THEN SET x = 1", "END IF", "END")
):
self.assertEqual(actual.sql(dialect="bigquery"), expected)
with self.assertLogs(helper_logger) as cm:
self.validate_identity(
"SELECT * FROM t AS t(c1, c2)",
"SELECT * FROM t AS t",
)
self.assertEqual(
cm.output, ["WARNING:sqlglot:Named columns are not supported in table alias."]
)
with self.assertLogs(helper_logger) as cm:
self.validate_all(
"SELECT a[1], b[OFFSET(1)], c[ORDINAL(1)], d[SAFE_OFFSET(1)], e[SAFE_ORDINAL(1)]",
write={
"duckdb": "SELECT a[2], b[2], c[1], d[2], e[1]",
"bigquery": "SELECT a[1], b[OFFSET(1)], c[ORDINAL(1)], d[SAFE_OFFSET(1)], e[SAFE_ORDINAL(1)]",
"presto": "SELECT a[2], b[2], c[1], ELEMENT_AT(d, 2), ELEMENT_AT(e, 1)",
},
)
self.validate_all(
"a[0]",
read={
"bigquery": "a[0]",
"duckdb": "a[1]",
"presto": "a[1]",
},
)
with self.assertRaises(TokenError):
transpile("'\\'", read="bigquery")
# Reference: https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#set_operators
with self.assertRaises(UnsupportedError):
transpile(
"SELECT * FROM a INTERSECT ALL SELECT * FROM b",
write="bigquery",
unsupported_level=ErrorLevel.RAISE,
)
with self.assertRaises(UnsupportedError):
transpile(
"SELECT * FROM a EXCEPT ALL SELECT * FROM b",
write="bigquery",
unsupported_level=ErrorLevel.RAISE,
)
with self.assertRaises(ParseError):
transpile("SELECT * FROM UNNEST(x) AS x(y)", read="bigquery")
with self.assertRaises(ParseError):
transpile("DATE_ADD(x, day)", read="bigquery")
with self.assertLogs(parser_logger) as cm:
for_in_stmts = parse(
"FOR record IN (SELECT word FROM shakespeare) DO SELECT record.word; END FOR;",
read="bigquery",
)
self.assertEqual(
[s.sql(dialect="bigquery") for s in for_in_stmts],
["FOR record IN (SELECT word FROM shakespeare) DO SELECT record.word", "END FOR"],
)
assert "'END FOR'" in cm.output[0]
self.validate_identity("CREATE SCHEMA x DEFAULT COLLATE 'en'")
self.validate_identity("CREATE TABLE x (y INT64) DEFAULT COLLATE 'en'")
self.validate_identity("PARSE_JSON('{}', wide_number_mode => 'exact')")
@@ -1086,6 +1004,127 @@ WHERE
pretty=True,
)
def test_errors(self):
with self.assertRaises(TokenError):
transpile("'\\'", read="bigquery")
# Reference: https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#set_operators
with self.assertRaises(UnsupportedError):
transpile(
"SELECT * FROM a INTERSECT ALL SELECT * FROM b",
write="bigquery",
unsupported_level=ErrorLevel.RAISE,
)
with self.assertRaises(UnsupportedError):
transpile(
"SELECT * FROM a EXCEPT ALL SELECT * FROM b",
write="bigquery",
unsupported_level=ErrorLevel.RAISE,
)
with self.assertRaises(ParseError):
transpile("SELECT * FROM UNNEST(x) AS x(y)", read="bigquery")
with self.assertRaises(ParseError):
transpile("DATE_ADD(x, day)", read="bigquery")
def test_warnings(self):
with self.assertLogs(helper_logger) as cm:
self.validate_identity(
"WITH cte(c) AS (SELECT * FROM t) SELECT * FROM cte",
"WITH cte AS (SELECT * FROM t) SELECT * FROM cte",
)
self.assertIn("Can't push down CTE column names for star queries.", cm.output[0])
self.assertIn("Named columns are not supported in table alias.", cm.output[1])
with self.assertLogs(helper_logger) as cm:
self.validate_identity(
"SELECT * FROM t AS t(c1, c2)",
"SELECT * FROM t AS t",
)
self.assertIn("Named columns are not supported in table alias.", cm.output[0])
with self.assertLogs(helper_logger) as cm:
statements = parse(
"""
BEGIN
DECLARE 1;
IF from_date IS NULL THEN SET x = 1;
END IF;
END
""",
read="bigquery",
)
for actual, expected in zip(
statements,
("BEGIN DECLARE 1", "IF from_date IS NULL THEN SET x = 1", "END IF", "END"),
):
self.assertEqual(actual.sql(dialect="bigquery"), expected)
self.assertIn("unsupported syntax", cm.output[0])
with self.assertLogs(helper_logger) as cm:
statements = parse(
"""
BEGIN CALL `project_id.dataset_id.stored_procedure_id`();
EXCEPTION WHEN ERROR THEN INSERT INTO `project_id.dataset_id.table_id` SELECT @@error.message, CURRENT_TIMESTAMP();
END
""",
read="bigquery",
)
expected_statements = (
"BEGIN CALL `project_id.dataset_id.stored_procedure_id`()",
"EXCEPTION WHEN ERROR THEN INSERT INTO `project_id.dataset_id.table_id` SELECT @@error.message, CURRENT_TIMESTAMP()",
"END",
)
for actual, expected in zip(statements, expected_statements):
self.assertEqual(actual.sql(dialect="bigquery"), expected)
self.assertIn("unsupported syntax", cm.output[0])
with self.assertLogs(helper_logger) as cm:
self.validate_identity(
"SELECT * FROM t AS t(c1, c2)",
"SELECT * FROM t AS t",
)
self.assertIn("Named columns are not supported in table alias.", cm.output[0])
with self.assertLogs(helper_logger):
self.validate_all(
"SELECT a[1], b[OFFSET(1)], c[ORDINAL(1)], d[SAFE_OFFSET(1)], e[SAFE_ORDINAL(1)]",
write={
"duckdb": "SELECT a[2], b[2], c[1], d[2], e[1]",
"bigquery": "SELECT a[1], b[OFFSET(1)], c[ORDINAL(1)], d[SAFE_OFFSET(1)], e[SAFE_ORDINAL(1)]",
"presto": "SELECT a[2], b[2], c[1], ELEMENT_AT(d, 2), ELEMENT_AT(e, 1)",
},
)
self.validate_all(
"a[0]",
read={
"bigquery": "a[0]",
"duckdb": "a[1]",
"presto": "a[1]",
},
)
with self.assertLogs(parser_logger) as cm:
for_in_stmts = parse(
"FOR record IN (SELECT word FROM shakespeare) DO SELECT record.word; END FOR;",
read="bigquery",
)
self.assertEqual(
[s.sql(dialect="bigquery") for s in for_in_stmts],
["FOR record IN (SELECT word FROM shakespeare) DO SELECT record.word", "END FOR"],
)
self.assertIn("'END FOR'", cm.output[0])
def test_user_defined_functions(self):
self.validate_identity(
"CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'"

View file

@@ -729,7 +729,7 @@ class TestDialect(Validator):
write={
"duckdb": "TO_TIMESTAMP(x)",
"hive": "FROM_UNIXTIME(x)",
"oracle": "TO_DATE('1970-01-01','YYYY-MM-DD') + (x / 86400)",
"oracle": "TO_DATE('1970-01-01', 'YYYY-MM-DD') + (x / 86400)",
"postgres": "TO_TIMESTAMP(x)",
"presto": "FROM_UNIXTIME(x)",
"starrocks": "FROM_UNIXTIME(x)",
@@ -2272,3 +2272,32 @@ SELECT
"tsql": "RAND()",
},
)
def test_array_any(self):
self.validate_all(
"ARRAY_ANY(arr, x -> pred)",
write={
"": "ARRAY_ANY(arr, x -> pred)",
"bigquery": "(ARRAY_LENGTH(arr) = 0 OR ARRAY_LENGTH(ARRAY(SELECT x FROM UNNEST(arr) AS x WHERE pred)) <> 0)",
"clickhouse": "(LENGTH(arr) = 0 OR LENGTH(arrayFilter(x -> pred, arr)) <> 0)",
"databricks": "(SIZE(arr) = 0 OR SIZE(FILTER(arr, x -> pred)) <> 0)",
"doris": UnsupportedError,
"drill": UnsupportedError,
"duckdb": "(ARRAY_LENGTH(arr) = 0 OR ARRAY_LENGTH(LIST_FILTER(arr, x -> pred)) <> 0)",
"hive": UnsupportedError,
"mysql": UnsupportedError,
"oracle": UnsupportedError,
"postgres": "(ARRAY_LENGTH(arr, 1) = 0 OR ARRAY_LENGTH(ARRAY(SELECT x FROM UNNEST(arr) AS _t(x) WHERE pred), 1) <> 0)",
"presto": "ANY_MATCH(arr, x -> pred)",
"redshift": UnsupportedError,
"snowflake": UnsupportedError,
"spark": "(SIZE(arr) = 0 OR SIZE(FILTER(arr, x -> pred)) <> 0)",
"spark2": "(SIZE(arr) = 0 OR SIZE(FILTER(arr, x -> pred)) <> 0)",
"sqlite": UnsupportedError,
"starrocks": UnsupportedError,
"tableau": UnsupportedError,
"teradata": "(CARDINALITY(arr) = 0 OR CARDINALITY(FILTER(arr, x -> pred)) <> 0)",
"trino": "ANY_MATCH(arr, x -> pred)",
"tsql": UnsupportedError,
},
)
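The new test_array_any matrix shows ARRAY_ANY lowering to a native predicate where one exists and to a length/filter fallback elsewhere; the UnsupportedError entries only raise under a strict error level. A small sketch of both paths, using only names that already appear in this diff:

    from sqlglot import transpile
    from sqlglot.errors import ErrorLevel, UnsupportedError

    # Presto has a native equivalent, so the lambda carries over directly.
    print(transpile("SELECT ARRAY_ANY(arr, x -> x > 0)", write="presto")[0])
    # SELECT ANY_MATCH(arr, x -> x > 0)

    # Dialects marked UnsupportedError above raise only when asked to be strict.
    try:
        transpile("SELECT ARRAY_ANY(arr, x -> x > 0)", write="mysql",
                  unsupported_level=ErrorLevel.RAISE)
    except UnsupportedError:
        pass  # mysql has no array predicate to target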

View file

@@ -290,15 +290,15 @@ class TestPostgres(Validator):
)
self.validate_identity(
"""'{"x": {"y": 1}}'::json->'x'->'y'""",
"""JSON_EXTRACT_PATH(JSON_EXTRACT_PATH(CAST('{"x": {"y": 1}}' AS JSON), 'x'), 'y')""",
"""CAST('{"x": {"y": 1}}' AS JSON) -> 'x' -> 'y'""",
)
self.validate_identity(
"""'[1,2,3]'::json->>2""",
"JSON_EXTRACT_PATH_TEXT(CAST('[1,2,3]' AS JSON), '2')",
"CAST('[1,2,3]' AS JSON) ->> 2",
)
self.validate_identity(
"""'{"a":1,"b":2}'::json->>'b'""",
"""JSON_EXTRACT_PATH_TEXT(CAST('{"a":1,"b":2}' AS JSON), 'b')""",
"""CAST('{"a":1,"b":2}' AS JSON) ->> 'b'""",
)
self.validate_identity(
"""'{"a":[1,2,3],"b":[4,5,6]}'::json#>'{a,2}'""",
@@ -310,11 +310,11 @@ class TestPostgres(Validator):
)
self.validate_identity(
"'[1,2,3]'::json->2",
"JSON_EXTRACT_PATH(CAST('[1,2,3]' AS JSON), '2')",
"CAST('[1,2,3]' AS JSON) -> 2",
)
self.validate_identity(
"""SELECT JSON_ARRAY_ELEMENTS((foo->'sections')::JSON) AS sections""",
"""SELECT JSON_ARRAY_ELEMENTS(CAST((JSON_EXTRACT_PATH(foo, 'sections')) AS JSON)) AS sections""",
"""SELECT JSON_ARRAY_ELEMENTS(CAST((foo -> 'sections') AS JSON)) AS sections""",
)
self.validate_identity(
"MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET x.a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)",
@@ -357,12 +357,13 @@ class TestPostgres(Validator):
"x -> 'y' -> 0 -> 'z'",
write={
"": "JSON_EXTRACT(JSON_EXTRACT(JSON_EXTRACT(x, '$.y'), '$[0]'), '$.z')",
"postgres": "JSON_EXTRACT_PATH(JSON_EXTRACT_PATH(JSON_EXTRACT_PATH(x, 'y'), '0'), 'z')",
"postgres": "x -> 'y' -> 0 -> 'z'",
},
)
self.validate_all(
"""JSON_EXTRACT_PATH('{"f2":{"f3":1},"f4":{"f5":99,"f6":"foo"}}','f4')""",
write={
"": """JSON_EXTRACT('{"f2":{"f3":1},"f4":{"f5":99,"f6":"foo"}}', '$.f4')""",
"bigquery": """JSON_EXTRACT('{"f2":{"f3":1},"f4":{"f5":99,"f6":"foo"}}', '$.f4')""",
"duckdb": """'{"f2":{"f3":1},"f4":{"f5":99,"f6":"foo"}}' -> '$.f4'""",
"mysql": """JSON_EXTRACT('{"f2":{"f3":1},"f4":{"f5":99,"f6":"foo"}}', '$.f4')""",
@@ -580,7 +581,7 @@ class TestPostgres(Validator):
self.validate_all(
"""'{"a":1,"b":2}'::json->'b'""",
write={
"postgres": """JSON_EXTRACT_PATH(CAST('{"a":1,"b":2}' AS JSON), 'b')""",
"postgres": """CAST('{"a":1,"b":2}' AS JSON) -> 'b'""",
"redshift": """JSON_EXTRACT_PATH_TEXT('{"a":1,"b":2}', 'b')""",
},
)
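The updated expectations reflect that the Postgres generator now keeps the native JSON arrow operators instead of rewriting them to JSON_EXTRACT_PATH calls, so a query round-trips in its original shape. A minimal sketch, assuming only that sqlglot is installed:

    import sqlglot

    sql = """'{"a":1,"b":2}'::json->'b'"""
    print(sqlglot.transpile(sql, read="postgres", write="postgres")[0])
    # CAST('{"a":1,"b":2}' AS JSON) -> 'b'

    # Targets without arrow operators still get a function form, e.g. Redshift:
    print(sqlglot.transpile(sql, read="postgres", write="redshift")[0])
    # JSON_EXTRACT_PATH_TEXT('{"a":1,"b":2}', 'b')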

View file

@@ -594,9 +594,9 @@ WHERE
self.validate_all(
"SELECT TO_TIMESTAMP(16599817290000, 4)",
write={
"bigquery": "SELECT TIMESTAMP_SECONDS(CAST(16599817290000 / POW(10, 4) AS INT64))",
"bigquery": "SELECT TIMESTAMP_SECONDS(CAST(16599817290000 / POWER(10, 4) AS INT64))",
"snowflake": "SELECT TO_TIMESTAMP(16599817290000, 4)",
"spark": "SELECT TIMESTAMP_SECONDS(16599817290000 / POW(10, 4))",
"spark": "SELECT TIMESTAMP_SECONDS(16599817290000 / POWER(10, 4))",
},
)
self.validate_all(
@@ -609,11 +609,11 @@ WHERE
self.validate_all(
"SELECT TO_TIMESTAMP(1659981729000000000, 9)",
write={
"bigquery": "SELECT TIMESTAMP_SECONDS(CAST(1659981729000000000 / POW(10, 9) AS INT64))",
"duckdb": "SELECT TO_TIMESTAMP(1659981729000000000 / POW(10, 9))",
"bigquery": "SELECT TIMESTAMP_SECONDS(CAST(1659981729000000000 / POWER(10, 9) AS INT64))",
"duckdb": "SELECT TO_TIMESTAMP(1659981729000000000 / POWER(10, 9))",
"presto": "SELECT FROM_UNIXTIME(CAST(1659981729000000000 AS DOUBLE) / POW(10, 9))",
"snowflake": "SELECT TO_TIMESTAMP(1659981729000000000, 9)",
"spark": "SELECT TIMESTAMP_SECONDS(1659981729000000000 / POW(10, 9))",
"spark": "SELECT TIMESTAMP_SECONDS(1659981729000000000 / POWER(10, 9))",
},
)
self.validate_all(
@@ -1548,6 +1548,17 @@ MATCH_RECOGNIZE (
self.assertTrue(isinstance(users_exp, exp.Show))
self.assertEqual(users_exp.this, "USERS")
def test_storage_integration(self):
self.validate_identity(
"""CREATE STORAGE INTEGRATION s3_int
TYPE=EXTERNAL_STAGE
STORAGE_PROVIDER='S3'
STORAGE_AWS_ROLE_ARN='arn:aws:iam::001234567890:role/myrole'
ENABLED=TRUE
STORAGE_ALLOWED_LOCATIONS=('s3://mybucket1/path1/', 's3://mybucket2/path2/')""",
pretty=True,
)
def test_swap(self):
ast = parse_one("ALTER TABLE a SWAP WITH b", read="snowflake")
assert isinstance(ast, exp.AlterTable)
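Two behaviors are covered in this file: Snowflake's scaled TO_TIMESTAMP now spells the exponent function POWER rather than POW in generated SQL, and CREATE STORAGE INTEGRATION statements parse and pretty-print as identities. A sketch of the first, taken directly from the expectations above:

    import sqlglot

    print(sqlglot.transpile("SELECT TO_TIMESTAMP(16599817290000, 4)",
                            read="snowflake", write="bigquery")[0])
    # SELECT TIMESTAMP_SECONDS(CAST(16599817290000 / POWER(10, 4) AS INT64))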

View file

@@ -920,28 +920,48 @@ WHERE
self.assertEqual(expr.sql(dialect="tsql"), expected_sql)
def test_charindex(self):
self.validate_identity(
"SELECT CAST(SUBSTRING('ABCD~1234', CHARINDEX('~', 'ABCD~1234') + 1, LEN('ABCD~1234')) AS BIGINT)"
)
self.validate_all(
"CHARINDEX(x, y, 9)",
read={
"spark": "LOCATE(x, y, 9)",
},
write={
"spark": "LOCATE(x, y, 9)",
"tsql": "CHARINDEX(x, y, 9)",
},
)
self.validate_all(
"CHARINDEX(x, y)",
read={
"spark": "LOCATE(x, y)",
},
write={
"spark": "LOCATE(x, y)",
"tsql": "CHARINDEX(x, y)",
},
)
self.validate_all(
"CHARINDEX('sub', 'testsubstring', 3)",
read={
"spark": "LOCATE('sub', 'testsubstring', 3)",
},
write={
"spark": "LOCATE('sub', 'testsubstring', 3)",
"tsql": "CHARINDEX('sub', 'testsubstring', 3)",
},
)
self.validate_all(
"CHARINDEX('sub', 'testsubstring')",
read={
"spark": "LOCATE('sub', 'testsubstring')",
},
write={
"spark": "LOCATE('sub', 'testsubstring')",
"tsql": "CHARINDEX('sub', 'testsubstring')",
},
)
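T-SQL's CHARINDEX and Spark's LOCATE share the (needle, haystack[, start]) argument order, so the mapping added here is a plain rename in both directions, with or without the optional start position. A minimal sketch, assuming only that sqlglot is installed:

    import sqlglot

    print(sqlglot.transpile("CHARINDEX('sub', 'testsubstring', 3)",
                            read="tsql", write="spark")[0])
    # LOCATE('sub', 'testsubstring', 3)

    print(sqlglot.transpile("LOCATE(x, y)", read="spark", write="tsql")[0])
    # CHARINDEX(x, y)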

View file

@@ -1311,3 +1311,79 @@ LEFT JOIN "_u_0" AS "_u_0"
ON "C"."EMAIL_DOMAIN" = "_u_0"."DOMAIN"
WHERE
NOT "_u_0"."DOMAIN" IS NULL;
# title: decorrelate subquery and transpile ArrayAny correctly when generating spark
# execute: false
# dialect: spark
SELECT
COUNT(DISTINCT cs1.cs_order_number) AS `order count`,
SUM(cs1.cs_ext_ship_cost) AS `total shipping cost`,
SUM(cs1.cs_net_profit) AS `total net profit`
FROM catalog_sales cs1, date_dim, customer_address, call_center
WHERE
date_dim.d_date BETWEEN '2002-02-01' AND (CAST('2002-02-01' AS DATE) + INTERVAL 60 days)
AND cs1.cs_ship_date_sk = date_dim.d_date_sk
AND cs1.cs_ship_addr_sk = customer_address.ca_address_sk
AND customer_address.ca_state = 'GA'
AND cs1.cs_call_center_sk = call_center.cc_call_center_sk
AND call_center.cc_county IN (
'Williamson County', 'Williamson County', 'Williamson County', 'Williamson County', 'Williamson County'
)
AND EXISTS(
SELECT *
FROM catalog_sales cs2
WHERE cs1.cs_order_number = cs2.cs_order_number
AND cs1.cs_warehouse_sk <> cs2.cs_warehouse_sk)
AND NOT EXISTS(
SELECT *
FROM catalog_returns cr1
WHERE cs1.cs_order_number = cr1.cr_order_number
)
ORDER BY COUNT(DISTINCT cs1.cs_order_number)
LIMIT 100;
WITH `_u_0` AS (
SELECT
`cs2`.`cs_order_number` AS `_u_1`,
COLLECT_LIST(`cs2`.`cs_warehouse_sk`) AS `_u_2`
FROM `catalog_sales` AS `cs2`
GROUP BY
`cs2`.`cs_order_number`
), `_u_3` AS (
SELECT
`cr1`.`cr_order_number` AS `_u_4`
FROM `catalog_returns` AS `cr1`
GROUP BY
`cr1`.`cr_order_number`
)
SELECT
COUNT(DISTINCT `cs1`.`cs_order_number`) AS `order count`,
SUM(`cs1`.`cs_ext_ship_cost`) AS `total shipping cost`,
SUM(`cs1`.`cs_net_profit`) AS `total net profit`
FROM `catalog_sales` AS `cs1`
LEFT JOIN `_u_0` AS `_u_0`
ON `_u_0`.`_u_1` = `cs1`.`cs_order_number`
LEFT JOIN `_u_3` AS `_u_3`
ON `_u_3`.`_u_4` = `cs1`.`cs_order_number`
JOIN `call_center` AS `call_center`
ON `call_center`.`cc_call_center_sk` = `cs1`.`cs_call_center_sk`
AND `call_center`.`cc_county` IN ('Williamson County', 'Williamson County', 'Williamson County', 'Williamson County', 'Williamson County')
JOIN `customer_address` AS `customer_address`
ON `cs1`.`cs_ship_addr_sk` = `customer_address`.`ca_address_sk`
AND `customer_address`.`ca_state` = 'GA'
JOIN `date_dim` AS `date_dim`
ON `cs1`.`cs_ship_date_sk` = `date_dim`.`d_date_sk`
AND `date_dim`.`d_date` <= (
CAST(CAST('2002-02-01' AS DATE) AS TIMESTAMP) + INTERVAL '60' DAYS
)
AND `date_dim`.`d_date` >= '2002-02-01'
WHERE
`_u_3`.`_u_4` IS NULL
AND NOT `_u_0`.`_u_1` IS NULL
AND (
SIZE(`_u_0`.`_u_2`) = 0
OR SIZE(FILTER(`_u_0`.`_u_2`, `_x` -> `cs1`.`cs_warehouse_sk` <> `_x`)) <> 0
)
ORDER BY
COUNT(DISTINCT `cs1`.`cs_order_number`)
LIMIT 100;
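Fixtures in this file pair an input query (with # title, # execute and # dialect headers) against the SQL the full optimizer pipeline is expected to produce; here the EXISTS/NOT EXISTS subqueries are decorrelated into grouped left joins, and the remaining per-row check becomes the SIZE/FILTER form that ArrayAny lowers to in Spark. A rough sketch of how such a fixture is exercised, with a hypothetical single-table schema standing in for the full one the real harness supplies:

    import sqlglot
    from sqlglot.optimizer import optimize

    # Hypothetical schema subset; the real harness supplies one for all of TPC-DS.
    schema = {"catalog_sales": {"cs_order_number": "int", "cs_warehouse_sk": "int"}}
    sql = "SELECT cs_order_number FROM catalog_sales WHERE cs_warehouse_sk <> 1"
    optimized = optimize(sqlglot.parse_one(sql, read="spark"), schema=schema, dialect="spark")
    print(optimized.sql(dialect="spark", pretty=True))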

View file

@@ -354,10 +354,17 @@ SELECT x.b AS b, y.b AS b, y.c AS c FROM x AS x, y AS y;
SELECT * EXCEPT(a) FROM x;
SELECT x.b AS b FROM x AS x;
# execute: false
SELECT * EXCEPT(x.a) FROM x AS x;
SELECT x.b AS b FROM x AS x;
# execute: false
# note: this query would fail in the engine level because there are 0 selected columns
SELECT * EXCEPT (a, b) FROM x;
SELECT * EXCEPT (x.a, x.b) FROM x AS x;
SELECT * EXCEPT (a, b) FROM x AS x;
SELECT x.a, * EXCEPT (a) FROM x AS x LEFT JOIN x AS y USING (a);
SELECT x.a AS a, x.b AS b, y.b AS b FROM x AS x LEFT JOIN x AS y ON x.a = y.a;
SELECT COALESCE(CAST(t1.a AS VARCHAR), '') AS a, t2.* EXCEPT (a) FROM x AS t1, x AS t2;
SELECT COALESCE(CAST(t1.a AS VARCHAR), '') AS a, t2.b AS b FROM x AS t1, x AS t2;

24 binary files not shown.

View file

@@ -166,7 +166,7 @@ WHERE
AND NOT x.a = _u_9.a
AND ARRAY_ANY(_u_10.a, _x -> _x = x.a)
AND (
x.a < _u_12.a AND ARRAY_ANY(_u_12._u_14, "_x" -> _x <> x.d)
x.a < _u_12.a AND ARRAY_ANY(_u_12._u_14, _x -> _x <> x.d)
)
AND NOT _u_15.a IS NULL
AND x.a IN (

View file

@@ -8,7 +8,7 @@ import numpy as np
import pandas as pd
from pandas.testing import assert_frame_equal
- from sqlglot import exp, parse_one
+ from sqlglot import exp, parse_one, transpile
from sqlglot.errors import ExecuteError
from sqlglot.executor import execute
from sqlglot.executor.python import Python
@@ -50,7 +50,7 @@ class TestExecutor(unittest.TestCase):
def cached_execute(self, sql):
if sql not in self.cache:
- self.cache[sql] = self.conn.execute(sql).fetchdf()
+ self.cache[sql] = self.conn.execute(transpile(sql, write="duckdb")[0]).fetchdf()
return self.cache[sql]
def rename_anonymous(self, source, target):
@@ -66,10 +66,10 @@ class TestExecutor(unittest.TestCase):
self.assertEqual(generate(parse_one("x is null")), "scope[None][x] is None")
def test_optimized_tpch(self):
- for i, (sql, optimized) in enumerate(self.sqls[:20], start=1):
+ for i, (sql, optimized) in enumerate(self.sqls, start=1):
with self.subTest(f"{i}, {sql}"):
a = self.cached_execute(sql)
- b = self.conn.execute(optimized).fetchdf()
+ b = self.conn.execute(transpile(optimized, write="duckdb")[0]).fetchdf()
self.rename_anonymous(b, a)
assert_frame_equal(a, b)
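The executor tests now run every TPC-H query (the [:20] slice is gone) and transpile each cached query to DuckDB before executing it, so fixtures written in other dialects still run on the engine. A minimal sketch of that pattern, assuming duckdb and pandas are installed:

    import duckdb
    import sqlglot

    conn = duckdb.connect()
    sql = "SELECT 1 AS x"  # stand-in for a TPC-H fixture query
    df = conn.execute(sqlglot.transpile(sql, write="duckdb")[0]).fetchdf()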

View file

@@ -156,7 +156,7 @@ class TestOptimizer(unittest.TestCase):
df1 = self.conn.execute(
sqlglot.transpile(sql, read=dialect, write="duckdb")[0]
).df()
- df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df()
+ df2 = self.conn.execute(optimized.sql(dialect="duckdb")).df()
assert_frame_equal(df1, df2)
@patch("sqlglot.generator.logger")