Merging upstream version 10.5.2.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
77197f1e44
commit
e0f3bbb5f3
58 changed files with 1480 additions and 383 deletions
|
@ -1152,17 +1152,17 @@ class TestFunctions(unittest.TestCase):
|
|||
|
||||
def test_regexp_extract(self):
|
||||
col_str = SF.regexp_extract("cola", r"(\d+)-(\d+)", 1)
|
||||
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)', 1)", col_str.sql())
|
||||
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col_str.sql())
|
||||
col = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)", 1)
|
||||
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)', 1)", col.sql())
|
||||
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col.sql())
|
||||
col_no_idx = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)")
|
||||
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)')", col_no_idx.sql())
|
||||
self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)')", col_no_idx.sql())
|
||||
|
||||
def test_regexp_replace(self):
|
||||
col_str = SF.regexp_replace("cola", r"(\d+)", "--")
|
||||
self.assertEqual("REGEXP_REPLACE(cola, '(\\\d+)', '--')", col_str.sql())
|
||||
self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col_str.sql())
|
||||
col = SF.regexp_replace(SF.col("cola"), r"(\d+)", "--")
|
||||
self.assertEqual("REGEXP_REPLACE(cola, '(\\\d+)', '--')", col.sql())
|
||||
self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col.sql())
|
||||
|
||||
def test_initcap(self):
|
||||
col_str = SF.initcap("cola")
|
||||
|
|
|
@ -15,11 +15,11 @@ class TestDataframeWindow(unittest.TestCase):
|
|||
|
||||
def test_window_spec_rows_between(self):
|
||||
rows_between = WindowSpec().rowsBetween(3, 5)
|
||||
self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
|
||||
self.assertEqual("OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
|
||||
|
||||
def test_window_spec_range_between(self):
|
||||
range_between = WindowSpec().rangeBetween(3, 5)
|
||||
self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
|
||||
self.assertEqual("OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
|
||||
|
||||
def test_window_partition_by(self):
|
||||
partition_by = Window.partitionBy(F.col("cola"), F.col("colb"))
|
||||
|
@ -31,46 +31,46 @@ class TestDataframeWindow(unittest.TestCase):
|
|||
|
||||
def test_window_rows_between(self):
|
||||
rows_between = Window.rowsBetween(3, 5)
|
||||
self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
|
||||
self.assertEqual("OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
|
||||
|
||||
def test_window_range_between(self):
|
||||
range_between = Window.rangeBetween(3, 5)
|
||||
self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
|
||||
self.assertEqual("OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
|
||||
|
||||
def test_window_rows_unbounded(self):
|
||||
rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
|
||||
self.assertEqual(
|
||||
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
|
||||
"OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
|
||||
rows_between_unbounded_start.sql(),
|
||||
)
|
||||
rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
|
||||
self.assertEqual(
|
||||
"OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
|
||||
"OVER (ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
|
||||
rows_between_unbounded_end.sql(),
|
||||
)
|
||||
rows_between_unbounded_both = Window.rowsBetween(
|
||||
Window.unboundedPreceding, Window.unboundedFollowing
|
||||
)
|
||||
self.assertEqual(
|
||||
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
|
||||
"OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
|
||||
rows_between_unbounded_both.sql(),
|
||||
)
|
||||
|
||||
def test_window_range_unbounded(self):
|
||||
range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
|
||||
self.assertEqual(
|
||||
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
|
||||
"OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
|
||||
range_between_unbounded_start.sql(),
|
||||
)
|
||||
range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
|
||||
self.assertEqual(
|
||||
"OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
|
||||
"OVER (RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
|
||||
range_between_unbounded_end.sql(),
|
||||
)
|
||||
range_between_unbounded_both = Window.rangeBetween(
|
||||
Window.unboundedPreceding, Window.unboundedFollowing
|
||||
)
|
||||
self.assertEqual(
|
||||
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
|
||||
"OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
|
||||
range_between_unbounded_both.sql(),
|
||||
)
|
||||
|
|
|
@ -125,7 +125,7 @@ class TestBigQuery(Validator):
|
|||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"CURRENT_DATE",
|
||||
"CURRENT_TIMESTAMP()",
|
||||
read={
|
||||
"tsql": "GETDATE()",
|
||||
},
|
||||
|
@ -299,6 +299,14 @@ class TestBigQuery(Validator):
|
|||
"snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)",
|
||||
write={
|
||||
"spark": "SELECT cola, colb, colc FROM VALUES (1, 'test', NULL) AS tab(cola, colb, colc)",
|
||||
"bigquery": "SELECT cola, colb, colc FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb, NULL AS colc)])",
|
||||
"snowflake": "SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) d, COUNT(*) e FOR c IN ('x', 'y'))",
|
||||
write={
|
||||
|
@ -324,3 +332,35 @@ class TestBigQuery(Validator):
|
|||
"SELECT a, GROUP_CONCAT(b) FROM table GROUP BY a",
|
||||
write={"bigquery": "SELECT a, STRING_AGG(b) FROM table GROUP BY a"},
|
||||
)
|
||||
|
||||
def test_remove_precision_parameterized_types(self):
|
||||
self.validate_all(
|
||||
"SELECT CAST(1 AS NUMERIC(10, 2))",
|
||||
write={
|
||||
"bigquery": "SELECT CAST(1 AS NUMERIC)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"CREATE TABLE test (a NUMERIC(10, 2))",
|
||||
write={
|
||||
"bigquery": "CREATE TABLE test (a NUMERIC(10, 2))",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT CAST('1' AS STRING(10)) UNION ALL SELECT CAST('2' AS STRING(10))",
|
||||
write={
|
||||
"bigquery": "SELECT CAST('1' AS STRING) UNION ALL SELECT CAST('2' AS STRING)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT cola FROM (SELECT CAST('1' AS STRING(10)) AS cola UNION ALL SELECT CAST('2' AS STRING(10)) AS cola)",
|
||||
write={
|
||||
"bigquery": "SELECT cola FROM (SELECT CAST('1' AS STRING) AS cola UNION ALL SELECT CAST('2' AS STRING) AS cola)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING(10)), CAST(14 AS STRING(10)))",
|
||||
write={
|
||||
"bigquery": "INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING), CAST(14 AS STRING))",
|
||||
},
|
||||
)
|
||||
|
|
|
@ -14,6 +14,9 @@ class TestClickhouse(Validator):
|
|||
self.validate_identity("SELECT * FROM foo LEFT ASOF JOIN bla")
|
||||
self.validate_identity("SELECT * FROM foo ASOF JOIN bla")
|
||||
self.validate_identity("SELECT * FROM foo ANY JOIN bla")
|
||||
self.validate_identity("SELECT quantile(0.5)(a)")
|
||||
self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t")
|
||||
self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)")
|
||||
|
||||
self.validate_all(
|
||||
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
|
||||
|
@ -38,3 +41,9 @@ class TestClickhouse(Validator):
|
|||
"SELECT x #! comment",
|
||||
write={"": "SELECT x /* comment */"},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT quantileIf(0.5)(a, true)",
|
||||
write={
|
||||
"clickhouse": "SELECT quantileIf(0.5)(a, TRUE)",
|
||||
},
|
||||
)
|
||||
|
|
|
@ -85,7 +85,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"CAST(a AS BINARY(4))",
|
||||
write={
|
||||
"bigquery": "CAST(a AS BINARY(4))",
|
||||
"bigquery": "CAST(a AS BINARY)",
|
||||
"clickhouse": "CAST(a AS BINARY(4))",
|
||||
"drill": "CAST(a AS VARBINARY(4))",
|
||||
"duckdb": "CAST(a AS BINARY(4))",
|
||||
|
@ -104,7 +104,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"CAST(a AS VARBINARY(4))",
|
||||
write={
|
||||
"bigquery": "CAST(a AS VARBINARY(4))",
|
||||
"bigquery": "CAST(a AS VARBINARY)",
|
||||
"clickhouse": "CAST(a AS VARBINARY(4))",
|
||||
"duckdb": "CAST(a AS VARBINARY(4))",
|
||||
"mysql": "CAST(a AS VARBINARY(4))",
|
||||
|
@ -181,7 +181,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"CAST(a AS VARCHAR(3))",
|
||||
write={
|
||||
"bigquery": "CAST(a AS STRING(3))",
|
||||
"bigquery": "CAST(a AS STRING)",
|
||||
"drill": "CAST(a AS VARCHAR(3))",
|
||||
"duckdb": "CAST(a AS TEXT(3))",
|
||||
"mysql": "CAST(a AS VARCHAR(3))",
|
||||
|
|
|
@ -338,6 +338,24 @@ class TestHive(Validator):
|
|||
)
|
||||
|
||||
def test_hive(self):
|
||||
self.validate_all(
|
||||
"SELECT A.1a AS b FROM test_a AS A",
|
||||
write={
|
||||
"spark": "SELECT A.1a AS b FROM test_a AS A",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT 1_a AS a FROM test_table",
|
||||
write={
|
||||
"spark": "SELECT 1_a AS a FROM test_table",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT a_b AS 1_a FROM test_table",
|
||||
write={
|
||||
"spark": "SELECT a_b AS 1_a FROM test_table",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"PERCENTILE(x, 0.5)",
|
||||
write={
|
||||
|
@ -411,7 +429,7 @@ class TestHive(Validator):
|
|||
"INITCAP('new york')",
|
||||
write={
|
||||
"duckdb": "INITCAP('new york')",
|
||||
"presto": "REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))",
|
||||
"presto": r"REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))",
|
||||
"hive": "INITCAP('new york')",
|
||||
"spark": "INITCAP('new york')",
|
||||
},
|
||||
|
|
|
@ -122,6 +122,10 @@ class TestPostgres(Validator):
|
|||
"TO_TIMESTAMP(123::DOUBLE PRECISION)",
|
||||
write={"postgres": "TO_TIMESTAMP(CAST(123 AS DOUBLE PRECISION))"},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT to_timestamp(123)::time without time zone",
|
||||
write={"postgres": "SELECT CAST(TO_TIMESTAMP(123) AS TIME)"},
|
||||
)
|
||||
|
||||
self.validate_identity(
|
||||
"CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)"
|
||||
|
|
|
@ -60,11 +60,11 @@ class TestPresto(Validator):
|
|||
self.validate_all(
|
||||
"CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
|
||||
write={
|
||||
"bigquery": "CAST(x AS TIMESTAMPTZ(9))",
|
||||
"bigquery": "CAST(x AS TIMESTAMPTZ)",
|
||||
"duckdb": "CAST(x AS TIMESTAMPTZ(9))",
|
||||
"presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
|
||||
"hive": "CAST(x AS TIMESTAMPTZ(9))",
|
||||
"spark": "CAST(x AS TIMESTAMPTZ(9))",
|
||||
"hive": "CAST(x AS TIMESTAMPTZ)",
|
||||
"spark": "CAST(x AS TIMESTAMPTZ)",
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
@ -523,3 +523,33 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERA
|
|||
"spark": "SELECT `c0`, `c1` FROM (VALUES (1, 2), (3, 4)) AS `t0`(`c0`, `c1`)",
|
||||
},
|
||||
)
|
||||
|
||||
def test_describe_table(self):
|
||||
self.validate_all(
|
||||
"DESCRIBE TABLE db.table",
|
||||
write={
|
||||
"snowflake": "DESCRIBE TABLE db.table",
|
||||
"spark": "DESCRIBE db.table",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"DESCRIBE db.table",
|
||||
write={
|
||||
"snowflake": "DESCRIBE TABLE db.table",
|
||||
"spark": "DESCRIBE db.table",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"DESC TABLE db.table",
|
||||
write={
|
||||
"snowflake": "DESCRIBE TABLE db.table",
|
||||
"spark": "DESCRIBE db.table",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"DESC VIEW db.table",
|
||||
write={
|
||||
"snowflake": "DESCRIBE VIEW db.table",
|
||||
"spark": "DESCRIBE db.table",
|
||||
},
|
||||
)
|
||||
|
|
|
@ -207,6 +207,7 @@ TBLPROPERTIES (
|
|||
)
|
||||
|
||||
def test_spark(self):
|
||||
self.validate_identity("SELECT UNIX_TIMESTAMP()")
|
||||
self.validate_all(
|
||||
"ARRAY_SORT(x, (left, right) -> -1)",
|
||||
write={
|
||||
|
|
|
@ -6,6 +6,8 @@ class TestTSQL(Validator):
|
|||
|
||||
def test_tsql(self):
|
||||
self.validate_identity('SELECT "x"."y" FROM foo')
|
||||
self.validate_identity("SELECT * FROM #foo")
|
||||
self.validate_identity("SELECT * FROM ##foo")
|
||||
self.validate_identity(
|
||||
"SELECT DISTINCT DepartmentName, PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY BaseRate) OVER (PARTITION BY DepartmentName) AS MedianCont FROM dbo.DimEmployee"
|
||||
)
|
||||
|
@ -71,6 +73,12 @@ class TestTSQL(Validator):
|
|||
"tsql": "CAST(x AS DATETIME2)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"CAST(x AS DATETIME2(6))",
|
||||
write={
|
||||
"hive": "CAST(x AS TIMESTAMP)",
|
||||
},
|
||||
)
|
||||
|
||||
def test_charindex(self):
|
||||
self.validate_all(
|
||||
|
@ -300,6 +308,12 @@ class TestTSQL(Validator):
|
|||
"spark": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT CAST((SELECT x FROM y) AS VARCHAR) AS test",
|
||||
write={
|
||||
"spark": "SELECT CAST((SELECT x FROM y) AS STRING) AS test",
|
||||
},
|
||||
)
|
||||
|
||||
def test_add_date(self):
|
||||
self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')")
|
||||
|
@ -441,3 +455,13 @@ class TestTSQL(Validator):
|
|||
"SELECT '''test'''",
|
||||
write={"spark": r"SELECT '\'test\''"},
|
||||
)
|
||||
|
||||
def test_eomonth(self):
|
||||
self.validate_all(
|
||||
"EOMONTH(GETDATE())",
|
||||
write={"spark": "LAST_DAY(CURRENT_TIMESTAMP())"},
|
||||
)
|
||||
self.validate_all(
|
||||
"EOMONTH(GETDATE(), -1)",
|
||||
write={"spark": "LAST_DAY(ADD_MONTHS(CURRENT_TIMESTAMP(), -1))"},
|
||||
)
|
||||
|
|
16
tests/fixtures/identity.sql
vendored
16
tests/fixtures/identity.sql
vendored
|
@ -89,6 +89,7 @@ POSEXPLODE("x") AS ("a", "b")
|
|||
POSEXPLODE("x") AS ("a", "b", "c")
|
||||
STR_POSITION(x, 'a')
|
||||
STR_POSITION(x, 'a', 3)
|
||||
LEVENSHTEIN('gumbo', 'gambol', 2, 1, 1)
|
||||
SPLIT(SPLIT(referrer, 'utm_source=')[OFFSET(1)], "&")[OFFSET(0)]
|
||||
x[ORDINAL(1)][SAFE_OFFSET(2)]
|
||||
x LIKE SUBSTR('abc', 1, 1)
|
||||
|
@ -425,6 +426,7 @@ SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT
|
|||
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 AND 3)
|
||||
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND 3)
|
||||
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING)
|
||||
SELECT AVG(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM t
|
||||
SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x) AS y
|
||||
SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x DESC)
|
||||
SELECT SUM(x) FILTER(WHERE x > 1)
|
||||
|
@ -450,14 +452,24 @@ SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(a, b)
|
|||
SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b)
|
||||
SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score)
|
||||
SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score)
|
||||
SELECT * FROM t WITH (TABLOCK, INDEX(myindex))
|
||||
SELECT * FROM t WITH (NOWAIT)
|
||||
CREATE TABLE foo AS (SELECT 1) UNION ALL (SELECT 2)
|
||||
CREATE TABLE foo (id INT PRIMARY KEY ASC)
|
||||
CREATE TABLE a.b AS SELECT 1
|
||||
CREATE TABLE a.b AS SELECT 1 WITH DATA AND STATISTICS
|
||||
CREATE TABLE a.b AS SELECT 1 WITH NO DATA AND NO STATISTICS
|
||||
CREATE TABLE a.b AS (SELECT 1) NO PRIMARY INDEX
|
||||
CREATE TABLE a.b AS (SELECT 1) UNIQUE PRIMARY INDEX index1 (a) UNIQUE INDEX index2 (b)
|
||||
CREATE TABLE a.b AS (SELECT 1) PRIMARY AMP INDEX index1 (a) UNIQUE INDEX index2 (b)
|
||||
CREATE TABLE a.b AS SELECT a FROM a.c
|
||||
CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d
|
||||
CREATE TEMPORARY TABLE x AS SELECT a FROM d
|
||||
CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d
|
||||
CREATE VIEW x AS SELECT a FROM b
|
||||
CREATE VIEW IF NOT EXISTS x AS SELECT a FROM b
|
||||
CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c') AS SELECT a, b, c FROM d
|
||||
CREATE VIEW IF NOT EXISTS z (a, b COMMENT 'b', c COMMENT 'c') AS SELECT a, b, c FROM d
|
||||
CREATE OR REPLACE VIEW x AS SELECT *
|
||||
CREATE OR REPLACE TEMPORARY VIEW x AS SELECT *
|
||||
CREATE TEMPORARY VIEW x AS SELECT a FROM d
|
||||
|
@ -490,6 +502,8 @@ CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT)
|
|||
CREATE TABLE z (a INT REFERENCES parent(b, c))
|
||||
CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id))
|
||||
CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c))
|
||||
CREATE VIEW z (a, b)
|
||||
CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c')
|
||||
CREATE TEMPORARY FUNCTION f
|
||||
CREATE TEMPORARY FUNCTION f AS 'g'
|
||||
CREATE FUNCTION f
|
||||
|
@ -559,6 +573,7 @@ INSERT INTO x.z IF EXISTS SELECT * FROM y
|
|||
INSERT INTO x VALUES (1, 'a', 2.0)
|
||||
INSERT INTO x VALUES (1, 'a', 2.0), (1, 'a', 3.0), (X(), y[1], z.x)
|
||||
INSERT INTO y (a, b, c) SELECT a, b, c FROM x
|
||||
INSERT INTO y (SELECT 1) UNION (SELECT 2)
|
||||
INSERT OVERWRITE TABLE x IF EXISTS SELECT * FROM y
|
||||
INSERT OVERWRITE TABLE a.b IF EXISTS SELECT * FROM y
|
||||
INSERT OVERWRITE DIRECTORY 'x' SELECT 1
|
||||
|
@ -627,3 +642,4 @@ ALTER TABLE integers ALTER COLUMN i SET DEFAULT 10
|
|||
ALTER TABLE integers ALTER COLUMN i DROP DEFAULT
|
||||
ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B
|
||||
ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT
|
||||
SELECT div.a FROM test_table AS div
|
||||
|
|
39
tests/fixtures/optimizer/merge_subqueries.sql
vendored
39
tests/fixtures/optimizer/merge_subqueries.sql
vendored
|
@ -311,3 +311,42 @@ FROM
|
|||
ON
|
||||
t1.cola = t2.cola;
|
||||
SELECT /*+ BROADCAST(a2) */ a1.cola AS cola, a2.cola AS cola FROM VALUES (1) AS a1(cola) JOIN VALUES (1) AS a2(cola) ON a1.cola = a2.cola;
|
||||
|
||||
# title: Nested subquery selects from same table as another subquery
|
||||
WITH i AS (
|
||||
SELECT
|
||||
x.a AS a
|
||||
FROM x AS x
|
||||
), j AS (
|
||||
SELECT
|
||||
x.a,
|
||||
x.b
|
||||
FROM x AS x
|
||||
), k AS (
|
||||
SELECT
|
||||
j.a,
|
||||
j.b
|
||||
FROM j AS j
|
||||
)
|
||||
SELECT
|
||||
i.a,
|
||||
k.b
|
||||
FROM i AS i
|
||||
LEFT JOIN k AS k
|
||||
ON i.a = k.a;
|
||||
SELECT x.a AS a, x_2.b AS b FROM x AS x LEFT JOIN x AS x_2 ON x.a = x_2.a;
|
||||
|
||||
# title: Outer select joins on inner select join
|
||||
WITH i AS (
|
||||
SELECT
|
||||
x.a AS a
|
||||
FROM y AS y
|
||||
JOIN x AS x
|
||||
ON y.b = x.b
|
||||
)
|
||||
SELECT
|
||||
x.a AS a
|
||||
FROM x AS x
|
||||
LEFT JOIN i AS i
|
||||
ON x.a = i.a;
|
||||
WITH i AS (SELECT x.a AS a FROM y AS y JOIN x AS x ON y.b = x.b) SELECT x.a AS a FROM x AS x LEFT JOIN i AS i ON x.a = i.a;
|
||||
|
|
2
tests/fixtures/optimizer/optimizer.sql
vendored
2
tests/fixtures/optimizer/optimizer.sql
vendored
|
@ -105,7 +105,7 @@ LEFT JOIN "_u_0" AS "_u_0"
|
|||
JOIN "y" AS "y"
|
||||
ON "x"."b" = "y"."b"
|
||||
WHERE
|
||||
"_u_0"."_col_0" >= 0 AND "x"."a" > 1 AND NOT "_u_0"."_u_1" IS NULL
|
||||
"_u_0"."_col_0" >= 0 AND "x"."a" > 1
|
||||
GROUP BY
|
||||
"x"."a";
|
||||
|
||||
|
|
|
@ -54,3 +54,6 @@ WITH t1 AS (SELECT q.cola AS cola FROM UNNEST(ARRAY(STRUCT(1 AS cola, 'test' AS
|
|||
|
||||
SELECT x FROM VALUES(1, 2) AS q(x, y);
|
||||
SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y);
|
||||
|
||||
SELECT i.a FROM x AS i LEFT JOIN (SELECT a, b FROM (SELECT a, b FROM x)) AS j ON i.a = j.a;
|
||||
SELECT i.a AS a FROM x AS i LEFT JOIN (SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0) AS j ON i.a = j.a;
|
||||
|
|
12
tests/fixtures/optimizer/simplify.sql
vendored
12
tests/fixtures/optimizer/simplify.sql
vendored
|
@ -375,6 +375,18 @@ CAST('1998-12-01' AS DATE) - INTERVAL '90' foo;
|
|||
date '1998-12-01' + interval '90' foo;
|
||||
CAST('1998-12-01' AS DATE) + INTERVAL '90' foo;
|
||||
|
||||
CAST(x AS DATE) + interval '1' week;
|
||||
CAST(x AS DATE) + INTERVAL '1' week;
|
||||
|
||||
CAST('2008-11-11' AS DATETIME) + INTERVAL '5' MONTH;
|
||||
CAST('2009-04-11 00:00:00' AS DATETIME);
|
||||
|
||||
datetime '1998-12-01' - interval '90' day;
|
||||
CAST('1998-09-02 00:00:00' AS DATETIME);
|
||||
|
||||
CAST(x AS DATETIME) + interval '1' week;
|
||||
CAST(x AS DATETIME) + INTERVAL '1' week;
|
||||
|
||||
--------------------------------------
|
||||
-- Comparisons
|
||||
--------------------------------------
|
||||
|
|
8
tests/fixtures/optimizer/tpc-h/tpc-h.sql
vendored
8
tests/fixtures/optimizer/tpc-h/tpc-h.sql
vendored
|
@ -150,7 +150,6 @@ WHERE
|
|||
"part"."p_size" = 15
|
||||
AND "part"."p_type" LIKE '%BRASS'
|
||||
AND "partsupp"."ps_supplycost" = "_u_0"."_col_0"
|
||||
AND NOT "_u_0"."_u_1" IS NULL
|
||||
ORDER BY
|
||||
"s_acctbal" DESC,
|
||||
"n_name",
|
||||
|
@ -1008,7 +1007,7 @@ JOIN "part" AS "part"
|
|||
LEFT JOIN "_u_0" AS "_u_0"
|
||||
ON "_u_0"."_u_1" = "part"."p_partkey"
|
||||
WHERE
|
||||
"lineitem"."l_quantity" < "_u_0"."_col_0" AND NOT "_u_0"."_u_1" IS NULL;
|
||||
"lineitem"."l_quantity" < "_u_0"."_col_0";
|
||||
|
||||
--------------------------------------
|
||||
-- TPC-H 18
|
||||
|
@ -1253,10 +1252,7 @@ WITH "_u_0" AS (
|
|||
LEFT JOIN "_u_3" AS "_u_3"
|
||||
ON "partsupp"."ps_partkey" = "_u_3"."p_partkey"
|
||||
WHERE
|
||||
"partsupp"."ps_availqty" > "_u_0"."_col_0"
|
||||
AND NOT "_u_0"."_u_1" IS NULL
|
||||
AND NOT "_u_0"."_u_2" IS NULL
|
||||
AND NOT "_u_3"."p_partkey" IS NULL
|
||||
"partsupp"."ps_availqty" > "_u_0"."_col_0" AND NOT "_u_3"."p_partkey" IS NULL
|
||||
GROUP BY
|
||||
"partsupp"."ps_suppkey"
|
||||
)
|
||||
|
|
59
tests/fixtures/optimizer/unnest_subqueries.sql
vendored
59
tests/fixtures/optimizer/unnest_subqueries.sql
vendored
|
@ -22,6 +22,8 @@ WHERE
|
|||
AND x.a > ANY (SELECT y.a FROM y)
|
||||
AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a LIMIT 10)
|
||||
AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a OFFSET 10)
|
||||
AND x.a > ALL (SELECT y.c FROM y WHERE y.a = x.a)
|
||||
AND x.a > (SELECT COUNT(*) as d FROM y WHERE y.a = x.a)
|
||||
;
|
||||
SELECT
|
||||
*
|
||||
|
@ -130,37 +132,42 @@ LEFT JOIN (
|
|||
y.a
|
||||
) AS _u_15
|
||||
ON x.a = _u_15.a
|
||||
LEFT JOIN (
|
||||
SELECT
|
||||
ARRAY_AGG(c),
|
||||
y.a AS _u_20
|
||||
FROM y
|
||||
WHERE
|
||||
TRUE
|
||||
GROUP BY
|
||||
y.a
|
||||
) AS _u_19
|
||||
ON _u_19._u_20 = x.a
|
||||
LEFT JOIN (
|
||||
SELECT
|
||||
COUNT(*) AS d,
|
||||
y.a AS _u_22
|
||||
FROM y
|
||||
WHERE
|
||||
TRUE
|
||||
GROUP BY
|
||||
y.a
|
||||
) AS _u_21
|
||||
ON _u_21._u_22 = x.a
|
||||
WHERE
|
||||
x.a = _u_0.a
|
||||
AND NOT "_u_1"."a" IS NULL
|
||||
AND NOT "_u_2"."b" IS NULL
|
||||
AND NOT "_u_3"."a" IS NULL
|
||||
AND x.a = _u_4.b
|
||||
AND x.a > _u_6.b
|
||||
AND x.a = _u_8.a
|
||||
AND NOT x.a = _u_9.a
|
||||
AND ARRAY_ANY(_u_10.a, _x -> _x = x.a)
|
||||
AND (
|
||||
x.a = _u_4.b AND NOT _u_4._u_5 IS NULL
|
||||
)
|
||||
AND (
|
||||
x.a > _u_6.b AND NOT _u_6._u_7 IS NULL
|
||||
)
|
||||
AND (
|
||||
None = _u_8.a AND NOT _u_8.a IS NULL
|
||||
)
|
||||
AND NOT (
|
||||
x.a = _u_9.a AND NOT _u_9.a IS NULL
|
||||
)
|
||||
AND (
|
||||
ARRAY_ANY(_u_10.a, _x -> _x = x.a) AND NOT _u_10._u_11 IS NULL
|
||||
)
|
||||
AND (
|
||||
(
|
||||
(
|
||||
x.a < _u_12.a AND NOT _u_12._u_13 IS NULL
|
||||
) AND NOT _u_12._u_13 IS NULL
|
||||
)
|
||||
AND ARRAY_ANY(_u_12._u_14, "_x" -> _x <> x.d)
|
||||
)
|
||||
AND (
|
||||
NOT _u_15.a IS NULL AND NOT _u_15.a IS NULL
|
||||
x.a < _u_12.a AND ARRAY_ANY(_u_12._u_14, "_x" -> _x <> x.d)
|
||||
)
|
||||
AND NOT _u_15.a IS NULL
|
||||
AND x.a IN (
|
||||
SELECT
|
||||
y.a AS a
|
||||
|
@ -199,4 +206,6 @@ WHERE
|
|||
WHERE
|
||||
y.a = x.a
|
||||
OFFSET 10
|
||||
);
|
||||
)
|
||||
AND ARRAY_ALL(_u_19."", _x -> _x = x.a)
|
||||
AND x.a > COALESCE(_u_21.d, 0);
|
||||
|
|
|
@ -27,8 +27,7 @@ def assert_logger_contains(message, logger, level="error"):
|
|||
|
||||
def load_sql_fixtures(filename):
|
||||
with open(os.path.join(FIXTURES_DIR, filename), encoding="utf-8") as f:
|
||||
for sql in _filter_comments(f.read()).splitlines():
|
||||
yield sql
|
||||
yield from _filter_comments(f.read()).splitlines()
|
||||
|
||||
|
||||
def load_sql_fixture_pairs(filename):
|
||||
|
|
|
@ -401,6 +401,36 @@ class TestExecutor(unittest.TestCase):
|
|||
],
|
||||
)
|
||||
|
||||
def test_correlated_count(self):
|
||||
tables = {
|
||||
"parts": [{"pnum": 0, "qoh": 1}],
|
||||
"supplies": [],
|
||||
}
|
||||
|
||||
schema = {
|
||||
"parts": {"pnum": "int", "qoh": "int"},
|
||||
"supplies": {"pnum": "int", "shipdate": "int"},
|
||||
}
|
||||
|
||||
self.assertEqual(
|
||||
execute(
|
||||
"""
|
||||
select *
|
||||
from parts
|
||||
where parts.qoh >= (
|
||||
select count(supplies.shipdate) + 1
|
||||
from supplies
|
||||
where supplies.pnum = parts.pnum and supplies.shipdate < 10
|
||||
)
|
||||
""",
|
||||
tables=tables,
|
||||
schema=schema,
|
||||
).rows,
|
||||
[
|
||||
(0, 1),
|
||||
],
|
||||
)
|
||||
|
||||
def test_table_depth_mismatch(self):
|
||||
tables = {"table": []}
|
||||
schema = {"db": {"table": {"col": "VARCHAR"}}}
|
||||
|
|
|
@ -646,3 +646,72 @@ FROM foo""",
|
|||
exp.Column(this=exp.to_identifier("colb")),
|
||||
],
|
||||
)
|
||||
|
||||
def test_values(self):
|
||||
self.assertEqual(
|
||||
exp.values([(1, 2), (3, 4)], "t", ["a", "b"]).sql(),
|
||||
"(VALUES (1, 2), (3, 4)) AS t(a, b)",
|
||||
)
|
||||
self.assertEqual(
|
||||
exp.values(
|
||||
[(1, 2), (3, 4)],
|
||||
"t",
|
||||
{"a": exp.DataType.build("TEXT"), "b": exp.DataType.build("TEXT")},
|
||||
).sql(),
|
||||
"(VALUES (CAST(1 AS TEXT), CAST(2 AS TEXT)), (3, 4)) AS t(a, b)",
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
exp.values([(1, 2), (3, 4)], columns=["a"])
|
||||
|
||||
def test_data_type_builder(self):
|
||||
self.assertEqual(exp.DataType.build("TEXT").sql(), "TEXT")
|
||||
self.assertEqual(exp.DataType.build("DECIMAL(10, 2)").sql(), "DECIMAL(10, 2)")
|
||||
self.assertEqual(exp.DataType.build("VARCHAR(255)").sql(), "VARCHAR(255)")
|
||||
self.assertEqual(exp.DataType.build("ARRAY<INT>").sql(), "ARRAY<INT>")
|
||||
self.assertEqual(exp.DataType.build("CHAR").sql(), "CHAR")
|
||||
self.assertEqual(exp.DataType.build("NCHAR").sql(), "CHAR")
|
||||
self.assertEqual(exp.DataType.build("VARCHAR").sql(), "VARCHAR")
|
||||
self.assertEqual(exp.DataType.build("NVARCHAR").sql(), "VARCHAR")
|
||||
self.assertEqual(exp.DataType.build("TEXT").sql(), "TEXT")
|
||||
self.assertEqual(exp.DataType.build("BINARY").sql(), "BINARY")
|
||||
self.assertEqual(exp.DataType.build("VARBINARY").sql(), "VARBINARY")
|
||||
self.assertEqual(exp.DataType.build("INT").sql(), "INT")
|
||||
self.assertEqual(exp.DataType.build("TINYINT").sql(), "TINYINT")
|
||||
self.assertEqual(exp.DataType.build("SMALLINT").sql(), "SMALLINT")
|
||||
self.assertEqual(exp.DataType.build("BIGINT").sql(), "BIGINT")
|
||||
self.assertEqual(exp.DataType.build("FLOAT").sql(), "FLOAT")
|
||||
self.assertEqual(exp.DataType.build("DOUBLE").sql(), "DOUBLE")
|
||||
self.assertEqual(exp.DataType.build("DECIMAL").sql(), "DECIMAL")
|
||||
self.assertEqual(exp.DataType.build("BOOLEAN").sql(), "BOOLEAN")
|
||||
self.assertEqual(exp.DataType.build("JSON").sql(), "JSON")
|
||||
self.assertEqual(exp.DataType.build("JSONB").sql(), "JSONB")
|
||||
self.assertEqual(exp.DataType.build("INTERVAL").sql(), "INTERVAL")
|
||||
self.assertEqual(exp.DataType.build("TIME").sql(), "TIME")
|
||||
self.assertEqual(exp.DataType.build("TIMESTAMP").sql(), "TIMESTAMP")
|
||||
self.assertEqual(exp.DataType.build("TIMESTAMPTZ").sql(), "TIMESTAMPTZ")
|
||||
self.assertEqual(exp.DataType.build("TIMESTAMPLTZ").sql(), "TIMESTAMPLTZ")
|
||||
self.assertEqual(exp.DataType.build("DATE").sql(), "DATE")
|
||||
self.assertEqual(exp.DataType.build("DATETIME").sql(), "DATETIME")
|
||||
self.assertEqual(exp.DataType.build("ARRAY").sql(), "ARRAY")
|
||||
self.assertEqual(exp.DataType.build("MAP").sql(), "MAP")
|
||||
self.assertEqual(exp.DataType.build("UUID").sql(), "UUID")
|
||||
self.assertEqual(exp.DataType.build("GEOGRAPHY").sql(), "GEOGRAPHY")
|
||||
self.assertEqual(exp.DataType.build("GEOMETRY").sql(), "GEOMETRY")
|
||||
self.assertEqual(exp.DataType.build("STRUCT").sql(), "STRUCT")
|
||||
self.assertEqual(exp.DataType.build("NULLABLE").sql(), "NULLABLE")
|
||||
self.assertEqual(exp.DataType.build("HLLSKETCH").sql(), "HLLSKETCH")
|
||||
self.assertEqual(exp.DataType.build("HSTORE").sql(), "HSTORE")
|
||||
self.assertEqual(exp.DataType.build("SUPER").sql(), "SUPER")
|
||||
self.assertEqual(exp.DataType.build("SERIAL").sql(), "SERIAL")
|
||||
self.assertEqual(exp.DataType.build("SMALLSERIAL").sql(), "SMALLSERIAL")
|
||||
self.assertEqual(exp.DataType.build("BIGSERIAL").sql(), "BIGSERIAL")
|
||||
self.assertEqual(exp.DataType.build("XML").sql(), "XML")
|
||||
self.assertEqual(exp.DataType.build("UNIQUEIDENTIFIER").sql(), "UNIQUEIDENTIFIER")
|
||||
self.assertEqual(exp.DataType.build("MONEY").sql(), "MONEY")
|
||||
self.assertEqual(exp.DataType.build("SMALLMONEY").sql(), "SMALLMONEY")
|
||||
self.assertEqual(exp.DataType.build("ROWVERSION").sql(), "ROWVERSION")
|
||||
self.assertEqual(exp.DataType.build("IMAGE").sql(), "IMAGE")
|
||||
self.assertEqual(exp.DataType.build("VARIANT").sql(), "VARIANT")
|
||||
self.assertEqual(exp.DataType.build("OBJECT").sql(), "OBJECT")
|
||||
self.assertEqual(exp.DataType.build("NULL").sql(), "NULL")
|
||||
self.assertEqual(exp.DataType.build("UNKNOWN").sql(), "UNKNOWN")
|
||||
|
|
|
@ -299,10 +299,10 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
|
|||
|
||||
self.assertEqual(set(scopes[6].sources), {"q", "z", "r", "s"})
|
||||
self.assertEqual(len(scopes[6].columns), 6)
|
||||
self.assertEqual(set(c.table for c in scopes[6].columns), {"r", "s"})
|
||||
self.assertEqual({c.table for c in scopes[6].columns}, {"r", "s"})
|
||||
self.assertEqual(scopes[6].source_columns("q"), [])
|
||||
self.assertEqual(len(scopes[6].source_columns("r")), 2)
|
||||
self.assertEqual(set(c.table for c in scopes[6].source_columns("r")), {"r"})
|
||||
self.assertEqual({c.table for c in scopes[6].source_columns("r")}, {"r"})
|
||||
|
||||
self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"})
|
||||
self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b")
|
||||
|
@ -578,3 +578,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
|
|||
scope_t, scope_y = build_scope(query).cte_scopes
|
||||
self.assertEqual(set(scope_t.cte_sources), {"t"})
|
||||
self.assertEqual(set(scope_y.cte_sources), {"t", "y"})
|
||||
|
||||
def test_schema_with_spaces(self):
|
||||
schema = {
|
||||
"a": {
|
||||
"b c": "text",
|
||||
'"d e"': "text",
|
||||
}
|
||||
}
|
||||
|
||||
self.assertEqual(
|
||||
optimizer.optimize(parse_one("SELECT * FROM a"), schema=schema),
|
||||
parse_one('SELECT "a"."b c" AS "b c", "a"."d e" AS "d e" FROM "a" AS "a"'),
|
||||
)
|
||||
|
|
|
@ -8,7 +8,8 @@ from tests.helpers import assert_logger_contains
|
|||
|
||||
class TestParser(unittest.TestCase):
|
||||
def test_parse_empty(self):
|
||||
self.assertIsNone(parse_one(""))
|
||||
with self.assertRaises(ParseError) as ctx:
|
||||
parse_one("")
|
||||
|
||||
def test_parse_into(self):
|
||||
self.assertIsInstance(parse_one("left join foo", into=exp.Join), exp.Join)
|
||||
|
@ -90,6 +91,9 @@ class TestParser(unittest.TestCase):
|
|||
parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),
|
||||
"""SELECT * FROM x, z CROSS JOIN y LATERAL VIEW EXPLODE(y)""",
|
||||
)
|
||||
self.assertIsNone(
|
||||
parse_one("create table a as (select b from c) index").find(exp.TableAlias)
|
||||
)
|
||||
|
||||
def test_command(self):
|
||||
expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1", read="hive")
|
||||
|
@ -155,6 +159,11 @@ class TestParser(unittest.TestCase):
|
|||
assert expressions[0].args["from"].expressions[0].this.name == "a"
|
||||
assert expressions[1].args["from"].expressions[0].this.name == "b"
|
||||
|
||||
expressions = parse("SELECT 1; ; SELECT 2")
|
||||
|
||||
assert len(expressions) == 3
|
||||
assert expressions[1] is None
|
||||
|
||||
def test_expression(self):
|
||||
ignore = Parser(error_level=ErrorLevel.IGNORE)
|
||||
self.assertIsInstance(ignore.expression(exp.Hint, expressions=[""]), exp.Hint)
|
||||
|
|
|
@ -184,3 +184,19 @@ class TestSchema(unittest.TestCase):
|
|||
|
||||
schema = MappingSchema({"foo": {"bar": parse_one("INT", into=exp.DataType)}})
|
||||
self.assertEqual(schema.get_column_type("foo", "bar").this, exp.DataType.Type.INT)
|
||||
|
||||
def test_schema_normalization(self):
|
||||
schema = MappingSchema(
|
||||
schema={"x": {"`y`": {"Z": {"a": "INT", "`B`": "VARCHAR"}, "w": {"C": "INT"}}}},
|
||||
dialect="spark",
|
||||
)
|
||||
|
||||
table_z = exp.Table(this="z", db="y", catalog="x")
|
||||
table_w = exp.Table(this="w", db="y", catalog="x")
|
||||
|
||||
self.assertEqual(schema.column_names(table_z), ["a", "B"])
|
||||
self.assertEqual(schema.column_names(table_w), ["c"])
|
||||
|
||||
# Clickhouse supports both `` and "" for identifier quotes; sqlglot uses "" when generating sql
|
||||
schema = MappingSchema(schema={"x": {"`y`": "INT"}}, dialect="clickhouse")
|
||||
self.assertEqual(schema.column_names(exp.Table(this="x")), ["y"])
|
||||
|
|
33
tests/test_serde.py
Normal file
33
tests/test_serde.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
import json
|
||||
import unittest
|
||||
|
||||
from sqlglot import exp, parse_one
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
from tests.helpers import load_sql_fixtures
|
||||
|
||||
|
||||
class CustomExpression(exp.Expression):
|
||||
...
|
||||
|
||||
|
||||
class TestSerDe(unittest.TestCase):
|
||||
def dump_load(self, expression):
|
||||
return exp.Expression.load(json.loads(json.dumps(expression.dump())))
|
||||
|
||||
def test_serde(self):
|
||||
for sql in load_sql_fixtures("identity.sql"):
|
||||
with self.subTest(sql):
|
||||
before = parse_one(sql)
|
||||
after = self.dump_load(before)
|
||||
self.assertEqual(before, after)
|
||||
|
||||
def test_custom_expression(self):
|
||||
before = CustomExpression()
|
||||
after = self.dump_load(before)
|
||||
self.assertEqual(before, after)
|
||||
|
||||
def test_type_annotations(self):
|
||||
before = annotate_types(parse_one("CAST('1' AS INT)"))
|
||||
after = self.dump_load(before)
|
||||
self.assertEqual(before.type, after.type)
|
||||
self.assertEqual(before.this.type, after.this.type)
|
|
@ -1,7 +1,11 @@
|
|||
import unittest
|
||||
|
||||
from sqlglot import parse_one
|
||||
from sqlglot.transforms import eliminate_distinct_on, unalias_group
|
||||
from sqlglot.transforms import (
|
||||
eliminate_distinct_on,
|
||||
remove_precision_parameterized_types,
|
||||
unalias_group,
|
||||
)
|
||||
|
||||
|
||||
class TestTime(unittest.TestCase):
|
||||
|
@ -62,3 +66,10 @@ class TestTime(unittest.TestCase):
|
|||
"SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC",
|
||||
'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE "_row_number_2" = 1',
|
||||
)
|
||||
|
||||
def test_remove_precision_parameterized_types(self):
|
||||
self.validate(
|
||||
remove_precision_parameterized_types,
|
||||
"SELECT CAST(1 AS DECIMAL(10, 2)), CAST('13' AS VARCHAR(10))",
|
||||
"SELECT CAST(1 AS DECIMAL), CAST('13' AS VARCHAR)",
|
||||
)
|
||||
|
|
|
@ -117,6 +117,11 @@ class TestTranspile(unittest.TestCase):
|
|||
"select x from foo -- x",
|
||||
"SELECT x FROM foo /* x */",
|
||||
)
|
||||
self.validate(
|
||||
"""select x, --
|
||||
from foo""",
|
||||
"SELECT x FROM foo",
|
||||
)
|
||||
self.validate(
|
||||
"""
|
||||
-- comment 1
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue