Merging upstream version 9.0.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
66ef36a209
commit
b1dc5c6faf
22 changed files with 742 additions and 223 deletions
@@ -9,7 +9,6 @@ from sqlglot.errors import ErrorLevel

class TestFunctions(unittest.TestCase):
    @unittest.skip("not yet fixed.")
    def test_invoke_anonymous(self):
        for name, func in inspect.getmembers(SF, inspect.isfunction):
            with self.subTest(f"{name} should not invoke anonymous_function"):
@@ -438,13 +437,13 @@ class TestFunctions(unittest.TestCase):
    def test_pow(self):
        col_str = SF.pow("cola", "colb")
        self.assertEqual("POW(cola, colb)", col_str.sql())
        self.assertEqual("POWER(cola, colb)", col_str.sql())
        col = SF.pow(SF.col("cola"), SF.col("colb"))
        self.assertEqual("POW(cola, colb)", col.sql())
        self.assertEqual("POWER(cola, colb)", col.sql())
        col_float = SF.pow(10.10, "colb")
        self.assertEqual("POW(10.1, colb)", col_float.sql())
        self.assertEqual("POWER(10.1, colb)", col_float.sql())
        col_float2 = SF.pow("cola", 10.10)
        self.assertEqual("POW(cola, 10.1)", col_float2.sql())
        self.assertEqual("POWER(cola, 10.1)", col_float2.sql())

    def test_row_number(self):
        col_str = SF.row_number()
@@ -493,6 +492,8 @@ class TestFunctions(unittest.TestCase):
        self.assertEqual("COALESCE(cola, colb, colc)", col_str.sql())
        col = SF.coalesce(SF.col("cola"), "colb", SF.col("colc"))
        self.assertEqual("COALESCE(cola, colb, colc)", col.sql())
        col_single = SF.coalesce("cola")
        self.assertEqual("COALESCE(cola)", col_single.sql())

    def test_corr(self):
        col_str = SF.corr("cola", "colb")
@@ -843,8 +844,8 @@ class TestFunctions(unittest.TestCase):
        self.assertEqual("TO_DATE(cola)", col_str.sql())
        col = SF.to_date(SF.col("cola"))
        self.assertEqual("TO_DATE(cola)", col.sql())
        col_with_format = SF.to_date("cola", "yyyy-MM-dd")
        self.assertEqual("TO_DATE(cola, 'yyyy-MM-dd')", col_with_format.sql())
        col_with_format = SF.to_date("cola", "yy-MM-dd")
        self.assertEqual("TO_DATE(cola, 'yy-MM-dd')", col_with_format.sql())

    def test_to_timestamp(self):
        col_str = SF.to_timestamp("cola")
@@ -883,16 +884,16 @@ class TestFunctions(unittest.TestCase):
        self.assertEqual("FROM_UNIXTIME(cola)", col_str.sql())
        col = SF.from_unixtime(SF.col("cola"))
        self.assertEqual("FROM_UNIXTIME(cola)", col.sql())
        col_format = SF.from_unixtime("cola", "yyyy-MM-dd HH:mm:ss")
        self.assertEqual("FROM_UNIXTIME(cola, 'yyyy-MM-dd HH:mm:ss')", col_format.sql())
        col_format = SF.from_unixtime("cola", "yyyy-MM-dd HH:mm")
        self.assertEqual("FROM_UNIXTIME(cola, 'yyyy-MM-dd HH:mm')", col_format.sql())

    def test_unix_timestamp(self):
        col_str = SF.unix_timestamp("cola")
        self.assertEqual("UNIX_TIMESTAMP(cola)", col_str.sql())
        col = SF.unix_timestamp(SF.col("cola"))
        self.assertEqual("UNIX_TIMESTAMP(cola)", col.sql())
        col_format = SF.unix_timestamp("cola", "yyyy-MM-dd HH:mm:ss")
        self.assertEqual("UNIX_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm:ss')", col_format.sql())
        col_format = SF.unix_timestamp("cola", "yyyy-MM-dd HH:mm")
        self.assertEqual("UNIX_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm')", col_format.sql())
        col_current = SF.unix_timestamp()
        self.assertEqual("UNIX_TIMESTAMP()", col_current.sql())
@@ -1427,6 +1428,13 @@ class TestFunctions(unittest.TestCase):
        self.assertEqual("ARRAY_SORT(cola)", col_str.sql())
        col = SF.array_sort(SF.col("cola"))
        self.assertEqual("ARRAY_SORT(cola)", col.sql())
        col_comparator = SF.array_sort(
            "cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x))
        )
        self.assertEqual(
            "ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)",
            col_comparator.sql(),
        )

    def test_reverse(self):
        col_str = SF.reverse("cola")
@@ -1514,8 +1522,6 @@ class TestFunctions(unittest.TestCase):
            SF.lit(0),
            lambda accumulator, target: accumulator + target,
            lambda accumulator: accumulator * 2,
            "accumulator",
            "target",
        )
        self.assertEqual(
            "AGGREGATE(cola, 0, (accumulator, target) -> accumulator + target, accumulator -> accumulator * 2)",
@@ -1527,7 +1533,7 @@ class TestFunctions(unittest.TestCase):
        self.assertEqual("TRANSFORM(cola, x -> x * 2)", col_str.sql())
        col = SF.transform(SF.col("cola"), lambda x, i: x * i)
        self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql())
        col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count, "target", "row_count")
        col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count)

        self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql())
@@ -1536,7 +1542,7 @@ class TestFunctions(unittest.TestCase):
        self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col_str.sql())
        col = SF.exists(SF.col("cola"), lambda x: x % 2 == 0)
        self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col.sql())
        col_custom_name = SF.exists("cola", lambda target: target > 0, "target")
        col_custom_name = SF.exists("cola", lambda target: target > 0)
        self.assertEqual("EXISTS(cola, target -> target > 0)", col_custom_name.sql())

    def test_forall(self):
@@ -1544,7 +1550,7 @@ class TestFunctions(unittest.TestCase):
        self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col_str.sql())
        col = SF.forall(SF.col("cola"), lambda x: x.rlike("foo"))
        self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col.sql())
        col_custom_name = SF.forall("cola", lambda target: target.rlike("foo"), "target")
        col_custom_name = SF.forall("cola", lambda target: target.rlike("foo"))
        self.assertEqual("FORALL(cola, target -> target RLIKE 'foo')", col_custom_name.sql())

    def test_filter(self):
@@ -1552,9 +1558,7 @@ class TestFunctions(unittest.TestCase):
        self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql())
        col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i))
        self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql())
        col_custom_names = SF.filter(
            "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count), "target", "row_count"
        )
        col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count))

        self.assertEqual(
            "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql()
@@ -1565,7 +1569,7 @@ class TestFunctions(unittest.TestCase):
        self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col_str.sql())
        col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y))
        self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql())
        col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r), "l", "r")
        col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r))
        self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql())

    def test_transform_keys(self):
@@ -1573,7 +1577,7 @@ class TestFunctions(unittest.TestCase):
        self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col_str.sql())
        col = SF.transform_keys(SF.col("cola"), lambda k, v: SF.upper(k))
        self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col.sql())
        col_custom_names = SF.transform_keys("cola", lambda key, _: SF.upper(key), "key", "_")
        col_custom_names = SF.transform_keys("cola", lambda key, _: SF.upper(key))
        self.assertEqual("TRANSFORM_KEYS(cola, (key, _) -> UPPER(key))", col_custom_names.sql())

    def test_transform_values(self):
@@ -1581,7 +1585,7 @@ class TestFunctions(unittest.TestCase):
        self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col_str.sql())
        col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v))
        self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql())
        col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value), "_", "value")
        col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value))
        self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql())

    def test_map_filter(self):
@@ -1589,5 +1593,9 @@ class TestFunctions(unittest.TestCase):
        self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col_str.sql())
        col = SF.map_filter(SF.col("cola"), lambda k, v: k > v)
        self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col.sql())
        col_custom_names = SF.map_filter("cola", lambda key, value: key > value, "key", "value")
        col_custom_names = SF.map_filter("cola", lambda key, value: key > value)
        self.assertEqual("MAP_FILTER(cola, (key, value) -> key > value)", col_custom_names.sql())

    def test_map_zip_with(self):
        col = SF.map_zip_with("base", "ratio", lambda k, v1, v2: SF.round(v1 * v2, 2))
        self.assertEqual("MAP_ZIP_WITH(base, ratio, (k, v1, v2) -> ROUND(v1 * v2, 2))", col.sql())
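
A minimal usage sketch of the behavior these dataframe-function tests assert (not part of the diff; the SF import path is assumed to match what this test module uses): pow now renders as POWER, and the higher-order helpers pick up parameter names from the lambda itself rather than from extra name arguments.

    # hedged sketch; SF is assumed to be sqlglot's Spark-style functions module
    from sqlglot.dataframe.sql import functions as SF

    print(SF.pow("cola", "colb").sql())
    # POWER(cola, colb)
    col = SF.transform("cola", lambda target, row_count: target * row_count)
    print(col.sql())
    # TRANSFORM(cola, (target, row_count) -> target * row_count)
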
@@ -105,6 +105,15 @@ class TestBigQuery(Validator):
                "spark": "x IS NULL",
            },
        )
        self.validate_all(
            "CURRENT_DATE",
            read={
                "tsql": "GETDATE()",
            },
            write={
                "tsql": "GETDATE()",
            },
        )
        self.validate_all(
            "current_datetime",
            write={
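
A rough sketch of what the new CURRENT_DATE assertion exercises (not part of the diff), using sqlglot.transpile directly: T-SQL's GETDATE() reads into BigQuery's CURRENT_DATE, and CURRENT_DATE writes back out as GETDATE() for tsql.

    import sqlglot

    # per the read={} side of the test above
    print(sqlglot.transpile("GETDATE()", read="tsql", write="bigquery")[0])     # CURRENT_DATE
    # per the write={} side
    print(sqlglot.transpile("CURRENT_DATE", read="bigquery", write="tsql")[0])  # GETDATE()
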
@@ -434,12 +434,7 @@ class TestDialect(Validator):
                "presto": "DATE_ADD('day', 1, x)",
                "spark": "DATE_ADD(x, 1)",
                "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
            },
        )
        self.validate_all(
            "DATE_ADD(x, y, 'day')",
            write={
                "postgres": UnsupportedError,
                "tsql": "DATEADD(day, 1, x)",
            },
        )
        self.validate_all(
@@ -634,11 +629,13 @@ class TestDialect(Validator):
            read={
                "postgres": "x->'y'",
                "presto": "JSON_EXTRACT(x, 'y')",
                "starrocks": "x->'y'",
            },
            write={
                "oracle": "JSON_EXTRACT(x, 'y')",
                "postgres": "x->'y'",
                "presto": "JSON_EXTRACT(x, 'y')",
                "starrocks": "x->'y'",
            },
        )
        self.validate_all(
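
A short sketch of the arrow-operator round trip covered by this hunk (not part of the diff): Postgres-style x->'y' parses to a JSON extraction and renders per target dialect.

    import sqlglot

    print(sqlglot.transpile("x->'y'", read="postgres", write="presto")[0])  # JSON_EXTRACT(x, 'y')
    print(sqlglot.transpile("x->'y'", read="postgres", write="oracle")[0])  # JSON_EXTRACT(x, 'y')
    print(sqlglot.transpile("JSON_EXTRACT(x, 'y')", read="presto", write="starrocks")[0])  # x->'y'
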
@@ -983,6 +980,7 @@ class TestDialect(Validator):
        )

    def test_limit(self):
        self.validate_all("SELECT * FROM data LIMIT 10, 20", write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"})
        self.validate_all(
            "SELECT x FROM y LIMIT 10",
            write={

@@ -282,3 +282,6 @@ TBLPROPERTIES (
                "spark": "SELECT ARRAY_SORT(x)",
            },
        )

    def test_iif(self):
        self.validate_all("SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"})

@@ -71,3 +71,226 @@ class TestTSQL(Validator):
                "spark": "LOCATE('sub', 'testsubstring')",
            },
        )

    def test_len(self):
        self.validate_all("LEN(x)", write={"spark": "LENGTH(x)"})

    def test_replicate(self):
        self.validate_all("REPLICATE('x', 2)", write={"spark": "REPEAT('x', 2)"})

    def test_isnull(self):
        self.validate_all("ISNULL(x, y)", write={"spark": "COALESCE(x, y)"})

    def test_jsonvalue(self):
        self.validate_all(
            "JSON_VALUE(r.JSON, '$.Attr_INT')",
            write={"spark": "GET_JSON_OBJECT(r.JSON, '$.Attr_INT')"},
        )

    def test_datefromparts(self):
        self.validate_all(
            "SELECT DATEFROMPARTS('2020', 10, 01)",
            write={"spark": "SELECT MAKE_DATE('2020', 10, 01)"},
        )

    def test_datename(self):
        self.validate_all(
            "SELECT DATENAME(mm,'01-01-1970')",
            write={"spark": "SELECT DATE_FORMAT('01-01-1970', 'MMMM')"},
        )
        self.validate_all(
            "SELECT DATENAME(dw,'01-01-1970')",
            write={"spark": "SELECT DATE_FORMAT('01-01-1970', 'EEEE')"},
        )

    def test_datepart(self):
        self.validate_all(
            "SELECT DATEPART(month,'01-01-1970')",
            write={"spark": "SELECT DATE_FORMAT('01-01-1970', 'MM')"},
        )

    def test_convert_date_format(self):
        self.validate_all(
            "CONVERT(NVARCHAR(200), x)",
            write={
                "spark": "CAST(x AS VARCHAR(200))",
            },
        )
        self.validate_all(
            "CONVERT(NVARCHAR, x)",
            write={
                "spark": "CAST(x AS VARCHAR(30))",
            },
        )
        self.validate_all(
            "CONVERT(NVARCHAR(MAX), x)",
            write={
                "spark": "CAST(x AS STRING)",
            },
        )
        self.validate_all(
            "CONVERT(VARCHAR(200), x)",
            write={
                "spark": "CAST(x AS VARCHAR(200))",
            },
        )
        self.validate_all(
            "CONVERT(VARCHAR, x)",
            write={
                "spark": "CAST(x AS VARCHAR(30))",
            },
        )
        self.validate_all(
            "CONVERT(VARCHAR(MAX), x)",
            write={
                "spark": "CAST(x AS STRING)",
            },
        )
        self.validate_all(
            "CONVERT(CHAR(40), x)",
            write={
                "spark": "CAST(x AS CHAR(40))",
            },
        )
        self.validate_all(
            "CONVERT(CHAR, x)",
            write={
                "spark": "CAST(x AS CHAR(30))",
            },
        )
        self.validate_all(
            "CONVERT(NCHAR(40), x)",
            write={
                "spark": "CAST(x AS CHAR(40))",
            },
        )
        self.validate_all(
            "CONVERT(NCHAR, x)",
            write={
                "spark": "CAST(x AS CHAR(30))",
            },
        )
        self.validate_all(
            "CONVERT(VARCHAR, x, 121)",
            write={
                "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))",
            },
        )
        self.validate_all(
            "CONVERT(VARCHAR(40), x, 121)",
            write={
                "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(40))",
            },
        )
        self.validate_all(
            "CONVERT(VARCHAR(MAX), x, 121)",
            write={
                "spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')",
            },
        )
        self.validate_all(
            "CONVERT(NVARCHAR, x, 121)",
            write={
                "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))",
            },
        )
        self.validate_all(
            "CONVERT(NVARCHAR(40), x, 121)",
            write={
                "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(40))",
            },
        )
        self.validate_all(
            "CONVERT(NVARCHAR(MAX), x, 121)",
            write={
                "spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')",
            },
        )
        self.validate_all(
            "CONVERT(DATE, x, 121)",
            write={
                "spark": "TO_DATE(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')",
            },
        )
        self.validate_all(
            "CONVERT(DATETIME, x, 121)",
            write={
                "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')",
            },
        )
        self.validate_all(
            "CONVERT(DATETIME2, x, 121)",
            write={
                "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')",
            },
        )
        self.validate_all(
            "CONVERT(INT, x)",
            write={
                "spark": "CAST(x AS INT)",
            },
        )
        self.validate_all(
            "CONVERT(INT, x, 121)",
            write={
                "spark": "CAST(x AS INT)",
            },
        )
        self.validate_all(
            "TRY_CONVERT(NVARCHAR, x, 121)",
            write={
                "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))",
            },
        )
        self.validate_all(
            "TRY_CONVERT(INT, x)",
            write={
                "spark": "CAST(x AS INT)",
            },
        )
        self.validate_all(
            "TRY_CAST(x AS INT)",
            write={
                "spark": "CAST(x AS INT)",
            },
        )
        self.validate_all(
            "CAST(x AS INT)",
            write={
                "spark": "CAST(x AS INT)",
            },
        )

    def test_add_date(self):
        self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')")
        self.validate_all(
            "SELECT DATEADD(year, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"}
        )
        self.validate_all("SELECT DATEADD(qq, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"})
        self.validate_all("SELECT DATEADD(wk, 1, '2017/08/25')", write={"spark": "SELECT DATE_ADD('2017/08/25', 7)"})

    def test_date_diff(self):
        self.validate_identity("SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')")
        self.validate_all(
            "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')",
            write={
                "tsql": "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')",
                "spark": "SELECT MONTHS_BETWEEN('2021/01/01', '2020/01/01') / 12",
            },
        )
        self.validate_all(
            "SELECT DATEDIFF(month, 'start','end')",
            write={"spark": "SELECT MONTHS_BETWEEN('end', 'start')", "tsql": "SELECT DATEDIFF(month, 'start', 'end')"},
        )
        self.validate_all(
            "SELECT DATEDIFF(quarter, 'start', 'end')", write={"spark": "SELECT MONTHS_BETWEEN('end', 'start') / 3"}
        )

    def test_iif(self):
        self.validate_identity("SELECT IIF(cond, 'True', 'False')")
        self.validate_all(
            "SELECT IIF(cond, 'True', 'False');",
            write={
                "spark": "SELECT IF(cond, 'True', 'False')",
            },
        )
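
A brief sketch of the T-SQL conversions tested above (not part of the diff), again via sqlglot.transpile: CONVERT/TRY_CONVERT with style 121 becomes a Spark DATE_FORMAT (wrapped in a CAST when a length is given), and DATEDIFF maps onto MONTHS_BETWEEN arithmetic.

    import sqlglot

    print(sqlglot.transpile("CONVERT(VARCHAR(40), x, 121)", read="tsql", write="spark")[0])
    # CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(40))
    print(sqlglot.transpile("TRY_CONVERT(INT, x)", read="tsql", write="spark")[0])
    # CAST(x AS INT)
    print(sqlglot.transpile("SELECT DATEDIFF(quarter, 'start', 'end')", read="tsql", write="spark")[0])
    # SELECT MONTHS_BETWEEN('end', 'start') / 3
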
tests/fixtures/identity.sql vendored
@@ -149,7 +149,6 @@ SELECT 1 AS count FROM test
SELECT 1 AS comment FROM test
SELECT 1 AS numeric FROM test
SELECT 1 AS number FROM test
SELECT 1 AS number # annotation
SELECT t.count
SELECT DISTINCT x FROM test
SELECT DISTINCT x, y FROM test

@@ -329,6 +329,10 @@ class TestBuild(unittest.TestCase):
                lambda: exp.update("tbl", {"x": 1}, where="y > 0"),
                "UPDATE tbl SET x = 1 WHERE y > 0",
            ),
            (
                lambda: exp.update("tbl", {"x": 1}, where=exp.condition("y > 0")),
                "UPDATE tbl SET x = 1 WHERE y > 0",
            ),
            (
                lambda: exp.update("tbl", {"x": 1}, from_="tbl2"),
                "UPDATE tbl SET x = 1 FROM tbl2",
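
For reference, a small sketch of the exp.update builder that the new test cases exercise (not part of the diff):

    from sqlglot import exp

    print(exp.update("tbl", {"x": 1}, where="y > 0").sql())
    # UPDATE tbl SET x = 1 WHERE y > 0
    print(exp.update("tbl", {"x": 1}, from_="tbl2").sql())
    # UPDATE tbl SET x = 1 FROM tbl2
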
@@ -135,6 +135,53 @@ class TestExpressions(unittest.TestCase):
            "SELECT * FROM a1 AS a JOIN b.a JOIN c.a2 JOIN d2 JOIN e.a",
        )

    def test_replace_placeholders(self):
        self.assertEqual(
            exp.replace_placeholders(
                parse_one("select * from :tbl1 JOIN :tbl2 ON :col1 = :col2 WHERE :col3 > 100"),
                tbl1="foo",
                tbl2="bar",
                col1="a",
                col2="b",
                col3="c",
            ).sql(),
            "SELECT * FROM foo JOIN bar ON a = b WHERE c > 100",
        )
        self.assertEqual(
            exp.replace_placeholders(
                parse_one("select * from ? JOIN ? ON ? = ? WHERE ? > 100"),
                "foo",
                "bar",
                "a",
                "b",
                "c",
            ).sql(),
            "SELECT * FROM foo JOIN bar ON a = b WHERE c > 100",
        )
        self.assertEqual(
            exp.replace_placeholders(
                parse_one("select * from ? WHERE ? > 100"),
                "foo",
            ).sql(),
            "SELECT * FROM foo WHERE ? > 100",
        )
        self.assertEqual(
            exp.replace_placeholders(parse_one("select * from :name WHERE ? > 100"), another_name="bla").sql(),
            "SELECT * FROM :name WHERE ? > 100",
        )
        self.assertEqual(
            exp.replace_placeholders(
                parse_one("select * from (SELECT :col1 FROM ?) WHERE :col2 > 100"),
                "tbl1",
                "tbl2",
                "tbl3",
                col1="a",
                col2="b",
                col3="c",
            ).sql(),
            "SELECT * FROM (SELECT a FROM tbl1) WHERE b > 100",
        )

    def test_named_selects(self):
        expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
        self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"])
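
A condensed sketch of exp.replace_placeholders as pinned down above (not part of the diff; this particular query is hypothetical but mirrors the assertions): keyword arguments fill named placeholders, positional arguments fill ? placeholders in order, and anything left unmatched stays a placeholder.

    from sqlglot import exp, parse_one

    # hypothetical query assembled from the same pieces the tests use
    expression = parse_one("select * from :tbl WHERE ? > 100")
    print(exp.replace_placeholders(expression, "foo", tbl="users").sql())
    # SELECT * FROM users WHERE foo > 100
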
@@ -504,9 +551,24 @@ class TestExpressions(unittest.TestCase):
            [e.alias_or_name for e in expression.expressions],
            ["a", "B", "c", "D"],
        )
        self.assertEqual(expression.sql(), sql)
        self.assertEqual(expression.sql(), "SELECT a, b AS B, c, d AS D")
        self.assertEqual(expression.expressions[2].name, "comment")
        self.assertEqual(expression.sql(annotations=False), "SELECT a, b AS B, c, d AS D")
        self.assertEqual(
            expression.sql(pretty=True, annotations=False),
            """SELECT
  a,
  b AS B,
  c,
  d AS D""",
        )
        self.assertEqual(
            expression.sql(pretty=True),
            """SELECT
  a,
  b AS B,
  c # comment,
  d AS D # another_comment FROM foo""",
        )

    def test_to_table(self):
        table_only = exp.to_table("table_name")

@@ -5,7 +5,7 @@ from sqlglot.time import format_time
class TestTime(unittest.TestCase):
    def test_format_time(self):
        self.assertEqual(format_time("", {}), "")
        self.assertEqual(format_time("", {}), None)
        self.assertEqual(format_time(" ", {}), " ")
        mapping = {"a": "b", "aa": "c"}
        self.assertEqual(format_time("a", mapping), "b")
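
A tiny sketch of format_time, whose edge cases the hunk above adjusts (not part of the diff): it maps format tokens through the given mapping, and with both "a" and "aa" present the longer token is assumed to win.

    from sqlglot.time import format_time

    mapping = {"a": "b", "aa": "c"}
    print(format_time("a", mapping))   # b
    print(format_time("aa", mapping))  # c (assumed: longest matching token)
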