
Merging upstream version 10.5.2.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:03:38 +01:00
parent 77197f1e44
commit e0f3bbb5f3
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
58 changed files with 1480 additions and 383 deletions

View file

@@ -1152,17 +1152,17 @@ class TestFunctions(unittest.TestCase):
def test_regexp_extract(self):
col_str = SF.regexp_extract("cola", r"(\d+)-(\d+)", 1)
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)', 1)", col_str.sql())
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col_str.sql())
col = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)", 1)
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)', 1)", col.sql())
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col.sql())
col_no_idx = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)")
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)')", col_no_idx.sql())
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)')", col_no_idx.sql())
def test_regexp_replace(self):
col_str = SF.regexp_replace("cola", r"(\d+)", "--")
self.assertEqual("REGEXP_REPLACE(cola, '(\\\d+)', '--')", col_str.sql())
self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col_str.sql())
col = SF.regexp_replace(SF.col("cola"), r"(\d+)", "--")
self.assertEqual("REGEXP_REPLACE(cola, '(\\\d+)', '--')", col.sql())
self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col.sql())
def test_initcap(self):
col_str = SF.initcap("cola")

View file

@@ -15,11 +15,11 @@ class TestDataframeWindow(unittest.TestCase):
def test_window_spec_rows_between(self):
rows_between = WindowSpec().rowsBetween(3, 5)
self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
self.assertEqual("OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
def test_window_spec_range_between(self):
range_between = WindowSpec().rangeBetween(3, 5)
self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
self.assertEqual("OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
def test_window_partition_by(self):
partition_by = Window.partitionBy(F.col("cola"), F.col("colb"))
@@ -31,46 +31,46 @@ class TestDataframeWindow(unittest.TestCase):
def test_window_rows_between(self):
rows_between = Window.rowsBetween(3, 5)
self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
self.assertEqual("OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
def test_window_range_between(self):
range_between = Window.rangeBetween(3, 5)
self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
self.assertEqual("OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
def test_window_rows_unbounded(self):
rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
self.assertEqual(
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
"OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
rows_between_unbounded_start.sql(),
)
rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
self.assertEqual(
"OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
"OVER (ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
rows_between_unbounded_end.sql(),
)
rows_between_unbounded_both = Window.rowsBetween(
Window.unboundedPreceding, Window.unboundedFollowing
)
self.assertEqual(
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
"OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
rows_between_unbounded_both.sql(),
)
def test_window_range_unbounded(self):
range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
self.assertEqual(
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
"OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
range_between_unbounded_start.sql(),
)
range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
self.assertEqual(
"OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
"OVER (RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
range_between_unbounded_end.sql(),
)
range_between_unbounded_both = Window.rangeBetween(
Window.unboundedPreceding, Window.unboundedFollowing
)
self.assertEqual(
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
"OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
range_between_unbounded_both.sql(),
)

View file

@@ -125,7 +125,7 @@ class TestBigQuery(Validator):
},
)
self.validate_all(
"CURRENT_DATE",
"CURRENT_TIMESTAMP()",
read={
"tsql": "GETDATE()",
},
@@ -299,6 +299,14 @@ class TestBigQuery(Validator):
"snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
},
)
self.validate_all(
"SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)",
write={
"spark": "SELECT cola, colb, colc FROM VALUES (1, 'test', NULL) AS tab(cola, colb, colc)",
"bigquery": "SELECT cola, colb, colc FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb, NULL AS colc)])",
"snowflake": "SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)",
},
)
self.validate_all(
"SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) d, COUNT(*) e FOR c IN ('x', 'y'))",
write={
@@ -324,3 +332,35 @@ class TestBigQuery(Validator):
"SELECT a, GROUP_CONCAT(b) FROM table GROUP BY a",
write={"bigquery": "SELECT a, STRING_AGG(b) FROM table GROUP BY a"},
)
def test_remove_precision_parameterized_types(self):
self.validate_all(
"SELECT CAST(1 AS NUMERIC(10, 2))",
write={
"bigquery": "SELECT CAST(1 AS NUMERIC)",
},
)
self.validate_all(
"CREATE TABLE test (a NUMERIC(10, 2))",
write={
"bigquery": "CREATE TABLE test (a NUMERIC(10, 2))",
},
)
self.validate_all(
"SELECT CAST('1' AS STRING(10)) UNION ALL SELECT CAST('2' AS STRING(10))",
write={
"bigquery": "SELECT CAST('1' AS STRING) UNION ALL SELECT CAST('2' AS STRING)",
},
)
self.validate_all(
"SELECT cola FROM (SELECT CAST('1' AS STRING(10)) AS cola UNION ALL SELECT CAST('2' AS STRING(10)) AS cola)",
write={
"bigquery": "SELECT cola FROM (SELECT CAST('1' AS STRING) AS cola UNION ALL SELECT CAST('2' AS STRING) AS cola)",
},
)
self.validate_all(
"INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING(10)), CAST(14 AS STRING(10)))",
write={
"bigquery": "INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING), CAST(14 AS STRING))",
},
)

View file

@@ -14,6 +14,9 @@ class TestClickhouse(Validator):
self.validate_identity("SELECT * FROM foo LEFT ASOF JOIN bla")
self.validate_identity("SELECT * FROM foo ASOF JOIN bla")
self.validate_identity("SELECT * FROM foo ANY JOIN bla")
self.validate_identity("SELECT quantile(0.5)(a)")
self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t")
self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)")
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
@@ -38,3 +41,9 @@ class TestClickhouse(Validator):
"SELECT x #! comment",
write={"": "SELECT x /* comment */"},
)
self.validate_all(
"SELECT quantileIf(0.5)(a, true)",
write={
"clickhouse": "SELECT quantileIf(0.5)(a, TRUE)",
},
)

