988 lines
32 KiB
Python
988 lines
32 KiB
Python
import os
|
|
import unittest
|
|
from unittest import mock
|
|
|
|
from sqlglot import parse_one, transpile
|
|
from sqlglot.errors import ErrorLevel, ParseError, UnsupportedError
|
|
from sqlglot.helper import logger as helper_logger
|
|
from sqlglot.parser import logger as parser_logger
|
|
from tests.helpers import (
|
|
assert_logger_contains,
|
|
load_sql_fixture_pairs,
|
|
load_sql_fixtures,
|
|
)
|
|
|
|
|
|
class TestTranspile(unittest.TestCase):
    """End-to-end tests for ``sqlglot.transpile`` (parse, then generate SQL)."""

    file_dir = os.path.dirname(__file__)
    fixtures_dir = os.path.join(file_dir, "fixtures")
    maxDiff = None

    def validate(self, sql, target, **kwargs):
        """Assert that transpiling ``sql`` with ``kwargs`` yields ``target`` as the first statement."""
        self.assertEqual(transpile(sql, **kwargs)[0], target)

    def test_weird_chars(self):
        """Non-ASCII text following a literal is generated as an alias."""
        self.assertEqual(transpile("0Êß")[0], "0 AS Êß")

    def test_alias(self):
        """Keyword-like names work as bare aliases; reserved words need AS or quoting."""
        self.assertEqual(transpile("SELECT SUM(y) KEEP")[0], "SELECT SUM(y) AS KEEP")
        self.assertEqual(transpile("SELECT 1 overwrite")[0], "SELECT 1 AS overwrite")
        self.assertEqual(transpile("SELECT 1 is")[0], "SELECT 1 AS is")
        self.assertEqual(transpile("SELECT 1 current_time")[0], "SELECT 1 AS current_time")
        self.assertEqual(
            transpile("SELECT 1 current_timestamp")[0], "SELECT 1 AS current_timestamp"
        )
        self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date")
        self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
        self.assertEqual(transpile("SELECT 1 row")[0], "SELECT 1 AS row")

        self.assertEqual(
            transpile("SELECT 1 FROM a.b.table1 t UNPIVOT((c3) FOR c4 IN (a, b))")[0],
            "SELECT 1 FROM a.b.table1 AS t UNPIVOT((c3) FOR c4 IN (a, b))",
        )

        for key in ("union", "over", "from", "join"):
            with self.subTest(f"alias {key}"):
                self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}")
                self.validate(f'SELECT x "{key}"', f'SELECT x AS "{key}"')

                with self.assertRaises(ParseError):
                    self.validate(f"SELECT x {key}", "")

    def test_unary(self):
        """Redundant unary plus signs are dropped; minus signs are kept."""
        self.validate("+++1", "1")
        self.validate("+-1", "-1")
        self.validate("+- - -1", "- - -1")

    def test_paren(self):
        """Unbalanced parentheses raise ParseError."""
        # One assertRaises per input: with a single shared block, the first
        # raising call exits the block and later inputs are never parsed.
        for sql in ("1 + (2 + 3", "select f("):
            with self.subTest(sql), self.assertRaises(ParseError):
                transpile(sql)

    def test_some(self):
        """SOME is generated as ANY."""
        self.validate(
            "SELECT * FROM x WHERE a = SOME (SELECT 1)",
            "SELECT * FROM x WHERE a = ANY(SELECT 1)",
        )

    def test_leading_comma(self):
        """leading_comma moves separators to line starts, but only in pretty mode."""
        self.validate(
            "SELECT a, b, c FROM (SELECT a, b, c FROM t)",
            "SELECT\n"
            "    a\n"
            "    , b\n"
            "    , c\n"
            "FROM (\n"
            "    SELECT\n"
            "        a\n"
            "        , b\n"
            "        , c\n"
            "    FROM t\n"
            ")",
            leading_comma=True,
            pretty=True,
            pad=4,
            indent=4,
        )
        self.validate(
            "SELECT FOO, BAR, BAZ",
            "SELECT\n  FOO\n  , BAR\n  , BAZ",
            leading_comma=True,
            pretty=True,
        )
        self.validate(
            "SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ",
            "SELECT\n  FOO /* x */\n  , BAR /* y */\n  , BAZ",
            leading_comma=True,
            pretty=True,
        )
        # without pretty, this should be a no-op
        self.validate(
            "SELECT FOO, BAR, BAZ",
            "SELECT FOO, BAR, BAZ",
            leading_comma=True,
        )

    def test_space(self):
        """Comparison operators are spaced and a CRLF line break becomes a single space."""
        self.validate("SELECT MIN(3)>MIN(2)", "SELECT MIN(3) > MIN(2)")
        self.validate("SELECT MIN(3)>=MIN(2)", "SELECT MIN(3) >= MIN(2)")
        self.validate("SELECT 1>0", "SELECT 1 > 0")
        self.validate("SELECT 3>=3", "SELECT 3 >= 3")
        self.validate("SELECT a\r\nFROM b", "SELECT a FROM b")

    def test_comments(self):
        """Comments survive transpilation and are rendered as /* ... */ in compact and pretty output."""
        self.validate(
            "SELECT c /* foo */ AS alias",
            "SELECT c AS alias /* foo */",
        )
        self.validate(
            "SELECT c AS /* foo */ (a, b, c) FROM t",
            "SELECT c AS (a, b, c) /* foo */ FROM t",
        )
        self.validate(
            "SELECT * FROM t1\n/*x*/\nUNION ALL SELECT * FROM t2",
            "SELECT * FROM t1 /* x */ UNION ALL SELECT * FROM t2",
        )
        self.validate(
            "/* comment */ SELECT * FROM a UNION SELECT * FROM b",
            "/* comment */ SELECT * FROM a UNION SELECT * FROM b",
        )
        self.validate(
            "SELECT * FROM t1\n/*x*/\nINTERSECT ALL SELECT * FROM t2",
            "SELECT * FROM t1 /* x */ INTERSECT ALL SELECT * FROM t2",
        )
        self.validate(
            "SELECT\n  foo\n/* comments */\n;",
            "SELECT foo /* comments */",
        )
        self.validate(
            "SELECT * FROM a INNER /* comments */ JOIN b",
            "SELECT * FROM a /* comments */ INNER JOIN b",
        )
        self.validate(
            "SELECT * FROM a LEFT /* comment 1 */ OUTER /* comment 2 */ JOIN b",
            "SELECT * FROM a /* comment 1 */ /* comment 2 */ LEFT OUTER JOIN b",
        )
        self.validate(
            "SELECT CASE /* test */ WHEN a THEN b ELSE c END",
            "SELECT CASE WHEN a THEN b ELSE c END /* test */",
        )
        self.validate("SELECT 1 /*/2 */", "SELECT 1 /* /2 */")
        self.validate("SELECT */*comment*/", "SELECT * /* comment */")
        self.validate(
            "SELECT * FROM table /*comment 1*/ /*comment 2*/",
            "SELECT * FROM table /* comment 1 */ /* comment 2 */",
        )
        self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
        self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
        self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo")
        self.validate(
            "SELECT 1 /* inline */ FROM foo -- comment",
            "SELECT 1 /* inline */ FROM foo /* comment */",
        )
        self.validate(
            "SELECT FUN(x) /*x*/, [1,2,3] /*y*/", "SELECT FUN(x) /* x */, ARRAY(1, 2, 3) /* y */"
        )
        self.validate(
            """
SELECT 1 -- comment
FROM foo -- comment
""",
            "SELECT 1 /* comment */ FROM foo /* comment */",
        )
        self.validate(
            """
SELECT 1 /* big comment
like this */
FROM foo -- comment
""",
            """SELECT 1 /* big comment
like this */ FROM foo /* comment */""",
        )
        self.validate(
            "select x from foo -- x",
            "SELECT x FROM foo /* x */",
        )
        self.validate(
            """select x, --
from foo""",
            "SELECT x FROM foo",
        )
        self.validate(
            """
-- comment 1
-- comment 2
-- comment 3
SELECT * FROM foo
""",
            "/* comment 1 */ /* comment 2 */ /* comment 3 */ SELECT * FROM foo",
        )
        self.validate(
            """
-- comment 1
-- comment 2
-- comment 3
SELECT * FROM foo""",
            """/* comment 1 */ /* comment 2 */ /* comment 3 */
SELECT
  *
FROM foo""",
            pretty=True,
        )
        self.validate(
            """
SELECT * FROM tbl /*line1
line2
line3*/ /*another comment*/ where 1=1 -- comment at the end""",
            """SELECT * FROM tbl /* line1
line2
line3 */ /* another comment */ WHERE 1 = 1 /* comment at the end */""",
        )
        self.validate(
            """
SELECT * FROM tbl /*line1
line2
line3*/ /*another comment*/ where 1=1 -- comment at the end""",
            """SELECT
  *
FROM tbl /* line1
line2
line3 */ /* another comment */
WHERE
  1 = 1 /* comment at the end */""",
            pretty=True,
        )
        self.validate(
            """
/* multi
line
comment
*/
SELECT
  tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
  CAST(x AS CHAR), # comment 3
  y -- comment 4
FROM
  bar /* comment 5 */,
  tbl # comment 6
""",
            """/* multi
line
comment
*/
SELECT
  tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
  CAST(x AS CHAR), /* comment 3 */
  y /* comment 4 */
FROM bar /* comment 5 */, tbl /* comment 6 */""",
            read="mysql",
            pretty=True,
        )
        self.validate(
            """
SELECT a FROM b
WHERE foo
-- comment 1
AND bar
-- comment 2
AND bla
-- comment 3
LIMIT 10
;
""",
            "SELECT a FROM b WHERE foo AND /* comment 1 */ bar AND /* comment 2 */ bla LIMIT 10 /* comment 3 */",
        )
        self.validate(
            """
SELECT a FROM b WHERE foo
-- comment 1
""",
            "SELECT a FROM b WHERE foo /* comment 1 */",
        )
        self.validate(
            """
select a
-- from
from b
-- where
where foo
-- comment 1
and bar
-- comment 2
and bla
""",
            """SELECT
  a
/* from */
FROM b
/* where */
WHERE
  foo AND /* comment 1 */ bar AND /* comment 2 */ bla""",
            pretty=True,
        )
        self.validate(
            """
-- test
WITH v AS (
  SELECT
    1 AS literal
)
SELECT
  *
FROM v
""",
            """/* test */
WITH v AS (
  SELECT
    1 AS literal
)
SELECT
  *
FROM v""",
            pretty=True,
        )
        self.validate(
            "(/* 1 */ 1 ) /* 2 */",
            "(1) /* 1 */ /* 2 */",
        )
        self.validate(
            "select * from t where not a in (23) /*test*/ and b in (14)",
            "SELECT * FROM t WHERE NOT a IN (23) /* test */ AND b IN (14)",
        )
        self.validate(
            "select * from t where a in (23) /*test*/ and b in (14)",
            "SELECT * FROM t WHERE a IN (23) /* test */ AND b IN (14)",
        )
        self.validate(
            "select * from t where ((condition = 1)/*test*/)",
            "SELECT * FROM t WHERE ((condition = 1) /* test */)",
        )
        self.validate(
            "SELECT 1 // hi this is a comment",
            "SELECT 1 /* hi this is a comment */",
            read="snowflake",
        )
        self.validate(
            "-- comment\nDROP TABLE IF EXISTS foo",
            "/* comment */ DROP TABLE IF EXISTS foo",
        )
        self.validate(
            """
-- comment1
-- comment2

-- comment3
DROP TABLE IF EXISTS db.tba
""",
            """/* comment1 */ /* comment2 */ /* comment3 */
DROP TABLE IF EXISTS db.tba""",
            pretty=True,
        )
        self.validate(
            """
-- comment4
CREATE TABLE db.tba AS
SELECT a, b, c
FROM tb_01
WHERE
-- comment5
a = 1 AND b = 2 --comment6
-- and c = 1
-- comment7
;
""",
            """/* comment4 */
CREATE TABLE db.tba AS
SELECT
  a,
  b,
  c
FROM tb_01
WHERE
  a /* comment5 */ = 1 AND b = 2 /* comment6 */ /* and c = 1 */ /* comment7 */""",
            pretty=True,
        )
        self.validate(
            """
SELECT
-- This is testing comments
col,
-- 2nd testing comments
CASE WHEN a THEN b ELSE c END as d
FROM t
""",
            """SELECT
  col, /* This is testing comments */
  CASE WHEN a THEN b ELSE c END AS d /* 2nd testing comments */
FROM t""",
            pretty=True,
        )
        self.validate(
            """
SELECT * FROM a
-- comments
INNER JOIN b
""",
            """SELECT
  *
FROM a
/* comments */
INNER JOIN b""",
            pretty=True,
        )
        self.validate(
            "SELECT * FROM a LEFT /* comment 1 */ OUTER /* comment 2 */ JOIN b",
            """SELECT
  *
FROM a
/* comment 1 */ /* comment 2 */
LEFT OUTER JOIN b""",
            pretty=True,
        )
        self.validate(
            "SELECT\n  a /* sqlglot.meta case_sensitive */ -- noqa\nFROM tbl",
            """SELECT
  a /* sqlglot.meta case_sensitive */ /* noqa */
FROM tbl""",
            pretty=True,
        )
        self.validate(
            """
SELECT
  'hotel1' AS hotel,
  *
FROM dw_1_dw_1_1.exactonline_1.transactionlines
/*
UNION ALL
SELECT
'Thon Partner Hotel Jølster' AS hotel,
name,
date,
CAST(identifier AS VARCHAR) AS identifier,
value
FROM d2o_889_oupjr_1348.public.accountvalues_forecast
*/
UNION ALL
SELECT
  'hotel2' AS hotel,
  *
FROM dw_1_dw_1_1.exactonline_2.transactionlines""",
            """SELECT
  'hotel1' AS hotel,
  *
FROM dw_1_dw_1_1.exactonline_1.transactionlines
/*
UNION ALL
SELECT
'Thon Partner Hotel Jølster' AS hotel,
name,
date,
CAST(identifier AS VARCHAR) AS identifier,
value
FROM d2o_889_oupjr_1348.public.accountvalues_forecast
*/
UNION ALL
SELECT
  'hotel2' AS hotel,
  *
FROM dw_1_dw_1_1.exactonline_2.transactionlines""",
            pretty=True,
        )
        self.validate(
            """/* The result of some calculations
*/
with
base as (
select
sum(sb.hep_amount) as hep_amount,
-- I AM REMOVED
sum(sb.hep_budget)
/* Budget defined in sharepoint */
as blub
, 1 as bla
from gold.data_budget sb
group by all
)
select
*
from base
""",
            """/* The result of some calculations
*/
WITH base AS (
  SELECT
    SUM(sb.hep_amount) AS hep_amount,
    SUM(sb.hep_budget) /* I AM REMOVED */ AS blub, /* Budget defined in sharepoint */
    1 AS bla
  FROM gold.data_budget AS sb
  GROUP BY ALL
)
SELECT
  *
FROM base""",
            pretty=True,
        )
        self.validate(
            """-- comment
SOME_FUNC(arg IGNORE NULLS)
  OVER (PARTITION BY foo ORDER BY bla) AS col""",
            "SOME_FUNC(arg IGNORE NULLS) OVER (PARTITION BY foo ORDER BY bla) AS col /* comment */",
            pretty=True,
        )
        self.validate(
            """
SELECT *
FROM x
INNER JOIN y
-- inner join z
LEFT JOIN z using (id)
using (id)
""",
            """SELECT
  *
FROM x
INNER JOIN y
  /* inner join z */
  LEFT JOIN z
    USING (id)
  USING (id)""",
            pretty=True,
        )
        self.validate(
            """with x as (
  SELECT *
  /*
  NOTE: LEFT JOIN because blah blah blah
  */
  FROM a
)
select * from x""",
            """WITH x AS (
  SELECT
    *
  /*
  NOTE: LEFT JOIN because blah blah blah
  */
  FROM a
)
SELECT
  *
FROM x""",
            pretty=True,
        )

    def test_types(self):
        """Type-prefixed literals and :: casts are generated as CAST expressions."""
        self.validate("INT 1", "CAST(1 AS INT)")
        self.validate("VARCHAR 'x' y", "CAST('x' AS VARCHAR) AS y")
        self.validate("STRING 'x' y", "CAST('x' AS TEXT) AS y")
        self.validate("x::INT", "CAST(x AS INT)")
        self.validate("x::INTEGER", "CAST(x AS INT)")
        self.validate("x::INT y", "CAST(x AS INT) AS y")
        self.validate("x::INT AS y", "CAST(x AS INT) AS y")
        self.validate("x::INT::BOOLEAN", "CAST(CAST(x AS INT) AS BOOLEAN)")
        self.validate("interval::int", "CAST(interval AS INT)")
        self.validate("x::user_defined_type", "CAST(x AS user_defined_type)")
        self.validate("CAST(x::INT AS BOOLEAN)", "CAST(CAST(x AS INT) AS BOOLEAN)")
        self.validate("CAST(x AS INT)::BOOLEAN", "CAST(CAST(x AS INT) AS BOOLEAN)")

        with self.assertRaises(ParseError):
            transpile("x::z", read="duckdb")

    def test_not_range(self):
        """Negated predicates are generated as NOT wrapped around the positive predicate."""
        self.validate("a NOT LIKE b", "NOT a LIKE b")
        self.validate("a NOT BETWEEN b AND c", "NOT a BETWEEN b AND c")
        self.validate("a NOT IN (1, 2)", "NOT a IN (1, 2)")
        self.validate("a IS NOT NULL", "NOT a IS NULL")
        self.validate("a LIKE TEXT 'y'", "a LIKE CAST('y' AS TEXT)")

    def test_extract(self):
        """EXTRACT keeps its unit; :: casts in the argument become CAST."""
        self.validate(
            "EXTRACT(day FROM '2020-01-01'::TIMESTAMP)",
            "EXTRACT(day FROM CAST('2020-01-01' AS TIMESTAMP))",
        )
        self.validate(
            "EXTRACT(timezone FROM '2020-01-01'::TIMESTAMP)",
            "EXTRACT(timezone FROM CAST('2020-01-01' AS TIMESTAMP))",
        )
        self.validate(
            "EXTRACT(year FROM '2020-01-01'::TIMESTAMP WITH TIME ZONE)",
            "EXTRACT(year FROM CAST('2020-01-01' AS TIMESTAMPTZ))",
        )
        self.validate(
            "extract(month from '2021-01-31'::timestamp without time zone)",
            "EXTRACT(month FROM CAST('2021-01-31' AS TIMESTAMP))",
        )
        self.validate("extract(week from current_date + 2)", "EXTRACT(week FROM CURRENT_DATE + 2)")
        self.validate(
            "EXTRACT(minute FROM datetime1 - datetime2)",
            "EXTRACT(minute FROM datetime1 - datetime2)",
        )

    def test_if(self):
        """Both the IF(...) function form and IF ... THEN syntax become CASE expressions."""
        self.validate(
            "SELECT IF(a > 1, 1, 0) FROM foo",
            "SELECT CASE WHEN a > 1 THEN 1 ELSE 0 END FROM foo",
        )
        self.validate(
            "SELECT IF a > 1 THEN b END",
            "SELECT CASE WHEN a > 1 THEN b END",
        )
        self.validate(
            "SELECT IF a > 1 THEN b ELSE c END",
            "SELECT CASE WHEN a > 1 THEN b ELSE c END",
        )
        self.validate("SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo")

    def test_with(self):
        """Repeated WITH keywords merge into one clause; presto VALUES rows get parenthesized."""
        self.validate(
            "WITH a AS (SELECT 1) WITH b AS (SELECT 2) SELECT *",
            "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *",
        )
        self.validate(
            "WITH a AS (SELECT 1), WITH b AS (SELECT 2) SELECT *",
            "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *",
        )
        self.validate(
            "WITH A(filter) AS (VALUES 1, 2, 3) SELECT * FROM A WHERE filter >= 2",
            "WITH A(filter) AS (VALUES (1), (2), (3)) SELECT * FROM A WHERE filter >= 2",
            read="presto",
        )
        self.validate(
            "SELECT BOOL_OR(a > 10) FROM (VALUES 1, 2, 15) AS T(a)",
            "SELECT BOOL_OR(a > 10) FROM (VALUES (1), (2), (15)) AS T(a)",
            read="presto",
        )

    def test_alter(self):
        """ALTER TABLE shorthands expand to ADD COLUMN / ALTER COLUMN ... SET DATA TYPE."""
        self.validate(
            "ALTER TABLE integers ADD k INTEGER",
            "ALTER TABLE integers ADD COLUMN k INT",
        )
        self.validate(
            "ALTER TABLE integers ALTER i TYPE VARCHAR",
            "ALTER TABLE integers ALTER COLUMN i SET DATA TYPE VARCHAR",
        )
        self.validate(
            "ALTER TABLE integers ALTER i TYPE VARCHAR COLLATE foo USING bar",
            "ALTER TABLE integers ALTER COLUMN i SET DATA TYPE VARCHAR COLLATE foo USING bar",
        )

    def test_time(self):
        """Intervals, temporal literals and time functions across several target dialects."""
        self.validate("INTERVAL '1 day'", "INTERVAL '1' DAY")
        self.validate("INTERVAL '1 days' * 5", "INTERVAL '1' DAYS * 5")
        self.validate("5 * INTERVAL '1 day'", "5 * INTERVAL '1' DAY")
        self.validate("INTERVAL 1 day", "INTERVAL '1' DAY")
        self.validate("INTERVAL 2 months", "INTERVAL '2' MONTHS")
        self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)")
        self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)")
        self.validate(
            "TIMESTAMP(9) WITH TIME ZONE '2020-01-01'",
            "CAST('2020-01-01' AS TIMESTAMPTZ(9))",
        )
        self.validate(
            "TIMESTAMP WITHOUT TIME ZONE '2020-01-01'",
            "CAST('2020-01-01' AS TIMESTAMP)",
        )
        self.validate("'2020-01-01'::TIMESTAMP", "CAST('2020-01-01' AS TIMESTAMP)")
        self.validate(
            "'2020-01-01'::TIMESTAMP WITHOUT TIME ZONE",
            "CAST('2020-01-01' AS TIMESTAMP)",
        )
        self.validate(
            "'2020-01-01'::TIMESTAMP WITH TIME ZONE",
            "CAST('2020-01-01' AS TIMESTAMPTZ)",
        )
        self.validate(
            "timestamp with time zone '2025-11-20 00:00:00+00' AT TIME ZONE 'Africa/Cairo'",
            "CAST('2025-11-20 00:00:00+00' AS TIMESTAMPTZ) AT TIME ZONE 'Africa/Cairo'",
        )

        self.validate("DATE '2020-01-01'", "CAST('2020-01-01' AS DATE)")
        self.validate("'2020-01-01'::DATE", "CAST('2020-01-01' AS DATE)")
        self.validate("STR_TO_TIME('x', 'y')", "STRPTIME('x', 'y')", write="duckdb")
        self.validate("STR_TO_UNIX('x', 'y')", "EPOCH(STRPTIME('x', 'y'))", write="duckdb")
        self.validate("TIME_TO_STR(x, 'y')", "STRFTIME(x, 'y')", write="duckdb")
        self.validate("TIME_TO_UNIX(x)", "EPOCH(x)", write="duckdb")
        self.validate(
            "UNIX_TO_STR(123, 'y')",
            "STRFTIME(TO_TIMESTAMP(123), 'y')",
            write="duckdb",
        )
        self.validate(
            "UNIX_TO_TIME(123)",
            "TO_TIMESTAMP(123)",
            write="duckdb",
        )

        self.validate(
            "STR_TO_TIME(x, 'y')",
            "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'y')) AS TIMESTAMP)",
            write="hive",
        )
        self.validate(
            "STR_TO_TIME(x, 'yyyy-MM-dd HH:mm:ss')",
            "CAST(x AS TIMESTAMP)",
            write="hive",
        )
        self.validate(
            "STR_TO_TIME(x, 'yyyy-MM-dd')",
            "CAST(x AS TIMESTAMP)",
            write="hive",
        )

        self.validate(
            "STR_TO_UNIX('x', 'y')",
            "UNIX_TIMESTAMP('x', 'y')",
            write="hive",
        )
        self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="hive")

        self.validate("TIME_STR_TO_TIME(x)", "TIME_STR_TO_TIME(x)", write=None)
        self.validate("TIME_STR_TO_UNIX(x)", "TIME_STR_TO_UNIX(x)", write=None)
        self.validate("TIME_TO_TIME_STR(x)", "CAST(x AS TEXT)", write=None)
        self.validate("TIME_TO_STR(x, 'y')", "TIME_TO_STR(x, 'y')", write=None)
        self.validate("TIME_TO_UNIX(x)", "TIME_TO_UNIX(x)", write=None)
        self.validate("UNIX_TO_STR(x, 'y')", "UNIX_TO_STR(x, 'y')", write=None)
        self.validate("UNIX_TO_TIME(x)", "UNIX_TO_TIME(x)", write=None)
        self.validate("UNIX_TO_TIME_STR(x)", "UNIX_TO_TIME_STR(x)", write=None)
        self.validate("TIME_STR_TO_DATE(x)", "TIME_STR_TO_DATE(x)", write=None)

        self.validate("TIME_STR_TO_DATE(x)", "TO_DATE(x)", write="hive")
        self.validate("UNIX_TO_STR(x, 'yyyy-MM-dd HH:mm:ss')", "FROM_UNIXTIME(x)", write="hive")
        self.validate("STR_TO_UNIX(x, 'yyyy-MM-dd HH:mm:ss')", "UNIX_TIMESTAMP(x)", write="hive")
        self.validate("IF(x > 1, x + 1)", "IF(x > 1, x + 1)", write="presto")
        self.validate("IF(x > 1, 1 + 1)", "IF(x > 1, 1 + 1)", write="hive")
        self.validate("IF(x > 1, 1, 0)", "IF(x > 1, 1, 0)", write="hive")

        self.validate(
            "TIME_TO_UNIX(x)",
            "UNIX_TIMESTAMP(x)",
            write="hive",
        )
        self.validate("UNIX_TO_STR(123, 'y')", "FROM_UNIXTIME(123, 'y')", write="hive")
        self.validate(
            "UNIX_TO_TIME(123)",
            "FROM_UNIXTIME(123)",
            write="hive",
        )

        self.validate("STR_TO_TIME('x', 'y')", "DATE_PARSE('x', 'y')", write="presto")
        self.validate(
            "STR_TO_UNIX('x', 'y')",
            "TO_UNIXTIME(COALESCE(TRY(DATE_PARSE(CAST('x' AS VARCHAR), 'y')), PARSE_DATETIME(CAST('x' AS VARCHAR), 'y')))",
            write="presto",
        )
        self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="presto")
        self.validate("TIME_TO_UNIX(x)", "TO_UNIXTIME(x)", write="presto")
        self.validate(
            "UNIX_TO_STR(123, 'y')",
            "DATE_FORMAT(FROM_UNIXTIME(123), 'y')",
            write="presto",
        )
        self.validate("UNIX_TO_TIME(123)", "FROM_UNIXTIME(123)", write="presto")

        self.validate("STR_TO_TIME('x', 'y')", "TO_TIMESTAMP('x', 'y')", write="spark")
        self.validate("STR_TO_UNIX('x', 'y')", "UNIX_TIMESTAMP('x', 'y')", write="spark")
        self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="spark")

        self.validate(
            "TIME_TO_UNIX(x)",
            "UNIX_TIMESTAMP(x)",
            write="spark",
        )
        self.validate("UNIX_TO_STR(123, 'y')", "FROM_UNIXTIME(123, 'y')", write="spark")
        self.validate(
            "UNIX_TO_TIME(123)",
            "CAST(FROM_UNIXTIME(123) AS TIMESTAMP)",
            write="spark",
        )
        self.validate(
            "CREATE TEMPORARY TABLE test AS SELECT 1",
            "CREATE TEMPORARY VIEW test AS SELECT 1",
            write="spark2",
        )

    def test_index_offset(self):
        """Array subscripts shift between 0- and 1-based dialects; each shift logs a warning."""
        with self.assertLogs(helper_logger) as cm:
            self.validate("x[0]", "x[1]", write="presto", identity=False)
            self.validate("x[1]", "x[0]", read="presto", identity=False)

            self.validate("x[x - 1]", "x[x - 1]", write="presto", identity=False)
            self.validate(
                "x[array_size(y) - 1]",
                "x[(CARDINALITY(y) - 1) + 1]",
                write="presto",
                identity=False,
            )
            self.validate("x[3 - 1]", "x[3]", write="presto", identity=False)
            self.validate("MAP(a, b)[0]", "MAP(a, b)[0]", write="presto", identity=False)

            self.assertEqual(
                cm.output,
                [
                    "WARNING:sqlglot:Applying array index offset (1)",
                    "WARNING:sqlglot:Applying array index offset (-1)",
                    "WARNING:sqlglot:Applying array index offset (1)",
                    "WARNING:sqlglot:Applying array index offset (1)",
                ],
            )

    def test_identify_lambda(self):
        """identify=True quotes lambda parameters and their references."""
        self.validate("x(y -> y)", 'X("y" -> "y")', identify=True)

    def test_identity(self):
        """Every statement in the identity.sql fixture round-trips unchanged."""
        self.assertEqual(transpile("")[0], "")
        for sql in load_sql_fixtures("identity.sql"):
            with self.subTest(sql):
                self.assertEqual(transpile(sql)[0], sql.strip())

    def test_command_identity(self):
        """Unsupported statements round-trip verbatim and log an unsupported-syntax warning."""
        for sql in (
            "ALTER AGGREGATE bla(foo) OWNER TO CURRENT_USER",
            "ALTER DOMAIN foo VALIDATE CONSTRAINT bla",
            "ALTER ROLE CURRENT_USER WITH REPLICATION",
            "ALTER RULE foo ON bla RENAME TO baz",
            "ALTER SEQUENCE IF EXISTS baz RESTART WITH boo",
            "ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS=3",
            "ALTER TABLE integers DROP PRIMARY KEY",
            "ALTER TABLE table1 MODIFY COLUMN name1 SET TAG foo='bar'",
            "ALTER TABLE table1 RENAME COLUMN c1 AS c2",
            "ALTER TABLE table1 RENAME COLUMN c1 TO c2, c2 TO c3",
            "ALTER TABLE table1 RENAME COLUMN c1 c2",
            "ALTER TYPE electronic_mail RENAME TO email",
            "ALTER VIEW foo ALTER COLUMN bla SET DEFAULT 'NOT SET'",
            "ALTER schema doo",
            "ANALYZE a.y",
            "CALL catalog.system.iceberg_procedure_name(named_arg_1 => 'arg_1', named_arg_2 => 'arg_2')",
            "COMMENT ON ACCESS METHOD gin IS 'GIN index access method'",
            "CREATE OR REPLACE STAGE",
            "EXECUTE statement",
            "EXPLAIN SELECT * FROM x",
            "GRANT INSERT ON foo TO bla",
            "LOAD foo",
            "OPTIMIZE TABLE y",
            "PREPARE statement",
            "SET -v",
            "SET @user OFF",
            "SHOW TABLES",
            "VACUUM FREEZE my_table",
        ):
            with self.subTest(sql):
                with self.assertLogs(parser_logger) as cm:
                    self.assertEqual(transpile(sql)[0], sql)
                    # assertIn rather than a bare assert: reports both operands on
                    # failure and is not stripped when Python runs with -O.
                    self.assertIn(f"'{sql[:100]}' contains unsupported syntax", cm.output[0])

    def test_normalize_name(self):
        """normalize_functions='lower' lowercases function names."""
        self.validate(
            "cardinality(x)",
            "cardinality(x)",
            read="presto",
            write="presto",
            normalize_functions="lower",
        )

    def test_partial(self):
        """Statements in partial.sql round-trip unchanged when parse errors are ignored."""
        for sql in load_sql_fixtures("partial.sql"):
            with self.subTest(sql):
                self.assertEqual(transpile(sql, error_level=ErrorLevel.IGNORE)[0], sql.strip())

    def test_pretty(self):
        """Pretty output matches pretty.sql and parses to the same tree as the input."""
        for _, sql, pretty in load_sql_fixture_pairs("pretty.sql"):
            with self.subTest(sql[:100]):
                generated = transpile(sql, pretty=True)[0]
                self.assertEqual(generated, pretty)
                self.assertEqual(parse_one(sql), parse_one(pretty))

    def test_pretty_line_breaks(self):
        """Line breaks inside string literals are not re-indented by pretty printing."""
        self.assertEqual(transpile("SELECT '1\n2'", pretty=True)[0], "SELECT\n  '1\n2'")
        self.assertEqual(
            transpile("SELECT '1\n2'", pretty=True, unsupported_level=ErrorLevel.IGNORE)[0],
            "SELECT\n  '1\n2'",
        )

    @mock.patch("sqlglot.parser.logger")
    def test_error_level(self, logger):
        """WARN logs parse errors, IMMEDIATE raises the first, RAISE raises them all together."""
        invalid = "x + 1. ("
        expected_messages = [
            "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>. Line 1, Col: 8.\n  x + 1. \033[4m(\033[0m",
            "Expecting ). Line 1, Col: 8.\n  x + 1. \033[4m(\033[0m",
        ]
        expected_errors = [
            {
                "description": "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>",
                "line": 1,
                "col": 8,
                "start_context": "x + 1. ",
                "highlight": "(",
                "end_context": "",
                "into_expression": None,
            },
            {
                "description": "Expecting )",
                "line": 1,
                "col": 8,
                "start_context": "x + 1. ",
                "highlight": "(",
                "end_context": "",
                "into_expression": None,
            },
        ]

        transpile(invalid, error_level=ErrorLevel.WARN)
        for error in expected_messages:
            assert_logger_contains(error, logger)

        with self.assertRaises(ParseError) as ctx:
            transpile(invalid, error_level=ErrorLevel.IMMEDIATE)

        self.assertEqual(str(ctx.exception), expected_messages[0])
        self.assertEqual(ctx.exception.errors[0], expected_errors[0])

        with self.assertRaises(ParseError) as ctx:
            transpile(invalid, error_level=ErrorLevel.RAISE)

        self.assertEqual(str(ctx.exception), "\n\n".join(expected_messages))
        self.assertEqual(ctx.exception.errors, expected_errors)

        more_than_max_errors = "(((("
        expected_messages = (
            "Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>. Line 1, Col: 4.\n  (((\033[4m(\033[0m\n\n"
            "Expecting ). Line 1, Col: 4.\n  (((\033[4m(\033[0m\n\n"
            "Expecting ). Line 1, Col: 4.\n  (((\033[4m(\033[0m\n\n"
            "... and 2 more"
        )
        expected_errors = [
            {
                "description": "Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>",
                "line": 1,
                "col": 4,
                "start_context": "(((",
                "highlight": "(",
                "end_context": "",
                "into_expression": None,
            },
            {
                "description": "Expecting )",
                "line": 1,
                "col": 4,
                "start_context": "(((",
                "highlight": "(",
                "end_context": "",
                "into_expression": None,
            },
        ]
        # Also expect three trailing structured errors identical to the second
        # ("Expecting )") one: five errors in total, of which the message shows
        # three and summarizes the rest as "... and 2 more".
        expected_errors += [expected_errors[1]] * 3

        with self.assertRaises(ParseError) as ctx:
            transpile(more_than_max_errors, error_level=ErrorLevel.RAISE)

        self.assertEqual(str(ctx.exception), expected_messages)
        self.assertEqual(ctx.exception.errors, expected_errors)

    @mock.patch("sqlglot.generator.logger")
    def test_unsupported_level(self, logger):
        """unsupported_level decides whether generation problems are logged or raised."""

        def unsupported(level):
            transpile(
                "SELECT MAP(a, b), MAP(a, b), MAP(a, b), MAP(a, b)",
                read="presto",
                write="hive",
                unsupported_level=level,
            )

        error = "Cannot convert array columns into map."

        unsupported(ErrorLevel.WARN)
        assert_logger_contains("\n".join([error] * 4), logger, level="warning")

        with self.assertRaises(UnsupportedError) as ctx:
            unsupported(ErrorLevel.RAISE)
        self.assertEqual(str(ctx.exception).count(error), 3)

        with self.assertRaises(UnsupportedError) as ctx:
            unsupported(ErrorLevel.IMMEDIATE)
        self.assertEqual(str(ctx.exception).count(error), 1)

    def test_recursion(self):
        """A 1000-term AND/OR chain parses and regenerates to its 17,001-character input length."""
        sql = "1 AND 2 OR 3 AND " * 1000
        sql += "4"
        self.assertEqual(len(parse_one(sql).sql()), 17001)