
Adding upstream version 22.2.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:29:15 +01:00
parent b01402dc30
commit f1aa09959c
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
148 changed files with 68457 additions and 63176 deletions


@@ -280,9 +280,9 @@ class TestFunctions(unittest.TestCase):
def test_signum(self):
col_str = SF.signum("cola")
self.assertEqual("SIGNUM(cola)", col_str.sql())
self.assertEqual("SIGN(cola)", col_str.sql())
col = SF.signum(SF.col("cola"))
self.assertEqual("SIGNUM(cola)", col.sql())
self.assertEqual("SIGN(cola)", col.sql())
def test_sin(self):
col_str = SF.sin("cola")


@@ -5,7 +5,9 @@ from sqlglot import (
ParseError,
TokenError,
UnsupportedError,
exp,
parse,
parse_one,
transpile,
)
from sqlglot.helper import logger as helper_logger
@@ -18,6 +20,51 @@ class TestBigQuery(Validator):
maxDiff = None
def test_bigquery(self):
self.validate_all(
"SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1 as a, 'abc' AS b), STRUCT(str_col AS abc)",
write={
"bigquery": "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1 AS a, 'abc' AS b), STRUCT(str_col AS abc)",
"duckdb": "SELECT {'_0': 1, '_1': 2, '_2': 3}, {}, {'_0': 'abc'}, {'_0': 1, '_1': t.str_col}, {'a': 1, 'b': 'abc'}, {'abc': str_col}",
"hive": "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1, 'abc'), STRUCT(str_col)",
"spark2": "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1 AS a, 'abc' AS b), STRUCT(str_col AS abc)",
"spark": "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1 AS a, 'abc' AS b), STRUCT(str_col AS abc)",
"snowflake": "SELECT OBJECT_CONSTRUCT('_0', 1, '_1', 2, '_2', 3), OBJECT_CONSTRUCT(), OBJECT_CONSTRUCT('_0', 'abc'), OBJECT_CONSTRUCT('_0', 1, '_1', t.str_col), OBJECT_CONSTRUCT('a', 1, 'b', 'abc'), OBJECT_CONSTRUCT('abc', str_col)",
# fallback to unnamed without type inference
"trino": "SELECT ROW(1, 2, 3), ROW(), ROW('abc'), ROW(1, t.str_col), CAST(ROW(1, 'abc') AS ROW(a INTEGER, b VARCHAR)), ROW(str_col)",
},
)
self.validate_all(
"PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)",
write={
"bigquery": "PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)",
"duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S.%f%z')",
},
)
table = parse_one("x-0._y.z", dialect="bigquery", into=exp.Table)
self.assertEqual(table.catalog, "x-0")
self.assertEqual(table.db, "_y")
self.assertEqual(table.name, "z")
table = parse_one("x-0._y", dialect="bigquery", into=exp.Table)
self.assertEqual(table.db, "x-0")
self.assertEqual(table.name, "_y")
self.validate_identity("SELECT * FROM x-0.y")
self.assertEqual(exp.to_table("`x.y.z`", dialect="bigquery").sql(), '"x"."y"."z"')
self.assertEqual(exp.to_table("`x.y.z`", dialect="bigquery").sql("bigquery"), "`x.y.z`")
self.assertEqual(exp.to_table("`x`.`y`", dialect="bigquery").sql("bigquery"), "`x`.`y`")
select_with_quoted_udf = self.validate_identity("SELECT `p.d.UdF`(data) FROM `p.d.t`")
self.assertEqual(select_with_quoted_udf.selects[0].name, "p.d.UdF")
self.validate_identity("SELECT `p.d.UdF`(data).* FROM `p.d.t`")
self.validate_identity("SELECT * FROM `my-project.my-dataset.my-table`")
self.validate_identity("CREATE OR REPLACE TABLE `a.b.c` CLONE `a.b.d`")
self.validate_identity("SELECT x, 1 AS y GROUP BY 1 ORDER BY 1")
self.validate_identity("SELECT * FROM x.*")
self.validate_identity("SELECT * FROM x.y*")
self.validate_identity("CASE A WHEN 90 THEN 'red' WHEN 50 THEN 'blue' ELSE 'green' END")
self.validate_identity("CREATE SCHEMA x DEFAULT COLLATE 'en'")
self.validate_identity("CREATE TABLE x (y INT64) DEFAULT COLLATE 'en'")
self.validate_identity("PARSE_JSON('{}', wide_number_mode => 'exact')")
@@ -90,6 +137,16 @@ class TestBigQuery(Validator):
self.validate_identity("LOG(n, b)")
self.validate_identity("SELECT COUNT(x RESPECT NULLS)")
self.validate_identity("SELECT LAST_VALUE(x IGNORE NULLS) OVER y AS x")
self.validate_identity("SELECT ARRAY((SELECT AS STRUCT 1 AS a, 2 AS b))")
self.validate_identity("SELECT ARRAY((SELECT AS STRUCT 1 AS a, 2 AS b) LIMIT 10)")
self.validate_identity("CAST(x AS CHAR)", "CAST(x AS STRING)")
self.validate_identity("CAST(x AS NCHAR)", "CAST(x AS STRING)")
self.validate_identity("CAST(x AS NVARCHAR)", "CAST(x AS STRING)")
self.validate_identity("CAST(x AS TIMESTAMPTZ)", "CAST(x AS TIMESTAMP)")
self.validate_identity("CAST(x AS RECORD)", "CAST(x AS STRUCT)")
self.validate_identity(
"SELECT * FROM `SOME_PROJECT_ID.SOME_DATASET_ID.INFORMATION_SCHEMA.SOME_VIEW`"
)
self.validate_identity(
"SELECT * FROM test QUALIFY a IS DISTINCT FROM b WINDOW c AS (PARTITION BY d)"
)
@@ -120,6 +177,10 @@ class TestBigQuery(Validator):
self.validate_identity(
"""SELECT JSON_EXTRACT_SCALAR('5')""", """SELECT JSON_EXTRACT_SCALAR('5', '$')"""
)
self.validate_identity(
"SELECT ARRAY(SELECT AS STRUCT 1 a, 2 b)",
"SELECT ARRAY(SELECT AS STRUCT 1 AS a, 2 AS b)",
)
self.validate_identity(
"select array_contains([1, 2, 3], 1)",
"SELECT EXISTS(SELECT 1 FROM UNNEST([1, 2, 3]) AS _col WHERE _col = 1)",
@@ -168,10 +229,6 @@ class TestBigQuery(Validator):
"""SELECT JSON '"foo"' AS json_data""",
"""SELECT PARSE_JSON('"foo"') AS json_data""",
)
self.validate_identity(
"CREATE OR REPLACE TABLE `a.b.c` CLONE `a.b.d`",
"CREATE OR REPLACE TABLE a.b.c CLONE a.b.d",
)
self.validate_identity(
"SELECT * FROM UNNEST(x) WITH OFFSET EXCEPT DISTINCT SELECT * FROM UNNEST(y) WITH OFFSET",
"SELECT * FROM UNNEST(x) WITH OFFSET AS offset EXCEPT DISTINCT SELECT * FROM UNNEST(y) WITH OFFSET AS offset",
@@ -185,6 +242,39 @@ class TestBigQuery(Validator):
r"REGEXP_EXTRACT(svc_plugin_output, '\\\\\\((.*)')",
)
self.validate_all(
"PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)",
write={
"bigquery": "PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)",
"duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S.%f%z')",
},
)
self.validate_all(
"SELECT results FROM Coordinates, Coordinates.position AS results",
write={
"bigquery": "SELECT results FROM Coordinates, UNNEST(Coordinates.position) AS results",
"presto": "SELECT results FROM Coordinates, UNNEST(Coordinates.position) AS _t(results)",
},
)
self.validate_all(
"SELECT results FROM Coordinates, `Coordinates.position` AS results",
write={
"bigquery": "SELECT results FROM Coordinates, `Coordinates.position` AS results",
"presto": 'SELECT results FROM Coordinates, "Coordinates"."position" AS results',
},
)
self.validate_all(
"SELECT results FROM Coordinates AS c, UNNEST(c.position) AS results",
read={
"presto": "SELECT results FROM Coordinates AS c, UNNEST(c.position) AS _t(results)",
"redshift": "SELECT results FROM Coordinates AS c, c.position AS results",
},
write={
"bigquery": "SELECT results FROM Coordinates AS c, UNNEST(c.position) AS results",
"presto": "SELECT results FROM Coordinates AS c, UNNEST(c.position) AS _t(results)",
"redshift": "SELECT results FROM Coordinates AS c, c.position AS results",
},
)
self.validate_all(
"TIMESTAMP(x)",
write={
@@ -434,8 +524,8 @@ class TestBigQuery(Validator):
self.validate_all(
"CREATE OR REPLACE TABLE `a.b.c` COPY `a.b.d`",
write={
"bigquery": "CREATE OR REPLACE TABLE a.b.c COPY a.b.d",
"snowflake": "CREATE OR REPLACE TABLE a.b.c CLONE a.b.d",
"bigquery": "CREATE OR REPLACE TABLE `a.b.c` COPY `a.b.d`",
"snowflake": 'CREATE OR REPLACE TABLE "a"."b"."c" CLONE "a"."b"."d"',
},
)
(
@@ -475,11 +565,6 @@ class TestBigQuery(Validator):
),
)
self.validate_all("LEAST(x, y)", read={"sqlite": "MIN(x, y)"})
self.validate_all("CAST(x AS CHAR)", write={"bigquery": "CAST(x AS STRING)"})
self.validate_all("CAST(x AS NCHAR)", write={"bigquery": "CAST(x AS STRING)"})
self.validate_all("CAST(x AS NVARCHAR)", write={"bigquery": "CAST(x AS STRING)"})
self.validate_all("CAST(x AS TIMESTAMPTZ)", write={"bigquery": "CAST(x AS TIMESTAMP)"})
self.validate_all("CAST(x AS RECORD)", write={"bigquery": "CAST(x AS STRUCT)"})
self.validate_all(
'SELECT TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)',
write={
@@ -566,11 +651,11 @@ class TestBigQuery(Validator):
read={"spark": "select posexplode_outer([])"},
)
self.validate_all(
"SELECT AS STRUCT ARRAY(SELECT AS STRUCT b FROM x) AS y FROM z",
"SELECT AS STRUCT ARRAY(SELECT AS STRUCT 1 AS b FROM x) AS y FROM z",
write={
"": "SELECT AS STRUCT ARRAY(SELECT AS STRUCT b FROM x) AS y FROM z",
"bigquery": "SELECT AS STRUCT ARRAY(SELECT AS STRUCT b FROM x) AS y FROM z",
"duckdb": "SELECT {'y': ARRAY(SELECT {'b': b} FROM x)} FROM z",
"": "SELECT AS STRUCT ARRAY(SELECT AS STRUCT 1 AS b FROM x) AS y FROM z",
"bigquery": "SELECT AS STRUCT ARRAY(SELECT AS STRUCT 1 AS b FROM x) AS y FROM z",
"duckdb": "SELECT {'y': ARRAY(SELECT {'b': 1} FROM x)} FROM z",
},
)
self.validate_all(
@@ -585,25 +670,9 @@ class TestBigQuery(Validator):
"bigquery": "PARSE_TIMESTAMP('%Y.%m.%d %I:%M:%S%z', x)",
},
)
self.validate_all(
self.validate_identity(
"CREATE TEMP TABLE foo AS SELECT 1",
write={"bigquery": "CREATE TEMPORARY TABLE foo AS SELECT 1"},
)
self.validate_all(
"SELECT * FROM `SOME_PROJECT_ID.SOME_DATASET_ID.INFORMATION_SCHEMA.SOME_VIEW`",
write={
"bigquery": "SELECT * FROM SOME_PROJECT_ID.SOME_DATASET_ID.INFORMATION_SCHEMA.SOME_VIEW",
},
)
self.validate_all(
"SELECT * FROM `my-project.my-dataset.my-table`",
write={"bigquery": "SELECT * FROM `my-project`.`my-dataset`.`my-table`"},
)
self.validate_all(
"SELECT ARRAY(SELECT AS STRUCT 1 a, 2 b)",
write={
"bigquery": "SELECT ARRAY(SELECT AS STRUCT 1 AS a, 2 AS b)",
},
"CREATE TEMPORARY TABLE foo AS SELECT 1",
)
self.validate_all(
"REGEXP_CONTAINS('foo', '.*')",
@@ -1088,6 +1157,35 @@ WHERE
self.assertIn("unsupported syntax", cm.output[0])
with self.assertLogs(helper_logger):
statements = parse(
"""
BEGIN
DECLARE MY_VAR INT64 DEFAULT 1;
SET MY_VAR = (SELECT 0);
IF MY_VAR = 1 THEN SELECT 'TRUE';
ELSEIF MY_VAR = 0 THEN SELECT 'FALSE';
ELSE SELECT 'NULL';
END IF;
END
""",
read="bigquery",
)
expected_statements = (
"BEGIN DECLARE MY_VAR INT64 DEFAULT 1",
"SET MY_VAR = (SELECT 0)",
"IF MY_VAR = 1 THEN SELECT 'TRUE'",
"ELSEIF MY_VAR = 0 THEN SELECT 'FALSE'",
"ELSE SELECT 'NULL'",
"END IF",
"END",
)
for actual, expected in zip(statements, expected_statements):
self.assertEqual(actual.sql(dialect="bigquery"), expected)
with self.assertLogs(helper_logger) as cm:
self.validate_identity(
"SELECT * FROM t AS t(c1, c2)",


@@ -6,6 +6,21 @@ class TestClickhouse(Validator):
dialect = "clickhouse"
def test_clickhouse(self):
self.validate_all(
"SELECT * FROM x PREWHERE y = 1 WHERE z = 2",
write={
"": "SELECT * FROM x WHERE z = 2",
"clickhouse": "SELECT * FROM x PREWHERE y = 1 WHERE z = 2",
},
)
self.validate_all(
"SELECT * FROM x AS prewhere",
read={
"clickhouse": "SELECT * FROM x AS prewhere",
"duckdb": "SELECT * FROM x prewhere",
},
)
self.validate_identity("SELECT * FROM x LIMIT 1 UNION ALL SELECT * FROM y")
string_types = [
@@ -77,6 +92,7 @@ class TestClickhouse(Validator):
self.validate_identity("""SELECT JSONExtractString('{"x": {"y": 1}}', 'x', 'y')""")
self.validate_identity("SELECT * FROM table LIMIT 1 BY a, b")
self.validate_identity("SELECT * FROM table LIMIT 2 OFFSET 1 BY a, b")
self.validate_identity(
"SELECT $1$foo$1$",
"SELECT 'foo'",
@@ -134,6 +150,9 @@ class TestClickhouse(Validator):
self.validate_identity(
"CREATE MATERIALIZED VIEW test_view (id UInt8) TO db.table1 AS SELECT * FROM test_data"
)
self.validate_identity("TRUNCATE TABLE t1 ON CLUSTER test_cluster")
self.validate_identity("TRUNCATE DATABASE db")
self.validate_identity("TRUNCATE DATABASE db ON CLUSTER test_cluster")
self.validate_all(
"SELECT arrayJoin([1,2,3])",
@@ -373,6 +392,7 @@ class TestClickhouse(Validator):
def test_cte(self):
self.validate_identity("WITH 'x' AS foo SELECT foo")
self.validate_identity("WITH ['c'] AS field_names SELECT field_names")
self.validate_identity("WITH SUM(bytes) AS foo SELECT foo FROM system.parts")
self.validate_identity("WITH (SELECT foo) AS bar SELECT bar + 5")
self.validate_identity("WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1")


@@ -38,6 +38,11 @@ class TestDatabricks(Validator):
"CREATE FUNCTION add_one(x INT) RETURNS INT LANGUAGE PYTHON AS $FOO$def add_one(x):\n return x+1$FOO$"
)
self.validate_identity("TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', address)")
self.validate_identity(
"TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', city LIKE 'LA')"
)
self.validate_all(
"CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))",
write={


@@ -1108,6 +1108,11 @@ class TestDialect(Validator):
)
def test_order_by(self):
self.validate_identity(
"SELECT c FROM t ORDER BY a, b,",
"SELECT c FROM t ORDER BY a, b",
)
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
@@ -1777,7 +1782,7 @@ class TestDialect(Validator):
"CREATE TABLE t (c CHAR, nc NCHAR, v1 VARCHAR, v2 VARCHAR2, nv NVARCHAR, nv2 NVARCHAR2)",
write={
"duckdb": "CREATE TABLE t (c TEXT, nc TEXT, v1 TEXT, v2 TEXT, nv TEXT, nv2 TEXT)",
"hive": "CREATE TABLE t (c CHAR, nc CHAR, v1 STRING, v2 STRING, nv STRING, nv2 STRING)",
"hive": "CREATE TABLE t (c STRING, nc STRING, v1 STRING, v2 STRING, nv STRING, nv2 STRING)",
"oracle": "CREATE TABLE t (c CHAR, nc NCHAR, v1 VARCHAR2, v2 VARCHAR2, nv NVARCHAR2, nv2 NVARCHAR2)",
"postgres": "CREATE TABLE t (c CHAR, nc CHAR, v1 VARCHAR, v2 VARCHAR, nv VARCHAR, nv2 VARCHAR)",
"sqlite": "CREATE TABLE t (c TEXT, nc TEXT, v1 TEXT, v2 TEXT, nv TEXT, nv2 TEXT)",
@@ -2301,3 +2306,9 @@ SELECT
"tsql": UnsupportedError,
},
)
def test_truncate(self):
self.validate_identity("TRUNCATE TABLE table")
self.validate_identity("TRUNCATE TABLE db.schema.test")
self.validate_identity("TRUNCATE TABLE IF EXISTS db.schema.test")
self.validate_identity("TRUNCATE TABLE t1, t2, t3")


@@ -26,6 +26,16 @@ class TestDoris(Validator):
"doris": "SELECT ARRAY_SUM(x -> x * x, ARRAY(2, 3))",
},
)
self.validate_all(
"MONTHS_ADD(d, n)",
read={
"oracle": "ADD_MONTHS(d, n)",
},
write={
"doris": "MONTHS_ADD(d, n)",
"oracle": "ADD_MONTHS(d, n)",
},
)
def test_identity(self):
self.validate_identity("COALECSE(a, b, c, d)")


@@ -7,9 +7,14 @@ class TestDuckDB(Validator):
dialect = "duckdb"
def test_duckdb(self):
struct_pack = parse_one('STRUCT_PACK("a b" := 1)', read="duckdb")
self.assertIsInstance(struct_pack.expressions[0].this, exp.Identifier)
self.assertEqual(struct_pack.sql(dialect="duckdb"), "{'a b': 1}")
self.validate_all(
'STRUCT_PACK("a b" := 1)',
write={
"duckdb": "{'a b': 1}",
"spark": "STRUCT(1 AS `a b`)",
"snowflake": "OBJECT_CONSTRUCT('a b', 1)",
},
)
self.validate_all(
"SELECT SUM(X) OVER (ORDER BY x)",
@@ -52,8 +57,21 @@ class TestDuckDB(Validator):
exp.select("*").from_("t").offset(exp.select("5").subquery()).sql(dialect="duckdb"),
)
for struct_value in ("{'a': 1}", "struct_pack(a := 1)"):
self.validate_all(struct_value, write={"presto": UnsupportedError})
self.validate_all(
"{'a': 1, 'b': '2'}", write={"presto": "CAST(ROW(1, '2') AS ROW(a INTEGER, b VARCHAR))"}
)
self.validate_all(
"struct_pack(a := 1, b := 2)",
write={"presto": "CAST(ROW(1, 2) AS ROW(a INTEGER, b INTEGER))"},
)
self.validate_all(
"struct_pack(a := 1, b := x)",
write={
"duckdb": "{'a': 1, 'b': x}",
"presto": UnsupportedError,
},
)
for join_type in ("SEMI", "ANTI"):
exists = "EXISTS" if join_type == "SEMI" else "NOT EXISTS"
@@ -171,7 +189,6 @@ class TestDuckDB(Validator):
},
)
self.validate_identity("SELECT i FROM RANGE(5) AS _(i) ORDER BY i ASC")
self.validate_identity("INSERT INTO x BY NAME SELECT 1 AS y")
self.validate_identity("SELECT 1 AS x UNION ALL BY NAME SELECT 2 AS x")
self.validate_identity("SELECT SUM(x) FILTER (x = 1)", "SELECT SUM(x) FILTER(WHERE x = 1)")
@@ -209,6 +226,10 @@ class TestDuckDB(Validator):
self.validate_identity("FROM (FROM tbl)", "SELECT * FROM (SELECT * FROM tbl)")
self.validate_identity("FROM tbl", "SELECT * FROM tbl")
self.validate_identity("x -> '$.family'")
self.validate_identity("CREATE TABLE color (name ENUM('RED', 'GREEN', 'BLUE'))")
self.validate_identity(
"SELECT * FROM x LEFT JOIN UNNEST(y)", "SELECT * FROM x LEFT JOIN UNNEST(y) ON TRUE"
)
self.validate_identity(
"""SELECT '{"foo": [1, 2, 3]}' -> 'foo' -> 0""",
"""SELECT '{"foo": [1, 2, 3]}' -> '$.foo' -> '$[0]'""",
@@ -623,6 +644,27 @@ class TestDuckDB(Validator):
},
)
self.validate_identity("SELECT * FROM RANGE(1, 5, 10)")
self.validate_identity("SELECT * FROM GENERATE_SERIES(2, 13, 4)")
self.validate_all(
"WITH t AS (SELECT i, i * i * i * i * i AS i5 FROM RANGE(1, 5) t(i)) SELECT * FROM t",
write={
"duckdb": "WITH t AS (SELECT i, i * i * i * i * i AS i5 FROM RANGE(1, 5) AS t(i)) SELECT * FROM t",
"sqlite": "WITH t AS (SELECT i, i * i * i * i * i AS i5 FROM (SELECT value AS i FROM GENERATE_SERIES(1, 5)) AS t) SELECT * FROM t",
},
)
self.validate_identity(
"""SELECT i FROM RANGE(5) AS _(i) ORDER BY i ASC""",
"""SELECT i FROM RANGE(0, 5) AS _(i) ORDER BY i ASC""",
)
self.validate_identity(
"""SELECT i FROM GENERATE_SERIES(12) AS _(i) ORDER BY i ASC""",
"""SELECT i FROM GENERATE_SERIES(0, 12) AS _(i) ORDER BY i ASC""",
)
def test_array_index(self):
with self.assertLogs(helper_logger) as cm:
self.validate_all(
@@ -994,3 +1036,10 @@ class TestDuckDB(Validator):
read={"bigquery": "IS_INF(x)"},
write={"bigquery": "IS_INF(x)", "duckdb": "ISINF(x)"},
)
def test_parameter_token(self):
self.validate_all(
"SELECT $foo",
read={"bigquery": "SELECT @foo"},
write={"bigquery": "SELECT @foo", "duckdb": "SELECT $foo"},
)


@@ -440,6 +440,9 @@ class TestHive(Validator):
self.validate_identity(
"SELECT key, value, GROUPING__ID, COUNT(*) FROM T1 GROUP BY key, value WITH ROLLUP"
)
self.validate_identity(
"TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', address = 'abc')"
)
self.validate_all(
"SELECT ${hiveconf:some_var}",
@@ -611,12 +614,6 @@ class TestHive(Validator):
"spark": "GET_JSON_OBJECT(x, '$.name')",
},
)
self.validate_all(
"STRUCT(a = b, c = d)",
read={
"snowflake": "OBJECT_CONSTRUCT(a, b, c, d)",
},
)
self.validate_all(
"MAP(a, b, c, d)",
read={


@@ -29,6 +29,7 @@ class TestMySQL(Validator):
self.validate_identity("CREATE TABLE foo (a BIGINT, INDEX USING BTREE (b))")
self.validate_identity("CREATE TABLE foo (a BIGINT, FULLTEXT INDEX (b))")
self.validate_identity("CREATE TABLE foo (a BIGINT, SPATIAL INDEX (b))")
self.validate_identity("ALTER TABLE t1 ADD COLUMN x INT, ALGORITHM=INPLACE, LOCK=EXCLUSIVE")
self.validate_identity(
"CREATE TABLE `oauth_consumer` (`key` VARCHAR(32) NOT NULL, UNIQUE `OAUTH_CONSUMER_KEY` (`key`))"
)
@@ -68,6 +69,26 @@ class TestMySQL(Validator):
self.validate_identity(
"CREATE OR REPLACE VIEW my_view AS SELECT column1 AS `boo`, column2 AS `foo` FROM my_table WHERE column3 = 'some_value' UNION SELECT q.* FROM fruits_table, JSON_TABLE(Fruits, '$[*]' COLUMNS(id VARCHAR(255) PATH '$.$id', value VARCHAR(255) PATH '$.value')) AS q",
)
self.validate_identity(
"CREATE TABLE `foo` (`id` char(36) NOT NULL DEFAULT (uuid()), PRIMARY KEY (`id`), UNIQUE KEY `id` (`id`))",
"CREATE TABLE `foo` (`id` CHAR(36) NOT NULL DEFAULT (UUID()), PRIMARY KEY (`id`), UNIQUE `id` (`id`))",
)
self.validate_identity(
"CREATE TABLE IF NOT EXISTS industry_info (a BIGINT(20) NOT NULL AUTO_INCREMENT, b BIGINT(20) NOT NULL, c VARCHAR(1000), PRIMARY KEY (a), UNIQUE KEY d (b), KEY e (b))",
"CREATE TABLE IF NOT EXISTS industry_info (a BIGINT(20) NOT NULL AUTO_INCREMENT, b BIGINT(20) NOT NULL, c VARCHAR(1000), PRIMARY KEY (a), UNIQUE d (b), INDEX e (b))",
)
self.validate_identity(
"CREATE TABLE test (ts TIMESTAMP, ts_tz TIMESTAMPTZ, ts_ltz TIMESTAMPLTZ)",
"CREATE TABLE test (ts DATETIME, ts_tz TIMESTAMP, ts_ltz TIMESTAMP)",
)
self.validate_identity(
"ALTER TABLE test_table ALTER COLUMN test_column SET DATA TYPE LONGTEXT",
"ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT",
)
self.validate_identity(
"CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP) DEFAULT CHARSET=utf8 ROW_FORMAT=DYNAMIC",
"CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP()) DEFAULT CHARACTER SET=utf8 ROW_FORMAT=DYNAMIC",
)
self.validate_all(
"CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'",
@@ -78,12 +99,6 @@ class TestMySQL(Validator):
"sqlite": "CREATE TABLE z (a INTEGER)",
},
)
self.validate_all(
"CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP) DEFAULT CHARSET=utf8 ROW_FORMAT=DYNAMIC",
write={
"mysql": "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP()) DEFAULT CHARACTER SET=utf8 ROW_FORMAT=DYNAMIC",
},
)
self.validate_all(
"CREATE TABLE x (id int not null auto_increment, primary key (id))",
write={
@@ -96,33 +111,9 @@ class TestMySQL(Validator):
"sqlite": "CREATE TABLE x (id INTEGER NOT NULL)",
},
)
self.validate_all(
"CREATE TABLE `foo` (`id` char(36) NOT NULL DEFAULT (uuid()), PRIMARY KEY (`id`), UNIQUE KEY `id` (`id`))",
write={
"mysql": "CREATE TABLE `foo` (`id` CHAR(36) NOT NULL DEFAULT (UUID()), PRIMARY KEY (`id`), UNIQUE `id` (`id`))",
},
)
self.validate_all(
"CREATE TABLE IF NOT EXISTS industry_info (a BIGINT(20) NOT NULL AUTO_INCREMENT, b BIGINT(20) NOT NULL, c VARCHAR(1000), PRIMARY KEY (a), UNIQUE KEY d (b), KEY e (b))",
write={
"mysql": "CREATE TABLE IF NOT EXISTS industry_info (a BIGINT(20) NOT NULL AUTO_INCREMENT, b BIGINT(20) NOT NULL, c VARCHAR(1000), PRIMARY KEY (a), UNIQUE d (b), INDEX e (b))",
},
)
self.validate_all(
"CREATE TABLE test (ts TIMESTAMP, ts_tz TIMESTAMPTZ, ts_ltz TIMESTAMPLTZ)",
write={
"mysql": "CREATE TABLE test (ts DATETIME, ts_tz TIMESTAMP, ts_ltz TIMESTAMP)",
},
)
self.validate_all(
"ALTER TABLE test_table ALTER COLUMN test_column SET DATA TYPE LONGTEXT",
write={
"mysql": "ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT",
},
)
self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1")
def test_identity(self):
self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1")
self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')")
self.validate_identity("SELECT @var1 := 1, @var2")
self.validate_identity("UNLOCK TABLES")


@@ -1,4 +1,4 @@
from sqlglot import exp, parse_one
from sqlglot import exp
from sqlglot.errors import UnsupportedError
from tests.dialects.test_dialect import Validator
@@ -7,11 +7,18 @@ class TestOracle(Validator):
dialect = "oracle"
def test_oracle(self):
self.validate_identity("REGEXP_REPLACE('source', 'search')")
parse_one("ALTER TABLE tbl_name DROP FOREIGN KEY fk_symbol", dialect="oracle").assert_is(
exp.AlterTable
self.validate_all(
"SELECT CONNECT_BY_ROOT x y",
write={
"": "SELECT CONNECT_BY_ROOT(x) AS y",
"oracle": "SELECT CONNECT_BY_ROOT x AS y",
},
)
self.parse_one("ALTER TABLE tbl_name DROP FOREIGN KEY fk_symbol").assert_is(exp.AlterTable)
self.validate_identity("CREATE GLOBAL TEMPORARY TABLE t AS SELECT * FROM orders")
self.validate_identity("CREATE PRIVATE TEMPORARY TABLE t AS SELECT * FROM orders")
self.validate_identity("REGEXP_REPLACE('source', 'search')")
self.validate_identity("TIMESTAMP(3) WITH TIME ZONE")
self.validate_identity("CURRENT_TIMESTAMP(precision)")
self.validate_identity("ALTER TABLE tbl_name DROP FOREIGN KEY fk_symbol")
@@ -88,6 +95,13 @@ class TestOracle(Validator):
)
self.validate_identity("SELECT TO_CHAR(-100, 'L99', 'NL_CURRENCY = '' AusDollars '' ')")
self.validate_all(
"TO_CHAR(x)",
write={
"doris": "CAST(x AS STRING)",
"oracle": "TO_CHAR(x)",
},
)
self.validate_all(
"SELECT TO_CHAR(TIMESTAMP '1999-12-01 10:00:00')",
write={


@@ -8,8 +8,10 @@ class TestPostgres(Validator):
dialect = "postgres"
def test_postgres(self):
self.validate_identity("1.x", "1. AS x")
self.validate_identity("|/ x", "SQRT(x)")
self.validate_identity("||/ x", "CBRT(x)")
expr = parse_one(
"SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)", read="postgres"
)
@@ -82,6 +84,7 @@ class TestPostgres(Validator):
self.validate_identity("CAST(1 AS DECIMAL) / CAST(2 AS DECIMAL) * -100")
self.validate_identity("EXEC AS myfunc @id = 123", check_command_warning=True)
self.validate_identity("SELECT CURRENT_USER")
self.validate_identity("SELECT * FROM ONLY t1")
self.validate_identity(
"""LAST_VALUE("col1") OVER (ORDER BY "col2" RANGE BETWEEN INTERVAL '1 DAY' PRECEDING AND '1 month' FOLLOWING)"""
)
@@ -163,6 +166,9 @@ class TestPostgres(Validator):
"SELECT $$Dianne's horse$$",
"SELECT 'Dianne''s horse'",
)
self.validate_identity(
"COMMENT ON TABLE mytable IS $$doc this$$", "COMMENT ON TABLE mytable IS 'doc this'"
)
self.validate_identity(
"UPDATE MYTABLE T1 SET T1.COL = 13",
"UPDATE MYTABLE AS T1 SET T1.COL = 13",
@@ -320,6 +326,7 @@ class TestPostgres(Validator):
"MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET x.a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)",
"MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)",
)
self.validate_identity("SELECT * FROM t1*", "SELECT * FROM t1")
self.validate_all(
"SELECT JSON_EXTRACT_PATH_TEXT(x, k1, k2, k3) FROM t",
@@ -653,6 +660,12 @@ class TestPostgres(Validator):
self.validate_identity("CREATE TABLE t (c CHAR(2) UNIQUE NOT NULL) INHERITS (t1)")
self.validate_identity("CREATE TABLE s.t (c CHAR(2) UNIQUE NOT NULL) INHERITS (s.t1, s.t2)")
self.validate_identity("CREATE FUNCTION x(INT) RETURNS INT SET search_path = 'public'")
self.validate_identity("TRUNCATE TABLE t1 CONTINUE IDENTITY")
self.validate_identity("TRUNCATE TABLE t1 RESTART IDENTITY")
self.validate_identity("TRUNCATE TABLE t1 CASCADE")
self.validate_identity("TRUNCATE TABLE t1 RESTRICT")
self.validate_identity("TRUNCATE TABLE t1 CONTINUE IDENTITY CASCADE")
self.validate_identity("TRUNCATE TABLE t1 RESTART IDENTITY RESTRICT")
self.validate_identity(
"CREATE TABLE cust_part3 PARTITION OF customers FOR VALUES WITH (MODULUS 3, REMAINDER 2)"
)
@@ -785,6 +798,10 @@ class TestPostgres(Validator):
self.validate_identity(
"CREATE INDEX index_ci_pipelines_on_project_idandrefandiddesc ON public.ci_pipelines USING btree(project_id, ref, id DESC)"
)
self.validate_identity(
"TRUNCATE TABLE ONLY t1, t2*, ONLY t3, t4, t5* RESTART IDENTITY CASCADE",
"TRUNCATE TABLE ONLY t1, t2, ONLY t3, t4, t5 RESTART IDENTITY CASCADE",
)
with self.assertRaises(ParseError):
transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres")
@@ -911,3 +928,31 @@ class TestPostgres(Validator):
"""See https://github.com/tobymao/sqlglot/pull/2404 for details."""
self.assertIsInstance(parse_one("'thomas' ~ '.*thomas.*'", read="postgres"), exp.Binary)
self.assertIsInstance(parse_one("'thomas' ~* '.*thomas.*'", read="postgres"), exp.Binary)
def test_unnest_json_array(self):
trino_input = """
WITH t(boxcrate) AS (
SELECT JSON '[{"boxes": [{"name": "f1", "type": "plant", "color": "red"}]}]'
)
SELECT
JSON_EXTRACT_SCALAR(boxes,'$.name') AS name,
JSON_EXTRACT_SCALAR(boxes,'$.type') AS type,
JSON_EXTRACT_SCALAR(boxes,'$.color') AS color
FROM t
CROSS JOIN UNNEST(CAST(boxcrate AS array(json))) AS x(tbox)
CROSS JOIN UNNEST(CAST(json_extract(tbox, '$.boxes') AS array(json))) AS y(boxes)
"""
expected_postgres = """WITH t(boxcrate) AS (
SELECT
CAST('[{"boxes": [{"name": "f1", "type": "plant", "color": "red"}]}]' AS JSON)
)
SELECT
JSON_EXTRACT_PATH_TEXT(boxes, 'name') AS name,
JSON_EXTRACT_PATH_TEXT(boxes, 'type') AS type,
JSON_EXTRACT_PATH_TEXT(boxes, 'color') AS color
FROM t
CROSS JOIN JSON_ARRAY_ELEMENTS(CAST(boxcrate AS JSON)) AS x(tbox)
CROSS JOIN JSON_ARRAY_ELEMENTS(CAST(JSON_EXTRACT_PATH(tbox, 'boxes') AS JSON)) AS y(boxes)"""
self.validate_all(expected_postgres, read={"trino": trino_input}, pretty=True)


@@ -647,6 +647,7 @@ class TestPresto(Validator):
"""JSON '"foo"'""",
write={
"bigquery": """PARSE_JSON('"foo"')""",
"postgres": """CAST('"foo"' AS JSON)""",
"presto": """JSON_PARSE('"foo"')""",
"snowflake": """PARSE_JSON('"foo"')""",
},
@@ -1142,3 +1143,18 @@ MATCH_RECOGNIZE (
"presto": "DATE_FORMAT(ts, '%y')",
},
)
def test_signum(self):
self.validate_all(
"SIGN(x)",
read={
"presto": "SIGN(x)",
"spark": "SIGNUM(x)",
"starrocks": "SIGN(x)",
},
write={
"presto": "SIGN(x)",
"spark": "SIGN(x)",
"starrocks": "SIGN(x)",
},
)


@@ -515,6 +515,11 @@ FROM (
)
def test_column_unnesting(self):
self.validate_identity("SELECT c.*, o FROM bloo AS c, c.c_orders AS o")
self.validate_identity(
"SELECT c.*, o, l FROM bloo AS c, c.c_orders AS o, o.o_lineitems AS l"
)
ast = parse_one("SELECT * FROM t.t JOIN t.c1 ON c1.c2 = t.c3", read="redshift")
ast.args["from"].this.assert_is(exp.Table)
ast.args["joins"][0].this.assert_is(exp.Table)
@@ -522,7 +527,7 @@ FROM (
ast = parse_one("SELECT * FROM t AS t CROSS JOIN t.c1", read="redshift")
ast.args["from"].this.assert_is(exp.Table)
ast.args["joins"][0].this.assert_is(exp.Column)
ast.args["joins"][0].this.assert_is(exp.Unnest)
self.assertEqual(ast.sql("redshift"), "SELECT * FROM t AS t CROSS JOIN t.c1")
ast = parse_one(
@@ -530,9 +535,9 @@ FROM (
)
joins = ast.args["joins"]
ast.args["from"].this.assert_is(exp.Table)
joins[0].this.this.assert_is(exp.Column)
joins[1].this.this.assert_is(exp.Column)
joins[2].this.this.assert_is(exp.Dot)
joins[0].this.assert_is(exp.Unnest)
joins[1].this.assert_is(exp.Unnest)
joins[2].this.assert_is(exp.Unnest).expressions[0].assert_is(exp.Dot)
self.assertEqual(
ast.sql("redshift"), "SELECT * FROM x AS a, a.b AS c, c.d.e AS f, f.g.h.i.j.k AS l"
)


@@ -40,6 +40,7 @@ WHERE
)""",
)
self.validate_identity("ALTER TABLE authors ADD CONSTRAINT c1 UNIQUE (id, email)")
self.validate_identity("RM @parquet_stage", check_command_warning=True)
self.validate_identity("REMOVE @parquet_stage", check_command_warning=True)
self.validate_identity("SELECT TIMESTAMP_FROM_PARTS(d, t)")
@@ -84,6 +85,7 @@ WHERE
self.validate_identity(
"SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) AS x TABLESAMPLE (0.1)"
)
self.validate_identity("x:from", "GET_PATH(x, 'from')")
self.validate_identity(
"value:values::string",
"CAST(GET_PATH(value, 'values') AS TEXT)",
@@ -371,15 +373,17 @@ WHERE
write={"snowflake": "SELECT * FROM (VALUES (0)) AS foo(bar)"},
)
self.validate_all(
"OBJECT_CONSTRUCT(a, b, c, d)",
"OBJECT_CONSTRUCT('a', b, 'c', d)",
read={
"": "STRUCT(a as b, c as d)",
"": "STRUCT(b as a, d as c)",
},
write={
"duckdb": "{'a': b, 'c': d}",
"snowflake": "OBJECT_CONSTRUCT(a, b, c, d)",
"snowflake": "OBJECT_CONSTRUCT('a', b, 'c', d)",
},
)
self.validate_identity("OBJECT_CONSTRUCT(a, b, c, d)")
self.validate_all(
"SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1",
write={
@@ -1461,26 +1465,22 @@ MATCH_RECOGNIZE (
pretty=True,
)
def test_show(self):
# Parsed as Show
self.validate_identity("SHOW PRIMARY KEYS")
self.validate_identity("SHOW PRIMARY KEYS IN ACCOUNT")
self.validate_identity("SHOW PRIMARY KEYS IN DATABASE")
self.validate_identity("SHOW PRIMARY KEYS IN DATABASE foo")
self.validate_identity("SHOW PRIMARY KEYS IN TABLE")
self.validate_identity("SHOW PRIMARY KEYS IN TABLE foo")
self.validate_identity(
'SHOW PRIMARY KEYS IN "TEST"."PUBLIC"."customers"',
'SHOW PRIMARY KEYS IN TABLE "TEST"."PUBLIC"."customers"',
)
self.validate_identity(
'SHOW TERSE PRIMARY KEYS IN "TEST"."PUBLIC"."customers"',
'SHOW PRIMARY KEYS IN TABLE "TEST"."PUBLIC"."customers"',
)
def test_show_users(self):
self.validate_identity("SHOW USERS")
self.validate_identity("SHOW TERSE USERS")
self.validate_identity("SHOW USERS LIKE '_foo%' STARTS WITH 'bar' LIMIT 5 FROM 'baz'")
def test_show_schemas(self):
self.validate_identity(
"show terse schemas in database db1 starts with 'a' limit 10 from 'b'",
"SHOW TERSE SCHEMAS IN DATABASE db1 STARTS WITH 'a' LIMIT 10 FROM 'b'",
)
ast = parse_one("SHOW SCHEMAS IN DATABASE db1", read="snowflake")
self.assertEqual(ast.args.get("scope_kind"), "DATABASE")
self.assertEqual(ast.find(exp.Table).sql(dialect="snowflake"), "db1")
def test_show_objects(self):
self.validate_identity(
"show terse objects in schema db1.schema1 starts with 'a' limit 10 from 'b'",
"SHOW TERSE OBJECTS IN SCHEMA db1.schema1 STARTS WITH 'a' LIMIT 10 FROM 'b'",
@@ -1489,6 +1489,23 @@ MATCH_RECOGNIZE (
"show terse objects in db1.schema1 starts with 'a' limit 10 from 'b'",
"SHOW TERSE OBJECTS IN SCHEMA db1.schema1 STARTS WITH 'a' LIMIT 10 FROM 'b'",
)
ast = parse_one("SHOW OBJECTS IN db1.schema1", read="snowflake")
self.assertEqual(ast.args.get("scope_kind"), "SCHEMA")
self.assertEqual(ast.find(exp.Table).sql(dialect="snowflake"), "db1.schema1")
def test_show_columns(self):
self.validate_identity("SHOW COLUMNS")
self.validate_identity("SHOW COLUMNS IN TABLE dt_test")
self.validate_identity("SHOW COLUMNS LIKE '_foo%' IN TABLE dt_test")
self.validate_identity("SHOW COLUMNS IN VIEW")
self.validate_identity("SHOW COLUMNS LIKE '_foo%' IN VIEW dt_test")
ast = parse_one("SHOW COLUMNS LIKE '_testing%' IN dt_test", read="snowflake")
self.assertEqual(ast.find(exp.Table).sql(dialect="snowflake"), "dt_test")
self.assertEqual(ast.find(exp.Literal).sql(dialect="snowflake"), "'_testing%'")
def test_show_tables(self):
self.validate_identity(
"SHOW TABLES LIKE 'line%' IN tpch.public",
"SHOW TABLES LIKE 'line%' IN SCHEMA tpch.public",
@@ -1506,47 +1523,97 @@ MATCH_RECOGNIZE (
"SHOW TERSE TABLES IN SCHEMA db1.schema1 STARTS WITH 'a' LIMIT 10 FROM 'b'",
)
ast = parse_one('SHOW PRIMARY KEYS IN "TEST"."PUBLIC"."customers"', read="snowflake")
table = ast.find(exp.Table)
self.assertEqual(table.sql(dialect="snowflake"), '"TEST"."PUBLIC"."customers"')
self.validate_identity("SHOW COLUMNS")
self.validate_identity("SHOW COLUMNS IN TABLE dt_test")
self.validate_identity("SHOW COLUMNS LIKE '_foo%' IN TABLE dt_test")
self.validate_identity("SHOW COLUMNS IN VIEW")
self.validate_identity("SHOW COLUMNS LIKE '_foo%' IN VIEW dt_test")
self.validate_identity("SHOW USERS")
self.validate_identity("SHOW TERSE USERS")
self.validate_identity("SHOW USERS LIKE '_foo%' STARTS WITH 'bar' LIMIT 5 FROM 'baz'")
ast = parse_one("SHOW COLUMNS LIKE '_testing%' IN dt_test", read="snowflake")
table = ast.find(exp.Table)
literal = ast.find(exp.Literal)
self.assertEqual(table.sql(dialect="snowflake"), "dt_test")
self.assertEqual(literal.sql(dialect="snowflake"), "'_testing%'")
ast = parse_one("SHOW SCHEMAS IN DATABASE db1", read="snowflake")
self.assertEqual(ast.args.get("scope_kind"), "DATABASE")
table = ast.find(exp.Table)
self.assertEqual(table.sql(dialect="snowflake"), "db1")
ast = parse_one("SHOW OBJECTS IN db1.schema1", read="snowflake")
self.assertEqual(ast.args.get("scope_kind"), "SCHEMA")
table = ast.find(exp.Table)
self.assertEqual(table.sql(dialect="snowflake"), "db1.schema1")
ast = parse_one("SHOW TABLES IN db1.schema1", read="snowflake")
self.assertEqual(ast.args.get("scope_kind"), "SCHEMA")
table = ast.find(exp.Table)
self.assertEqual(table.sql(dialect="snowflake"), "db1.schema1")
self.assertEqual(ast.find(exp.Table).sql(dialect="snowflake"), "db1.schema1")
users_exp = self.validate_identity("SHOW USERS")
self.assertTrue(isinstance(users_exp, exp.Show))
self.assertEqual(users_exp.this, "USERS")
def test_show_primary_keys(self):
self.validate_identity("SHOW PRIMARY KEYS")
self.validate_identity("SHOW PRIMARY KEYS IN ACCOUNT")
self.validate_identity("SHOW PRIMARY KEYS IN DATABASE")
self.validate_identity("SHOW PRIMARY KEYS IN DATABASE foo")
self.validate_identity("SHOW PRIMARY KEYS IN TABLE")
self.validate_identity("SHOW PRIMARY KEYS IN TABLE foo")
self.validate_identity(
'SHOW PRIMARY KEYS IN "TEST"."PUBLIC"."foo"',
'SHOW PRIMARY KEYS IN TABLE "TEST"."PUBLIC"."foo"',
)
self.validate_identity(
'SHOW TERSE PRIMARY KEYS IN "TEST"."PUBLIC"."foo"',
'SHOW PRIMARY KEYS IN TABLE "TEST"."PUBLIC"."foo"',
)
ast = parse_one('SHOW PRIMARY KEYS IN "TEST"."PUBLIC"."foo"', read="snowflake")
self.assertEqual(ast.find(exp.Table).sql(dialect="snowflake"), '"TEST"."PUBLIC"."foo"')
def test_show_views(self):
self.validate_identity("SHOW TERSE VIEWS")
self.validate_identity("SHOW VIEWS")
self.validate_identity("SHOW VIEWS LIKE 'foo%'")
self.validate_identity("SHOW VIEWS IN ACCOUNT")
self.validate_identity("SHOW VIEWS IN DATABASE")
self.validate_identity("SHOW VIEWS IN DATABASE foo")
self.validate_identity("SHOW VIEWS IN SCHEMA foo")
self.validate_identity(
"SHOW VIEWS IN foo",
"SHOW VIEWS IN SCHEMA foo",
)
ast = parse_one("SHOW VIEWS IN db1.schema1", read="snowflake")
self.assertEqual(ast.find(exp.Table).sql(dialect="snowflake"), "db1.schema1")
def test_show_unique_keys(self):
self.validate_identity("SHOW UNIQUE KEYS")
self.validate_identity("SHOW UNIQUE KEYS IN ACCOUNT")
self.validate_identity("SHOW UNIQUE KEYS IN DATABASE")
self.validate_identity("SHOW UNIQUE KEYS IN DATABASE foo")
self.validate_identity("SHOW UNIQUE KEYS IN TABLE")
self.validate_identity("SHOW UNIQUE KEYS IN TABLE foo")
self.validate_identity(
'SHOW UNIQUE KEYS IN "TEST"."PUBLIC"."foo"',
'SHOW UNIQUE KEYS IN SCHEMA "TEST"."PUBLIC"."foo"',
)
self.validate_identity(
'SHOW TERSE UNIQUE KEYS IN "TEST"."PUBLIC"."foo"',
'SHOW UNIQUE KEYS IN SCHEMA "TEST"."PUBLIC"."foo"',
)
ast = parse_one('SHOW UNIQUE KEYS IN "TEST"."PUBLIC"."foo"', read="snowflake")
self.assertEqual(ast.find(exp.Table).sql(dialect="snowflake"), '"TEST"."PUBLIC"."foo"')
def test_show_imported_keys(self):
self.validate_identity("SHOW IMPORTED KEYS")
self.validate_identity("SHOW IMPORTED KEYS IN ACCOUNT")
self.validate_identity("SHOW IMPORTED KEYS IN DATABASE")
self.validate_identity("SHOW IMPORTED KEYS IN DATABASE foo")
self.validate_identity("SHOW IMPORTED KEYS IN TABLE")
self.validate_identity("SHOW IMPORTED KEYS IN TABLE foo")
self.validate_identity(
'SHOW IMPORTED KEYS IN "TEST"."PUBLIC"."foo"',
'SHOW IMPORTED KEYS IN SCHEMA "TEST"."PUBLIC"."foo"',
)
self.validate_identity(
'SHOW TERSE IMPORTED KEYS IN "TEST"."PUBLIC"."foo"',
'SHOW IMPORTED KEYS IN SCHEMA "TEST"."PUBLIC"."foo"',
)
ast = parse_one('SHOW IMPORTED KEYS IN "TEST"."PUBLIC"."foo"', read="snowflake")
self.assertEqual(ast.find(exp.Table).sql(dialect="snowflake"), '"TEST"."PUBLIC"."foo"')
def test_show_sequences(self):
self.validate_identity("SHOW TERSE SEQUENCES")
self.validate_identity("SHOW SEQUENCES")
self.validate_identity("SHOW SEQUENCES LIKE '_foo%' IN ACCOUNT")
self.validate_identity("SHOW SEQUENCES LIKE '_foo%' IN DATABASE")
self.validate_identity("SHOW SEQUENCES LIKE '_foo%' IN DATABASE foo")
self.validate_identity("SHOW SEQUENCES LIKE '_foo%' IN SCHEMA")
self.validate_identity("SHOW SEQUENCES LIKE '_foo%' IN SCHEMA foo")
self.validate_identity(
"SHOW SEQUENCES LIKE '_foo%' IN foo",
"SHOW SEQUENCES LIKE '_foo%' IN SCHEMA foo",
)
ast = parse_one("SHOW SEQUENCES IN dt_test", read="snowflake")
self.assertEqual(ast.args.get("scope_kind"), "SCHEMA")
def test_storage_integration(self):
self.validate_identity(


@@ -16,6 +16,7 @@ class TestSpark(Validator):
self.validate_identity(
"CREATE TABLE foo (col STRING) CLUSTERED BY (col) SORTED BY (col) INTO 10 BUCKETS"
)
self.validate_identity("TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', address)")
self.validate_all(
"CREATE TABLE db.example_table (col_a struct<struct_col_a:int, struct_col_b:string>)",


@@ -1,5 +1,7 @@
from tests.dialects.test_dialect import Validator
from sqlglot.helper import logger as helper_logger
class TestSQLite(Validator):
dialect = "sqlite"
@@ -76,6 +78,7 @@ class TestSQLite(Validator):
self.validate_identity(
"""SELECT item AS "item", some AS "some" FROM data WHERE (item = 'value_1' COLLATE NOCASE) AND (some = 't' COLLATE NOCASE) ORDER BY item ASC LIMIT 1 OFFSET 0"""
)
self.validate_identity("SELECT * FROM GENERATE_SERIES(1, 5)")
self.validate_all("SELECT LIKE(y, x)", write={"sqlite": "SELECT x LIKE y"})
self.validate_all("SELECT GLOB('*y*', 'xyz')", write={"sqlite": "SELECT 'xyz' GLOB '*y*'"})
@@ -178,3 +181,12 @@ class TestSQLite(Validator):
"CREATE TABLE foo (bar LONGVARCHAR)",
write={"sqlite": "CREATE TABLE foo (bar TEXT)"},
)
def test_warnings(self):
with self.assertLogs(helper_logger) as cm:
self.validate_identity(
"SELECT * FROM t AS t(c1, c2)",
"SELECT * FROM t AS t",
)
self.assertIn("Named columns are not supported in table alias.", cm.output[0])


@@ -1,6 +1,7 @@
from sqlglot import exp, parse, parse_one
from sqlglot.parser import logger as parser_logger
from tests.dialects.test_dialect import Validator
from sqlglot.errors import ParseError
class TestTSQL(Validator):
@@ -27,6 +28,7 @@ class TestTSQL(Validator):
self.validate_identity("SELECT * FROM t WHERE NOT c", "SELECT * FROM t WHERE NOT c <> 0")
self.validate_identity("1 AND true", "1 <> 0 AND (1 = 1)")
self.validate_identity("CAST(x AS int) OR y", "CAST(x AS INTEGER) <> 0 OR y <> 0")
self.validate_identity("TRUNCATE TABLE t1 WITH (PARTITIONS(1, 2 TO 5, 10 TO 20, 84))")
self.validate_all(
"SELECT IIF(cond <> 0, 'True', 'False')",
@@ -142,7 +144,7 @@ class TestTSQL(Validator):
"tsql": "CREATE TABLE #mytemptable (a INTEGER)",
"snowflake": "CREATE TEMPORARY TABLE mytemptable (a INT)",
"duckdb": "CREATE TEMPORARY TABLE mytemptable (a INT)",
"oracle": "CREATE TEMPORARY TABLE mytemptable (a NUMBER)",
"oracle": "CREATE GLOBAL TEMPORARY TABLE mytemptable (a NUMBER)",
"hive": "CREATE TEMPORARY TABLE mytemptable (a INT)",
"spark2": "CREATE TEMPORARY TABLE mytemptable (a INT) USING PARQUET",
"spark": "CREATE TEMPORARY TABLE mytemptable (a INT) USING PARQUET",
@@ -281,7 +283,7 @@ class TestTSQL(Validator):
"CONVERT(INT, CONVERT(NUMERIC, '444.75'))",
write={
"mysql": "CAST(CAST('444.75' AS DECIMAL) AS SIGNED)",
"tsql": "CAST(CAST('444.75' AS NUMERIC) AS INTEGER)",
"tsql": "CONVERT(INTEGER, CONVERT(NUMERIC, '444.75'))",
},
)
self.validate_all(
@@ -356,6 +358,76 @@ class TestTSQL(Validator):
self.validate_identity("HASHBYTES('MD2', 'x')")
self.validate_identity("LOG(n, b)")
def test_option(self):
possible_options = [
"HASH GROUP",
"ORDER GROUP",
"CONCAT UNION",
"HASH UNION",
"MERGE UNION",
"LOOP JOIN",
"MERGE JOIN",
"HASH JOIN",
"DISABLE_OPTIMIZED_PLAN_FORCING",
"EXPAND VIEWS",
"FAST 15",
"FORCE ORDER",
"FORCE EXTERNALPUSHDOWN",
"DISABLE EXTERNALPUSHDOWN",
"FORCE SCALEOUTEXECUTION",
"DISABLE SCALEOUTEXECUTION",
"IGNORE_NONCLUSTERED_COLUMNSTORE_INDEX",
"KEEP PLAN",
"KEEPFIXED PLAN",
"MAX_GRANT_PERCENT = 5",
"MIN_GRANT_PERCENT = 10",
"MAXDOP 13",
"MAXRECURSION 8",
"NO_PERFORMANCE_SPOOL",
"OPTIMIZE FOR UNKNOWN",
"PARAMETERIZATION SIMPLE",
"PARAMETERIZATION FORCED",
"QUERYTRACEON 99",
"RECOMPILE",
"ROBUST PLAN",
"USE PLAN N'<xml_plan>'",
"LABEL = 'MyLabel'",
]
possible_statements = [
# These should be un-commented once support for the OPTION clause is added for DELETE, MERGE and UPDATE
# "DELETE FROM Table1",
# "MERGE INTO Locations AS T USING locations_stage AS S ON T.LocationID = S.LocationID WHEN MATCHED THEN UPDATE SET LocationName = S.LocationName",
# "UPDATE Customers SET ContactName = 'Alfred Schmidt', City = 'Frankfurt' WHERE CustomerID = 1",
"SELECT * FROM Table1",
"SELECT * FROM Table1 WHERE id = 2",
]
for statement in possible_statements:
for option in possible_options:
query = f"{statement} OPTION({option})"
result = self.validate_identity(query)
options = result.args.get("options")
self.assertIsInstance(options, list, f"When parsing query {query}")
is_query_options = map(lambda o: isinstance(o, exp.QueryOption), options)
self.assertTrue(all(is_query_options), f"When parsing query {query}")
self.validate_identity(
f"{statement} OPTION(RECOMPILE, USE PLAN N'<xml_plan>', MAX_GRANT_PERCENT = 5)"
)
raising_queries = [
# Missing parentheses
"SELECT * FROM Table1 OPTION HASH GROUP",
# Must be followed by 'PLAN'
"SELECT * FROM Table1 OPTION(KEEPFIXED)",
# Missing commas
"SELECT * FROM Table1 OPTION(HASH GROUP HASH GROUP)",
]
for query in raising_queries:
with self.assertRaises(ParseError, msg=f"When running '{query}'"):
self.parse_one(query)
def test_types(self):
self.validate_identity("CAST(x AS XML)")
self.validate_identity("CAST(x AS UNIQUEIDENTIFIER)")
@@ -525,7 +597,7 @@ class TestTSQL(Validator):
"CAST(x as NCHAR(1))",
write={
"spark": "CAST(x AS CHAR(1))",
"tsql": "CAST(x AS CHAR(1))",
"tsql": "CAST(x AS NCHAR(1))",
},
)
@@ -533,7 +605,7 @@ class TestTSQL(Validator):
"CAST(x as NVARCHAR(2))",
write={
"spark": "CAST(x AS VARCHAR(2))",
"tsql": "CAST(x AS VARCHAR(2))",
"tsql": "CAST(x AS NVARCHAR(2))",
},
)
@@ -692,12 +764,7 @@ class TestTSQL(Validator):
"SELECT * INTO foo.bar.baz FROM (SELECT * FROM a.b.c) AS temp",
read={
"": "CREATE TABLE foo.bar.baz AS SELECT * FROM a.b.c",
},
)
self.validate_all(
"SELECT * INTO foo.bar.baz FROM (SELECT * FROM a.b.c) AS temp",
read={
"": "CREATE TABLE foo.bar.baz AS (SELECT * FROM a.b.c)",
"duckdb": "CREATE TABLE foo.bar.baz AS (SELECT * FROM a.b.c)",
},
)
self.validate_all(
@@ -759,11 +826,6 @@ class TestTSQL(Validator):
)
def test_transaction(self):
# BEGIN { TRAN | TRANSACTION }
# [ { transaction_name | @tran_name_variable }
# [ WITH MARK [ 'description' ] ]
# ]
# [ ; ]
self.validate_identity("BEGIN TRANSACTION")
self.validate_all("BEGIN TRAN", write={"tsql": "BEGIN TRANSACTION"})
self.validate_identity("BEGIN TRANSACTION transaction_name")
@@ -771,8 +833,6 @@ class TestTSQL(Validator):
self.validate_identity("BEGIN TRANSACTION transaction_name WITH MARK 'description'")
def test_commit(self):
# COMMIT [ { TRAN | TRANSACTION } [ transaction_name | @tran_name_variable ] ] [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ] [ ; ]
self.validate_all("COMMIT", write={"tsql": "COMMIT TRANSACTION"})
self.validate_all("COMMIT TRAN", write={"tsql": "COMMIT TRANSACTION"})
self.validate_identity("COMMIT TRANSACTION")
@@ -787,11 +847,6 @@ class TestTSQL(Validator):
)
def test_rollback(self):
# Applies to SQL Server and Azure SQL Database
# ROLLBACK { TRAN | TRANSACTION }
# [ transaction_name | @tran_name_variable
# | savepoint_name | @savepoint_variable ]
# [ ; ]
self.validate_all("ROLLBACK", write={"tsql": "ROLLBACK TRANSACTION"})
self.validate_all("ROLLBACK TRAN", write={"tsql": "ROLLBACK TRANSACTION"})
self.validate_identity("ROLLBACK TRANSACTION")
@@ -911,7 +966,7 @@ WHERE
expected_sqls = [
"CREATE PROC [dbo].[transform_proc] AS DECLARE @CurrentDate VARCHAR(20)",
"SET @CurrentDate = CAST(FORMAT(GETDATE(), 'yyyy-MM-dd HH:mm:ss') AS VARCHAR(20))",
"SET @CurrentDate = CONVERT(VARCHAR(20), GETDATE(), 120)",
"CREATE TABLE [target_schema].[target_table] (a INTEGER) WITH (DISTRIBUTION=REPLICATE, HEAP)",
]
@@ -1090,155 +1145,173 @@ WHERE
},
)
def test_convert_date_format(self):
def test_convert(self):
self.validate_all(
"CONVERT(NVARCHAR(200), x)",
write={
"spark": "CAST(x AS VARCHAR(200))",
"tsql": "CONVERT(NVARCHAR(200), x)",
},
)
self.validate_all(
"CONVERT(NVARCHAR, x)",
write={
"spark": "CAST(x AS VARCHAR(30))",
"tsql": "CONVERT(NVARCHAR, x)",
},
)
self.validate_all(
"CONVERT(NVARCHAR(MAX), x)",
write={
"spark": "CAST(x AS STRING)",
"tsql": "CONVERT(NVARCHAR(MAX), x)",
},
)
self.validate_all(
"CONVERT(VARCHAR(200), x)",
write={
"spark": "CAST(x AS VARCHAR(200))",
"tsql": "CONVERT(VARCHAR(200), x)",
},
)
self.validate_all(
"CONVERT(VARCHAR, x)",
write={
"spark": "CAST(x AS VARCHAR(30))",
"tsql": "CONVERT(VARCHAR, x)",
},
)
self.validate_all(
"CONVERT(VARCHAR(MAX), x)",
write={
"spark": "CAST(x AS STRING)",
"tsql": "CONVERT(VARCHAR(MAX), x)",
},
)
self.validate_all(
"CONVERT(CHAR(40), x)",
write={
"spark": "CAST(x AS CHAR(40))",
"tsql": "CONVERT(CHAR(40), x)",
},
)
self.validate_all(
"CONVERT(CHAR, x)",
write={
"spark": "CAST(x AS CHAR(30))",
"tsql": "CONVERT(CHAR, x)",
},
)
self.validate_all(
"CONVERT(NCHAR(40), x)",
write={
"spark": "CAST(x AS CHAR(40))",
"tsql": "CONVERT(NCHAR(40), x)",
},
)
self.validate_all(
"CONVERT(NCHAR, x)",
write={
"spark": "CAST(x AS CHAR(30))",
"tsql": "CONVERT(NCHAR, x)",
},
)
self.validate_all(
"CONVERT(VARCHAR, x, 121)",
write={
"spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))",
"tsql": "CONVERT(VARCHAR, x, 121)",
},
)
self.validate_all(
"CONVERT(VARCHAR(40), x, 121)",
write={
"spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(40))",
"tsql": "CONVERT(VARCHAR(40), x, 121)",
},
)
self.validate_all(
"CONVERT(VARCHAR(MAX), x, 121)",
write={
"spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')",
"spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS STRING)",
"tsql": "CONVERT(VARCHAR(MAX), x, 121)",
},
)
self.validate_all(
"CONVERT(NVARCHAR, x, 121)",
write={
"spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))",
"tsql": "CONVERT(NVARCHAR, x, 121)",
},
)
self.validate_all(
"CONVERT(NVARCHAR(40), x, 121)",
write={
"spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(40))",
"tsql": "CONVERT(NVARCHAR(40), x, 121)",
},
)
self.validate_all(
"CONVERT(NVARCHAR(MAX), x, 121)",
write={
"spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')",
"spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS STRING)",
"tsql": "CONVERT(NVARCHAR(MAX), x, 121)",
},
)
self.validate_all(
"CONVERT(DATE, x, 121)",
write={
"spark": "TO_DATE(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')",
"tsql": "CONVERT(DATE, x, 121)",
},
)
self.validate_all(
"CONVERT(DATETIME, x, 121)",
write={
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')",
"tsql": "CONVERT(DATETIME2, x, 121)",
},
)
self.validate_all(
"CONVERT(DATETIME2, x, 121)",
write={
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')",
"tsql": "CONVERT(DATETIME2, x, 121)",
},
)
self.validate_all(
"CONVERT(INT, x)",
write={
"spark": "CAST(x AS INT)",
"tsql": "CONVERT(INTEGER, x)",
},
)
self.validate_all(
"CONVERT(INT, x, 121)",
write={
"spark": "CAST(x AS INT)",
"tsql": "CONVERT(INTEGER, x, 121)",
},
)
self.validate_all(
"TRY_CONVERT(NVARCHAR, x, 121)",
write={
"spark": "TRY_CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))",
"tsql": "TRY_CONVERT(NVARCHAR, x, 121)",
},
)
self.validate_all(
"TRY_CONVERT(INT, x)",
write={
"spark": "TRY_CAST(x AS INT)",
"tsql": "TRY_CONVERT(INTEGER, x)",
},
)
self.validate_all(
"TRY_CAST(x AS INT)",
write={
"spark": "TRY_CAST(x AS INT)",
},
)
self.validate_all(
"CAST(x AS INT)",
write={
"spark": "CAST(x AS INT)",
"tsql": "TRY_CAST(x AS INTEGER)",
},
)
self.validate_all(
@@ -1246,6 +1319,7 @@ WHERE
write={
"mysql": "SELECT CAST(DATE_FORMAT(testdb.dbo.test.x, '%Y-%m-%d %T') AS CHAR(10)) AS y FROM testdb.dbo.test",
"spark": "SELECT CAST(DATE_FORMAT(testdb.dbo.test.x, 'yyyy-MM-dd HH:mm:ss') AS VARCHAR(10)) AS y FROM testdb.dbo.test",
"tsql": "SELECT CONVERT(VARCHAR(10), testdb.dbo.test.x, 120) AS y FROM testdb.dbo.test",
},
)
self.validate_all(
@@ -1253,12 +1327,14 @@ WHERE
write={
"mysql": "SELECT CAST(y.x AS CHAR(10)) AS z FROM testdb.dbo.test AS y",
"spark": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y",
"tsql": "SELECT CONVERT(VARCHAR(10), y.x) AS z FROM testdb.dbo.test AS y",
},
)
self.validate_all(
"SELECT CAST((SELECT x FROM y) AS VARCHAR) AS test",
write={
"spark": "SELECT CAST((SELECT x FROM y) AS STRING) AS test",
"tsql": "SELECT CAST((SELECT x FROM y) AS VARCHAR) AS test",
},
)
@@ -1654,7 +1730,7 @@ FROM OPENJSON(@json) WITH (
Date DATETIME2 '$.Order.Date',
Customer VARCHAR(200) '$.AccountNumber',
Quantity INTEGER '$.Item.Quantity',
[Order] VARCHAR(MAX) AS JSON
[Order] NVARCHAR(MAX) AS JSON
)"""
},
pretty=True,


@@ -196,10 +196,10 @@ SET LOCAL variable = value
@"x"
COMMIT
USE db
USE role x
USE warehouse x
USE database x
USE schema x.y
USE ROLE x
USE WAREHOUSE x
USE DATABASE x
USE SCHEMA x.y
NOT 1
NOT NOT 1
SELECT * FROM test
@@ -643,6 +643,7 @@ DROP MATERIALIZED VIEW x.y.z
CACHE TABLE x
CACHE LAZY TABLE x
CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value')
CACHE LAZY TABLE x OPTIONS(N'storageLevel' = 'value')
CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1
CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS WITH a AS (SELECT 1) SELECT a.* FROM a
CACHE LAZY TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
@@ -708,6 +709,7 @@ COMMENT ON COLUMN my_schema.my_table.my_column IS 'Employee ID number'
COMMENT ON DATABASE my_database IS 'Development Database'
COMMENT ON PROCEDURE my_proc(integer, integer) IS 'Runs a report'
COMMENT ON TABLE my_schema.my_table IS 'Employee Information'
COMMENT ON TABLE my_schema.my_table IS N'National String'
WITH a AS (SELECT 1) INSERT INTO b SELECT * FROM a
WITH a AS (SELECT * FROM b) UPDATE a SET col = 1
WITH a AS (SELECT * FROM b) CREATE TABLE b AS SELECT * FROM a
@@ -785,6 +787,7 @@ ALTER TABLE baa ADD CONSTRAINT boo PRIMARY KEY (x, y) NOT ENFORCED DEFERRABLE IN
ALTER TABLE baa ADD CONSTRAINT boo FOREIGN KEY (x, y) REFERENCES persons ON UPDATE NO ACTION ON DELETE NO ACTION MATCH FULL
ALTER TABLE a ADD PRIMARY KEY (x, y) NOT ENFORCED
ALTER TABLE a ADD FOREIGN KEY (x, y) REFERENCES bla
ALTER TABLE s_ut ADD CONSTRAINT s_ut_uq UNIQUE hajo
SELECT partition FROM a
SELECT end FROM a
SELECT id FROM b.a AS a QUALIFY ROW_NUMBER() OVER (PARTITION BY br ORDER BY sadf DESC) = 1
@@ -850,3 +853,7 @@ CAST(foo AS BPCHAR)
values
SELECT values
SELECT values AS values FROM t WHERE values + 1 > 3
SELECT truncate
SELECT only
TRUNCATE(a, b)
SELECT enum


@@ -820,7 +820,7 @@ SELECT
`TOp_TeRmS`.`refresh_date` AS `day`,
`TOp_TeRmS`.`term` AS `top_term`,
`TOp_TeRmS`.`rank` AS `rank`
FROM `bigquery-public-data`.`GooGle_tReNDs`.`TOp_TeRmS` AS `TOp_TeRmS`
FROM `bigquery-public-data.GooGle_tReNDs.TOp_TeRmS` AS `TOp_TeRmS`
WHERE
`TOp_TeRmS`.`rank` = 1
AND CAST(`TOp_TeRmS`.`refresh_date` AS DATE) >= DATE_SUB(CURRENT_DATE, INTERVAL 2 WEEK)


@@ -172,6 +172,10 @@ SELECT _q_0._col_0 AS _col_0, _q_0._col_1 AS _col_1 FROM (VALUES (1, 2)) AS _q_0
select * from (values (1, 2)) x;
SELECT x._col_0 AS _col_0, x._col_1 AS _col_1 FROM (VALUES (1, 2)) AS x(_col_0, _col_1);
# execute: false
SELECT SOME_UDF(data).* FROM t;
SELECT SOME_UDF(t.data).* FROM t AS t;
--------------------------------------
-- Derived tables
--------------------------------------
@@ -333,6 +337,10 @@ WITH cte AS (SELECT 1 AS x) SELECT cte.a AS a FROM cte AS cte(a);
WITH cte(x, y) AS (SELECT 1, 2) SELECT cte.* FROM cte AS cte(a);
WITH cte AS (SELECT 1 AS x, 2 AS y) SELECT cte.a AS a, cte.y AS y FROM cte AS cte(a);
-- Cannot pop table column aliases for recursive ctes (redshift).
WITH RECURSIVE cte(x) AS (SELECT 1), cte2(y) AS (SELECT 2) SELECT * FROM cte, cte2;
WITH RECURSIVE cte(x) AS (SELECT 1 AS x), cte2(y) AS (SELECT 2 AS y) SELECT cte.x AS x, cte2.y AS y FROM cte AS cte, cte2 AS cte2;
# execute: false
WITH player AS (SELECT player.name, player.asset.info FROM players) SELECT * FROM player;
WITH player AS (SELECT players.player.name AS name, players.player.asset.info AS info FROM players AS players) SELECT player.name AS name, player.info AS info FROM player AS player;
@@ -549,6 +557,10 @@ SELECT x.a + x.b AS f, (x.a + x.b) * x.b AS _col_1 FROM x AS x;
SELECT x.a + x.b AS f, f, f + 5 FROM x;
SELECT x.a + x.b AS f, x.a + x.b AS _col_1, x.a + x.b + 5 AS _col_2 FROM x AS x;
# title: expand double agg if window func
SELECT a, SUM(b) AS c, SUM(c) OVER(PARTITION BY a) AS d from x group by 1 ORDER BY a;
SELECT x.a AS a, SUM(x.b) AS c, SUM(SUM(x.b)) OVER (PARTITION BY x.a) AS d FROM x AS x GROUP BY x.a ORDER BY a;
--------------------------------------
-- Wrapped tables / join constructs
--------------------------------------


@@ -19,6 +19,21 @@ SELECT 1 FROM x.y.z AS z;
SELECT 1 FROM y.z AS z, z.a;
SELECT 1 FROM c.y.z AS z, z.a;
# title: bigquery implicit unnest syntax, coordinates.position should be a column, not a table
# dialect: bigquery
SELECT results FROM Coordinates, coordinates.position AS results;
SELECT results FROM c.db.Coordinates AS Coordinates, UNNEST(coordinates.position) AS results;
# title: bigquery implicit unnest syntax, table is already qualified
# dialect: bigquery
SELECT results FROM db.coordinates, Coordinates.position AS results;
SELECT results FROM c.db.coordinates AS coordinates, UNNEST(Coordinates.position) AS results;
# title: bigquery schema name clashes with CTE name - this is a join, not an implicit unnest
# dialect: bigquery
WITH Coordinates AS (SELECT [1, 2] AS position) SELECT results FROM Coordinates, `Coordinates.position` AS results;
WITH Coordinates AS (SELECT [1, 2] AS position) SELECT results FROM Coordinates AS Coordinates, `c.Coordinates.position` AS results;
# title: single cte
WITH a AS (SELECT 1 FROM z) SELECT 1 FROM a;
WITH a AS (SELECT 1 FROM c.db.z AS z) SELECT 1 FROM a AS a;
@@ -83,7 +98,7 @@ SELECT * FROM ((c.db.a AS foo CROSS JOIN c.db.b AS bar) CROSS JOIN c.db.c AS baz
SELECT * FROM (tbl1 CROSS JOIN (SELECT * FROM tbl2) AS t1);
SELECT * FROM (c.db.tbl1 AS tbl1 CROSS JOIN (SELECT * FROM c.db.tbl2 AS tbl2) AS t1);
# title: wrapped join with subquery with alias, parentheses can't be omitted because of alias
# title: wrapped join with subquery with alias, parentheses cant be omitted because of alias
SELECT * FROM (tbl1 CROSS JOIN (SELECT * FROM tbl2) AS t1) AS t2;
SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 CROSS JOIN (SELECT * FROM c.db.tbl2 AS tbl2) AS t1) AS t2;
@@ -95,7 +110,7 @@ SELECT * FROM c.db.a AS a LEFT JOIN (c.db.b AS b INNER JOIN c.db.c AS c ON c.id
SELECT * FROM a LEFT JOIN b INNER JOIN c ON c.id = b.id ON b.id = a.id;
SELECT * FROM c.db.a AS a LEFT JOIN c.db.b AS b INNER JOIN c.db.c AS c ON c.id = b.id ON b.id = a.id;
# title: parentheses can't be omitted because alias shadows inner table names
# title: parentheses cant be omitted because alias shadows inner table names
SELECT t.a FROM (tbl AS tbl) AS t;
SELECT t.a FROM (SELECT * FROM c.db.tbl AS tbl) AS t;
@ -146,3 +161,7 @@ CREATE TABLE c.db.t1 AS (WITH cte AS (SELECT x FROM c.db.t2 AS t2) SELECT * FROM
# title: insert statement with cte
WITH cte AS (SELECT b FROM y) INSERT INTO s SELECT * FROM cte;
WITH cte AS (SELECT b FROM c.db.y AS y) INSERT INTO c.db.s SELECT * FROM cte AS cte;
# title: qualify wrapped query
(SELECT x FROM t);
(SELECT x FROM c.db.t AS t);
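The wrapped-query case confirms that table qualification now descends into a top-level parenthesized SELECT. A quick check with the same assumed helper:

from sqlglot import parse_one
from sqlglot.optimizer.qualify_tables import qualify_tables

print(qualify_tables(parse_one("(SELECT x FROM t)"), catalog="c", db="db").sql())
# expected per the fixture: (SELECT x FROM c.db.t AS t)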

View file

@ -29,3 +29,7 @@ SELECT "dual" FROM "t";
# dialect: snowflake
SELECT * FROM t AS dual;
SELECT * FROM "t" AS "dual";
# dialect: bigquery
SELECT `p.d.udf`(data).* FROM `p.d.t`;
SELECT `p.d.udf`(`data`).* FROM `p.d.t`;
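The BigQuery pair above checks that a quoted, dotted UDF name is treated as a single identifier while ordinary identifiers still get quoted. A sketch, assuming the quote_identifiers helper these fixtures run through is importable from sqlglot.optimizer.qualify_columns:

from sqlglot import parse_one
from sqlglot.optimizer.qualify_columns import quote_identifiers  # assumed import path

expr = parse_one("SELECT `p.d.udf`(data).* FROM `p.d.t`", read="bigquery")
print(quote_identifiers(expr, dialect="bigquery").sql("bigquery"))
# expected per the fixture: SELECT `p.d.udf`(`data`).* FROM `p.d.t`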

Binary file not shown

View file

@ -1,6 +1,7 @@
--------------------------------------
-- TPC-DS 1
--------------------------------------
# execute: true
WITH customer_total_return
AS (SELECT sr_customer_sk AS ctr_customer_sk,
sr_store_sk AS ctr_store_sk,
@ -219,6 +220,7 @@ ORDER BY
--------------------------------------
-- TPC-DS 3
--------------------------------------
# execute: true
SELECT dt.d_year,
item.i_brand_id brand_id,
item.i_brand brand,
@ -859,6 +861,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 6
--------------------------------------
# execute: true
SELECT a.ca_state state,
Count(*) cnt
FROM customer_address a,
@ -924,6 +927,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 7
--------------------------------------
# execute: true
SELECT i_item_id,
Avg(ss_quantity) agg1,
Avg(ss_list_price) agg2,
@ -1247,6 +1251,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 9
--------------------------------------
# execute: true
SELECT CASE
WHEN (SELECT Count(*)
FROM store_sales
@ -1448,6 +1453,7 @@ WHERE
--------------------------------------
-- TPC-DS 10
--------------------------------------
# execute: true
SELECT cd_gender,
cd_marital_status,
cd_education_status,
@ -3056,6 +3062,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 24
--------------------------------------
# execute: true
WITH ssales
AS (SELECT c_last_name,
c_first_name,
@ -3158,6 +3165,7 @@ HAVING
--------------------------------------
-- TPC-DS 25
--------------------------------------
# execute: true
SELECT i_item_id,
i_item_desc,
s_store_id,
@ -3247,6 +3255,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 26
--------------------------------------
# execute: true
SELECT i_item_id,
Avg(cs_quantity) agg1,
Avg(cs_list_price) agg2,
@ -3527,6 +3536,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 29
--------------------------------------
# execute: true
SELECT i_item_id,
i_item_desc,
s_store_id,
@ -3726,6 +3736,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 31
--------------------------------------
# execute: true
WITH ss
AS (SELECT ca_county,
d_qoy,
@ -3948,6 +3959,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 33
--------------------------------------
# execute: true
WITH ss
AS (SELECT i_manufact_id,
Sum(ss_ext_sales_price) total_sales
@ -5014,6 +5026,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 43
--------------------------------------
# execute: true
SELECT s_store_name,
s_store_id,
Sum(CASE
@ -6194,6 +6207,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 52
--------------------------------------
# execute: true
SELECT dt.d_year,
item.i_brand_id brand_id,
item.i_brand brand,
@ -6357,6 +6371,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 54
--------------------------------------
# execute: true
WITH my_customers
AS (SELECT DISTINCT c_customer_sk,
c_current_addr_sk
@ -6493,6 +6508,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 55
--------------------------------------
# execute: true
SELECT i_brand_id brand_id,
i_brand brand,
Sum(ss_ext_sales_price) ext_price
@ -6531,6 +6547,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 56
--------------------------------------
# execute: true
WITH ss
AS (SELECT i_item_id,
Sum(ss_ext_sales_price) total_sales
@ -7231,6 +7248,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 60
--------------------------------------
# execute: true
WITH ss
AS (SELECT i_item_id,
Sum(ss_ext_sales_price) total_sales
@ -8012,6 +8030,7 @@ ORDER BY
--------------------------------------
-- TPC-DS 65
--------------------------------------
# execute: true
SELECT s_store_name,
i_item_desc,
sc.revenue,
@ -9113,6 +9132,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 69
--------------------------------------
# execute: true
SELECT cd_gender,
cd_marital_status,
cd_education_status,
@ -9355,6 +9375,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 71
--------------------------------------
# execute: true
SELECT i_brand_id brand_id,
i_brand brand,
t_hour,
@ -11064,6 +11085,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 83
--------------------------------------
# execute: true
WITH sr_items
AS (SELECT i_item_id item_id,
Sum(sr_return_quantity) sr_item_qty
@ -11262,6 +11284,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 84
--------------------------------------
# execute: true
SELECT c_customer_id AS customer_id,
c_last_name
|| ', '
@ -11563,6 +11586,7 @@ FROM "cool_cust" AS "cool_cust";
--------------------------------------
-- TPC-DS 88
--------------------------------------
# execute: true
select *
from
(select count(*) h8_30_to_9
@ -12140,6 +12164,7 @@ LIMIT 100;
--------------------------------------
-- TPC-DS 93
--------------------------------------
# execute: true
SELECT ss_customer_sk,
Sum(act_sales) sumsales
FROM (SELECT ss_item_sk,

Binary file not shown

View file

@ -1047,7 +1047,6 @@ WITH "_u_0" AS (
"lineitem"."l_orderkey" AS "l_orderkey"
FROM "lineitem" AS "lineitem"
GROUP BY
"lineitem"."l_orderkey",
"lineitem"."l_orderkey"
HAVING
SUM("lineitem"."l_quantity") > 300

View file

@ -25,6 +25,7 @@ WHERE
AND x.a > ALL (SELECT y.c FROM y WHERE y.a = x.a)
AND x.a > (SELECT COUNT(*) as d FROM y WHERE y.a = x.a)
AND x.a = SUM(SELECT 1) -- invalid statement left alone
AND x.a IN (SELECT max(y.b) AS b FROM y GROUP BY y.a)
;
SELECT
*
@ -155,6 +156,20 @@ LEFT JOIN (
y.a
) AS _u_21
ON _u_21._u_22 = x.a
LEFT JOIN (
SELECT
_q.b
FROM (
SELECT
MAX(y.b) AS b
FROM y
GROUP BY
y.a
) AS _q
GROUP BY
_q.b
) AS _u_24
ON x.a = _u_24.b
WHERE
x.a = _u_0.a
AND NOT _u_1.a IS NULL
@ -212,6 +227,7 @@ WHERE
AND x.a > COALESCE(_u_21.d, 0)
AND x.a = SUM(SELECT
1) /* invalid statement left alone */
AND NOT _u_24.b IS NULL
;
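The new IN (SELECT MAX(...) ... GROUP BY ...) predicate above is decorrelated into the _u_24 LEFT JOIN plus a NOT ... IS NULL filter. A sketch of the rule in isolation, assuming the unnest_subqueries module path under sqlglot.optimizer:

from sqlglot import parse_one
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

sql = "SELECT * FROM x WHERE x.a IN (SELECT MAX(y.b) AS b FROM y GROUP BY y.a)"
# the IN-subquery should come back as a LEFT JOIN on a grouped derived table
print(unnest_subqueries(parse_one(sql)).sql(pretty=True))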
SELECT
CAST((

View file

@ -94,6 +94,7 @@ class TestBuild(unittest.TestCase):
(lambda: select("x").from_("tbl"), "SELECT x FROM tbl"),
(lambda: select("x", "y").from_("tbl"), "SELECT x, y FROM tbl"),
(lambda: select("x").select("y").from_("tbl"), "SELECT x, y FROM tbl"),
(lambda: select("comment", "begin"), "SELECT comment, begin"),
(
lambda: select("x").select("y", append=False).from_("tbl"),
"SELECT y FROM tbl",
@ -501,6 +502,25 @@ class TestBuild(unittest.TestCase):
),
"SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned",
),
(lambda: parse_one("(SELECT 1)").select("2"), "(SELECT 1, 2)"),
(
lambda: parse_one("(SELECT 1)").limit(1),
"SELECT * FROM ((SELECT 1)) AS _l_0 LIMIT 1",
),
(
lambda: parse_one("WITH t AS (SELECT 1) (SELECT 1)").limit(1),
"SELECT * FROM (WITH t AS (SELECT 1) (SELECT 1)) AS _l_0 LIMIT 1",
),
(
lambda: parse_one("(SELECT 1 LIMIT 2)").limit(1),
"SELECT * FROM ((SELECT 1 LIMIT 2)) AS _l_0 LIMIT 1",
),
(lambda: parse_one("(SELECT 1)").subquery(), "((SELECT 1))"),
(lambda: parse_one("(SELECT 1)").subquery("alias"), "((SELECT 1)) AS alias"),
(
lambda: parse_one("(select * from foo)").with_("foo", "select 1 as c"),
"WITH foo AS (SELECT 1 AS c) (SELECT * FROM foo)",
),
(
lambda: exp.update("tbl", {"x": None, "y": {"x": 1}}),
"UPDATE tbl SET x = NULL, y = MAP(ARRAY('x'), ARRAY(1))",

View file

@ -2,7 +2,7 @@ import unittest
from sqlglot import exp, parse_one
from sqlglot.diff import Insert, Keep, Move, Remove, Update, diff
from sqlglot.expressions import Join, to_identifier
from sqlglot.expressions import Join, to_table
class TestDiff(unittest.TestCase):
@ -18,7 +18,6 @@ class TestDiff(unittest.TestCase):
self._validate_delta_only(
diff(parse_one("SELECT a, b, c"), parse_one("SELECT a, c")),
[
Remove(to_identifier("b", quoted=False)), # the Identifier node
Remove(parse_one("b")), # the Column node
],
)
@ -26,7 +25,6 @@ class TestDiff(unittest.TestCase):
self._validate_delta_only(
diff(parse_one("SELECT a, b"), parse_one("SELECT a, b, c")),
[
Insert(to_identifier("c", quoted=False)), # the Identifier node
Insert(parse_one("c")), # the Column node
],
)
@ -38,9 +36,39 @@ class TestDiff(unittest.TestCase):
),
[
Update(
to_identifier("table_one", quoted=False),
to_identifier("table_two", quoted=False),
), # the Identifier node
to_table("table_one", quoted=False),
to_table("table_two", quoted=False),
), # the Table node
],
)
def test_lambda(self):
self._validate_delta_only(
diff(parse_one("SELECT a, b, c, x(a -> a)"), parse_one("SELECT a, b, c, x(b -> b)")),
[
Update(
exp.Lambda(this=exp.to_identifier("a"), expressions=[exp.to_identifier("a")]),
exp.Lambda(this=exp.to_identifier("b"), expressions=[exp.to_identifier("b")]),
),
],
)
def test_udf(self):
self._validate_delta_only(
diff(parse_one('SELECT a, b, "my.udf1"()'), parse_one('SELECT a, b, "my.udf2"()')),
[
Insert(parse_one('"my.udf2"()')),
Remove(parse_one('"my.udf1"()')),
],
)
self._validate_delta_only(
diff(
parse_one('SELECT a, b, "my.udf"(x, y, z)'),
parse_one('SELECT a, b, "my.udf"(x, y, w)'),
),
[
Insert(exp.column("w")),
Remove(exp.column("z")),
],
)
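test_udf relies on diff() treating a quoted, dotted function name as one node rather than a dotted path. A sketch that filters out Keep edits, using only the public sqlglot.diff surface:

from sqlglot import parse_one
from sqlglot.diff import Keep, diff

edits = diff(parse_one('SELECT a, b, "my.udf1"()'), parse_one('SELECT a, b, "my.udf2"()'))
for edit in edits:
    if not isinstance(edit, Keep):
        print(edit)  # expect one Insert for "my.udf2"() and one Remove for "my.udf1"()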
@ -95,7 +123,6 @@ class TestDiff(unittest.TestCase):
diff(parse_one(expr_src), parse_one(expr_tgt)),
[
Remove(parse_one("LOWER(c) AS c")), # the Alias node
Remove(to_identifier("c", quoted=False)), # the Identifier node
Remove(parse_one("LOWER(c)")), # the Lower node
Remove(parse_one("'filter'")), # the Literal node
Insert(parse_one("'different_filter'")), # the Literal node
@ -162,9 +189,7 @@ class TestDiff(unittest.TestCase):
self._validate_delta_only(
diff(expr_src, expr_tgt),
[
Insert(expression=exp.to_identifier("b")),
Insert(expression=exp.to_column("tbl.b")),
Insert(expression=exp.to_identifier("tbl")),
],
)

View file

@ -1,3 +1,4 @@
import os
import datetime
import unittest
from datetime import date
@ -17,40 +18,53 @@ from tests.helpers import (
FIXTURES_DIR,
SKIP_INTEGRATION,
TPCH_SCHEMA,
TPCDS_SCHEMA,
load_sql_fixture_pairs,
string_to_bool,
)
DIR = FIXTURES_DIR + "/optimizer/tpc-h/"
DIR_TPCH = FIXTURES_DIR + "/optimizer/tpc-h/"
DIR_TPCDS = FIXTURES_DIR + "/optimizer/tpc-ds/"
@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set")
class TestExecutor(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.conn = duckdb.connect()
cls.tpch_conn = duckdb.connect()
cls.tpcds_conn = duckdb.connect()
for table, columns in TPCH_SCHEMA.items():
cls.conn.execute(
cls.tpch_conn.execute(
f"""
CREATE VIEW {table} AS
SELECT *
FROM READ_CSV('{DIR}{table}.csv.gz', delim='|', header=True, columns={columns})
FROM READ_CSV('{DIR_TPCH}{table}.csv.gz', delim='|', header=True, columns={columns})
"""
)
for table, columns in TPCDS_SCHEMA.items():
cls.tpcds_conn.execute(
f"""
CREATE VIEW {table} AS
SELECT *
FROM READ_CSV('{DIR_TPCDS}{table}.csv.gz', delim='|', header=True, columns={columns})
"""
)
cls.cache = {}
cls.sqls = [
(sql, expected)
for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")
]
cls.tpch_sqls = list(load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql"))
cls.tpcds_sqls = list(load_sql_fixture_pairs("optimizer/tpc-ds/tpc-ds.sql"))
@classmethod
def tearDownClass(cls):
cls.conn.close()
cls.tpch_conn.close()
cls.tpcds_conn.close()
def cached_execute(self, sql):
def cached_execute(self, sql, tpch=True):
conn = self.tpch_conn if tpch else self.tpcds_conn
if sql not in self.cache:
self.cache[sql] = self.conn.execute(transpile(sql, write="duckdb")[0]).fetchdf()
self.cache[sql] = conn.execute(transpile(sql, write="duckdb")[0]).fetchdf()
return self.cache[sql]
def rename_anonymous(self, source, target):
@ -66,18 +80,28 @@ class TestExecutor(unittest.TestCase):
self.assertEqual(generate(parse_one("x is null")), "scope[None][x] is None")
def test_optimized_tpch(self):
for i, (sql, optimized) in enumerate(self.sqls, start=1):
for i, (_, sql, optimized) in enumerate(self.tpch_sqls, start=1):
with self.subTest(f"{i}, {sql}"):
a = self.cached_execute(sql)
b = self.conn.execute(transpile(optimized, write="duckdb")[0]).fetchdf()
a = self.cached_execute(sql, tpch=True)
b = self.tpch_conn.execute(transpile(optimized, write="duckdb")[0]).fetchdf()
self.rename_anonymous(b, a)
assert_frame_equal(a, b)
def subtestHelper(self, i, table, tpch=True):
with self.subTest(f"{'tpc-h' if tpch else 'tpc-ds'} {i + 1}"):
_, sql, _ = self.tpch_sqls[i] if tpch else self.tpcds_sqls[i]
a = self.cached_execute(sql, tpch=tpch)
b = pd.DataFrame(
((np.nan if c is None else c for c in r) for r in table.rows),
columns=table.columns,
)
assert_frame_equal(a, b, check_dtype=False, check_index_type=False)
def test_execute_tpch(self):
def to_csv(expression):
if isinstance(expression, exp.Table) and expression.name not in ("revenue"):
return parse_one(
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
f"READ_CSV('{DIR_TPCH}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
)
return expression
@ -87,19 +111,26 @@ class TestExecutor(unittest.TestCase):
execute,
(
(parse_one(sql).transform(to_csv).sql(pretty=True), TPCH_SCHEMA)
for sql, _ in self.sqls
for _, sql, _ in self.tpch_sqls
),
)
):
with self.subTest(f"tpch-h {i + 1}"):
sql, _ = self.sqls[i]
a = self.cached_execute(sql)
b = pd.DataFrame(
((np.nan if c is None else c for c in r) for r in table.rows),
columns=table.columns,
)
self.subtestHelper(i, table, tpch=True)
assert_frame_equal(a, b, check_dtype=False, check_index_type=False)
def test_execute_tpcds(self):
def to_csv(expression):
if isinstance(expression, exp.Table) and os.path.exists(
f"{DIR_TPCDS}{expression.name}.csv.gz"
):
return parse_one(
f"READ_CSV('{DIR_TPCDS}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
)
return expression
for i, (meta, sql, _) in enumerate(self.tpcds_sqls):
if string_to_bool(meta.get("execute")):
table = execute(parse_one(sql).transform(to_csv).sql(pretty=True), TPCDS_SCHEMA)
self.subtestHelper(i, table, tpch=False)
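test_execute_tpcds feeds each fixture marked # execute: true through the same executor; a self-contained sketch against an in-memory table (the table name and rows here are made up):

from sqlglot.executor import execute

tables = {"t": [{"a": 1, "b": "x"}, {"a": 2, "b": "x"}, {"a": 3, "b": "y"}]}
result = execute("SELECT b, SUM(a) AS s FROM t GROUP BY b ORDER BY b", tables=tables)
print(result.columns, result.rows)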
def test_execute_callable(self):
tables = {

View file

@ -249,7 +249,7 @@ class TestExpressions(unittest.TestCase):
{"example.table": "`my-project.example.table`"},
dialect="bigquery",
).sql(),
'SELECT * FROM "my-project".example.table /* example.table */',
'SELECT * FROM "my-project"."example"."table" /* example.table */',
)
def test_expand(self):
@ -313,6 +313,18 @@ class TestExpressions(unittest.TestCase):
).sql(),
"SELECT * FROM (SELECT a FROM tbl1) WHERE b > 100",
)
self.assertEqual(
exp.replace_placeholders(
parse_one("select * from foo WHERE x > ? AND y IS ?"), 0, False
).sql(),
"SELECT * FROM foo WHERE x > 0 AND y IS FALSE",
)
self.assertEqual(
exp.replace_placeholders(
parse_one("select * from foo WHERE x > :int1 AND y IS :bool1"), int1=0, bool1=False
).sql(),
"SELECT * FROM foo WHERE x > 0 AND y IS FALSE",
)
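The two added assertions cover positional ? and named :name placeholders; the same calls as a standalone snippet:

from sqlglot import exp, parse_one

positional = parse_one("SELECT * FROM foo WHERE x > ? AND y IS ?")
print(exp.replace_placeholders(positional, 0, False).sql())

named = parse_one("SELECT * FROM foo WHERE x > :int1 AND y IS :bool1")
print(exp.replace_placeholders(named, int1=0, bool1=False).sql())
# both print: SELECT * FROM foo WHERE x > 0 AND y IS FALSE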
def test_function_building(self):
self.assertEqual(exp.func("max", 1).sql(), "MAX(1)")
@ -645,6 +657,7 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(parse_one("TO_HEX(foo)", read="bigquery"), exp.Hex)
self.assertIsInstance(parse_one("TO_HEX(MD5(foo))", read="bigquery"), exp.MD5)
self.assertIsInstance(parse_one("TRANSFORM(a, b)", read="spark"), exp.Transform)
self.assertIsInstance(parse_one("ADD_MONTHS(a, b)"), exp.AddMonths)
def test_column(self):
column = parse_one("a.b.c.d")

View file

@ -25,21 +25,21 @@ class TestLineage(unittest.TestCase):
node.source.sql(),
"SELECT z.a AS a FROM (SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y /* source: y */) AS z /* source: z */",
)
self.assertEqual(node.alias, "")
self.assertEqual(node.source_name, "")
downstream = node.downstream[0]
self.assertEqual(
downstream.source.sql(),
"SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y /* source: y */",
)
self.assertEqual(downstream.alias, "z")
self.assertEqual(downstream.source_name, "z")
downstream = downstream.downstream[0]
self.assertEqual(
downstream.source.sql(),
"SELECT x.a AS a FROM x AS x",
)
self.assertEqual(downstream.alias, "y")
self.assertEqual(downstream.source_name, "y")
self.assertGreater(len(node.to_html()._repr_html_()), 1000)
def test_lineage_sql_with_cte(self) -> None:
@ -53,7 +53,8 @@ class TestLineage(unittest.TestCase):
node.source.sql(),
"WITH z AS (SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y /* source: y */) SELECT z.a AS a FROM z AS z",
)
self.assertEqual(node.alias, "")
self.assertEqual(node.source_name, "")
self.assertEqual(node.reference_node_name, "")
# Node containing expanded CTE expression
downstream = node.downstream[0]
@ -61,14 +62,16 @@ class TestLineage(unittest.TestCase):
downstream.source.sql(),
"SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y /* source: y */",
)
self.assertEqual(downstream.alias, "")
self.assertEqual(downstream.source_name, "")
self.assertEqual(downstream.reference_node_name, "z")
downstream = downstream.downstream[0]
self.assertEqual(
downstream.source.sql(),
"SELECT x.a AS a FROM x AS x",
)
self.assertEqual(downstream.alias, "y")
self.assertEqual(downstream.source_name, "y")
self.assertEqual(downstream.reference_node_name, "")
def test_lineage_source_with_cte(self) -> None:
node = lineage(
@ -81,21 +84,24 @@ class TestLineage(unittest.TestCase):
node.source.sql(),
"SELECT z.a AS a FROM (WITH y AS (SELECT x.a AS a FROM x AS x) SELECT y.a AS a FROM y AS y) AS z /* source: z */",
)
self.assertEqual(node.alias, "")
self.assertEqual(node.source_name, "")
self.assertEqual(node.reference_node_name, "")
downstream = node.downstream[0]
self.assertEqual(
downstream.source.sql(),
"WITH y AS (SELECT x.a AS a FROM x AS x) SELECT y.a AS a FROM y AS y",
)
self.assertEqual(downstream.alias, "z")
self.assertEqual(downstream.source_name, "z")
self.assertEqual(downstream.reference_node_name, "")
downstream = downstream.downstream[0]
self.assertEqual(
downstream.source.sql(),
"SELECT x.a AS a FROM x AS x",
)
self.assertEqual(downstream.alias, "z")
self.assertEqual(downstream.source_name, "z")
self.assertEqual(downstream.reference_node_name, "y")
def test_lineage_source_with_star(self) -> None:
node = lineage(
@ -106,14 +112,16 @@ class TestLineage(unittest.TestCase):
node.source.sql(),
"WITH y AS (SELECT * FROM x AS x) SELECT y.a AS a FROM y AS y",
)
self.assertEqual(node.alias, "")
self.assertEqual(node.source_name, "")
self.assertEqual(node.reference_node_name, "")
downstream = node.downstream[0]
self.assertEqual(
downstream.source.sql(),
"SELECT * FROM x AS x",
)
self.assertEqual(downstream.alias, "")
self.assertEqual(downstream.source_name, "")
self.assertEqual(downstream.reference_node_name, "y")
def test_lineage_external_col(self) -> None:
node = lineage(
@ -124,14 +132,16 @@ class TestLineage(unittest.TestCase):
node.source.sql(),
"WITH y AS (SELECT * FROM x AS x) SELECT a AS a FROM y AS y JOIN z AS z ON y.uid = z.uid",
)
self.assertEqual(node.alias, "")
self.assertEqual(node.source_name, "")
self.assertEqual(node.reference_node_name, "")
downstream = node.downstream[0]
self.assertEqual(
downstream.source.sql(),
"?",
)
self.assertEqual(downstream.alias, "")
self.assertEqual(downstream.source_name, "")
self.assertEqual(downstream.reference_node_name, "")
def test_lineage_values(self) -> None:
node = lineage(
@ -143,17 +153,17 @@ class TestLineage(unittest.TestCase):
node.source.sql(),
"SELECT y.a AS a FROM (SELECT t.a AS a FROM (VALUES (1), (2)) AS t(a)) AS y /* source: y */",
)
self.assertEqual(node.alias, "")
self.assertEqual(node.source_name, "")
downstream = node.downstream[0]
self.assertEqual(downstream.source.sql(), "SELECT t.a AS a FROM (VALUES (1), (2)) AS t(a)")
self.assertEqual(downstream.expression.sql(), "t.a AS a")
self.assertEqual(downstream.alias, "y")
self.assertEqual(downstream.source_name, "y")
downstream = downstream.downstream[0]
self.assertEqual(downstream.source.sql(), "(VALUES (1), (2)) AS t(a)")
self.assertEqual(downstream.expression.sql(), "a")
self.assertEqual(downstream.alias, "y")
self.assertEqual(downstream.source_name, "y")
def test_lineage_cte_name_appears_in_schema(self) -> None:
schema = {"a": {"b": {"t1": {"c1": "int"}, "t2": {"c2": "int"}}}}
@ -168,22 +178,22 @@ class TestLineage(unittest.TestCase):
node.source.sql(),
"WITH t1 AS (SELECT t2.c2 AS c2 FROM a.b.t2 AS t2), inter AS (SELECT t1.c2 AS c2 FROM t1 AS t1) SELECT inter.c2 AS c2 FROM inter AS inter",
)
self.assertEqual(node.alias, "")
self.assertEqual(node.source_name, "")
downstream = node.downstream[0]
self.assertEqual(downstream.source.sql(), "SELECT t1.c2 AS c2 FROM t1 AS t1")
self.assertEqual(downstream.expression.sql(), "t1.c2 AS c2")
self.assertEqual(downstream.alias, "")
self.assertEqual(downstream.source_name, "")
downstream = downstream.downstream[0]
self.assertEqual(downstream.source.sql(), "SELECT t2.c2 AS c2 FROM a.b.t2 AS t2")
self.assertEqual(downstream.expression.sql(), "t2.c2 AS c2")
self.assertEqual(downstream.alias, "")
self.assertEqual(downstream.source_name, "")
downstream = downstream.downstream[0]
self.assertEqual(downstream.source.sql(), "a.b.t2 AS t2")
self.assertEqual(downstream.expression.sql(), "a.b.t2 AS t2")
self.assertEqual(downstream.alias, "")
self.assertEqual(downstream.source_name, "")
self.assertEqual(downstream.downstream, [])
@ -280,9 +290,11 @@ class TestLineage(unittest.TestCase):
downstream_a = node.downstream[0]
self.assertEqual(downstream_a.name, "0")
self.assertEqual(downstream_a.source.sql(), "SELECT * FROM catalog.db.table_a AS table_a")
self.assertEqual(downstream_a.reference_node_name, "dataset")
downstream_b = node.downstream[1]
self.assertEqual(downstream_b.name, "0")
self.assertEqual(downstream_b.source.sql(), "SELECT * FROM catalog.db.table_b AS table_b")
self.assertEqual(downstream_b.reference_node_name, "dataset")
def test_lineage_source_union(self) -> None:
query = "SELECT x, created_at FROM dataset;"
@ -306,12 +318,14 @@ class TestLineage(unittest.TestCase):
downstream_a = node.downstream[0]
self.assertEqual(downstream_a.name, "0")
self.assertEqual(downstream_a.alias, "dataset")
self.assertEqual(downstream_a.source_name, "dataset")
self.assertEqual(downstream_a.source.sql(), "SELECT * FROM catalog.db.table_a AS table_a")
self.assertEqual(downstream_a.reference_node_name, "")
downstream_b = node.downstream[1]
self.assertEqual(downstream_b.name, "0")
self.assertEqual(downstream_b.alias, "dataset")
self.assertEqual(downstream_b.source_name, "dataset")
self.assertEqual(downstream_b.source.sql(), "SELECT * FROM catalog.db.table_b AS table_b")
self.assertEqual(downstream_b.reference_node_name, "")
def test_select_star(self) -> None:
node = lineage("x", "SELECT x from (SELECT * from table_a)")
@ -332,3 +346,10 @@ class TestLineage(unittest.TestCase):
"with _data as (select [struct(1 as a, 2 as b)] as col) select b from _data cross join unnest(col)",
)
self.assertEqual(node.name, "b")
def test_lineage_normalize(self) -> None:
node = lineage("a", "WITH x AS (SELECT 1 a) SELECT a FROM x", dialect="snowflake")
self.assertEqual(node.name, "A")
with self.assertRaises(sqlglot.errors.SqlglotError):
lineage('"a"', "WITH x AS (SELECT 1 a) SELECT a FROM x", dialect="snowflake")

View file

@ -205,6 +205,7 @@ class TestOptimizer(unittest.TestCase):
optimizer.qualify_tables.qualify_tables,
db="db",
catalog="c",
set_dialect=True,
)
def test_normalize(self):
@ -285,6 +286,15 @@ class TestOptimizer(unittest.TestCase):
"SELECT `test`.`bar_bazfoo_$id` AS `bar_bazfoo_$id` FROM `test` AS `test`",
)
qualified = optimizer.qualify.qualify(
parse_one("WITH t AS (SELECT 1 AS c) (SELECT c FROM t)")
)
self.assertIs(qualified.selects[0].parent, qualified.this)
self.assertEqual(
qualified.sql(),
'WITH "t" AS (SELECT 1 AS "c") (SELECT "t"."c" AS "c" FROM "t" AS "t")',
)
self.check_file(
"qualify_columns", qualify_columns, execute=True, schema=self.schema, set_dialect=True
)
@ -348,6 +358,23 @@ class TestOptimizer(unittest.TestCase):
self.assertEqual("CONCAT('a', x, 'bc')", simplified_concat.sql(dialect="presto"))
self.assertEqual("CONCAT('a', x, 'bc')", simplified_safe_concat.sql())
anon_unquoted_str = parse_one("anonymous(x, y)")
self.assertEqual(optimizer.simplify.gen(anon_unquoted_str), "ANONYMOUS x,y")
anon_unquoted_identifier = exp.Anonymous(
this=exp.to_identifier("anonymous"), expressions=[exp.column("x"), exp.column("y")]
)
self.assertEqual(optimizer.simplify.gen(anon_unquoted_identifier), "ANONYMOUS x,y")
anon_quoted = parse_one('"anonymous"(x, y)')
self.assertEqual(optimizer.simplify.gen(anon_quoted), '"anonymous" x,y')
with self.assertRaises(ValueError) as e:
anon_invalid = exp.Anonymous(this=5)
optimizer.simplify.gen(anon_invalid)
self.assertIn("Anonymous.this expects a str or an Identifier, got 'int'.", str(e.exception))
def test_unnest_subqueries(self):
self.check_file(
"unnest_subqueries",
@ -982,9 +1009,12 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.selects[0].type.sql(), "ARRAY<INT>")
schema = MappingSchema({"t": {"c": "STRUCT<`f` STRING>"}}, dialect="bigquery")
expression = annotate_types(parse_one("SELECT t.c FROM t"), schema=schema)
expression = annotate_types(parse_one("SELECT t.c, [t.c] FROM t"), schema=schema)
self.assertEqual(expression.selects[0].type.sql(dialect="bigquery"), "STRUCT<`f` STRING>")
self.assertEqual(
expression.selects[1].type.sql(dialect="bigquery"), "ARRAY<STRUCT<`f` STRING>>"
)
expression = annotate_types(
parse_one("SELECT unnest(t.x) FROM t AS t", dialect="postgres"),
@ -1010,6 +1040,22 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(exp.DataType.Type.USERDEFINED, expression.selects[0].type.this)
self.assertEqual(expression.selects[0].type.sql(dialect="postgres"), "IPADDRESS")
def test_unnest_annotation(self):
expression = annotate_types(
optimizer.qualify.qualify(
parse_one(
"""
SELECT a, a.b, a.b.c FROM x, UNNEST(x.a) AS a
""",
read="bigquery",
)
),
schema={"x": {"a": "ARRAY<STRUCT<b STRUCT<c int>>>"}},
)
self.assertEqual(expression.selects[0].type, exp.DataType.build("STRUCT<b STRUCT<c int>>"))
self.assertEqual(expression.selects[1].type, exp.DataType.build("STRUCT<c int>"))
self.assertEqual(expression.selects[2].type, exp.DataType.build("int"))
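test_unnest_annotation walks an ARRAY<STRUCT<...>> element type down through BigQuery's UNNEST and struct field access; the same call pattern standalone:

from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.qualify import qualify

expression = annotate_types(
    qualify(parse_one("SELECT a, a.b, a.b.c FROM x, UNNEST(x.a) AS a", read="bigquery")),
    schema={"x": {"a": "ARRAY<STRUCT<b STRUCT<c int>>>"}},
)
for select in expression.selects:
    print(select.type.sql())
# per the test: STRUCT<b STRUCT<c INT>>, then STRUCT<c INT>, then INT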
def test_recursive_cte(self):
query = parse_one(
"""

View file

@ -852,3 +852,6 @@ class TestParser(unittest.TestCase):
):
with self.subTest(dialect):
self.assertEqual(parse_one(sql, dialect=dialect).sql(dialect=dialect), sql)
def test_distinct_from(self):
self.assertIsInstance(parse_one("a IS DISTINCT FROM b OR c IS DISTINCT FROM d"), exp.Or)

View file

@ -6,8 +6,7 @@ from sqlglot.optimizer.annotate_types import annotate_types
from tests.helpers import load_sql_fixtures
class CustomExpression(exp.Expression):
...
class CustomExpression(exp.Expression): ...
class TestSerDe(unittest.TestCase):

View file

@ -747,7 +747,6 @@ FROM base""",
"ALTER SEQUENCE IF EXISTS baz RESTART WITH boo",
"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS=3",
"ALTER TABLE integers DROP PRIMARY KEY",
"ALTER TABLE s_ut ADD CONSTRAINT s_ut_uq UNIQUE hajo",
"ALTER TABLE table1 MODIFY COLUMN name1 SET TAG foo='bar'",
"ALTER TABLE table1 RENAME COLUMN c1 AS c2",
"ALTER TABLE table1 RENAME COLUMN c1 TO c2, c2 TO c3",
@ -769,7 +768,6 @@ FROM base""",
"SET -v",
"SET @user OFF",
"SHOW TABLES",
"TRUNCATE TABLE x",
"VACUUM FREEZE my_table",
):
with self.subTest(sql):