View file

@@ -85,7 +85,7 @@ class TestDialect(Validator):
self.validate_all(
"CAST(a AS BINARY(4))",
write={
"bigquery": "CAST(a AS BINARY(4))",
"bigquery": "CAST(a AS BINARY)",
"clickhouse": "CAST(a AS BINARY(4))",
"drill": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS BINARY(4))",
@@ -104,7 +104,7 @@ class TestDialect(Validator):
self.validate_all(
"CAST(a AS VARBINARY(4))",
write={
"bigquery": "CAST(a AS VARBINARY(4))",
"bigquery": "CAST(a AS VARBINARY)",
"clickhouse": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS VARBINARY(4))",
"mysql": "CAST(a AS VARBINARY(4))",
@@ -181,7 +181,7 @@ class TestDialect(Validator):
self.validate_all(
"CAST(a AS VARCHAR(3))",
write={
"bigquery": "CAST(a AS STRING(3))",
"bigquery": "CAST(a AS STRING)",
"drill": "CAST(a AS VARCHAR(3))",
"duckdb": "CAST(a AS TEXT(3))",
"mysql": "CAST(a AS VARCHAR(3))",

View file

@@ -338,6 +338,24 @@ class TestHive(Validator):
)
def test_hive(self):
self.validate_all(
"SELECT A.1a AS b FROM test_a AS A",
write={
"spark": "SELECT A.1a AS b FROM test_a AS A",
},
)
self.validate_all(
"SELECT 1_a AS a FROM test_table",
write={
"spark": "SELECT 1_a AS a FROM test_table",
},
)
self.validate_all(
"SELECT a_b AS 1_a FROM test_table",
write={
"spark": "SELECT a_b AS 1_a FROM test_table",
},
)
self.validate_all(
"PERCENTILE(x, 0.5)",
write={
@@ -411,7 +429,7 @@ class TestHive(Validator):
"INITCAP('new york')",
write={
"duckdb": "INITCAP('new york')",
"presto": "REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))",
"presto": r"REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))",
"hive": "INITCAP('new york')",
"spark": "INITCAP('new york')",
},

View file

@@ -122,6 +122,10 @@ class TestPostgres(Validator):
"TO_TIMESTAMP(123::DOUBLE PRECISION)",
write={"postgres": "TO_TIMESTAMP(CAST(123 AS DOUBLE PRECISION))"},
)
self.validate_all(
"SELECT to_timestamp(123)::time without time zone",
write={"postgres": "SELECT CAST(TO_TIMESTAMP(123) AS TIME)"},
)
self.validate_identity(
"CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)"

View file

@@ -60,11 +60,11 @@ class TestPresto(Validator):
self.validate_all(
"CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
write={
"bigquery": "CAST(x AS TIMESTAMPTZ(9))",
"bigquery": "CAST(x AS TIMESTAMPTZ)",
"duckdb": "CAST(x AS TIMESTAMPTZ(9))",
"presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
"hive": "CAST(x AS TIMESTAMPTZ(9))",
"spark": "CAST(x AS TIMESTAMPTZ(9))",
"hive": "CAST(x AS TIMESTAMPTZ)",
"spark": "CAST(x AS TIMESTAMPTZ)",
},
)

View file

@@ -523,3 +523,33 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERA
"spark": "SELECT `c0`, `c1` FROM (VALUES (1, 2), (3, 4)) AS `t0`(`c0`, `c1`)",
},
)
def test_describe_table(self):
self.validate_all(
"DESCRIBE TABLE db.table",
write={
"snowflake": "DESCRIBE TABLE db.table",
"spark": "DESCRIBE db.table",
},
)
self.validate_all(
"DESCRIBE db.table",
write={
"snowflake": "DESCRIBE TABLE db.table",
"spark": "DESCRIBE db.table",
},
)
self.validate_all(
"DESC TABLE db.table",
write={
"snowflake": "DESCRIBE TABLE db.table",
"spark": "DESCRIBE db.table",
},
)
self.validate_all(
"DESC VIEW db.table",
write={
"snowflake": "DESCRIBE VIEW db.table",
"spark": "DESCRIBE db.table",
},
)

View file

@@ -207,6 +207,7 @@ TBLPROPERTIES (
)
def test_spark(self):
self.validate_identity("SELECT UNIX_TIMESTAMP()")
self.validate_all(
"ARRAY_SORT(x, (left, right) -> -1)",
write={

View file

@@ -6,6 +6,8 @@ class TestTSQL(Validator):
def test_tsql(self):
self.validate_identity('SELECT "x"."y" FROM foo')
self.validate_identity("SELECT * FROM #foo")
self.validate_identity("SELECT * FROM ##foo")
self.validate_identity(
"SELECT DISTINCT DepartmentName, PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY BaseRate) OVER (PARTITION BY DepartmentName) AS MedianCont FROM dbo.DimEmployee"
)
@@ -71,6 +73,12 @@ class TestTSQL(Validator):
"tsql": "CAST(x AS DATETIME2)",
},
)
self.validate_all(
"CAST(x AS DATETIME2(6))",
write={
"hive": "CAST(x AS TIMESTAMP)",
},
)
def test_charindex(self):
self.validate_all(
@@ -300,6 +308,12 @@ class TestTSQL(Validator):
"spark": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y",
},
)
self.validate_all(
"SELECT CAST((SELECT x FROM y) AS VARCHAR) AS test",
write={
"spark": "SELECT CAST((SELECT x FROM y) AS STRING) AS test",
},
)
def test_add_date(self):
self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')")
@@ -441,3 +455,13 @@ class TestTSQL(Validator):
"SELECT '''test'''",
write={"spark": r"SELECT '\'test\''"},
)
def test_eomonth(self):
self.validate_all(
"EOMONTH(GETDATE())",
write={"spark": "LAST_DAY(CURRENT_TIMESTAMP())"},
)
self.validate_all(
"EOMONTH(GETDATE(), -1)",
write={"spark": "LAST_DAY(ADD_MONTHS(CURRENT_TIMESTAMP(), -1))"},
)

View file

@@ -89,6 +89,7 @@ POSEXPLODE("x") AS ("a", "b")
POSEXPLODE("x") AS ("a", "b", "c")
STR_POSITION(x, 'a')
STR_POSITION(x, 'a', 3)
LEVENSHTEIN('gumbo', 'gambol', 2, 1, 1)
SPLIT(SPLIT(referrer, 'utm_source=')[OFFSET(1)], "&")[OFFSET(0)]
x[ORDINAL(1)][SAFE_OFFSET(2)]
x LIKE SUBSTR('abc', 1, 1)
@@ -425,6 +426,7 @@ SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 AND 3)
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND 3)
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING)
SELECT AVG(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM t
SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x) AS y
SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x DESC)
SELECT SUM(x) FILTER(WHERE x > 1)
@@ -450,14 +452,24 @@ SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(a, b)
SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b)
SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score)
SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score)
SELECT * FROM t WITH (TABLOCK, INDEX(myindex))
SELECT * FROM t WITH (NOWAIT)
CREATE TABLE foo AS (SELECT 1) UNION ALL (SELECT 2)
CREATE TABLE foo (id INT PRIMARY KEY ASC)
CREATE TABLE a.b AS SELECT 1
CREATE TABLE a.b AS SELECT 1 WITH DATA AND STATISTICS
CREATE TABLE a.b AS SELECT 1 WITH NO DATA AND NO STATISTICS
CREATE TABLE a.b AS (SELECT 1) NO PRIMARY INDEX
CREATE TABLE a.b AS (SELECT 1) UNIQUE PRIMARY INDEX index1 (a) UNIQUE INDEX index2 (b)
CREATE TABLE a.b AS (SELECT 1) PRIMARY AMP INDEX index1 (a) UNIQUE INDEX index2 (b)
CREATE TABLE a.b AS SELECT a FROM a.c
CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d
CREATE TEMPORARY TABLE x AS SELECT a FROM d
CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d
CREATE VIEW x AS SELECT a FROM b
CREATE VIEW IF NOT EXISTS x AS SELECT a FROM b
CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c') AS SELECT a, b, c FROM d
CREATE VIEW IF NOT EXISTS z (a, b COMMENT 'b', c COMMENT 'c') AS SELECT a, b, c FROM d
CREATE OR REPLACE VIEW x AS SELECT *
CREATE OR REPLACE TEMPORARY VIEW x AS SELECT *
CREATE TEMPORARY VIEW x AS SELECT a FROM d
@@ -490,6 +502,8 @@ CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT)
CREATE TABLE z (a INT REFERENCES parent(b, c))
CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id))
CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c))
CREATE VIEW z (a, b)
CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c')
CREATE TEMPORARY FUNCTION f
CREATE TEMPORARY FUNCTION f AS 'g'
CREATE FUNCTION f
@@ -559,6 +573,7 @@ INSERT INTO x.z IF EXISTS SELECT * FROM y
INSERT INTO x VALUES (1, 'a', 2.0)
INSERT INTO x VALUES (1, 'a', 2.0), (1, 'a', 3.0), (X(), y[1], z.x)
INSERT INTO y (a, b, c) SELECT a, b, c FROM x
INSERT INTO y (SELECT 1) UNION (SELECT 2)
INSERT OVERWRITE TABLE x IF EXISTS SELECT * FROM y
INSERT OVERWRITE TABLE a.b IF EXISTS SELECT * FROM y
INSERT OVERWRITE DIRECTORY 'x' SELECT 1
@@ -627,3 +642,4 @@ ALTER TABLE integers ALTER COLUMN i SET DEFAULT 10
ALTER TABLE integers ALTER COLUMN i DROP DEFAULT
ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B
ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT
SELECT div.a FROM test_table AS div

