Merging upstream version 11.4.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent ecb42ec17f
commit 63746a3e92

89 changed files with 35352 additions and 33081 deletions
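For orientation: the hunks below update sqlglot's dialect test suites and optimizer fixtures. In these tests, validate_identity(sql) asserts that parsing and regenerating sql is a no-op, while validate_all(sql, read=..., write=...) asserts that each read snippet parses into the same tree as sql and that the tree generates each write snippet. A minimal sketch (mine, not part of the commit) of the same round-trip through the public API, using the example from sqlglot's README:

    import sqlglot

    # Parse DuckDB syntax and emit the Hive equivalent.
    print(sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read="duckdb", write="hive")[0])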
@@ -8,6 +8,9 @@ class TestBigQuery(Validator):
     def test_bigquery(self):
         self.validate_identity("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])")
         self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))")
+        self.validate_identity(
+            "SELECT * FROM (SELECT * FROM `t`) AS a UNPIVOT((c) FOR c_name IN (v1, v2))"
+        )
 
         self.validate_all("LEAST(x, y)", read={"sqlite": "MIN(x, y)"})
         self.validate_all(
@@ -280,7 +283,7 @@ class TestBigQuery(Validator):
                 "duckdb": "CURRENT_DATE + INTERVAL 1 DAY",
                 "mysql": "DATE_ADD(CURRENT_DATE, INTERVAL 1 DAY)",
                 "postgres": "CURRENT_DATE + INTERVAL '1' DAY",
-                "presto": "DATE_ADD(DAY, 1, CURRENT_DATE)",
+                "presto": "DATE_ADD('DAY', 1, CURRENT_DATE)",
                 "hive": "DATE_ADD(CURRENT_DATE, 1)",
                 "spark": "DATE_ADD(CURRENT_DATE, 1)",
             },
@@ -16,6 +16,10 @@ class TestClickhouse(Validator):
         self.validate_identity("SELECT * FROM foo ANY JOIN bla")
         self.validate_identity("SELECT quantile(0.5)(a)")
         self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t")
+        self.validate_identity("SELECT quantiles(0.1, 0.2, 0.3)(a)")
+        self.validate_identity("SELECT histogram(5)(a)")
+        self.validate_identity("SELECT groupUniqArray(2)(a)")
+        self.validate_identity("SELECT exponentialTimeDecayedAvg(60)(a, b)")
         self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)")
         self.validate_identity("position(haystack, needle)")
         self.validate_identity("position(haystack, needle, position)")
@@ -519,7 +519,7 @@ class TestDialect(Validator):
                 "duckdb": "x + INTERVAL 1 day",
                 "hive": "DATE_ADD(x, 1)",
                 "mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
-                "postgres": "x + INTERVAL '1' 'day'",
+                "postgres": "x + INTERVAL '1' day",
                 "presto": "DATE_ADD('day', 1, x)",
                 "snowflake": "DATEADD(day, 1, x)",
                 "spark": "DATE_ADD(x, 1)",
@@ -543,11 +543,48 @@ class TestDialect(Validator):
         )
         self.validate_all(
             "DATE_TRUNC('day', x)",
+            read={
+                "bigquery": "DATE_TRUNC(x, day)",
+                "duckdb": "DATE_TRUNC('day', x)",
+                "spark": "TRUNC(x, 'day')",
+            },
             write={
+                "bigquery": "DATE_TRUNC(x, day)",
+                "duckdb": "DATE_TRUNC('day', x)",
                 "mysql": "DATE(x)",
+                "presto": "DATE_TRUNC('day', x)",
+                "postgres": "DATE_TRUNC('day', x)",
+                "snowflake": "DATE_TRUNC('day', x)",
+                "starrocks": "DATE_TRUNC('day', x)",
+                "spark": "TRUNC(x, 'day')",
             },
         )
+        self.validate_all(
+            "TIMESTAMP_TRUNC(x, day)",
+            read={
+                "bigquery": "TIMESTAMP_TRUNC(x, day)",
+                "presto": "DATE_TRUNC('day', x)",
+                "postgres": "DATE_TRUNC('day', x)",
+                "snowflake": "DATE_TRUNC('day', x)",
+                "starrocks": "DATE_TRUNC('day', x)",
+                "spark": "DATE_TRUNC('day', x)",
+            },
+        )
+        self.validate_all(
+            "DATE_TRUNC('day', CAST(x AS DATE))",
+            read={
+                "presto": "DATE_TRUNC('day', x::DATE)",
+                "snowflake": "DATE_TRUNC('day', x::DATE)",
+            },
+        )
+        self.validate_all(
+            "TIMESTAMP_TRUNC(CAST(x AS DATE), day)",
+            read={
+                "postgres": "DATE_TRUNC('day', x::DATE)",
+                "starrocks": "DATE_TRUNC('day', x::DATE)",
+            },
+        )
 
         self.validate_all(
             "DATE_TRUNC('week', x)",
             write={
@@ -582,8 +619,6 @@ class TestDialect(Validator):
             "DATE_TRUNC('year', x)",
             read={
                 "bigquery": "DATE_TRUNC(x, year)",
-                "snowflake": "DATE_TRUNC(year, x)",
-                "starrocks": "DATE_TRUNC('year', x)",
                 "spark": "TRUNC(x, 'year')",
             },
             write={
@@ -599,7 +634,10 @@ class TestDialect(Validator):
             "TIMESTAMP_TRUNC(x, year)",
             read={
                 "bigquery": "TIMESTAMP_TRUNC(x, year)",
+                "postgres": "DATE_TRUNC(year, x)",
                 "spark": "DATE_TRUNC('year', x)",
+                "snowflake": "DATE_TRUNC(year, x)",
+                "starrocks": "DATE_TRUNC('year', x)",
             },
             write={
                 "bigquery": "TIMESTAMP_TRUNC(x, year)",
@@ -752,7 +790,6 @@ class TestDialect(Validator):
                 "trino": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
                 "duckdb": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
                 "hive": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
-                "presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
                 "spark": "AGGREGATE(x, 0, (acc, x) -> acc + x, acc -> acc)",
                 "presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
             },
@@ -1455,3 +1492,36 @@ SELECT
             "postgres": "SUBSTRING('123456' FROM 2 FOR 3)",
         },
     )
+
+    def test_count_if(self):
+        self.validate_identity("COUNT_IF(DISTINCT cond)")
+
+        self.validate_all(
+            "SELECT COUNT_IF(cond) FILTER", write={"": "SELECT COUNT_IF(cond) AS FILTER"}
+        )
+        self.validate_all(
+            "SELECT COUNT_IF(col % 2 = 0) FROM foo",
+            write={
+                "": "SELECT COUNT_IF(col % 2 = 0) FROM foo",
+                "databricks": "SELECT COUNT_IF(col % 2 = 0) FROM foo",
+                "presto": "SELECT COUNT_IF(col % 2 = 0) FROM foo",
+                "snowflake": "SELECT COUNT_IF(col % 2 = 0) FROM foo",
+                "sqlite": "SELECT SUM(CASE WHEN col % 2 = 0 THEN 1 ELSE 0 END) FROM foo",
+                "tsql": "SELECT COUNT_IF(col % 2 = 0) FROM foo",
+            },
+        )
+        self.validate_all(
+            "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
+            read={
+                "": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
+                "databricks": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
+                "tsql": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
+            },
+            write={
+                "": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
+                "databricks": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
+                "presto": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
+                "sqlite": "SELECT SUM(CASE WHEN col % 2 = 0 THEN 1 ELSE 0 END) FILTER(WHERE col < 1000) FROM foo",
+                "tsql": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
+            },
+        )
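The new COUNT_IF fixtures above can be reproduced through the public API. A small sketch (mine, not from the diff), matching the expected SQLite output in the test:

    import sqlglot

    # SQLite has no COUNT_IF, so sqlglot rewrites it to a conditional sum.
    print(sqlglot.transpile("SELECT COUNT_IF(col % 2 = 0) FROM foo", write="sqlite")[0])
    # SELECT SUM(CASE WHEN col % 2 = 0 THEN 1 ELSE 0 END) FROM foo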
@@ -4,6 +4,12 @@ from tests.dialects.test_dialect import Validator
 class TestDrill(Validator):
     dialect = "drill"
 
+    def test_drill(self):
+        self.validate_all(
+            "DATE_FORMAT(a, 'yyyy')",
+            write={"drill": "TO_CHAR(a, 'yyyy')"},
+        )
+
     def test_string_literals(self):
         self.validate_all(
             "SELECT '2021-01-01' + INTERVAL 1 MONTH",
@@ -125,6 +125,11 @@ class TestDuckDB(Validator):
             "SELECT a['x space'] FROM (SELECT {'x space': 1, 'y': 2, 'z': 3} AS a)"
         )
 
+        self.validate_all("SELECT * FROM 'x.y'", write={"duckdb": 'SELECT * FROM "x.y"'})
+        self.validate_all(
+            "WITH 'x' AS (SELECT 1) SELECT * FROM x",
+            write={"duckdb": 'WITH "x" AS (SELECT 1) SELECT * FROM x'},
+        )
         self.validate_all(
             "CREATE TABLE IF NOT EXISTS table (cola INT, colb STRING) USING ICEBERG PARTITIONED BY (colb)",
             write={
@@ -63,8 +63,8 @@ class TestPresto(Validator):
                 "bigquery": "CAST(x AS TIMESTAMPTZ)",
                 "duckdb": "CAST(x AS TIMESTAMPTZ(9))",
                 "presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
-                "hive": "CAST(x AS TIMESTAMPTZ)",
-                "spark": "CAST(x AS TIMESTAMPTZ)",
+                "hive": "CAST(x AS TIMESTAMP)",
+                "spark": "CAST(x AS TIMESTAMP)",
             },
         )
 
@@ -189,34 +189,38 @@ class TestPresto(Validator):
         )
 
         self.validate_all(
-            "DAY_OF_WEEK(timestamp '2012-08-08 01:00')",
+            "DAY_OF_WEEK(timestamp '2012-08-08 01:00:00')",
             write={
-                "spark": "DAYOFWEEK(CAST('2012-08-08 01:00' AS TIMESTAMP))",
-                "presto": "DAY_OF_WEEK(CAST('2012-08-08 01:00' AS TIMESTAMP))",
+                "spark": "DAYOFWEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
+                "presto": "DAY_OF_WEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
+                "duckdb": "DAYOFWEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
             },
         )
 
         self.validate_all(
-            "DAY_OF_MONTH(timestamp '2012-08-08 01:00')",
+            "DAY_OF_MONTH(timestamp '2012-08-08 01:00:00')",
             write={
-                "spark": "DAYOFMONTH(CAST('2012-08-08 01:00' AS TIMESTAMP))",
-                "presto": "DAY_OF_MONTH(CAST('2012-08-08 01:00' AS TIMESTAMP))",
+                "spark": "DAYOFMONTH(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
+                "presto": "DAY_OF_MONTH(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
+                "duckdb": "DAYOFMONTH(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
             },
         )
 
         self.validate_all(
-            "DAY_OF_YEAR(timestamp '2012-08-08 01:00')",
+            "DAY_OF_YEAR(timestamp '2012-08-08 01:00:00')",
             write={
-                "spark": "DAYOFYEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))",
-                "presto": "DAY_OF_YEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))",
+                "spark": "DAYOFYEAR(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
+                "presto": "DAY_OF_YEAR(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
+                "duckdb": "DAYOFYEAR(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
             },
         )
 
         self.validate_all(
-            "WEEK_OF_YEAR(timestamp '2012-08-08 01:00')",
+            "WEEK_OF_YEAR(timestamp '2012-08-08 01:00:00')",
             write={
-                "spark": "WEEKOFYEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))",
-                "presto": "WEEK_OF_YEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))",
+                "spark": "WEEKOFYEAR(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
+                "presto": "WEEK_OF_YEAR(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
+                "duckdb": "WEEKOFYEAR(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
             },
         )
 
@@ -365,6 +369,15 @@ class TestPresto(Validator):
         self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")
         self.validate_identity("APPROX_PERCENTILE(a, b, c, d)")
 
+        self.validate_all(
+            "ARRAY_AGG(x ORDER BY y DESC)",
+            write={
+                "hive": "COLLECT_LIST(x)",
+                "presto": "ARRAY_AGG(x ORDER BY y DESC)",
+                "spark": "COLLECT_LIST(x)",
+                "trino": "ARRAY_AGG(x ORDER BY y DESC)",
+            },
+        )
         self.validate_all(
             "SELECT a FROM t GROUP BY a, ROLLUP(b), ROLLUP(c), ROLLUP(d)",
             write={
@@ -6,6 +6,9 @@ class TestRedshift(Validator):
 
     def test_redshift(self):
         self.validate_all("CONVERT(INTEGER, x)", write={"redshift": "CAST(x AS INTEGER)"})
+        self.validate_all(
+            "DATEADD('day', ndays, caldate)", write={"redshift": "DATEADD(day, ndays, caldate)"}
+        )
         self.validate_all(
             'create table "group" ("col" char(10))',
             write={
@@ -80,10 +83,10 @@ class TestRedshift(Validator):
             },
         )
         self.validate_all(
-            "DATEDIFF(d, a, b)",
+            "DATEDIFF('day', a, b)",
             write={
-                "redshift": "DATEDIFF(d, a, b)",
-                "presto": "DATE_DIFF(d, a, b)",
+                "redshift": "DATEDIFF(day, a, b)",
+                "presto": "DATE_DIFF('day', a, b)",
             },
         )
         self.validate_all(
@@ -17,6 +17,41 @@ class TestSnowflake(Validator):
         )
         self.validate_identity("COMMENT IF EXISTS ON TABLE foo IS 'bar'")
 
+        self.validate_all(
+            "SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1",
+            write={
+                "": "SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o NULLS LAST) = 1",
+                "databricks": "SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o NULLS LAST) = 1",
+                "hive": "SELECT i, p, o FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o NULLS LAST) AS _w FROM qt) AS _t WHERE _w = 1",
+                "presto": "SELECT i, p, o FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w FROM qt) AS _t WHERE _w = 1",
+                "snowflake": "SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1",
+                "spark": "SELECT i, p, o FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o NULLS LAST) AS _w FROM qt) AS _t WHERE _w = 1",
+                "sqlite": "SELECT i, p, o FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o NULLS LAST) AS _w FROM qt) AS _t WHERE _w = 1",
+                "trino": "SELECT i, p, o FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w FROM qt) AS _t WHERE _w = 1",
+            },
+        )
+        self.validate_all(
+            "SELECT BOOLOR_AGG(c1), BOOLOR_AGG(c2) FROM test",
+            write={
+                "": "SELECT LOGICAL_OR(c1), LOGICAL_OR(c2) FROM test",
+                "duckdb": "SELECT BOOL_OR(c1), BOOL_OR(c2) FROM test",
+                "postgres": "SELECT BOOL_OR(c1), BOOL_OR(c2) FROM test",
+                "snowflake": "SELECT BOOLOR_AGG(c1), BOOLOR_AGG(c2) FROM test",
+                "spark": "SELECT BOOL_OR(c1), BOOL_OR(c2) FROM test",
+                "sqlite": "SELECT MAX(c1), MAX(c2) FROM test",
+            },
+        )
+        self.validate_all(
+            "SELECT BOOLAND_AGG(c1), BOOLAND_AGG(c2) FROM test",
+            write={
+                "": "SELECT LOGICAL_AND(c1), LOGICAL_AND(c2) FROM test",
+                "duckdb": "SELECT BOOL_AND(c1), BOOL_AND(c2) FROM test",
+                "postgres": "SELECT BOOL_AND(c1), BOOL_AND(c2) FROM test",
+                "snowflake": "SELECT BOOLAND_AGG(c1), BOOLAND_AGG(c2) FROM test",
+                "spark": "SELECT BOOL_AND(c1), BOOL_AND(c2) FROM test",
+                "sqlite": "SELECT MIN(c1), MIN(c2) FROM test",
+            },
+        )
         self.validate_all(
             "TO_CHAR(x, y)",
             read={
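The BOOLOR_AGG/BOOLAND_AGG fixtures above translate directly into a transpile call. A brief sketch (mine, not from the diff), mirroring the Postgres entry in the write map:

    import sqlglot

    # Snowflake's BOOLOR_AGG maps to BOOL_OR in Postgres.
    print(sqlglot.transpile("SELECT BOOLOR_AGG(c1) FROM test", read="snowflake", write="postgres")[0])
    # SELECT BOOL_OR(c1) FROM test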
@@ -213,6 +213,9 @@ TBLPROPERTIES (
         self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')")
         self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')")
 
+        self.validate_all(
+            "CAST(x AS TIMESTAMP)", read={"trino": "CAST(x AS TIMESTAMP(6) WITH TIME ZONE)"}
+        )
         self.validate_all(
             "SELECT DATE_ADD(my_date_column, 1)",
             write={
@@ -56,6 +56,33 @@ class TestSQLite(Validator):
         )
 
     def test_sqlite(self):
+        self.validate_all(
+            "CURRENT_DATE",
+            read={
+                "": "CURRENT_DATE",
+                "snowflake": "CURRENT_DATE()",
+            },
+        )
+        self.validate_all(
+            "CURRENT_TIME",
+            read={
+                "": "CURRENT_TIME",
+                "snowflake": "CURRENT_TIME()",
+            },
+        )
+        self.validate_all(
+            "CURRENT_TIMESTAMP",
+            read={
+                "": "CURRENT_TIMESTAMP",
+                "snowflake": "CURRENT_TIMESTAMP()",
+            },
+        )
+        self.validate_all(
+            "SELECT DATE('2020-01-01 16:03:05')",
+            read={
+                "snowflake": "SELECT CAST('2020-01-01 16:03:05' AS DATE)",
+            },
+        )
         self.validate_all(
             "SELECT CAST([a].[b] AS SMALLINT) FROM foo",
             write={
tests/fixtures/identity.sql (vendored, 6 changes)
@@ -132,6 +132,8 @@ INTERVAL '-31' CAST(GETDATE() AS DATE)
 INTERVAL 2 months
 INTERVAL (1 + 3) DAYS
 CAST('45' AS INTERVAL DAYS)
+FILTER(a, x -> x.a.b.c.d.e.f.g)
+FILTER(a, x -> FOO(x.a.b.c.d.e.f.g) + x.a.b.c.d.e.f.g)
 TIMESTAMP_DIFF(CURRENT_TIMESTAMP(), 1, DAY)
 DATETIME_DIFF(CURRENT_DATE, 1, DAY)
 QUANTILE(x, 0.5)
@@ -161,6 +163,10 @@ CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo'
 SET x = 1
 SET -v
 SET x = ';'
+SET variable = value
+SET GLOBAL variable = value
+SET LOCAL variable = value
+SET @user OFF
 COMMIT
 USE db
 USE role x
tests/fixtures/optimizer/canonicalize.sql (vendored, 16 changes)
@@ -9,3 +9,19 @@ SELECT CAST(1 AS VARCHAR) AS "a" FROM "w" AS "w";
 
 SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w;
 SELECT 1 + 3.2 AS "a" FROM "w" AS "w";
+
+--------------------------------------
+-- Ensure boolean predicates
+--------------------------------------
+
+SELECT a FROM x WHERE b;
+SELECT "x"."a" AS "a" FROM "x" AS "x" WHERE "x"."b" <> 0;
+
+SELECT a FROM x GROUP BY a HAVING SUM(b);
+SELECT "x"."a" AS "a" FROM "x" AS "x" GROUP BY "x"."a" HAVING SUM("x"."b") <> 0;
+
+SELECT a FROM x GROUP BY a HAVING SUM(b) AND TRUE;
+SELECT "x"."a" AS "a" FROM "x" AS "x" GROUP BY "x"."a" HAVING SUM("x"."b") <> 0 AND TRUE;
+
+SELECT a FROM x WHERE 1;
+SELECT "x"."a" AS "a" FROM "x" AS "x" WHERE 1 <> 0;
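Each pair in this fixture is optimizer input followed by expected output. A rough sketch (mine, not from the diff) of driving the first "boolean predicates" case through the optimizer, assuming a toy schema for x:

    from sqlglot import parse_one
    from sqlglot.optimizer import optimize

    # The canonicalize rule rewrites a bare column predicate into an explicit comparison.
    schema = {"x": {"a": "int", "b": "int"}}
    print(optimize(parse_one("SELECT a FROM x WHERE b"), schema=schema).sql())
    # SELECT "x"."a" AS "a" FROM "x" AS "x" WHERE "x"."b" <> 0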
tests/fixtures/optimizer/optimizer.sql (vendored, 31 changes)
@@ -386,6 +386,29 @@ SELECT
   "x"."b" + 1 AS "c"
 FROM "x" AS "x";
 
+# title: unqualified struct element is selected in the outer query
+# execute: false
+WITH "cte" AS (
+  SELECT
+    FROM_JSON("value", 'STRUCT<f1: STRUCT<f2: STRUCT<f3: STRUCT<f4: STRING>>>>') AS "struct"
+  FROM "tbl"
+) SELECT "struct"."f1"."f2"."f3"."f4" AS "f4" FROM "cte";
+SELECT
+  FROM_JSON("tbl"."value", 'STRUCT<f1: STRUCT<f2: STRUCT<f3: STRUCT<f4: STRING>>>>')."f1"."f2"."f3"."f4" AS "f4"
+FROM "tbl" AS "tbl";
+
+# title: qualified struct element is selected in the outer query
+# execute: false
+WITH "cte" AS (
+  SELECT
+    FROM_JSON("value", 'STRUCT<f1: STRUCT<f2: INTEGER>, STRUCT<f3: STRING>>') AS "struct"
+  FROM "tbl"
+) SELECT "cte"."struct"."f1"."f2" AS "f2", "cte"."struct"."f1"."f3" AS "f3" FROM "cte";
+SELECT
+  FROM_JSON("tbl"."value", 'STRUCT<f1: STRUCT<f2: INTEGER>, STRUCT<f3: STRING>>')."f1"."f2" AS "f2",
+  FROM_JSON("tbl"."value", 'STRUCT<f1: STRUCT<f2: INTEGER>, STRUCT<f3: STRING>>')."f1"."f3" AS "f3"
+FROM "tbl" AS "tbl";
+
 # title: left join doesnt push down predicate to join in merge subqueries
 # execute: false
 SELECT
@@ -430,3 +453,11 @@ LEFT JOIN "unlocked" AS "unlocked"
 WHERE
   CASE WHEN "unlocked"."company_id" IS NULL THEN 0 ELSE 1 END = FALSE
   AND NOT "company_table_2"."id" IS NULL;
+
+# title: db.table alias clash
+# execute: false
+select * from db1.tbl, db2.tbl;
+SELECT
+  *
+FROM "db1"."tbl" AS "tbl"
+CROSS JOIN "db2"."tbl" AS "tbl_2";
@@ -4,6 +4,9 @@ SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0;
 SELECT 1 FROM (SELECT * FROM x) WHERE b = 2;
 SELECT 1 AS "1" FROM (SELECT x.b AS b FROM x AS x) AS _q_0 WHERE _q_0.b = 2;
 
+SELECT a, b, a from x;
+SELECT x.a AS a, x.b AS b, x.a AS a FROM x AS x;
+
 SELECT (SELECT c FROM y WHERE q.b = y.b) FROM (SELECT * FROM x) AS q;
 SELECT (SELECT y.c AS c FROM y AS y WHERE q.b = y.b) AS _col_0 FROM (SELECT x.b AS b FROM x AS x) AS q;
 
tests/fixtures/pretty.sql (vendored, 3 changes)
@@ -1,3 +1,6 @@
+SET x TO 1;
+SET x = 1;
+
 SELECT * FROM test;
 SELECT
   *
@@ -510,6 +510,27 @@ class TestBuild(unittest.TestCase):
                 .qualify("row_number() OVER (PARTITION BY a ORDER BY b) = 1"),
                 "SELECT * FROM table QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) = 1",
             ),
+            (lambda: exp.delete("tbl1", "x = 1").delete("tbl2"), "DELETE FROM tbl2 WHERE x = 1"),
+            (lambda: exp.delete("tbl").where("x = 1"), "DELETE FROM tbl WHERE x = 1"),
+            (lambda: exp.delete(exp.table_("tbl")), "DELETE FROM tbl"),
+            (
+                lambda: exp.delete("tbl", "x = 1").where("y = 2"),
+                "DELETE FROM tbl WHERE x = 1 AND y = 2",
+            ),
+            (
+                lambda: exp.delete("tbl", "x = 1").where(exp.condition("y = 2").or_("z = 3")),
+                "DELETE FROM tbl WHERE x = 1 AND (y = 2 OR z = 3)",
+            ),
+            (
+                lambda: exp.delete("tbl").where("x = 1").returning("*", dialect="postgres"),
+                "DELETE FROM tbl WHERE x = 1 RETURNING *",
+                "postgres",
+            ),
+            (
+                lambda: exp.delete("tbl", where="x = 1", returning="*", dialect="postgres"),
+                "DELETE FROM tbl WHERE x = 1 RETURNING *",
+                "postgres",
+            ),
         ]:
             with self.subTest(sql):
                 self.assertEqual(expression().sql(dialect[0] if dialect else None), sql)
@@ -6,6 +6,8 @@ from sqlglot import alias, exp, parse_one
 
 
 class TestExpressions(unittest.TestCase):
+    maxDiff = None
+
     def test_arg_key(self):
         self.assertEqual(parse_one("sum(1)").find(exp.Literal).arg_key, "this")
 
@@ -91,6 +93,32 @@ class TestExpressions(unittest.TestCase):
         self.assertIsInstance(column.parent_select, exp.Select)
         self.assertIsNone(column.find_ancestor(exp.Join))
 
+    def test_to_dot(self):
+        column = parse_one('a.b.c."d".e.f').find(exp.Column)
+        dot = column.to_dot()
+
+        self.assertEqual(dot.sql(), 'a.b.c."d".e.f')
+
+        self.assertEqual(
+            dot,
+            exp.Dot(
+                this=exp.Dot(
+                    this=exp.Dot(
+                        this=exp.Dot(
+                            this=exp.Dot(
+                                this=exp.to_identifier("a"),
+                                expression=exp.to_identifier("b"),
+                            ),
+                            expression=exp.to_identifier("c"),
+                        ),
+                        expression=exp.to_identifier("d", quoted=True),
+                    ),
+                    expression=exp.to_identifier("e"),
+                ),
+                expression=exp.to_identifier("f"),
+            ),
+        )
+
     def test_root(self):
         ast = parse_one("select * from (select a from x)")
         self.assertIs(ast, ast.root())
@@ -480,6 +508,7 @@ class TestExpressions(unittest.TestCase):
         self.assertIsInstance(parse_one("COMMIT"), exp.Commit)
         self.assertIsInstance(parse_one("ROLLBACK"), exp.Rollback)
         self.assertIsInstance(parse_one("GENERATE_SERIES(a, b, c)"), exp.GenerateSeries)
+        self.assertIsInstance(parse_one("COUNT_IF(a > 0)"), exp.CountIf)
 
     def test_column(self):
         column = parse_one("a.b.c.d")
@@ -1,5 +1,6 @@
 import unittest
 
+from sqlglot import parse_one
 from sqlglot.expressions import Func
 from sqlglot.parser import Parser
 from sqlglot.tokens import Tokenizer
@@ -28,3 +29,12 @@ class TestGenerator(unittest.TestCase):
         tokens = Tokenizer().tokenize("SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x")
         expression = NewParser().parse(tokens)[0]
         self.assertEqual(expression.sql(), "SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x")
+
+    def test_identify(self):
+        assert parse_one("x").sql(identify=True) == '"x"'
+        assert parse_one("x").sql(identify="always") == '"x"'
+        assert parse_one("X").sql(identify="always") == '"X"'
+        assert parse_one("x").sql(identify="safe") == '"x"'
+        assert parse_one("X").sql(identify="safe") == "X"
+        assert parse_one("x as 1").sql(identify="safe") == '"x" AS "1"'
+        assert parse_one("X as 1").sql(identify="safe") == 'X AS "1"'
@@ -102,6 +102,13 @@ class TestParser(unittest.TestCase):
         self.assertEqual(expressions[1].sql(), "ADD JAR s3://a")
         self.assertEqual(expressions[2].sql(), "SELECT 1")
 
+    def test_lambda_struct(self):
+        expression = parse_one("FILTER(a.b, x -> x.id = id)")
+        lambda_expr = expression.expression
+
+        self.assertIsInstance(lambda_expr.this.this, exp.Dot)
+        self.assertEqual(lambda_expr.sql(), "x -> x.id = id")
+
     def test_transactions(self):
         expression = parse_one("BEGIN TRANSACTION")
         self.assertIsNone(expression.this)
@@ -280,6 +287,39 @@ class TestParser(unittest.TestCase):
         self.assertIsInstance(parse_one("TIMESTAMP()"), exp.Func)
         self.assertIsInstance(parse_one("map.x"), exp.Column)
 
+    def test_set_expression(self):
+        set_ = parse_one("SET")
+
+        self.assertEqual(set_.sql(), "SET")
+        self.assertIsInstance(set_, exp.Set)
+
+        set_session = parse_one("SET SESSION x = 1")
+
+        self.assertEqual(set_session.sql(), "SET SESSION x = 1")
+        self.assertIsInstance(set_session, exp.Set)
+
+        set_item = set_session.expressions[0]
+
+        self.assertIsInstance(set_item, exp.SetItem)
+        self.assertIsInstance(set_item.this, exp.EQ)
+        self.assertIsInstance(set_item.this.this, exp.Identifier)
+        self.assertIsInstance(set_item.this.expression, exp.Literal)
+
+        self.assertEqual(set_item.args.get("kind"), "SESSION")
+
+        set_to = parse_one("SET x TO 1")
+
+        self.assertEqual(set_to.sql(), "SET x = 1")
+        self.assertIsInstance(set_to, exp.Set)
+
+        set_as_command = parse_one("SET DEFAULT ROLE ALL TO USER")
+
+        self.assertEqual(set_as_command.sql(), "SET DEFAULT ROLE ALL TO USER")
+
+        self.assertIsInstance(set_as_command, exp.Command)
+        self.assertEqual(set_as_command.this, "SET")
+        self.assertEqual(set_as_command.expression, " DEFAULT ROLE ALL TO USER")
+
     def test_pretty_config_override(self):
         self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x")
         with patch("sqlglot.pretty", True):
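A compact sketch (mine, not from the commit) of the SET parsing behavior these assertions pin down:

    from sqlglot import exp, parse_one

    # Recognized SET statements become exp.Set; SET x TO 1 normalizes to SET x = 1.
    assert isinstance(parse_one("SET SESSION x = 1"), exp.Set)
    assert parse_one("SET x TO 1").sql() == "SET x = 1"

    # Anything the parser cannot model falls back to an opaque exp.Command.
    assert isinstance(parse_one("SET DEFAULT ROLE ALL TO USER"), exp.Command)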
@@ -30,6 +30,21 @@ class TestTokens(unittest.TestCase):
 
         self.assertEqual(tokens[-1].line, 6)
 
+    def test_command(self):
+        tokens = Tokenizer().tokenize("SHOW;")
+        self.assertEqual(tokens[0].token_type, TokenType.SHOW)
+        self.assertEqual(tokens[1].token_type, TokenType.SEMICOLON)
+
+        tokens = Tokenizer().tokenize("EXECUTE")
+        self.assertEqual(tokens[0].token_type, TokenType.EXECUTE)
+        self.assertEqual(len(tokens), 1)
+
+        tokens = Tokenizer().tokenize("FETCH;SHOW;")
+        self.assertEqual(tokens[0].token_type, TokenType.FETCH)
+        self.assertEqual(tokens[1].token_type, TokenType.SEMICOLON)
+        self.assertEqual(tokens[2].token_type, TokenType.SHOW)
+        self.assertEqual(tokens[3].token_type, TokenType.SEMICOLON)
+
     def test_jinja(self):
         tokenizer = Tokenizer()
 
@@ -3,12 +3,13 @@ import unittest
 from sqlglot import parse_one
 from sqlglot.transforms import (
     eliminate_distinct_on,
+    eliminate_qualify,
     remove_precision_parameterized_types,
     unalias_group,
 )
 
 
-class TestTime(unittest.TestCase):
+class TestTransforms(unittest.TestCase):
     maxDiff = None
 
     def validate(self, transform, sql, target):
@@ -74,6 +75,38 @@ class TestTime(unittest.TestCase):
             'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE "_row_number_2" = 1',
         )
 
+    def test_eliminate_qualify(self):
+        self.validate(
+            eliminate_qualify,
+            "SELECT i, a + 1 FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p) = 1",
+            "SELECT i, _c FROM (SELECT i, a + 1 AS _c, ROW_NUMBER() OVER (PARTITION BY p) AS _w, p FROM qt) AS _t WHERE _w = 1",
+        )
+        self.validate(
+            eliminate_qualify,
+            "SELECT i FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1 AND p = 0",
+            "SELECT i FROM (SELECT i, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w, p, o FROM qt) AS _t WHERE _w = 1 AND p = 0",
+        )
+        self.validate(
+            eliminate_qualify,
+            "SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1",
+            "SELECT i, p, o FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w FROM qt) AS _t WHERE _w = 1",
+        )
+        self.validate(
+            eliminate_qualify,
+            "SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS row_num FROM qt QUALIFY row_num = 1",
+            "SELECT i, p, o, row_num FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS row_num FROM qt) AS _t WHERE row_num = 1",
+        )
+        self.validate(
+            eliminate_qualify,
+            "SELECT * FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1",
+            "SELECT * FROM (SELECT *, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w, p, o FROM qt) AS _t WHERE _w = 1",
+        )
+        self.validate(
+            eliminate_qualify,
+            "SELECT c2, SUM(c3) OVER (PARTITION BY c2) AS r FROM t1 WHERE c3 < 4 GROUP BY c2, c3 HAVING SUM(c1) > 3 QUALIFY r IN (SELECT MIN(c1) FROM test GROUP BY c2 HAVING MIN(c1) > 3)",
+            "SELECT c2, r FROM (SELECT c2, SUM(c3) OVER (PARTITION BY c2) AS r, c1 FROM t1 WHERE c3 < 4 GROUP BY c2, c3 HAVING SUM(c1) > 3) AS _t WHERE r IN (SELECT MIN(c1) FROM test GROUP BY c2 HAVING MIN(c1) > 3)",
+        )
+
     def test_remove_precision_parameterized_types(self):
         self.validate(
             remove_precision_parameterized_types,
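The test's validate helper applies each transform with Expression.transform. A short sketch (mine, not from the diff) of the same call for the new eliminate_qualify transform:

    from sqlglot import parse_one
    from sqlglot.transforms import eliminate_qualify

    # QUALIFY becomes a wrapping subquery plus a WHERE on the aliased window.
    sql = "SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1"
    print(parse_one(sql).transform(eliminate_qualify).sql())
    # SELECT i, p, o FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w FROM qt) AS _t WHERE _w = 1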