Merging upstream version 20.1.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann <daniel@debian.org>, 2025-02-13 21:17:09 +01:00
parent d4fe7bdb16
commit 90988d8258
Signed by: daniel (GPG key ID: FBB4F0E80A80222F)
127 changed files with 73384 additions and 73067 deletions

tests/dialects/test_dialect.py

@@ -7,9 +7,10 @@ from sqlglot import (
     ParseError,
     TokenError,
     UnsupportedError,
+    exp,
     parse_one,
 )
-from sqlglot.dialects import Hive
+from sqlglot.dialects import BigQuery, Hive, Snowflake
 
 
 class Validator(unittest.TestCase):
@@ -78,9 +79,56 @@ class TestDialect(Validator):
             self.assertIsNotNone(Dialect[dialect.value])
 
     def test_get_or_raise(self):
-        self.assertEqual(Dialect.get_or_raise(Hive), Hive)
-        self.assertEqual(Dialect.get_or_raise(Hive()), Hive)
-        self.assertEqual(Dialect.get_or_raise("hive"), Hive)
+        self.assertIsInstance(Dialect.get_or_raise(Hive), Hive)
+        self.assertIsInstance(Dialect.get_or_raise(Hive()), Hive)
+        self.assertIsInstance(Dialect.get_or_raise("hive"), Hive)
+
+        with self.assertRaises(ValueError):
+            Dialect.get_or_raise(1)
+
+        default_mysql = Dialect.get_or_raise("mysql")
+        self.assertEqual(default_mysql.normalization_strategy, "CASE_SENSITIVE")
+
+        lowercase_mysql = Dialect.get_or_raise("mysql,normalization_strategy=lowercase")
+        self.assertEqual(lowercase_mysql.normalization_strategy, "LOWERCASE")
+
+        lowercase_mysql = Dialect.get_or_raise("mysql, normalization_strategy = lowercase")
+        self.assertEqual(lowercase_mysql.normalization_strategy.value, "LOWERCASE")
+
+        with self.assertRaises(ValueError) as cm:
+            Dialect.get_or_raise("mysql, normalization_strategy")
+
+        self.assertEqual(
+            str(cm.exception),
+            "Invalid dialect format: 'mysql, normalization_strategy'. "
+            "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'.",
+        )
+
+    def test_compare_dialects(self):
+        bigquery_class = Dialect["bigquery"]
+        bigquery_object = BigQuery()
+        bigquery_string = "bigquery"
+
+        snowflake_class = Dialect["snowflake"]
+        snowflake_object = Snowflake()
+        snowflake_string = "snowflake"
+
+        self.assertEqual(snowflake_class, snowflake_class)
+        self.assertEqual(snowflake_class, snowflake_object)
+        self.assertEqual(snowflake_class, snowflake_string)
+        self.assertEqual(snowflake_object, snowflake_object)
+        self.assertEqual(snowflake_object, snowflake_string)
+
+        self.assertNotEqual(snowflake_class, bigquery_class)
+        self.assertNotEqual(snowflake_class, bigquery_object)
+        self.assertNotEqual(snowflake_class, bigquery_string)
+        self.assertNotEqual(snowflake_object, bigquery_object)
+        self.assertNotEqual(snowflake_object, bigquery_string)
+
+        self.assertTrue(snowflake_class in {"snowflake", "bigquery"})
+        self.assertTrue(snowflake_object in {"snowflake", "bigquery"})
+        self.assertFalse(snowflake_class in {"bigquery", "redshift"})
+        self.assertFalse(snowflake_object in {"bigquery", "redshift"})
 
     def test_cast(self):
         self.validate_all(
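The rewritten assertions capture two behavioral changes: Dialect.get_or_raise now returns configured dialect instances rather than classes (optionally parameterized through a comma-separated settings string), and a dialect's class, instance, and name all compare and hash as equal. A minimal sketch of the asserted behavior, using only calls exercised by the tests above:

    from sqlglot.dialects.dialect import Dialect

    # get_or_raise now returns a Dialect instance, optionally configured via a
    # settings string appended to the dialect name
    mysql = Dialect.get_or_raise("mysql")
    assert mysql.normalization_strategy == "CASE_SENSITIVE"

    lower = Dialect.get_or_raise("mysql, normalization_strategy = lowercase")
    assert lower.normalization_strategy == "LOWERCASE"

    # Instances compare equal to their dialect name, so set membership works too
    snowflake = Dialect.get_or_raise("snowflake")
    assert snowflake == "snowflake"
    assert snowflake in {"snowflake", "bigquery"}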
@@ -561,6 +609,7 @@ class TestDialect(Validator):
         self.validate_all(
             "TIME_TO_STR(x, '%Y-%m-%d')",
             write={
+                "bigquery": "FORMAT_DATE('%Y-%m-%d', x)",
                 "drill": "TO_CHAR(x, 'yyyy-MM-dd')",
                 "duckdb": "STRFTIME(x, '%Y-%m-%d')",
                 "hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
@@ -866,9 +915,9 @@ class TestDialect(Validator):
             write={
                 "drill": "CAST(x AS DATE)",
                 "duckdb": "CAST(x AS DATE)",
-                "hive": "TO_DATE(x)",
-                "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
-                "spark": "TO_DATE(x)",
+                "hive": "CAST(x AS DATE)",
+                "presto": "CAST(x AS DATE)",
+                "spark": "CAST(x AS DATE)",
                 "sqlite": "x",
             },
         )
@@ -893,7 +942,7 @@ class TestDialect(Validator):
         self.validate_all(
             "TS_OR_DS_ADD(CURRENT_DATE, 1, 'DAY')",
             write={
-                "presto": "DATE_ADD('DAY', 1, CURRENT_DATE)",
+                "presto": "DATE_ADD('DAY', 1, CAST(CAST(CURRENT_DATE AS TIMESTAMP) AS DATE))",
                 "hive": "DATE_ADD(CURRENT_DATE, 1)",
             },
         )
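The Presto expectation now wraps the date operand defensively: TS_OR_DS_ADD accepts either a timestamp or a date-like value, and Presto's DATE_ADD needs an actual DATE, hence the double cast. A sketch mirroring the test:

    from sqlglot import parse_one

    expr = parse_one("TS_OR_DS_ADD(CURRENT_DATE, 1, 'DAY')")
    # Presto normalizes the timestamp-or-date operand down to a DATE first
    print(expr.sql(dialect="presto"))
    # DATE_ADD('DAY', 1, CAST(CAST(CURRENT_DATE AS TIMESTAMP) AS DATE))
    print(expr.sql(dialect="hive"))
    # DATE_ADD(CURRENT_DATE, 1)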
@@ -1268,13 +1317,6 @@ class TestDialect(Validator):
                 "doris": "LOWER(x) LIKE '%y'",
             },
         )
-        self.validate_all(
-            "SELECT * FROM a ORDER BY col_a NULLS LAST",
-            write={
-                "mysql": UnsupportedError,
-                "starrocks": UnsupportedError,
-            },
-        )
         self.validate_all(
             "POSITION(needle in haystack)",
             write={
@@ -1315,35 +1357,37 @@ class TestDialect(Validator):
         self.validate_all(
             "CONCAT_WS('-', 'a', 'b')",
             write={
+                "clickhouse": "CONCAT_WS('-', 'a', 'b')",
                 "duckdb": "CONCAT_WS('-', 'a', 'b')",
-                "presto": "CONCAT_WS('-', 'a', 'b')",
+                "presto": "CONCAT_WS('-', CAST('a' AS VARCHAR), CAST('b' AS VARCHAR))",
                 "hive": "CONCAT_WS('-', 'a', 'b')",
                 "spark": "CONCAT_WS('-', 'a', 'b')",
-                "trino": "CONCAT_WS('-', 'a', 'b')",
+                "trino": "CONCAT_WS('-', CAST('a' AS VARCHAR), CAST('b' AS VARCHAR))",
             },
         )
 
         self.validate_all(
             "CONCAT_WS('-', x)",
             write={
+                "clickhouse": "CONCAT_WS('-', x)",
                 "duckdb": "CONCAT_WS('-', x)",
                 "hive": "CONCAT_WS('-', x)",
-                "presto": "CONCAT_WS('-', x)",
+                "presto": "CONCAT_WS('-', CAST(x AS VARCHAR))",
                 "spark": "CONCAT_WS('-', x)",
-                "trino": "CONCAT_WS('-', x)",
+                "trino": "CONCAT_WS('-', CAST(x AS VARCHAR))",
             },
         )
         self.validate_all(
             "CONCAT(a)",
             write={
-                "clickhouse": "a",
-                "presto": "a",
-                "trino": "a",
+                "clickhouse": "CONCAT(a)",
+                "presto": "CAST(a AS VARCHAR)",
+                "trino": "CAST(a AS VARCHAR)",
                 "tsql": "a",
             },
         )
         self.validate_all(
-            "COALESCE(CAST(a AS TEXT), '')",
+            "CONCAT(COALESCE(a, ''))",
             read={
                 "drill": "CONCAT(a)",
                 "duckdb": "CONCAT(a)",
@@ -1442,6 +1486,76 @@ class TestDialect(Validator):
                 "spark": "FILTER(the_array, x -> x > 0)",
             },
         )
+        self.validate_all(
+            "a / b",
+            write={
+                "bigquery": "a / b",
+                "clickhouse": "a / b",
+                "databricks": "a / b",
+                "duckdb": "a / b",
+                "hive": "a / b",
+                "mysql": "a / b",
+                "oracle": "a / b",
+                "snowflake": "a / b",
+                "spark": "a / b",
+                "starrocks": "a / b",
+                "drill": "CAST(a AS DOUBLE) / b",
+                "postgres": "CAST(a AS DOUBLE PRECISION) / b",
+                "presto": "CAST(a AS DOUBLE) / b",
+                "redshift": "CAST(a AS DOUBLE PRECISION) / b",
+                "sqlite": "CAST(a AS REAL) / b",
+                "teradata": "CAST(a AS DOUBLE) / b",
+                "trino": "CAST(a AS DOUBLE) / b",
+                "tsql": "CAST(a AS FLOAT) / b",
+            },
+        )
+
+    def test_typeddiv(self):
+        typed_div = exp.Div(this=exp.column("a"), expression=exp.column("b"), typed=True)
+        div = exp.Div(this=exp.column("a"), expression=exp.column("b"))
+        typed_div_dialect = "presto"
+        div_dialect = "hive"
+        INT = exp.DataType.Type.INT
+        FLOAT = exp.DataType.Type.FLOAT
+
+        for expression, types, dialect, expected in [
+            (typed_div, (None, None), typed_div_dialect, "a / b"),
+            (typed_div, (None, None), div_dialect, "a / b"),
+            (div, (None, None), typed_div_dialect, "CAST(a AS DOUBLE) / b"),
+            (div, (None, None), div_dialect, "a / b"),
+            (typed_div, (INT, INT), typed_div_dialect, "a / b"),
+            (typed_div, (INT, INT), div_dialect, "CAST(a / b AS BIGINT)"),
+            (div, (INT, INT), typed_div_dialect, "CAST(a AS DOUBLE) / b"),
+            (div, (INT, INT), div_dialect, "a / b"),
+            (typed_div, (FLOAT, FLOAT), typed_div_dialect, "a / b"),
+            (typed_div, (FLOAT, FLOAT), div_dialect, "a / b"),
+            (div, (FLOAT, FLOAT), typed_div_dialect, "a / b"),
+            (div, (FLOAT, FLOAT), div_dialect, "a / b"),
+            (typed_div, (INT, FLOAT), typed_div_dialect, "a / b"),
+            (typed_div, (INT, FLOAT), div_dialect, "a / b"),
+            (div, (INT, FLOAT), typed_div_dialect, "a / b"),
+            (div, (INT, FLOAT), div_dialect, "a / b"),
+        ]:
+            with self.subTest(f"{expression.__class__.__name__} {types} {dialect} -> {expected}"):
+                expression = expression.copy()
+                expression.left.type = types[0]
+                expression.right.type = types[1]
+                self.assertEqual(expected, expression.sql(dialect=dialect))
+
+    def test_safediv(self):
+        safe_div = exp.Div(this=exp.column("a"), expression=exp.column("b"), safe=True)
+        div = exp.Div(this=exp.column("a"), expression=exp.column("b"))
+        safe_div_dialect = "mysql"
+        div_dialect = "snowflake"
+
+        for expression, dialect, expected in [
+            (safe_div, safe_div_dialect, "a / b"),
+            (safe_div, div_dialect, "a / NULLIF(b, 0)"),
+            (div, safe_div_dialect, "a / b"),
+            (div, div_dialect, "a / b"),
+        ]:
+            with self.subTest(f"{expression.__class__.__name__} {dialect} -> {expected}"):
+                self.assertEqual(expected, expression.sql(dialect=dialect))
 
     def test_limit(self):
         self.validate_all(
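test_typeddiv and test_safediv pin down the two new flags on exp.Div: `typed` marks a division whose result is typed like its operands (so INT / INT stays integral), and `safe` marks one that yields NULL on division by zero (so targets that would raise instead get a NULLIF guard). A sketch of setting the flags directly when building expressions:

    from sqlglot import exp

    def div(**flags):
        # Hypothetical helper for this sketch: builds a fresh `a / b` node
        return exp.Div(this=exp.column("a"), expression=exp.column("b"), **flags)

    # Typed division keeps plain `/` on Presto; untyped division is promoted
    # to DOUBLE so integer operands are not silently truncated
    print(div(typed=True).sql(dialect="presto"))   # a / b
    print(div().sql(dialect="presto"))             # CAST(a AS DOUBLE) / b

    # Safe division gets a NULLIF guard on Snowflake, which would otherwise
    # raise on division by zero
    print(div(safe=True).sql(dialect="snowflake")) # a / NULLIF(b, 0)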
@@ -1547,7 +1661,7 @@ class TestDialect(Validator):
             "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 TEXT, c2 TEXT(1024))",
             write={
                 "duckdb": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))",
-                "hive": "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 STRING, c2 STRING(1024))",
+                "hive": "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 STRING, c2 VARCHAR(1024))",
                 "oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))",
                 "postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))",
                 "sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))",
@@ -1864,7 +1978,7 @@ SELECT
             write={
                 "bigquery": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
                 "clickhouse": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
-                "databricks": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
+                "databricks": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq",
                 "duckdb": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
                 "hive": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq",
                 "mysql": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
@@ -1872,11 +1986,11 @@ SELECT
                 "presto": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
                 "redshift": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
                 "snowflake": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
-                "spark": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
+                "spark": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq",
                 "spark2": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq",
                 "sqlite": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
                 "trino": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
-                "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq",
+                "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c AS c FROM t) AS subq",
             },
         )
         self.validate_all(
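Databricks and Spark 3.x now join Hive, Spark 2, and T-SQL in hoisting CTEs out of derived tables, since those engines reject a WITH clause inside a subquery. The transformation is observable directly; a sketch using the same query as the test:

    from sqlglot import parse_one

    sql = "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq"

    # DuckDB allows a WITH clause inside a derived table, so it round-trips
    print(parse_one(sql).sql(dialect="duckdb"))
    # SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq

    # T-SQL does not, so the CTE is hoisted to the top level (and the inner
    # projection is re-aliased, per the updated expectation)
    print(parse_one(sql).sql(dialect="tsql"))
    # WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c AS c FROM t) AS subq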
@@ -1885,13 +1999,60 @@ SELECT
                 "bigquery": "SELECT * FROM (SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq1) AS subq2",
                 "duckdb": "SELECT * FROM (SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq1) AS subq2",
                 "hive": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT * FROM (SELECT c FROM t) AS subq1) AS subq2",
-                "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT * FROM (SELECT c FROM t) AS subq1) AS subq2",
+                "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT * FROM (SELECT c AS c FROM t) AS subq1) AS subq2",
             },
         )
         self.validate_all(
             "WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq",
             write={
                 "duckdb": "WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq",
-                "tsql": "WITH t1(x) AS (SELECT 1), t2(y) AS (SELECT 2) SELECT * FROM (SELECT y FROM t2) AS subq",
+                "tsql": "WITH t1(x) AS (SELECT 1), t2(y) AS (SELECT 2) SELECT * FROM (SELECT y AS y FROM t2) AS subq",
             },
         )
+
+    def test_unsupported_null_ordering(self):
+        # We'll transpile a portable query from the following dialects to MySQL / T-SQL, which
+        # both treat NULLs as small values, so the expected output queries should be equivalent
+        with_last_nulls = "duckdb"
+        with_small_nulls = "spark"
+        with_large_nulls = "postgres"
+
+        sql = "SELECT * FROM t ORDER BY c"
+        sql_nulls_last = "SELECT * FROM t ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END, c"
+        sql_nulls_first = "SELECT * FROM t ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c"
+
+        for read_dialect, desc, nulls_first, expected_sql in (
+            (with_last_nulls, False, None, sql_nulls_last),
+            (with_last_nulls, True, None, sql),
+            (with_last_nulls, False, True, sql),
+            (with_last_nulls, True, True, sql_nulls_first),
+            (with_last_nulls, False, False, sql_nulls_last),
+            (with_last_nulls, True, False, sql),
+            (with_small_nulls, False, None, sql),
+            (with_small_nulls, True, None, sql),
+            (with_small_nulls, False, True, sql),
+            (with_small_nulls, True, True, sql_nulls_first),
+            (with_small_nulls, False, False, sql_nulls_last),
+            (with_small_nulls, True, False, sql),
+            (with_large_nulls, False, None, sql_nulls_last),
+            (with_large_nulls, True, None, sql_nulls_first),
+            (with_large_nulls, False, True, sql),
+            (with_large_nulls, True, True, sql_nulls_first),
+            (with_large_nulls, False, False, sql_nulls_last),
+            (with_large_nulls, True, False, sql),
+        ):
+            with self.subTest(
+                f"read: {read_dialect}, descending: {desc}, nulls first: {nulls_first}"
+            ):
+                sort_order = " DESC" if desc else ""
+                null_order = (
+                    " NULLS FIRST"
+                    if nulls_first
+                    else (" NULLS LAST" if nulls_first is not None else "")
+                )
+
+                expected_sql = f"{expected_sql}{sort_order}"
+                expression = parse_one(f"{sql}{sort_order}{null_order}", read=read_dialect)
+
+                self.assertEqual(expression.sql(dialect="mysql"), expected_sql)
+                self.assertEqual(expression.sql(dialect="tsql"), expected_sql)