1
0
Fork 0

Merging upstream version 10.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:53:05 +01:00
parent 528822bfd4
commit b7d21c45b7
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
98 changed files with 4080 additions and 1666 deletions

View file

@ -157,6 +157,14 @@ class TestBigQuery(Validator):
},
)
self.validate_all(
"DIV(x, y)",
write={
"bigquery": "DIV(x, y)",
"duckdb": "CAST(x / y AS INT)",
},
)
self.validate_identity(
"SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)"
)
@ -284,4 +292,6 @@ class TestBigQuery(Validator):
"CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'"
)
self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)")
self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t")
self.validate_identity(
"CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t"
)

View file

@ -18,7 +18,6 @@ class TestClickhouse(Validator):
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
},
)
self.validate_all(
"CAST(1 AS NULLABLE(Int64))",
write={
@ -31,3 +30,7 @@ class TestClickhouse(Validator):
"clickhouse": "CAST(1 AS Nullable(DateTime64(6, 'UTC')))",
},
)
self.validate_all(
"SELECT x #! comment",
write={"": "SELECT x /* comment */"},
)

View file

@ -22,7 +22,8 @@ class TestDatabricks(Validator):
},
)
self.validate_all(
"SELECT DATEDIFF('end', 'start')", write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"}
"SELECT DATEDIFF('end', 'start')",
write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"},
)
self.validate_all(
"SELECT DATE_ADD('2020-01-01', 1)",

View file

@ -1,20 +1,18 @@
import unittest
from sqlglot import (
Dialect,
Dialects,
ErrorLevel,
UnsupportedError,
parse_one,
transpile,
)
from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one
class Validator(unittest.TestCase):
dialect = None
def validate_identity(self, sql):
self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql)
def parse_one(self, sql):
    """Parse ``sql`` with the module-level ``parse_one``, reading it as ``self.dialect``."""
    source_dialect = self.dialect
    return parse_one(sql, read=source_dialect)
def validate_identity(self, sql, write_sql=None):
    """Assert that ``sql`` regenerates unchanged in ``self.dialect``.

    Args:
        sql (str): Statement to parse via ``self.parse_one``.
        write_sql (str): Expected generated SQL; when falsy, ``sql`` itself is expected.

    Returns:
        The parsed expression, so callers can make further assertions on it.
    """
    parsed = self.parse_one(sql)
    expected_sql = write_sql or sql
    self.assertEqual(expected_sql, parsed.sql(dialect=self.dialect))
    return parsed