View file

@@ -311,3 +311,42 @@ FROM
ON
t1.cola = t2.cola;
SELECT /*+ BROADCAST(a2) */ a1.cola AS cola, a2.cola AS cola FROM VALUES (1) AS a1(cola) JOIN VALUES (1) AS a2(cola) ON a1.cola = a2.cola;
# title: Nested subquery selects from same table as another subquery
WITH i AS (
SELECT
x.a AS a
FROM x AS x
), j AS (
SELECT
x.a,
x.b
FROM x AS x
), k AS (
SELECT
j.a,
j.b
FROM j AS j
)
SELECT
i.a,
k.b
FROM i AS i
LEFT JOIN k AS k
ON i.a = k.a;
SELECT x.a AS a, x_2.b AS b FROM x AS x LEFT JOIN x AS x_2 ON x.a = x_2.a;
# title: Outer select joins on inner select join
WITH i AS (
SELECT
x.a AS a
FROM y AS y
JOIN x AS x
ON y.b = x.b
)
SELECT
x.a AS a
FROM x AS x
LEFT JOIN i AS i
ON x.a = i.a;
WITH i AS (SELECT x.a AS a FROM y AS y JOIN x AS x ON y.b = x.b) SELECT x.a AS a FROM x AS x LEFT JOIN i AS i ON x.a = i.a;

