Merging upstream version 25.1.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent 7ab180cac9
commit 3b7539dcad
79 changed files with 28803 additions and 24929 deletions
tests/dialects/test_bigquery.py

@@ -20,6 +20,14 @@ class TestBigQuery(Validator):
     maxDiff = None

     def test_bigquery(self):
+        self.validate_all(
+            "EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))",
+            write={
+                "bigquery": "EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))",
+                "duckdb": "EXTRACT(HOUR FROM MAKE_TIMESTAMP(2008, 12, 25, 15, 30, 00))",
+                "snowflake": "DATE_PART(HOUR, TIMESTAMP_FROM_PARTS(2008, 12, 25, 15, 30, 00))",
+            },
+        )
         self.validate_identity(
             """CREATE TEMPORARY FUNCTION FOO()
 RETURNS STRING
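The new assertions above pin down how BigQuery's DATETIME constructor transpiles. A minimal standalone sketch (not part of the diff), using sqlglot's public transpile API and the exact strings from the test:

import sqlglot

# BigQuery's DATETIME(y, m, d, h, mi, s) constructor maps to DuckDB's MAKE_TIMESTAMP.
sql = "EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))"
print(sqlglot.transpile(sql, read="bigquery", write="duckdb")[0])
# EXTRACT(HOUR FROM MAKE_TIMESTAMP(2008, 12, 25, 15, 30, 00))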
@@ -619,9 +627,9 @@ LANGUAGE js AS
             'SELECT TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)',
             write={
                 "bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)",
-                "databricks": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
+                "databricks": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
                 "mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)",
-                "spark": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
+                "spark": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
             },
         )
         self.validate_all(
@@ -761,12 +769,15 @@ LANGUAGE js AS
                 "clickhouse": "SHA256(x)",
                 "presto": "SHA256(x)",
                 "trino": "SHA256(x)",
                 "postgres": "SHA256(x)",
             },
             write={
                 "bigquery": "SHA256(x)",
                 "spark2": "SHA2(x, 256)",
                 "clickhouse": "SHA256(x)",
                 "postgres": "SHA256(x)",
                 "presto": "SHA256(x)",
                 "redshift": "SHA2(x, 256)",
                 "trino": "SHA256(x)",
             },
         )
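The SHA256 test above encodes which dialects have a native SHA256 spelling and which fall back to the generic SHA2 function. A minimal sketch of the same behavior (not part of the diff, outputs taken from the test's write dict):

import sqlglot

# Dialects without a native SHA256 function fall back to SHA2(x, 256).
print(sqlglot.transpile("SHA256(x)", read="bigquery", write="redshift")[0])  # SHA2(x, 256)
print(sqlglot.transpile("SHA256(x)", read="bigquery", write="trino")[0])     # SHA256(x)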

tests/dialects/test_duckdb.py

@@ -18,6 +18,13 @@ class TestDuckDB(Validator):
             "WITH _data AS (SELECT [STRUCT(1 AS a, 2 AS b), STRUCT(2 AS a, 3 AS b)] AS col) SELECT col.b FROM _data, UNNEST(_data.col) AS col WHERE col.a = 1",
         )

+        self.validate_all(
+            "SELECT straight_join",
+            write={
+                "duckdb": "SELECT straight_join",
+                "mysql": "SELECT `straight_join`",
+            },
+        )
         self.validate_all(
             "SELECT CAST('2020-01-01 12:05:01' AS TIMESTAMP)",
             read={
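The straight_join test above relies on STRAIGHT_JOIN being a reserved keyword in MySQL but an ordinary identifier in DuckDB. A minimal sketch (not part of the diff):

import sqlglot

# DuckDB treats straight_join as a plain column name; MySQL must backtick-quote it
# because STRAIGHT_JOIN is a keyword there.
print(sqlglot.transpile("SELECT straight_join", read="duckdb", write="mysql")[0])
# SELECT `straight_join`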
@@ -278,6 +285,7 @@ class TestDuckDB(Validator):
         self.validate_identity("FROM tbl", "SELECT * FROM tbl")
         self.validate_identity("x -> '$.family'")
         self.validate_identity("CREATE TABLE color (name ENUM('RED', 'GREEN', 'BLUE'))")
         self.validate_identity("SELECT * FROM foo WHERE bar > $baz AND bla = $bob")
         self.validate_identity(
             "SELECT * FROM x LEFT JOIN UNNEST(y)", "SELECT * FROM x LEFT JOIN UNNEST(y) ON TRUE"
         )
@@ -1000,6 +1008,7 @@ class TestDuckDB(Validator):
         self.validate_identity("CAST(x AS CHAR)", "CAST(x AS TEXT)")
         self.validate_identity("CAST(x AS BPCHAR)", "CAST(x AS TEXT)")
         self.validate_identity("CAST(x AS STRING)", "CAST(x AS TEXT)")
         self.validate_identity("CAST(x AS VARCHAR)", "CAST(x AS TEXT)")
         self.validate_identity("CAST(x AS INT1)", "CAST(x AS TINYINT)")
         self.validate_identity("CAST(x AS FLOAT4)", "CAST(x AS REAL)")
         self.validate_identity("CAST(x AS FLOAT)", "CAST(x AS REAL)")
@@ -1027,6 +1036,13 @@ class TestDuckDB(Validator):
             "CAST([{'a': 1}] AS STRUCT(a BIGINT)[])",
         )

+        self.validate_all(
+            "CAST(x AS VARCHAR(5))",
+            write={
+                "duckdb": "CAST(x AS TEXT)",
+                "postgres": "CAST(x AS TEXT)",
+            },
+        )
         self.validate_all(
             "CAST(x AS DECIMAL(38, 0))",
             read={

tests/dialects/test_mysql.py

@@ -21,6 +21,9 @@ class TestMySQL(Validator):
         self.validate_identity("CREATE TABLE foo (a BIGINT, FULLTEXT INDEX (b))")
         self.validate_identity("CREATE TABLE foo (a BIGINT, SPATIAL INDEX (b))")
         self.validate_identity("ALTER TABLE t1 ADD COLUMN x INT, ALGORITHM=INPLACE, LOCK=EXCLUSIVE")
+        self.validate_identity("ALTER TABLE t ADD INDEX `i` (`c`)")
+        self.validate_identity("ALTER TABLE t ADD UNIQUE `i` (`c`)")
+        self.validate_identity("ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT")
         self.validate_identity(
             "CREATE TABLE `oauth_consumer` (`key` VARCHAR(32) NOT NULL, UNIQUE `OAUTH_CONSUMER_KEY` (`key`))"
         )
@@ -60,6 +63,10 @@ class TestMySQL(Validator):
         self.validate_identity(
             "CREATE OR REPLACE VIEW my_view AS SELECT column1 AS `boo`, column2 AS `foo` FROM my_table WHERE column3 = 'some_value' UNION SELECT q.* FROM fruits_table, JSON_TABLE(Fruits, '$[*]' COLUMNS(id VARCHAR(255) PATH '$.$id', value VARCHAR(255) PATH '$.value')) AS q",
         )
+        self.validate_identity(
+            "ALTER TABLE t ADD KEY `i` (`c`)",
+            "ALTER TABLE t ADD INDEX `i` (`c`)",
+        )
         self.validate_identity(
             "CREATE TABLE `foo` (`id` char(36) NOT NULL DEFAULT (uuid()), PRIMARY KEY (`id`), UNIQUE KEY `id` (`id`))",
             "CREATE TABLE `foo` (`id` CHAR(36) NOT NULL DEFAULT (UUID()), PRIMARY KEY (`id`), UNIQUE `id` (`id`))",
@@ -76,9 +83,6 @@ class TestMySQL(Validator):
             "ALTER TABLE test_table ALTER COLUMN test_column SET DATA TYPE LONGTEXT",
             "ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT",
         )
-        self.validate_identity(
-            "ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT",
-        )
         self.validate_identity(
             "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP) DEFAULT CHARSET=utf8 ROW_FORMAT=DYNAMIC",
             "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP()) DEFAULT CHARACTER SET=utf8 ROW_FORMAT=DYNAMIC",
@@ -113,6 +117,7 @@ class TestMySQL(Validator):
         )

     def test_identity(self):
+        self.validate_identity("SELECT e.* FROM e STRAIGHT_JOIN p ON e.x = p.y")
         self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1")
         self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')")
         self.validate_identity("SELECT @var1 := 1, @var2")

tests/dialects/test_oracle.py

@@ -1,5 +1,5 @@
-from sqlglot import exp
-from sqlglot.errors import UnsupportedError
+from sqlglot import exp, UnsupportedError
+from sqlglot.dialects.oracle import eliminate_join_marks
 from tests.dialects.test_dialect import Validator

@@ -43,6 +43,7 @@ class TestOracle(Validator):
         self.validate_identity("SELECT * FROM table_name SAMPLE (25) s")
         self.validate_identity("SELECT COUNT(*) * 10 FROM orders SAMPLE (10) SEED (1)")
         self.validate_identity("SELECT * FROM V$SESSION")
+        self.validate_identity("SELECT TO_DATE('January 15, 1989, 11:00 A.M.')")
         self.validate_identity(
             "SELECT last_name, employee_id, manager_id, LEVEL FROM employees START WITH employee_id = 100 CONNECT BY PRIOR employee_id = manager_id ORDER SIBLINGS BY last_name"
         )
@@ -249,7 +250,8 @@ class TestOracle(Validator):
         self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y")

         self.validate_all(
-            "SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)", write={"": UnsupportedError}
+            "SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)",
+            write={"": UnsupportedError},
         )
         self.validate_all(
             "SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)",
@@ -413,3 +415,65 @@ WHERE

         for query in (f"{body}{start}{connect}", f"{body}{connect}{start}"):
             self.validate_identity(query, pretty, pretty=True)
+
+    def test_eliminate_join_marks(self):
+        test_sql = [
+            (
+                "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) > 5",
+                "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y > 5",
+            ),
+            (
+                "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) IS NULL",
+                "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y IS NULL",
+            ),
+            (
+                "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y IS NULL",
+                "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T2.y IS NULL",
+            ),
+            (
+                "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T1.Z > 4",
+                "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T1.Z > 4",
+            ),
+            (
+                "SELECT * FROM table1, table2 WHERE table1.column = table2.column(+)",
+                "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column",
+            ),
+            (
+                "SELECT * FROM table1, table2, table3, table4 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+) and table1.column = table4.column(+)",
+                "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column LEFT JOIN table4 ON table1.column = table4.column",
+            ),
+            (
+                "SELECT * FROM table1, table2, table3 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+)",
+                "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column",
+            ),
+            (
+                "SELECT table1.id, table2.cloumn1, table3.id FROM table1, table2, (SELECT tableInner1.id FROM tableInner1, tableInner2 WHERE tableInner1.id = tableInner2.id(+)) AS table3 WHERE table1.id = table2.id(+) and table1.id = table3.id(+)",
+                "SELECT table1.id, table2.cloumn1, table3.id FROM table1 LEFT JOIN table2 ON table1.id = table2.id LEFT JOIN (SELECT tableInner1.id FROM tableInner1 LEFT JOIN tableInner2 ON tableInner1.id = tableInner2.id) table3 ON table1.id = table3.id",
+            ),
+            # 2 join marks on one side of predicate
+            (
+                "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + table2.column2(+)",
+                "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + table2.column2",
+            ),
+            # join mark and expression
+            (
+                "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + 25",
+                "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + 25",
+            ),
+        ]
+
+        for original, expected in test_sql:
+            with self.subTest(original):
+                self.assertEqual(
+                    eliminate_join_marks(self.parse_one(original)).sql(dialect=self.dialect),
+                    expected,
+                )
+
+    def test_query_restrictions(self):
+        for restriction in ("READ ONLY", "CHECK OPTION"):
+            for constraint_name in (" CONSTRAINT name", ""):
+                with self.subTest(f"Restriction: {restriction}"):
+                    self.validate_identity(f"SELECT * FROM tbl WITH {restriction}{constraint_name}")
+                    self.validate_identity(
+                        f"CREATE VIEW view AS SELECT * FROM tbl WITH {restriction}{constraint_name}"
+                    )
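The new test_eliminate_join_marks cases document how Oracle's (+) outer-join marker is rewritten: predicates on the marked side become ON conditions of a LEFT JOIN, while unmarked filters stay in WHERE. A minimal sketch of the transform (not part of the diff), mirroring the test's own use of parse_one and sql:

from sqlglot import parse_one
from sqlglot.dialects.oracle import eliminate_join_marks

# The (+) marker on table2's column turns the implicit cross join into a LEFT JOIN.
ast = parse_one(
    "SELECT * FROM table1, table2 WHERE table1.column = table2.column(+)",
    read="oracle",
)
print(eliminate_join_marks(ast).sql(dialect="oracle"))
# SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column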

tests/dialects/test_postgres.py

@@ -8,6 +8,7 @@ class TestPostgres(Validator):
     dialect = "postgres"

     def test_postgres(self):
+        self.validate_identity("SHA384(x)")
         self.validate_identity(
             'CREATE TABLE x (a TEXT COLLATE "de_DE")', "CREATE TABLE x (a TEXT COLLATE de_DE)"
         )
@@ -724,6 +725,28 @@ class TestPostgres(Validator):
         self.validate_identity("cast(a as FLOAT8)", "CAST(a AS DOUBLE PRECISION)")
         self.validate_identity("cast(a as FLOAT4)", "CAST(a AS REAL)")

+        self.validate_all(
+            "1 / DIV(4, 2)",
+            read={
+                "postgres": "1 / DIV(4, 2)",
+            },
+            write={
+                "sqlite": "1 / CAST(CAST(CAST(4 AS REAL) / 2 AS INTEGER) AS REAL)",
+                "duckdb": "1 / CAST(4 // 2 AS DECIMAL)",
+                "bigquery": "1 / CAST(DIV(4, 2) AS NUMERIC)",
+            },
+        )
+        self.validate_all(
+            "CAST(DIV(4, 2) AS DECIMAL(5, 3))",
+            read={
+                "duckdb": "CAST(4 // 2 AS DECIMAL(5, 3))",
+            },
+            write={
+                "duckdb": "CAST(CAST(4 // 2 AS DECIMAL) AS DECIMAL(5, 3))",
+                "postgres": "CAST(DIV(4, 2) AS DECIMAL(5, 3))",
+            },
+        )
+
     def test_ddl(self):
         # Checks that user-defined types are parsed into DataType instead of Identifier
         self.parse_one("CREATE TABLE t (a udt)").this.expressions[0].args["kind"].assert_is(
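The new DIV tests capture that Postgres' DIV() is integer division with a NUMERIC result, which other dialects spell differently. A minimal sketch (not part of the diff, output taken from the test's write dict):

import sqlglot

# DuckDB spells integer division with //, wrapped in a DECIMAL cast to keep
# Postgres' NUMERIC return type.
print(sqlglot.transpile("1 / DIV(4, 2)", read="postgres", write="duckdb")[0])
# 1 / CAST(4 // 2 AS DECIMAL)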

tests/dialects/test_presto.py

@@ -564,6 +564,7 @@ class TestPresto(Validator):
             self.validate_all(
                 f"{prefix}'Hello winter \\2603 !'",
                 write={
+                    "oracle": "U'Hello winter \\2603 !'",
                     "presto": "U&'Hello winter \\2603 !'",
                     "snowflake": "'Hello winter \\u2603 !'",
                     "spark": "'Hello winter \\u2603 !'",
@@ -572,6 +573,7 @@ class TestPresto(Validator):
             self.validate_all(
                 f"{prefix}'Hello winter #2603 !' UESCAPE '#'",
                 write={
+                    "oracle": "U'Hello winter \\2603 !'",
                     "presto": "U&'Hello winter #2603 !' UESCAPE '#'",
                     "snowflake": "'Hello winter \\u2603 !'",
                     "spark": "'Hello winter \\u2603 !'",

tests/dialects/test_redshift.py

@@ -281,6 +281,9 @@ class TestRedshift(Validator):
                 "redshift": "SELECT DATEADD(MONTH, 18, '2008-02-28')",
                 "snowflake": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS TIMESTAMP))",
                 "tsql": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS DATETIME2))",
+                "spark": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')",
+                "spark2": "SELECT ADD_MONTHS('2008-02-28', 18)",
+                "databricks": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')",
             },
         )
         self.validate_all(
@@ -585,3 +588,9 @@ FROM (
         self.assertEqual(
             ast.sql("redshift"), "SELECT * FROM x AS a, a.b AS c, c.d.e AS f, f.g.h.i.j.k AS l"
         )
+
+    def test_join_markers(self):
+        self.validate_identity(
+            "select a.foo, b.bar, a.baz from a, b where a.baz = b.baz (+)",
+            "SELECT a.foo, b.bar, a.baz FROM a, b WHERE a.baz = b.baz (+)",
+        )

tests/dialects/test_snowflake.py

@@ -125,6 +125,10 @@ WHERE
             "SELECT a:from::STRING, a:from || ' test' ",
             "SELECT CAST(GET_PATH(a, 'from') AS TEXT), GET_PATH(a, 'from') || ' test'",
         )
+        self.validate_identity(
+            "SELECT a:select",
+            "SELECT GET_PATH(a, 'select')",
+        )
         self.validate_identity("x:from", "GET_PATH(x, 'from')")
         self.validate_identity(
             "value:values::string::int",
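The added a:select case checks that Snowflake's colon path syntax accepts reserved words as keys and is normalized to GET_PATH. A minimal sketch (not part of the diff):

import sqlglot

# Even a reserved word like select is valid as a JSON path key.
print(sqlglot.transpile("SELECT a:select", read="snowflake", write="snowflake")[0])
# SELECT GET_PATH(a, 'select')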
@@ -1196,16 +1200,16 @@ WHERE
         for constraint_prefix in ("WITH ", ""):
             with self.subTest(f"Constraint prefix: {constraint_prefix}"):
                 self.validate_identity(
-                    f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p)",
-                    "CREATE TABLE t (id INT MASKING POLICY p)",
+                    f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p.q.r)",
+                    "CREATE TABLE t (id INT MASKING POLICY p.q.r)",
                 )
                 self.validate_identity(
                     f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p USING (c1, c2, c3))",
                     "CREATE TABLE t (id INT MASKING POLICY p USING (c1, c2, c3))",
                 )
                 self.validate_identity(
-                    f"CREATE TABLE t (id INT {constraint_prefix}PROJECTION POLICY p)",
-                    "CREATE TABLE t (id INT PROJECTION POLICY p)",
+                    f"CREATE TABLE t (id INT {constraint_prefix}PROJECTION POLICY p.q.r)",
+                    "CREATE TABLE t (id INT PROJECTION POLICY p.q.r)",
                )
                 self.validate_identity(
                     f"CREATE TABLE t (id INT {constraint_prefix}TAG (key1='value_1', key2='value_2'))",
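The updated policy tests switch the policy name from p to p.q.r, meaning masking and projection policies may now be schema-qualified. A minimal sketch (not part of the diff):

import sqlglot

# A fully qualified masking policy name round-trips unchanged.
sql = "CREATE TABLE t (id INT MASKING POLICY p.q.r)"
assert sqlglot.transpile(sql, read="snowflake", write="snowflake")[0] == sql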

tests/dialects/test_spark.py

@@ -563,6 +563,7 @@ TBLPROPERTIES (
             "SELECT DATE_ADD(my_date_column, 1)",
             write={
                 "spark": "SELECT DATE_ADD(my_date_column, 1)",
                 "spark2": "SELECT DATE_ADD(my_date_column, 1)",
                 "bigquery": "SELECT DATE_ADD(CAST(CAST(my_date_column AS DATETIME) AS DATE), INTERVAL 1 DAY)",
             },
         )
@@ -675,6 +676,16 @@ TBLPROPERTIES (
                 "spark": "SELECT ARRAY_SORT(x)",
             },
         )
+        self.validate_all(
+            "SELECT DATE_ADD(MONTH, 20, col)",
+            read={
+                "spark": "SELECT TIMESTAMPADD(MONTH, 20, col)",
+            },
+            write={
+                "spark": "SELECT DATE_ADD(MONTH, 20, col)",
+                "databricks": "SELECT DATE_ADD(MONTH, 20, col)",
+            },
+        )

     def test_bool_or(self):
         self.validate_all(
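The new Spark case reads TIMESTAMPADD as an alias and regenerates it in the canonical DATE_ADD form. A minimal sketch (not part of the diff):

import sqlglot

# TIMESTAMPADD is parsed as a date-add and written back as DATE_ADD.
print(sqlglot.transpile("SELECT TIMESTAMPADD(MONTH, 20, col)", read="spark", write="databricks")[0])
# SELECT DATE_ADD(MONTH, 20, col)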

tests/dialects/test_sqlite.py

@@ -202,6 +202,7 @@ class TestSQLite(Validator):
             "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)",
             read={
                 "mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)",
+                "postgres": "CREATE TABLE z (a INT GENERATED BY DEFAULT AS IDENTITY NOT NULL UNIQUE PRIMARY KEY)",
             },
             write={
                 "sqlite": "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)",

tests/dialects/test_tsql.py

@@ -1,12 +1,18 @@
-from sqlglot import exp, parse, parse_one
+from sqlglot import exp, parse
 from tests.dialects.test_dialect import Validator
 from sqlglot.errors import ParseError
+from sqlglot.optimizer.annotate_types import annotate_types


 class TestTSQL(Validator):
     dialect = "tsql"

     def test_tsql(self):
+        self.assertEqual(
+            annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"),
+            "SELECT 1 WHERE EXISTS(SELECT 1)",
+        )
+
         self.validate_identity("CREATE view a.b.c", "CREATE VIEW b.c")
         self.validate_identity("DROP view a.b.c", "DROP VIEW b.c")
         self.validate_identity("ROUND(x, 1, 0)")
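The new assertion at the top of test_tsql checks that running the type annotator over an EXISTS predicate leaves the generated T-SQL unchanged. A minimal sketch (not part of the diff):

from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types

# Type annotation must not rewrite the EXISTS predicate when regenerating T-SQL.
ast = annotate_types(parse_one("SELECT 1 WHERE EXISTS(SELECT 1)", read="tsql"))
print(ast.sql("tsql"))  # SELECT 1 WHERE EXISTS(SELECT 1)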
@@ -217,9 +223,9 @@ class TestTSQL(Validator):
             "CREATE TABLE [db].[tbl] ([a] INTEGER)",
         )

-        projection = parse_one("SELECT a = 1", read="tsql").selects[0]
-        projection.assert_is(exp.Alias)
-        projection.args["alias"].assert_is(exp.Identifier)
+        self.validate_identity("SELECT a = 1", "SELECT 1 AS a").selects[0].assert_is(
+            exp.Alias
+        ).args["alias"].assert_is(exp.Identifier)

         self.validate_all(
             "IF OBJECT_ID('tempdb.dbo.#TempTableName', 'U') IS NOT NULL DROP TABLE #TempTableName",
@@ -756,12 +762,9 @@ class TestTSQL(Validator):
         for view_attr in ("ENCRYPTION", "SCHEMABINDING", "VIEW_METADATA"):
             self.validate_identity(f"CREATE VIEW a.b WITH {view_attr} AS SELECT * FROM x")

-        expression = parse_one("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B", dialect="tsql")
-        self.assertIsInstance(expression, exp.AlterTable)
-        self.assertIsInstance(expression.args["actions"][0], exp.Drop)
-        self.assertEqual(
-            expression.sql(dialect="tsql"), "ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B"
-        )
+        self.validate_identity("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B").assert_is(
+            exp.AlterTable
+        ).args["actions"][0].assert_is(exp.Drop)

         for clustered_keyword in ("CLUSTERED", "NONCLUSTERED"):
             self.validate_identity(
@@ -795,10 +798,10 @@ class TestTSQL(Validator):
         )

         self.validate_all(
-            "CREATE TABLE [#temptest] (name VARCHAR)",
+            "CREATE TABLE [#temptest] (name INTEGER)",
             read={
-                "duckdb": "CREATE TEMPORARY TABLE 'temptest' (name VARCHAR)",
-                "tsql": "CREATE TABLE [#temptest] (name VARCHAR)",
+                "duckdb": "CREATE TEMPORARY TABLE 'temptest' (name INTEGER)",
+                "tsql": "CREATE TABLE [#temptest] (name INTEGER)",
             },
         )
         self.validate_all(
@@ -1632,27 +1635,23 @@ WHERE
         )

     def test_identifier_prefixes(self):
-        expr = parse_one("#x", read="tsql")
-        self.assertIsInstance(expr, exp.Column)
-        self.assertIsInstance(expr.this, exp.Identifier)
-        self.assertTrue(expr.this.args.get("temporary"))
-        self.assertEqual(expr.sql("tsql"), "#x")
+        self.assertTrue(
+            self.validate_identity("#x")
+            .assert_is(exp.Column)
+            .this.assert_is(exp.Identifier)
+            .args.get("temporary")
+        )
+        self.assertTrue(
+            self.validate_identity("##x")
+            .assert_is(exp.Column)
+            .this.assert_is(exp.Identifier)
+            .args.get("global")
+        )
-
-        expr = parse_one("##x", read="tsql")
-        self.assertIsInstance(expr, exp.Column)
-        self.assertIsInstance(expr.this, exp.Identifier)
-        self.assertTrue(expr.this.args.get("global"))
-        self.assertEqual(expr.sql("tsql"), "##x")

-        expr = parse_one("@x", read="tsql")
-        self.assertIsInstance(expr, exp.Parameter)
-        self.assertIsInstance(expr.this, exp.Var)
-        self.assertEqual(expr.sql("tsql"), "@x")
+        self.validate_identity("@x").assert_is(exp.Parameter).this.assert_is(exp.Var)

-        table = parse_one("select * from @x", read="tsql").args["from"].this
-        self.assertIsInstance(table, exp.Table)
-        self.assertIsInstance(table.this, exp.Parameter)
-        self.assertIsInstance(table.this.this, exp.Var)
+        self.validate_identity("SELECT * FROM @x").args["from"].this.assert_is(
+            exp.Table
+        ).this.assert_is(exp.Parameter).this.assert_is(exp.Var)

         self.validate_all(
             "SELECT @x",
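The rewritten test keeps the same semantics as the old parse_one-based assertions: in T-SQL a leading # marks a temporary object, ## a global one, and @ a parameter. A minimal sketch of what is being asserted (not part of the diff):

from sqlglot import exp, parse_one

# '#x' parses as a column whose identifier carries the temporary flag;
# '@x' parses as a parameter wrapping a variable.
col = parse_one("#x", read="tsql")
assert isinstance(col, exp.Column) and col.this.args.get("temporary")
param = parse_one("@x", read="tsql")
assert isinstance(param, exp.Parameter) and isinstance(param.this, exp.Var)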
@@ -1663,8 +1662,6 @@ WHERE
                 "tsql": "SELECT @x",
             },
         )

     def test_temp_table(self):
         self.validate_all(
             "SELECT * FROM #mytemptable",
             write={

tests/fixtures/identity.sql (vendored)

@@ -872,3 +872,4 @@ SELECT name
 SELECT copy
 SELECT rollup
 SELECT unnest
+SELECT * FROM a STRAIGHT_JOIN b

tests/fixtures/optimizer/simplify.sql (vendored)

@@ -1047,6 +1047,9 @@ x < CAST('2021-01-02' AS DATE) AND x >= CAST('2021-01-01' AS DATE);
 TIMESTAMP_TRUNC(x, YEAR) = CAST(CAST('2021-01-01 01:02:03' AS DATE) AS DATETIME);
 x < CAST('2022-01-01 00:00:00' AS DATETIME) AND x >= CAST('2021-01-01 00:00:00' AS DATETIME);

+DATE_TRUNC('day', CAST(x AS DATE)) <= CAST('2021-01-01 01:02:03' AS TIMESTAMP);
+CAST(x AS DATE) < CAST('2021-01-02 01:02:03' AS TIMESTAMP);
+
 --------------------------------------
 -- EQUALITY
 --------------------------------------
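The new simplify fixture pair shows a DATE_TRUNC comparison being collapsed into a plain range check on the casted column. A minimal sketch of how such a fixture is exercised (not part of the diff):

from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

# DATE_TRUNC('day', x) <= ts simplifies to x < (the day after ts's date).
ast = parse_one("DATE_TRUNC('day', CAST(x AS DATE)) <= CAST('2021-01-01 01:02:03' AS TIMESTAMP)")
print(simplify(ast).sql())
# CAST(x AS DATE) < CAST('2021-01-02 01:02:03' AS TIMESTAMP)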

tests/test_optimizer.py

@@ -29,7 +29,11 @@ def parse_and_optimize(func, sql, read_dialect, **kwargs):

 def qualify_columns(expression, **kwargs):
     expression = optimizer.qualify.qualify(
-        expression, infer_schema=True, validate_qualify_columns=False, identify=False, **kwargs
+        expression,
+        infer_schema=True,
+        validate_qualify_columns=False,
+        identify=False,
+        **kwargs,
     )
     return expression

@@ -111,7 +115,14 @@ class TestOptimizer(unittest.TestCase):
     }

     def check_file(
-        self, file, func, pretty=False, execute=False, set_dialect=False, only=None, **kwargs
+        self,
+        file,
+        func,
+        pretty=False,
+        execute=False,
+        set_dialect=False,
+        only=None,
+        **kwargs,
     ):
         with ProcessPoolExecutor() as pool:
             results = {}
@@ -331,7 +342,11 @@ class TestOptimizer(unittest.TestCase):
         )

         self.check_file(
-            "qualify_columns", qualify_columns, execute=True, schema=self.schema, set_dialect=True
+            "qualify_columns",
+            qualify_columns,
+            execute=True,
+            schema=self.schema,
+            set_dialect=True,
         )
         self.check_file(
             "qualify_columns_ddl", qualify_columns, schema=self.schema, set_dialect=True
@@ -343,7 +358,8 @@ class TestOptimizer(unittest.TestCase):

     def test_pushdown_cte_alias_columns(self):
         self.check_file(
-            "pushdown_cte_alias_columns", optimizer.qualify_columns.pushdown_cte_alias_columns
+            "pushdown_cte_alias_columns",
+            optimizer.qualify_columns.pushdown_cte_alias_columns,
         )

     def test_qualify_columns__invalid(self):
@@ -405,7 +421,8 @@ class TestOptimizer(unittest.TestCase):
         self.assertEqual(optimizer.simplify.gen(query), optimizer.simplify.gen(query.copy()))

         anon_unquoted_identifier = exp.Anonymous(
-            this=exp.to_identifier("anonymous"), expressions=[exp.column("x"), exp.column("y")]
+            this=exp.to_identifier("anonymous"),
+            expressions=[exp.column("x"), exp.column("y")],
         )
         self.assertEqual(optimizer.simplify.gen(anon_unquoted_identifier), "ANONYMOUS(x,y)")

@@ -416,7 +433,10 @@ class TestOptimizer(unittest.TestCase):
             anon_invalid = exp.Anonymous(this=5)
             optimizer.simplify.gen(anon_invalid)

-        self.assertIn("Anonymous.this expects a str or an Identifier, got 'int'.", str(e.exception))
+        self.assertIn(
+            "Anonymous.this expects a str or an Identifier, got 'int'.",
+            str(e.exception),
+        )

         sql = parse_one(
             """
@@ -906,7 +926,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')

         # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
         for d, t in zip(
-            cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]
+            cte_select.find_all(exp.Subquery),
+            [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT],
         ):
             self.assertEqual(d.this.expressions[0].this.type.this, t)

@@ -1020,7 +1041,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')

         for (func, col), target_type in tests.items():
             expression = annotate_types(
-                parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema
+                parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"),
+                schema=schema,
             )
             self.assertEqual(expression.expressions[0].type.this, target_type)

@@ -1035,7 +1057,13 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         self.assertEqual(exp.DataType.Type.INT, expression.selects[1].type.this)

     def test_nested_type_annotation(self):
-        schema = {"order": {"customer_id": "bigint", "item_id": "bigint", "item_price": "numeric"}}
+        schema = {
+            "order": {
+                "customer_id": "bigint",
+                "item_id": "bigint",
+                "item_price": "numeric",
+            }
+        }
         sql = """
         SELECT ARRAY_AGG(DISTINCT order.item_id) FILTER (WHERE order.item_price > 10) AS items,
         FROM order AS order
@@ -1057,7 +1085,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')

         self.assertEqual(expression.selects[0].type.sql(dialect="bigquery"), "STRUCT<`f` STRING>")
         self.assertEqual(
-            expression.selects[1].type.sql(dialect="bigquery"), "ARRAY<STRUCT<`f` STRING>>"
+            expression.selects[1].type.sql(dialect="bigquery"),
+            "ARRAY<STRUCT<`f` STRING>>",
         )

         expression = annotate_types(
@@ -1206,7 +1235,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')

         self.assertEqual(
             optimizer.optimize(
-                parse_one("SELECT * FROM a"), schema=MappingSchema(schema, dialect="bigquery")
+                parse_one("SELECT * FROM a"),
+                schema=MappingSchema(schema, dialect="bigquery"),
             ),
             parse_one('SELECT "a"."a" AS "a", "a"."b" AS "b" FROM "a" AS "a"'),
         )

tests/test_parser.py

@@ -106,6 +106,7 @@ class TestParser(unittest.TestCase):
         expr = parse_one("SELECT foo IN UNNEST(bla) AS bar")
         self.assertIsInstance(expr.selects[0], exp.Alias)
         self.assertEqual(expr.selects[0].output_name, "bar")
+        self.assertIsNotNone(parse_one("select unnest(x)").find(exp.Unnest))

     def test_unary_plus(self):
         self.assertEqual(parse_one("+15"), exp.Literal.number(15))
@@ -880,10 +881,12 @@ class TestParser(unittest.TestCase):
         self.assertIsInstance(parse_one("a IS DISTINCT FROM b OR c IS DISTINCT FROM d"), exp.Or)

     def test_trailing_comments(self):
-        expressions = parse("""
+        expressions = parse(
+            """
         select * from x;
         -- my comment
-        """)
+        """
+        )

         self.assertEqual(
             ";\n".join(e.sql() for e in expressions), "SELECT * FROM x;\n/* my comment */"