def validate_all(self, sql, read=None, write=None, pretty=False):
"""
@ -28,12 +26,14 @@ class Validator(unittest.TestCase):
read (dict): Mapping of dialect -> SQL
write (dict): Mapping of dialect -> SQL
"""
expression = parse_one(sql, read=self.dialect)
expression = self.parse_one(sql)
for read_dialect, read_sql in (read or {}).items():
with self.subTest(f"{read_dialect} -> {sql}"):
self.assertEqual(
parse_one(read_sql, read_dialect).sql(self.dialect, unsupported_level=ErrorLevel.IGNORE),
parse_one(read_sql, read_dialect).sql(
self.dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty
),
sql,
)
@ -83,10 +83,6 @@ class TestDialect(Validator):
)
self.validate_all(
"CAST(a AS BINARY(4))",
read={
"presto": "CAST(a AS VARBINARY(4))",
"sqlite": "CAST(a AS VARBINARY(4))",
},
write={
"bigquery": "CAST(a AS BINARY(4))",
"clickhouse": "CAST(a AS BINARY(4))",
@ -103,6 +99,24 @@ class TestDialect(Validator):
"starrocks": "CAST(a AS BINARY(4))",
},
)
self.validate_all(
"CAST(a AS VARBINARY(4))",
write={
"bigquery": "CAST(a AS VARBINARY(4))",
"clickhouse": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS VARBINARY(4))",
"mysql": "CAST(a AS VARBINARY(4))",
"hive": "CAST(a AS BINARY(4))",
"oracle": "CAST(a AS BLOB(4))",
"postgres": "CAST(a AS BYTEA(4))",
"presto": "CAST(a AS VARBINARY(4))",
"redshift": "CAST(a AS VARBYTE(4))",
"snowflake": "CAST(a AS VARBINARY(4))",
"sqlite": "CAST(a AS BLOB(4))",
"spark": "CAST(a AS BINARY(4))",
"starrocks": "CAST(a AS VARBINARY(4))",
},
)
self.validate_all(
"CAST(MAP('a', '1') AS MAP(TEXT, TEXT))",
write={
@ -472,45 +486,57 @@ class TestDialect(Validator):
},
)
self.validate_all(
"DATE_TRUNC(x, 'day')",
"DATE_TRUNC('day', x)",
write={
"mysql": "DATE(x)",
"starrocks": "DATE(x)",
},
)
self.validate_all(
"DATE_TRUNC(x, 'week')",
"DATE_TRUNC('week', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
"starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
},
)
self.validate_all(
"DATE_TRUNC(x, 'month')",
"DATE_TRUNC('month', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
"starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
},
)
self.validate_all(
"DATE_TRUNC(x, 'quarter')",
"DATE_TRUNC('quarter', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
"starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
},
)
self.validate_all(
"DATE_TRUNC(x, 'year')",
"DATE_TRUNC('year', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
"starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
},
)
self.validate_all(
"DATE_TRUNC(x, 'millenium')",
"DATE_TRUNC('millenium', x)",
write={
"mysql": UnsupportedError,
"starrocks": UnsupportedError,
},
)
self.validate_all(
"DATE_TRUNC('year', x)",
read={
"starrocks": "DATE_TRUNC('year', x)",
},
write={
"starrocks": "DATE_TRUNC('year', x)",
},
)
self.validate_all(
"DATE_TRUNC(x, year)",
read={
"bigquery": "DATE_TRUNC(x, year)",
},
write={
"bigquery": "DATE_TRUNC(x, year)",
},
)
self.validate_all(
@ -564,6 +590,22 @@ class TestDialect(Validator):
"spark": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
},
)
self.validate_all(
"TIMESTAMP '2022-01-01'",
write={
"mysql": "CAST('2022-01-01' AS TIMESTAMP)",
"starrocks": "CAST('2022-01-01' AS DATETIME)",
"hive": "CAST('2022-01-01' AS TIMESTAMP)",
},
)
self.validate_all(
"TIMESTAMP('2022-01-01')",
write={
"mysql": "TIMESTAMP('2022-01-01')",
"starrocks": "TIMESTAMP('2022-01-01')",
"hive": "TIMESTAMP('2022-01-01')",
},
)
for unit in ("DAY", "MONTH", "YEAR"):
self.validate_all(
@ -1002,7 +1044,10 @@ class TestDialect(Validator):
)
def test_limit(self):
self.validate_all("SELECT * FROM data LIMIT 10, 20", write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"})
self.validate_all(
"SELECT * FROM data LIMIT 10, 20",
write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"},
)
self.validate_all(
"SELECT x FROM y LIMIT 10",
write={
@ -1132,3 +1177,56 @@ class TestDialect(Validator):
"sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
},
)
def test_nullsafe_eq(self):
    """Null-safe equality: MySQL `<=>` corresponds to `IS NOT DISTINCT FROM`."""
    mysql_sql = "SELECT a <=> b"
    distinct_sql = "SELECT a IS NOT DISTINCT FROM b"
    self.validate_all(
        distinct_sql,
        read={"mysql": mysql_sql, "postgres": distinct_sql},
        write={"mysql": mysql_sql, "postgres": distinct_sql},
    )
def test_nullsafe_neq(self):
    """`IS DISTINCT FROM` stays as-is for Postgres and is written as `NOT a <=> b` for MySQL."""
    distinct_sql = "SELECT a IS DISTINCT FROM b"
    expected_reads = {"postgres": distinct_sql}
    expected_writes = {
        "mysql": "SELECT NOT a <=> b",
        "postgres": distinct_sql,
    }
    self.validate_all(distinct_sql, read=expected_reads, write=expected_writes)
def test_hash_comments(self):
# `#` line comments (and ClickHouse's `#!` form) read from the listed dialects
# are expected to be emitted as a `/* ... */` block comment.
self.validate_all(
"SELECT 1 /* arbitrary content,,, until end-of-line */",
read={
"mysql": "SELECT 1 # arbitrary content,,, until end-of-line",
"bigquery": "SELECT 1 # arbitrary content,,, until end-of-line",
"clickhouse": "SELECT 1 #! arbitrary content,,, until end-of-line",
},
)
# Multi-line input generated with pretty=True: the comment following SELECT is
# expected as a leading block comment, the per-column ones as `--` comments.
self.validate_all(
"""/* comment1 */
SELECT
x, -- comment2
y -- comment3""",
read={
"mysql": """SELECT # comment1
x, # comment2
y # comment3""",
"bigquery": """SELECT # comment1
x, # comment2
y # comment3""",
"clickhouse": """SELECT # comment1
x, # comment2
y # comment3""",
},
pretty=True,
)

View file

@ -1,3 +1,4 @@
from sqlglot import expressions as exp
from tests.dialects.test_dialect import Validator
@ -20,6 +21,52 @@ class TestMySQL(Validator):
self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')")
self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')")
self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')")
self.validate_identity("@@GLOBAL.max_connections")
# SET Commands
self.validate_identity("SET @var_name = expr")
self.validate_identity("SET @name = 43")
self.validate_identity("SET @total_tax = (SELECT SUM(tax) FROM taxable_transactions)")
self.validate_identity("SET GLOBAL max_connections = 1000")
self.validate_identity("SET @@GLOBAL.max_connections = 1000")
self.validate_identity("SET SESSION sql_mode = 'TRADITIONAL'")
self.validate_identity("SET LOCAL sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@SESSION.sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@LOCAL.sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@sql_mode = 'TRADITIONAL'")
self.validate_identity("SET sql_mode = 'TRADITIONAL'")
self.validate_identity("SET PERSIST max_connections = 1000")
self.validate_identity("SET @@PERSIST.max_connections = 1000")
self.validate_identity("SET PERSIST_ONLY back_log = 100")
self.validate_identity("SET @@PERSIST_ONLY.back_log = 100")
self.validate_identity("SET @@SESSION.max_join_size = DEFAULT")
self.validate_identity("SET @@SESSION.max_join_size = @@GLOBAL.max_join_size")
self.validate_identity("SET @x = 1, SESSION sql_mode = ''")
self.validate_identity(
"SET GLOBAL sort_buffer_size = 1000000, SESSION sort_buffer_size = 1000000"
)
self.validate_identity(
"SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000"
)
self.validate_identity("SET GLOBAL max_connections = 1000, sort_buffer_size = 1000000")
self.validate_identity("SET @@GLOBAL.sort_buffer_size = 50000, sort_buffer_size = 1000000")
self.validate_identity("SET CHARACTER SET 'utf8'")
self.validate_identity("SET CHARACTER SET utf8")
self.validate_identity("SET CHARACTER SET DEFAULT")
self.validate_identity("SET NAMES 'utf8'")
self.validate_identity("SET NAMES DEFAULT")
self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'")
self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci")
self.validate_identity("SET autocommit = ON")
def test_escape(self):
    """Quote escapes inside a string literal are rewritten per target dialect."""
    source_literal = r"'a \' b '' '"
    expected_by_dialect = {
        "mysql": r"'a '' b '' '",
        "spark": r"'a \' b \' '",
    }
    self.validate_all(source_literal, write=expected_by_dialect)
def test_introducers(self):
self.validate_all(
@ -115,14 +162,6 @@ class TestMySQL(Validator):
},
)
def test_hash_comments(self):
self.validate_all(
"SELECT 1 # arbitrary content,,, until end-of-line",
write={
"mysql": "SELECT 1",
},
)
def test_mysql(self):
self.validate_all(
"GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
@ -174,3 +213,242 @@ COMMENT='客户账户表'"""
},
pretty=True,
)
def test_show_simple(self):
    """Argument-less SHOW statements parse to `exp.Show` named after the written spelling."""
    # (spelling that is parsed, spelling expected back from the generator)
    cases = (
        ("BINARY LOGS", "BINARY LOGS"),
        ("MASTER LOGS", "BINARY LOGS"),
        ("STORAGE ENGINES", "ENGINES"),
        ("ENGINES", "ENGINES"),
        ("EVENTS", "EVENTS"),
        ("MASTER STATUS", "MASTER STATUS"),
        ("PLUGINS", "PLUGINS"),
        ("PRIVILEGES", "PRIVILEGES"),
        ("PROFILES", "PROFILES"),
        ("REPLICAS", "REPLICAS"),
        ("SLAVE HOSTS", "REPLICAS"),
    )
    for read_key, written_key in cases:
        stmt = self.validate_identity(f"SHOW {read_key}", f"SHOW {written_key}")
        self.assertIsInstance(stmt, exp.Show)
        self.assertEqual(stmt.name, written_key)
def test_show_events(self):
    """SHOW BINLOG/RELAYLOG EVENTS exposes its log, position, limit and offset arguments."""
    for log_kind in ("BINLOG", "RELAYLOG"):
        base = f"SHOW {log_kind} EVENTS"

        stmt = self.validate_identity(base)
        self.assertIsInstance(stmt, exp.Show)
        self.assertEqual(stmt.name, f"{log_kind} EVENTS")

        # `LIMIT 2, 3` is expected to be read as offset 2, limit 3.
        stmt = self.validate_identity(f"{base} IN 'log' FROM 1 LIMIT 2, 3")
        self.assertEqual(stmt.text("log"), "log")
        self.assertEqual(stmt.text("position"), "1")
        self.assertEqual(stmt.text("limit"), "3")
        self.assertEqual(stmt.text("offset"), "2")

        stmt = self.validate_identity(f"{base} LIMIT 1")
        self.assertEqual(stmt.text("limit"), "1")
        self.assertIsNone(stmt.args.get("offset"))
def test_show_like_or_where(self):
    """Check SHOW statements that accept an optional LIKE or WHERE filter.

    Each case pairs the spelling that is parsed with the spelling expected back
    from the generator. ``show.name`` is expected to equal the written spelling
    without a leading ``GLOBAL`` qualifier.
    """
    global_prefix = "GLOBAL "
    for key, write_key in [
        ("CHARSET", "CHARACTER SET"),
        ("CHARACTER SET", "CHARACTER SET"),
        ("COLLATION", "COLLATION"),
        ("DATABASES", "DATABASES"),
        ("FUNCTION STATUS", "FUNCTION STATUS"),
        ("PROCEDURE STATUS", "PROCEDURE STATUS"),
        ("GLOBAL STATUS", "GLOBAL STATUS"),
        ("SESSION STATUS", "STATUS"),
        ("STATUS", "STATUS"),
        ("GLOBAL VARIABLES", "GLOBAL VARIABLES"),
        ("SESSION VARIABLES", "VARIABLES"),
        ("VARIABLES", "VARIABLES"),
    ]:
        # Remove the qualifier as a prefix. `str.strip("GLOBAL")` would treat its
        # argument as a set of characters to trim from both ends, which yields the
        # right name only by coincidence for the keys listed above.
        if write_key.startswith(global_prefix):
            expected_name = write_key[len(global_prefix):]
        else:
            expected_name = write_key

        # Bare statement.
        template = "SHOW {}"
        show = self.validate_identity(template.format(key), template.format(write_key))
        self.assertIsInstance(show, exp.Show)
        self.assertEqual(show.name, expected_name)

        # LIKE filter: the pattern is expected as a literal under args["like"].
        template = "SHOW {} LIKE '%foo%'"
        show = self.validate_identity(template.format(key), template.format(write_key))
        self.assertIsInstance(show, exp.Show)
        self.assertIsInstance(show.args["like"], exp.Literal)
        self.assertEqual(show.text("like"), "%foo%")

        # WHERE filter: the condition is expected as an exp.Where under args["where"].
        template = "SHOW {} WHERE Column_name LIKE '%foo%'"
        show = self.validate_identity(template.format(key), template.format(write_key))
        self.assertIsInstance(show, exp.Show)
        self.assertIsInstance(show.args["where"], exp.Where)
        self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'")
def test_show_columns(self):
    """SHOW [FULL] COLUMNS exposes the target table, optional db, FULL flag and LIKE pattern."""
    plain = self.validate_identity("SHOW COLUMNS FROM tbl_name")
    self.assertIsInstance(plain, exp.Show)
    self.assertEqual(plain.name, "COLUMNS")
    self.assertEqual(plain.text("target"), "tbl_name")
    self.assertFalse(plain.args["full"])

    detailed = self.validate_identity("SHOW FULL COLUMNS FROM tbl_name FROM db_name LIKE '%foo%'")
    self.assertIsInstance(detailed, exp.Show)
    self.assertEqual(detailed.text("target"), "tbl_name")
    self.assertTrue(detailed.args["full"])
    self.assertEqual(detailed.text("db"), "db_name")
    self.assertIsInstance(detailed.args["like"], exp.Literal)
    self.assertEqual(detailed.text("like"), "%foo%")
def test_show_name(self):
    """SHOW variants that take a single object name keep the kind as name and `foo` as target."""
    kinds = (
        "CREATE DATABASE",
        "CREATE EVENT",
        "CREATE FUNCTION",
        "CREATE PROCEDURE",
        "CREATE TABLE",
        "CREATE TRIGGER",
        "CREATE VIEW",
        "FUNCTION CODE",
        "PROCEDURE CODE",
    )
    for kind in kinds:
        stmt = self.validate_identity(f"SHOW {kind} foo")
        self.assertIsInstance(stmt, exp.Show)
        self.assertEqual(stmt.name, kind)
        self.assertEqual(stmt.text("target"), "foo")
def test_show_grants(self):
    """SHOW GRANTS FOR <name> parses to `exp.Show` named GRANTS with the name as target."""
    # Plain literal: the statement has no placeholders, so no f-string is needed.
    show = self.validate_identity("SHOW GRANTS FOR foo")
    self.assertIsInstance(show, exp.Show)
    self.assertEqual(show.name, "GRANTS")
    self.assertEqual(show.text("target"), "foo")
def test_show_engine(self):
    """SHOW ENGINE <name> STATUS|MUTEX records the engine as target and a `mutex` flag."""
    status_stmt = self.validate_identity("SHOW ENGINE foo STATUS")
    self.assertIsInstance(status_stmt, exp.Show)
    self.assertEqual(status_stmt.name, "ENGINE")
    self.assertEqual(status_stmt.text("target"), "foo")
    self.assertFalse(status_stmt.args["mutex"])

    mutex_stmt = self.validate_identity("SHOW ENGINE foo MUTEX")
    self.assertEqual(mutex_stmt.name, "ENGINE")
    self.assertEqual(mutex_stmt.text("target"), "foo")
    self.assertTrue(mutex_stmt.args["mutex"])
def test_show_errors(self):
    """SHOW ERRORS / SHOW WARNINGS, with and without a `LIMIT offset, count` clause."""
    for kind in ("ERRORS", "WARNINGS"):
        stmt = self.validate_identity(f"SHOW {kind}")
        self.assertIsInstance(stmt, exp.Show)
        self.assertEqual(stmt.name, kind)

        # `LIMIT 2, 3` is expected to be read as offset 2, limit 3.
        limited = self.validate_identity(f"SHOW {kind} LIMIT 2, 3")
        self.assertEqual(limited.text("limit"), "3")
        self.assertEqual(limited.text("offset"), "2")
def test_show_index(self):
    """SHOW INDEX FROM <table> [FROM <db>] exposes the table as target and the db as `db`."""
    by_table = self.validate_identity("SHOW INDEX FROM foo")
    self.assertIsInstance(by_table, exp.Show)
    self.assertEqual(by_table.name, "INDEX")
    self.assertEqual(by_table.text("target"), "foo")

    by_table_and_db = self.validate_identity("SHOW INDEX FROM foo FROM bar")
    self.assertEqual(by_table_and_db.text("db"), "bar")
def test_show_db_like_or_where_sql(self):
    """SHOW variants that accept an optional `FROM db`, `LIKE` pattern or `WHERE` condition."""
    for kind in ("OPEN TABLES", "TABLE STATUS", "TRIGGERS"):
        base = f"SHOW {kind}"

        stmt = self.validate_identity(base)
        self.assertIsInstance(stmt, exp.Show)
        self.assertEqual(stmt.name, kind)

        stmt = self.validate_identity(f"{base} FROM db_name")
        self.assertEqual(stmt.name, kind)
        self.assertEqual(stmt.text("db"), "db_name")

        stmt = self.validate_identity(f"{base} LIKE '%foo%'")
        self.assertEqual(stmt.name, kind)
        self.assertIsInstance(stmt.args["like"], exp.Literal)
        self.assertEqual(stmt.text("like"), "%foo%")

        stmt = self.validate_identity(f"{base} WHERE Column_name LIKE '%foo%'")
        self.assertEqual(stmt.name, kind)
        where_clause = stmt.args["where"]
        self.assertIsInstance(where_clause, exp.Where)
        self.assertEqual(where_clause.sql(), "WHERE Column_name LIKE '%foo%'")
def test_show_processlist(self):
    """SHOW [FULL] PROCESSLIST keeps the name PROCESSLIST and records the FULL flag."""
    brief = self.validate_identity("SHOW PROCESSLIST")
    self.assertIsInstance(brief, exp.Show)
    self.assertEqual(brief.name, "PROCESSLIST")
    self.assertFalse(brief.args["full"])

    full = self.validate_identity("SHOW FULL PROCESSLIST")
    self.assertEqual(full.name, "PROCESSLIST")
    self.assertTrue(full.args["full"])
def test_show_profile(self):
    """SHOW PROFILE with optional type list, FOR QUERY, OFFSET and LIMIT."""
    stmt = self.validate_identity("SHOW PROFILE")
    self.assertIsInstance(stmt, exp.Show)
    self.assertEqual(stmt.name, "PROFILE")

    stmt = self.validate_identity("SHOW PROFILE BLOCK IO")
    self.assertEqual(stmt.args["types"][0].name, "BLOCK IO")

    full_sql = "SHOW PROFILE BLOCK IO, PAGE FAULTS FOR QUERY 1 OFFSET 2 LIMIT 3"
    stmt = self.validate_identity(full_sql)
    self.assertEqual(stmt.args["types"][0].name, "BLOCK IO")
    self.assertEqual(stmt.args["types"][1].name, "PAGE FAULTS")
    self.assertEqual(stmt.text("query"), "1")
    self.assertEqual(stmt.text("offset"), "2")
    self.assertEqual(stmt.text("limit"), "3")
def test_show_replica_status(self):
    """SHOW REPLICA STATUS, its SLAVE spelling, and the FOR CHANNEL clause."""
    canonical = "SHOW REPLICA STATUS"

    stmt = self.validate_identity(canonical)
    self.assertIsInstance(stmt, exp.Show)
    self.assertEqual(stmt.name, "REPLICA STATUS")

    # The SLAVE spelling is expected to be written back as REPLICA.
    stmt = self.validate_identity("SHOW SLAVE STATUS", canonical)
    self.assertIsInstance(stmt, exp.Show)
    self.assertEqual(stmt.name, "REPLICA STATUS")

    stmt = self.validate_identity("SHOW REPLICA STATUS FOR CHANNEL channel_name")
    self.assertEqual(stmt.text("channel"), "channel_name")
def test_show_tables(self):
    """SHOW [FULL] TABLES with optional `FROM db` and `LIKE` pattern."""
    plain = self.validate_identity("SHOW TABLES")
    self.assertIsInstance(plain, exp.Show)
    self.assertEqual(plain.name, "TABLES")

    filtered = self.validate_identity("SHOW FULL TABLES FROM db_name LIKE '%foo%'")
    self.assertTrue(filtered.args["full"])
    self.assertEqual(filtered.text("db"), "db_name")
    self.assertIsInstance(filtered.args["like"], exp.Literal)
    self.assertEqual(filtered.text("like"), "%foo%")
def test_set_variable(self):
    """Inspect the structure of the first item produced when parsing SET statements."""
    # Scoped assignment: the scope keyword is expected under "kind".
    stmt = self.parse_one("SET SESSION x = 1")
    first = stmt.expressions[0]
    self.assertEqual(first.text("kind"), "SESSION")
    self.assertIsInstance(first.this, exp.EQ)
    self.assertEqual(first.this.left.name, "x")
    self.assertEqual(first.this.right.name, "1")

    # `@@GLOBAL.x` form: no "kind" text, both sides are session parameters.
    stmt = self.parse_one("SET @@GLOBAL.x = @@GLOBAL.y")
    first = stmt.expressions[0]
    self.assertEqual(first.text("kind"), "")
    self.assertIsInstance(first.this, exp.EQ)
    self.assertIsInstance(first.this.left, exp.SessionParameter)
    self.assertIsInstance(first.this.right, exp.SessionParameter)

    stmt = self.parse_one("SET NAMES 'charset_name' COLLATE 'collation_name'")
    first = stmt.expressions[0]
    self.assertEqual(first.text("kind"), "NAMES")
    self.assertEqual(first.name, "charset_name")
    self.assertEqual(first.text("collate"), "collation_name")

    # CHARSET is expected to be reported with kind CHARACTER SET.
    stmt = self.parse_one("SET CHARSET DEFAULT")
    first = stmt.expressions[0]
    self.assertEqual(first.text("kind"), "CHARACTER SET")
    self.assertEqual(first.this.name, "DEFAULT")

    # Comma-separated assignments yield one expression each.
    stmt = self.parse_one("SET x = 1, y = 2")
    self.assertEqual(len(stmt.expressions), 2)

View file

@ -8,7 +8,9 @@ class TestPostgres(Validator):
def test_ddl(self):
self.validate_all(
"CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
write={"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"},
write={
"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"
},
)
self.validate_all(
"CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)",
@ -59,15 +61,27 @@ class TestPostgres(Validator):
def test_postgres(self):
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END")
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END")
self.validate_identity('SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')')
self.validate_identity(
"SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END"
)
self.validate_identity(
"SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END"
)
self.validate_identity(
'SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')'
)
self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')")
self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')")
self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))")
self.validate_identity(
"SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')"
)
self.validate_identity(
"SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))"
)
self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')")
self.validate_identity(
"SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')"
)
self.validate_identity("COMMENT ON TABLE mytable IS 'this'")
self.validate_identity("SELECT e'\\xDEADBEEF'")
self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
@ -75,7 +89,7 @@ class TestPostgres(Validator):
self.validate_all(
"CREATE TABLE x (a UUID, b BYTEA)",
write={
"duckdb": "CREATE TABLE x (a UUID, b BINARY)",
"duckdb": "CREATE TABLE x (a UUID, b VARBINARY)",
"presto": "CREATE TABLE x (a UUID, b VARBINARY)",
"hive": "CREATE TABLE x (a UUID, b BINARY)",
"spark": "CREATE TABLE x (a UUID, b BINARY)",
@ -153,7 +167,9 @@ class TestPostgres(Validator):
)
self.validate_all(
"SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss",
read={"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"},
read={
"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"
},
)
self.validate_all(
"SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL",
@ -169,11 +185,15 @@ class TestPostgres(Validator):
)
self.validate_all(
"SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL",
read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL"},
read={
"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL"
},
)
self.validate_all(
"SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL",
read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL"},
read={
"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL"
},
)
self.validate_all(
"'[1,2,3]'::json->2",
@ -184,7 +204,8 @@ class TestPostgres(Validator):
write={"postgres": """CAST('{"a":1,"b":2}' AS JSON)->'b'"""},
)
self.validate_all(
"""'{"x": {"y": 1}}'::json->'x'->'y'""", write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""}
"""'{"x": {"y": 1}}'::json->'x'->'y'""",
write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""},
)
self.validate_all(
"""'{"x": {"y": 1}}'::json->'x'::json->'y'""",

View file

@ -61,4 +61,6 @@ class TestRedshift(Validator):
"SELECT caldate + INTERVAL '1 second' AS dateplus FROM date WHERE caldate = '12-31-2008'"
)
self.validate_identity("CREATE TABLE datetable (start_date DATE, end_date DATE)")
self.validate_identity("SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'")
self.validate_identity(
"SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
)

View file

@ -336,7 +336,8 @@ class TestSnowflake(Validator):
def test_table_literal(self):
# All examples from https://docs.snowflake.com/en/sql-reference/literals-table.html
self.validate_all(
r"""SELECT * FROM TABLE('MYTABLE')""", write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""}
r"""SELECT * FROM TABLE('MYTABLE')""",
write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""},
)
self.validate_all(
@ -352,15 +353,123 @@ class TestSnowflake(Validator):
write={"snowflake": r"""SELECT * FROM TABLE('MYDB. "MYSCHEMA"."MYTABLE"')"""},
)
self.validate_all(r"""SELECT * FROM TABLE($MYVAR)""", write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""})
self.validate_all(r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""})
self.validate_all(
r"""SELECT * FROM TABLE($MYVAR)""",
write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""},
)
self.validate_all(
r"""SELECT * FROM TABLE(:BINDING)""", write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""}
r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""}
)
self.validate_all(
r"""SELECT * FROM TABLE(:BINDING)""",
write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""},
)
self.validate_all(
r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10""",
write={"snowflake": r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10"""},
)
def test_flatten(self):
# FLATTEN used as a table function: lower-case input with `::` casts and a bare
# alias is expected to be written upper-cased, with CAST(...) and an explicit
# `AS f` alias, when generated with pretty=True.
self.validate_all(
"""
select
dag_report.acct_id,
dag_report.report_date,
dag_report.report_uuid,
dag_report.airflow_name,
dag_report.dag_id,
f.value::varchar as operator
from cs.telescope.dag_report,
table(flatten(input=>split(operators, ','))) f
""",
write={
"snowflake": """SELECT
dag_report.acct_id,
dag_report.report_date,
dag_report.report_uuid,
dag_report.airflow_name,
dag_report.dag_id,
CAST(f.value AS VARCHAR) AS operator
FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS f"""
},
pretty=True,
)
# All examples from https://docs.snowflake.com/en/sql-reference/functions/flatten.html#syntax
# Each case expects PARSE_JSON and boolean keywords upper-cased and the table
# alias written with an explicit AS.
self.validate_all(
"SELECT * FROM TABLE(FLATTEN(input => parse_json('[1, ,77]'))) f",
write={
"snowflake": "SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[1, ,77]'))) AS f"
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), outer => true)) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), outer => TRUE)) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), path => 'b')) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), path => 'b')) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'))) f""",
write={"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'))) AS f"""},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'), outer => true)) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'), outer => TRUE)) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true)) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE)) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true, mode => 'object')) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE, mode => 'object')) AS f"""
},
)
# LATERAL FLATTEN: `value:key` path access is expected to be written as
# `value['key']`; the lateral aliases are expected without an AS keyword.
self.validate_all(
"""
SELECT id as "ID",
f.value AS "Contact",
f1.value:type AS "Type",
f1.value:content AS "Details"
FROM persons p,
lateral flatten(input => p.c, path => 'contact') f,
lateral flatten(input => f.value:business) f1
""",
write={
"snowflake": """SELECT
id AS "ID",
f.value AS "Contact",
f1.value['type'] AS "Type",
f1.value['content'] AS "Details"
FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL FLATTEN(input => f.value['business']) f1""",
},
pretty=True,
)

View file

@ -284,4 +284,6 @@ TBLPROPERTIES (
)
def test_iif(self):
self.validate_all("SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"})
self.validate_all(
"SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"}
)

View file

@ -6,3 +6,6 @@ class TestMySQL(Validator):
def test_identity(self):
    """A CAST over a backtick-quoted column reference is regenerated unchanged."""
    cast_sql = "SELECT CAST(`a`.`b` AS INT) FROM foo"
    self.validate_identity(cast_sql)
def test_time(self):
    """A `TIMESTAMP(...)` function call is regenerated unchanged."""
    timestamp_call = "TIMESTAMP('2022-01-01')"
    self.validate_identity(timestamp_call)

View file

@ -278,12 +278,19 @@ class TestTSQL(Validator):
def test_add_date(self):
self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')")
self.validate_all(
"SELECT DATEADD(year, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"}
"SELECT DATEADD(year, 1, '2017/08/25')",
write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"},
)
self.validate_all(
"SELECT DATEADD(qq, 1, '2017/08/25')",
write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"},
)
self.validate_all("SELECT DATEADD(qq, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"})
self.validate_all(
"SELECT DATEADD(wk, 1, '2017/08/25')",
write={"spark": "SELECT DATE_ADD('2017/08/25', 7)", "databricks": "SELECT DATEADD(week, 1, '2017/08/25')"},
write={
"spark": "SELECT DATE_ADD('2017/08/25', 7)",
"databricks": "SELECT DATEADD(week, 1, '2017/08/25')",
},
)
def test_date_diff(self):
@ -370,13 +377,21 @@ class TestTSQL(Validator):
"SELECT FORMAT(1000000.01,'###,###.###')",
write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"},
)
self.validate_all("SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"})
self.validate_all(
"SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"}
)
self.validate_all(
"SELECT FORMAT('01-01-1991', 'dd.mm.yyyy')",
write={"spark": "SELECT DATE_FORMAT('01-01-1991', 'dd.mm.yyyy')"},
)
self.validate_all(
"SELECT FORMAT(date_col, 'dd.mm.yyyy')", write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"}
"SELECT FORMAT(date_col, 'dd.mm.yyyy')",
write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"},
)
self.validate_all(
"SELECT FORMAT(date_col, 'm')",
write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"},
)
self.validate_all(
"SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}
)
self.validate_all("SELECT FORMAT(date_col, 'm')", write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"})
self.validate_all("SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"})