View file

@@ -105,7 +105,7 @@ LEFT JOIN "_u_0" AS "_u_0"
JOIN "y" AS "y"
ON "x"."b" = "y"."b"
WHERE
"_u_0"."_col_0" >= 0 AND "x"."a" > 1 AND NOT "_u_0"."_u_1" IS NULL
"_u_0"."_col_0" >= 0 AND "x"."a" > 1
GROUP BY
"x"."a";

View file

@@ -54,3 +54,6 @@ WITH t1 AS (SELECT q.cola AS cola FROM UNNEST(ARRAY(STRUCT(1 AS cola, 'test' AS
SELECT x FROM VALUES(1, 2) AS q(x, y);
SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y);
SELECT i.a FROM x AS i LEFT JOIN (SELECT a, b FROM (SELECT a, b FROM x)) AS j ON i.a = j.a;
SELECT i.a AS a FROM x AS i LEFT JOIN (SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0) AS j ON i.a = j.a;

View file

@@ -375,6 +375,18 @@ CAST('1998-12-01' AS DATE) - INTERVAL '90' foo;
date '1998-12-01' + interval '90' foo;
CAST('1998-12-01' AS DATE) + INTERVAL '90' foo;
CAST(x AS DATE) + interval '1' week;
CAST(x AS DATE) + INTERVAL '1' week;
CAST('2008-11-11' AS DATETIME) + INTERVAL '5' MONTH;
CAST('2009-04-11 00:00:00' AS DATETIME);
datetime '1998-12-01' - interval '90' day;
CAST('1998-09-02 00:00:00' AS DATETIME);
CAST(x AS DATETIME) + interval '1' week;
CAST(x AS DATETIME) + INTERVAL '1' week;
--------------------------------------
-- Comparisons
--------------------------------------

View file

@@ -150,7 +150,6 @@ WHERE
"part"."p_size" = 15
AND "part"."p_type" LIKE '%BRASS'
AND "partsupp"."ps_supplycost" = "_u_0"."_col_0"
AND NOT "_u_0"."_u_1" IS NULL
ORDER BY
"s_acctbal" DESC,
"n_name",
@@ -1008,7 +1007,7 @@ JOIN "part" AS "part"
LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."_u_1" = "part"."p_partkey"
WHERE
"lineitem"."l_quantity" < "_u_0"."_col_0" AND NOT "_u_0"."_u_1" IS NULL;
"lineitem"."l_quantity" < "_u_0"."_col_0";
--------------------------------------
-- TPC-H 18
@@ -1253,10 +1252,7 @@ WITH "_u_0" AS (
LEFT JOIN "_u_3" AS "_u_3"
ON "partsupp"."ps_partkey" = "_u_3"."p_partkey"
WHERE
"partsupp"."ps_availqty" > "_u_0"."_col_0"
AND NOT "_u_0"."_u_1" IS NULL
AND NOT "_u_0"."_u_2" IS NULL
AND NOT "_u_3"."p_partkey" IS NULL
"partsupp"."ps_availqty" > "_u_0"."_col_0" AND NOT "_u_3"."p_partkey" IS NULL
GROUP BY
"partsupp"."ps_suppkey"
)

View file

@@ -22,6 +22,8 @@ WHERE
AND x.a > ANY (SELECT y.a FROM y)
AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a LIMIT 10)
AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a OFFSET 10)
AND x.a > ALL (SELECT y.c FROM y WHERE y.a = x.a)
AND x.a > (SELECT COUNT(*) as d FROM y WHERE y.a = x.a)
;
SELECT
*
@@ -130,37 +132,42 @@ LEFT JOIN (
y.a
) AS _u_15
ON x.a = _u_15.a
LEFT JOIN (
SELECT
ARRAY_AGG(c),
y.a AS _u_20
FROM y
WHERE
TRUE
GROUP BY
y.a
) AS _u_19
ON _u_19._u_20 = x.a
LEFT JOIN (
SELECT
COUNT(*) AS d,
y.a AS _u_22
FROM y
WHERE
TRUE
GROUP BY
y.a
) AS _u_21
ON _u_21._u_22 = x.a
WHERE
x.a = _u_0.a
AND NOT "_u_1"."a" IS NULL
AND NOT "_u_2"."b" IS NULL
AND NOT "_u_3"."a" IS NULL
AND x.a = _u_4.b
AND x.a > _u_6.b
AND x.a = _u_8.a
AND NOT x.a = _u_9.a
AND ARRAY_ANY(_u_10.a, _x -> _x = x.a)
AND (
x.a = _u_4.b AND NOT _u_4._u_5 IS NULL
)
AND (
x.a > _u_6.b AND NOT _u_6._u_7 IS NULL
)
AND (
None = _u_8.a AND NOT _u_8.a IS NULL
)
AND NOT (
x.a = _u_9.a AND NOT _u_9.a IS NULL
)
AND (
ARRAY_ANY(_u_10.a, _x -> _x = x.a) AND NOT _u_10._u_11 IS NULL
)
AND (
(
(
x.a < _u_12.a AND NOT _u_12._u_13 IS NULL
) AND NOT _u_12._u_13 IS NULL
)
AND ARRAY_ANY(_u_12._u_14, "_x" -> _x <> x.d)
)
AND (
NOT _u_15.a IS NULL AND NOT _u_15.a IS NULL
x.a < _u_12.a AND ARRAY_ANY(_u_12._u_14, "_x" -> _x <> x.d)
)
AND NOT _u_15.a IS NULL
AND x.a IN (
SELECT
y.a AS a
@@ -199,4 +206,6 @@ WHERE
WHERE
y.a = x.a
OFFSET 10
);
)
AND ARRAY_ALL(_u_19."", _x -> _x = x.a)
AND x.a > COALESCE(_u_21.d, 0);

View file

@@ -27,8 +27,7 @@ def assert_logger_contains(message, logger, level="error"):
def load_sql_fixtures(filename):
with open(os.path.join(FIXTURES_DIR, filename), encoding="utf-8") as f:
for sql in _filter_comments(f.read()).splitlines():
yield sql
yield from _filter_comments(f.read()).splitlines()
def load_sql_fixture_pairs(filename):

View file

@@ -401,6 +401,36 @@ class TestExecutor(unittest.TestCase):
],
)
def test_correlated_count(self):
tables = {
"parts": [{"pnum": 0, "qoh": 1}],
"supplies": [],
}
schema = {
"parts": {"pnum": "int", "qoh": "int"},
"supplies": {"pnum": "int", "shipdate": "int"},
}
self.assertEqual(
execute(
"""
select *
from parts
where parts.qoh >= (
select count(supplies.shipdate) + 1
from supplies
where supplies.pnum = parts.pnum and supplies.shipdate < 10
)
""",
tables=tables,
schema=schema,
).rows,
[
(0, 1),
],
)
def test_table_depth_mismatch(self):
tables = {"table": []}
schema = {"db": {"table": {"col": "VARCHAR"}}}

View file

@@ -646,3 +646,72 @@ FROM foo""",
exp.Column(this=exp.to_identifier("colb")),
],
)
def test_values(self):
self.assertEqual(
exp.values([(1, 2), (3, 4)], "t", ["a", "b"]).sql(),
"(VALUES (1, 2), (3, 4)) AS t(a, b)",
)
self.assertEqual(
exp.values(
[(1, 2), (3, 4)],
"t",
{"a": exp.DataType.build("TEXT"), "b": exp.DataType.build("TEXT")},
).sql(),
"(VALUES (CAST(1 AS TEXT), CAST(2 AS TEXT)), (3, 4)) AS t(a, b)",
)
with self.assertRaises(ValueError):
exp.values([(1, 2), (3, 4)], columns=["a"])
def test_data_type_builder(self):
self.assertEqual(exp.DataType.build("TEXT").sql(), "TEXT")
self.assertEqual(exp.DataType.build("DECIMAL(10, 2)").sql(), "DECIMAL(10, 2)")
self.assertEqual(exp.DataType.build("VARCHAR(255)").sql(), "VARCHAR(255)")
self.assertEqual(exp.DataType.build("ARRAY<INT>").sql(), "ARRAY<INT>")
self.assertEqual(exp.DataType.build("CHAR").sql(), "CHAR")
self.assertEqual(exp.DataType.build("NCHAR").sql(), "CHAR")
self.assertEqual(exp.DataType.build("VARCHAR").sql(), "VARCHAR")
self.assertEqual(exp.DataType.build("NVARCHAR").sql(), "VARCHAR")
self.assertEqual(exp.DataType.build("TEXT").sql(), "TEXT")
self.assertEqual(exp.DataType.build("BINARY").sql(), "BINARY")
self.assertEqual(exp.DataType.build("VARBINARY").sql(), "VARBINARY")
self.assertEqual(exp.DataType.build("INT").sql(), "INT")
self.assertEqual(exp.DataType.build("TINYINT").sql(), "TINYINT")
self.assertEqual(exp.DataType.build("SMALLINT").sql(), "SMALLINT")
self.assertEqual(exp.DataType.build("BIGINT").sql(), "BIGINT")
self.assertEqual(exp.DataType.build("FLOAT").sql(), "FLOAT")
self.assertEqual(exp.DataType.build("DOUBLE").sql(), "DOUBLE")
self.assertEqual(exp.DataType.build("DECIMAL").sql(), "DECIMAL")
self.assertEqual(exp.DataType.build("BOOLEAN").sql(), "BOOLEAN")
self.assertEqual(exp.DataType.build("JSON").sql(), "JSON")
self.assertEqual(exp.DataType.build("JSONB").sql(), "JSONB")
self.assertEqual(exp.DataType.build("INTERVAL").sql(), "INTERVAL")
self.assertEqual(exp.DataType.build("TIME").sql(), "TIME")
self.assertEqual(exp.DataType.build("TIMESTAMP").sql(), "TIMESTAMP")
self.assertEqual(exp.DataType.build("TIMESTAMPTZ").sql(), "TIMESTAMPTZ")
self.assertEqual(exp.DataType.build("TIMESTAMPLTZ").sql(), "TIMESTAMPLTZ")
self.assertEqual(exp.DataType.build("DATE").sql(), "DATE")
self.assertEqual(exp.DataType.build("DATETIME").sql(), "DATETIME")
self.assertEqual(exp.DataType.build("ARRAY").sql(), "ARRAY")
self.assertEqual(exp.DataType.build("MAP").sql(), "MAP")
self.assertEqual(exp.DataType.build("UUID").sql(), "UUID")
self.assertEqual(exp.DataType.build("GEOGRAPHY").sql(), "GEOGRAPHY")
self.assertEqual(exp.DataType.build("GEOMETRY").sql(), "GEOMETRY")
self.assertEqual(exp.DataType.build("STRUCT").sql(), "STRUCT")
self.assertEqual(exp.DataType.build("NULLABLE").sql(), "NULLABLE")
self.assertEqual(exp.DataType.build("HLLSKETCH").sql(), "HLLSKETCH")
self.assertEqual(exp.DataType.build("HSTORE").sql(), "HSTORE")
self.assertEqual(exp.DataType.build("SUPER").sql(), "SUPER")
self.assertEqual(exp.DataType.build("SERIAL").sql(), "SERIAL")
self.assertEqual(exp.DataType.build("SMALLSERIAL").sql(), "SMALLSERIAL")
self.assertEqual(exp.DataType.build("BIGSERIAL").sql(), "BIGSERIAL")
self.assertEqual(exp.DataType.build("XML").sql(), "XML")
self.assertEqual(exp.DataType.build("UNIQUEIDENTIFIER").sql(), "UNIQUEIDENTIFIER")
self.assertEqual(exp.DataType.build("MONEY").sql(), "MONEY")
self.assertEqual(exp.DataType.build("SMALLMONEY").sql(), "SMALLMONEY")
self.assertEqual(exp.DataType.build("ROWVERSION").sql(), "ROWVERSION")
self.assertEqual(exp.DataType.build("IMAGE").sql(), "IMAGE")
self.assertEqual(exp.DataType.build("VARIANT").sql(), "VARIANT")
self.assertEqual(exp.DataType.build("OBJECT").sql(), "OBJECT")
self.assertEqual(exp.DataType.build("NULL").sql(), "NULL")
self.assertEqual(exp.DataType.build("UNKNOWN").sql(), "UNKNOWN")

View file

@@ -299,10 +299,10 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(set(scopes[6].sources), {"q", "z", "r", "s"})
self.assertEqual(len(scopes[6].columns), 6)
self.assertEqual(set(c.table for c in scopes[6].columns), {"r", "s"})
self.assertEqual({c.table for c in scopes[6].columns}, {"r", "s"})
self.assertEqual(scopes[6].source_columns("q"), [])
self.assertEqual(len(scopes[6].source_columns("r")), 2)
self.assertEqual(set(c.table for c in scopes[6].source_columns("r")), {"r"})
self.assertEqual({c.table for c in scopes[6].source_columns("r")}, {"r"})
self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"})
self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b")
@@ -578,3 +578,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
scope_t, scope_y = build_scope(query).cte_scopes
self.assertEqual(set(scope_t.cte_sources), {"t"})
self.assertEqual(set(scope_y.cte_sources), {"t", "y"})
def test_schema_with_spaces(self):
schema = {
"a": {
"b c": "text",
'"d e"': "text",
}
}
self.assertEqual(
optimizer.optimize(parse_one("SELECT * FROM a"), schema=schema),
parse_one('SELECT "a"."b c" AS "b c", "a"."d e" AS "d e" FROM "a" AS "a"'),
)

View file

@@ -8,7 +8,8 @@ from tests.helpers import assert_logger_contains
class TestParser(unittest.TestCase):
def test_parse_empty(self):
self.assertIsNone(parse_one(""))
with self.assertRaises(ParseError) as ctx:
parse_one("")
def test_parse_into(self):
self.assertIsInstance(parse_one("left join foo", into=exp.Join), exp.Join)
@@ -90,6 +91,9 @@ class TestParser(unittest.TestCase):
parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),
"""SELECT * FROM x, z CROSS JOIN y LATERAL VIEW EXPLODE(y)""",
)
self.assertIsNone(
parse_one("create table a as (select b from c) index").find(exp.TableAlias)
)
def test_command(self):
expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1", read="hive")
@@ -155,6 +159,11 @@ class TestParser(unittest.TestCase):
assert expressions[0].args["from"].expressions[0].this.name == "a"
assert expressions[1].args["from"].expressions[0].this.name == "b"
expressions = parse("SELECT 1; ; SELECT 2")
assert len(expressions) == 3
assert expressions[1] is None
def test_expression(self):
ignore = Parser(error_level=ErrorLevel.IGNORE)
self.assertIsInstance(ignore.expression(exp.Hint, expressions=[""]), exp.Hint)

View file

@@ -184,3 +184,19 @@ class TestSchema(unittest.TestCase):
schema = MappingSchema({"foo": {"bar": parse_one("INT", into=exp.DataType)}})
self.assertEqual(schema.get_column_type("foo", "bar").this, exp.DataType.Type.INT)
def test_schema_normalization(self):
schema = MappingSchema(
schema={"x": {"`y`": {"Z": {"a": "INT", "`B`": "VARCHAR"}, "w": {"C": "INT"}}}},
dialect="spark",
)
table_z = exp.Table(this="z", db="y", catalog="x")
table_w = exp.Table(this="w", db="y", catalog="x")
self.assertEqual(schema.column_names(table_z), ["a", "B"])
self.assertEqual(schema.column_names(table_w), ["c"])
# Clickhouse supports both `` and "" for identifier quotes; sqlglot uses "" when generating sql
schema = MappingSchema(schema={"x": {"`y`": "INT"}}, dialect="clickhouse")
self.assertEqual(schema.column_names(exp.Table(this="x")), ["y"])

tests/test_serde.py (new file, 33 lines)
View file

@@ -0,0 +1,33 @@
import json
import unittest
from sqlglot import exp, parse_one
from sqlglot.optimizer.annotate_types import annotate_types
from tests.helpers import load_sql_fixtures
class CustomExpression(exp.Expression):
...
class TestSerDe(unittest.TestCase):
def dump_load(self, expression):
return exp.Expression.load(json.loads(json.dumps(expression.dump())))
def test_serde(self):
for sql in load_sql_fixtures("identity.sql"):
with self.subTest(sql):
before = parse_one(sql)
after = self.dump_load(before)
self.assertEqual(before, after)
def test_custom_expression(self):
before = CustomExpression()
after = self.dump_load(before)
self.assertEqual(before, after)
def test_type_annotations(self):
before = annotate_types(parse_one("CAST('1' AS INT)"))
after = self.dump_load(before)
self.assertEqual(before.type, after.type)
self.assertEqual(before.this.type, after.this.type)

View file

@@ -1,7 +1,11 @@
import unittest
from sqlglot import parse_one
from sqlglot.transforms import eliminate_distinct_on, unalias_group
from sqlglot.transforms import (
eliminate_distinct_on,
remove_precision_parameterized_types,
unalias_group,
)
class TestTime(unittest.TestCase):
@@ -62,3 +66,10 @@ class TestTime(unittest.TestCase):
"SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC",
'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE "_row_number_2" = 1',
)
def test_remove_precision_parameterized_types(self):
self.validate(
remove_precision_parameterized_types,
"SELECT CAST(1 AS DECIMAL(10, 2)), CAST('13' AS VARCHAR(10))",
"SELECT CAST(1 AS DECIMAL), CAST('13' AS VARCHAR)",
)

View file

@@ -117,6 +117,11 @@ class TestTranspile(unittest.TestCase):
"select x from foo -- x",
"SELECT x FROM foo /* x */",
)
self.validate(
"""select x, --
from foo""",
"SELECT x FROM foo",
)
self.validate(
"""
-- comment 1