Merging upstream version 25.24.5.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
f2b92bd29a
commit
1763c7a4ef
80 changed files with 61531 additions and 59444 deletions
|
@ -62,8 +62,12 @@ class TestAthena(Validator):
|
|||
|
||||
# CTAS goes to the Trino engine, where the table properties cant be encased in single quotes like they can for Hive
|
||||
# ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties
|
||||
# They're also case sensitive and need to be lowercase, otherwise you get eg "Table properties [FORMAT] are not supported."
|
||||
self.validate_identity(
|
||||
"CREATE TABLE foo WITH (table_type='ICEBERG', external_location='s3://foo/') AS SELECT * FROM a"
|
||||
"CREATE TABLE foo WITH (table_type='ICEBERG', location='s3://foo/', format='orc', partitioning=ARRAY['bucket(id, 5)']) AS SELECT * FROM a"
|
||||
)
|
||||
self.validate_identity(
|
||||
"CREATE TABLE foo WITH (table_type='HIVE', external_location='s3://foo/', format='parquet', partitioned_by=ARRAY['ds']) AS SELECT * FROM a"
|
||||
)
|
||||
self.validate_identity(
|
||||
"CREATE TABLE foo AS WITH foo AS (SELECT a, b FROM bar) SELECT * FROM foo"
|
||||
|
|
|
@ -1985,3 +1985,17 @@ OPTIONS (
|
|||
self.validate_identity(
|
||||
"SELECT RANGE(CAST('2022-10-01 14:53:27 America/Los_Angeles' AS TIMESTAMP), CAST('2022-10-01 16:00:00 America/Los_Angeles' AS TIMESTAMP))"
|
||||
)
|
||||
|
||||
def test_null_ordering(self):
|
||||
# Aggregate functions allow "NULLS FIRST" only with ascending order and
|
||||
# "NULLS LAST" only with descending
|
||||
for sort_order, null_order in (("ASC", "NULLS LAST"), ("DESC", "NULLS FIRST")):
|
||||
self.validate_all(
|
||||
f"SELECT color, ARRAY_AGG(id ORDER BY id {sort_order}) AS ids FROM colors GROUP BY 1",
|
||||
read={
|
||||
"": f"SELECT color, ARRAY_AGG(id ORDER BY id {sort_order} {null_order}) AS ids FROM colors GROUP BY 1"
|
||||
},
|
||||
write={
|
||||
"bigquery": f"SELECT color, ARRAY_AGG(id ORDER BY id {sort_order}) AS ids FROM colors GROUP BY 1",
|
||||
},
|
||||
)
|
||||
|
|
|
@ -858,6 +858,28 @@ class TestDuckDB(Validator):
|
|||
self.validate_identity(
|
||||
"SELECT COALESCE(*COLUMNS(['a', 'b', 'c'])) AS result FROM (SELECT NULL AS a, 42 AS b, TRUE AS c)"
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT UNNEST(foo) AS x",
|
||||
write={
|
||||
"redshift": UnsupportedError,
|
||||
},
|
||||
)
|
||||
self.validate_identity("a ^ b", "POWER(a, b)")
|
||||
self.validate_identity("a ** b", "POWER(a, b)")
|
||||
self.validate_identity("a ~~~ b", "a GLOB b")
|
||||
self.validate_identity("a ~~ b", "a LIKE b")
|
||||
self.validate_identity("a @> b")
|
||||
self.validate_identity("a <@ b", "b @> a")
|
||||
self.validate_identity("a && b").assert_is(exp.ArrayOverlaps)
|
||||
self.validate_identity("a ^@ b", "STARTS_WITH(a, b)")
|
||||
self.validate_identity(
|
||||
"a !~~ b",
|
||||
"NOT a LIKE b",
|
||||
)
|
||||
self.validate_identity(
|
||||
"a !~~* b",
|
||||
"NOT a ILIKE b",
|
||||
)
|
||||
|
||||
def test_array_index(self):
|
||||
with self.assertLogs(helper_logger) as cm:
|
||||
|
@ -967,6 +989,15 @@ class TestDuckDB(Validator):
|
|||
"spark": "DATE_FORMAT(x, 'yy-M-ss')",
|
||||
},
|
||||
)
|
||||
|
||||
self.validate_all(
|
||||
"SHA1(x)",
|
||||
write={
|
||||
"duckdb": "SHA1(x)",
|
||||
"": "SHA(x)",
|
||||
},
|
||||
)
|
||||
|
||||
self.validate_all(
|
||||
"STRFTIME(x, '%Y-%m-%d %H:%M:%S')",
|
||||
write={
|
||||
|
@ -1086,6 +1117,7 @@ class TestDuckDB(Validator):
|
|||
self.validate_identity("CAST(x AS INT16)", "CAST(x AS SMALLINT)")
|
||||
self.validate_identity("CAST(x AS NUMERIC(1, 2))", "CAST(x AS DECIMAL(1, 2))")
|
||||
self.validate_identity("CAST(x AS HUGEINT)", "CAST(x AS INT128)")
|
||||
self.validate_identity("CAST(x AS UHUGEINT)", "CAST(x AS UINT128)")
|
||||
self.validate_identity("CAST(x AS CHAR)", "CAST(x AS TEXT)")
|
||||
self.validate_identity("CAST(x AS BPCHAR)", "CAST(x AS TEXT)")
|
||||
self.validate_identity("CAST(x AS STRING)", "CAST(x AS TEXT)")
|
||||
|
|
|
@ -747,16 +747,28 @@ class TestMySQL(Validator):
|
|||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT * FROM x LEFT JOIN y ON x.id = y.id UNION SELECT * FROM x RIGHT JOIN y ON x.id = y.id LIMIT 0",
|
||||
"SELECT * FROM x LEFT JOIN y ON x.id = y.id UNION ALL SELECT * FROM x RIGHT JOIN y ON x.id = y.id WHERE NOT EXISTS(SELECT 1 FROM x WHERE x.id = y.id) ORDER BY 1 LIMIT 0",
|
||||
read={
|
||||
"postgres": "SELECT * FROM x FULL JOIN y ON x.id = y.id LIMIT 0",
|
||||
"postgres": "SELECT * FROM x FULL JOIN y ON x.id = y.id ORDER BY 1 LIMIT 0",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
# MySQL doesn't support FULL OUTER joins
|
||||
"WITH t1 AS (SELECT 1) SELECT * FROM t1 LEFT OUTER JOIN t2 ON t1.x = t2.x UNION SELECT * FROM t1 RIGHT OUTER JOIN t2 ON t1.x = t2.x",
|
||||
"SELECT * FROM t1 LEFT OUTER JOIN t2 ON t1.x = t2.x UNION ALL SELECT * FROM t1 RIGHT OUTER JOIN t2 ON t1.x = t2.x WHERE NOT EXISTS(SELECT 1 FROM t1 WHERE t1.x = t2.x)",
|
||||
read={
|
||||
"postgres": "WITH t1 AS (SELECT 1) SELECT * FROM t1 FULL OUTER JOIN t2 ON t1.x = t2.x",
|
||||
"postgres": "SELECT * FROM t1 FULL OUTER JOIN t2 ON t1.x = t2.x",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT * FROM t1 LEFT OUTER JOIN t2 USING (x) UNION ALL SELECT * FROM t1 RIGHT OUTER JOIN t2 USING (x) WHERE NOT EXISTS(SELECT 1 FROM t1 WHERE t1.x = t2.x)",
|
||||
read={
|
||||
"postgres": "SELECT * FROM t1 FULL OUTER JOIN t2 USING (x) ",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT * FROM t1 LEFT OUTER JOIN t2 USING (x, y) UNION ALL SELECT * FROM t1 RIGHT OUTER JOIN t2 USING (x, y) WHERE NOT EXISTS(SELECT 1 FROM t1 WHERE t1.x = t2.x AND t1.y = t2.y)",
|
||||
read={
|
||||
"postgres": "SELECT * FROM t1 FULL OUTER JOIN t2 USING (x, y) ",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
|
|
|
@ -66,6 +66,15 @@ class TestOracle(Validator):
|
|||
self.validate_identity(
|
||||
"SELECT COUNT(1) INTO V_Temp FROM TABLE(CAST(somelist AS data_list)) WHERE col LIKE '%contact'"
|
||||
)
|
||||
self.validate_identity(
|
||||
"SELECT department_id INTO v_department_id FROM departments FETCH FIRST 1 ROWS ONLY"
|
||||
)
|
||||
self.validate_identity(
|
||||
"SELECT department_id BULK COLLECT INTO v_department_ids FROM departments"
|
||||
)
|
||||
self.validate_identity(
|
||||
"SELECT department_id, department_name BULK COLLECT INTO v_department_ids, v_department_names FROM departments"
|
||||
)
|
||||
self.validate_identity(
|
||||
"SELECT MIN(column_name) KEEP (DENSE_RANK FIRST ORDER BY column_name DESC) FROM table_name"
|
||||
)
|
||||
|
@ -102,6 +111,14 @@ class TestOracle(Validator):
|
|||
"SELECT * FROM t START WITH col CONNECT BY NOCYCLE PRIOR col1 = col2"
|
||||
)
|
||||
|
||||
self.validate_all(
|
||||
"SELECT department_id, department_name INTO v_department_id, v_department_name FROM departments FETCH FIRST 1 ROWS ONLY",
|
||||
write={
|
||||
"oracle": "SELECT department_id, department_name INTO v_department_id, v_department_name FROM departments FETCH FIRST 1 ROWS ONLY",
|
||||
"postgres": UnsupportedError,
|
||||
"tsql": UnsupportedError,
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"TRUNC(SYSDATE, 'YEAR')",
|
||||
write={
|
||||
|
|
|
@ -354,10 +354,10 @@ class TestPostgres(Validator):
|
|||
self.validate_all(
|
||||
"SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]",
|
||||
read={
|
||||
"duckdb": "SELECT ARRAY_HAS_ALL([1, 2, 3], [1, 2])",
|
||||
"duckdb": "SELECT [1, 2, 3] @> [1, 2]",
|
||||
},
|
||||
write={
|
||||
"duckdb": "SELECT ARRAY_HAS_ALL([1, 2, 3], [1, 2])",
|
||||
"duckdb": "SELECT [1, 2, 3] @> [1, 2]",
|
||||
"mysql": UnsupportedError,
|
||||
"postgres": "SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]",
|
||||
},
|
||||
|
@ -398,13 +398,6 @@ class TestPostgres(Validator):
|
|||
"postgres": "SELECT (data ->> 'en-US') AS acat FROM my_table",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT ARRAY[1, 2, 3] && ARRAY[1, 2]",
|
||||
write={
|
||||
"": "SELECT ARRAY_OVERLAPS(ARRAY(1, 2, 3), ARRAY(1, 2))",
|
||||
"postgres": "SELECT ARRAY[1, 2, 3] && ARRAY[1, 2]",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT JSON_EXTRACT_PATH_TEXT(x, k1, k2, k3) FROM t",
|
||||
read={
|
||||
|
@ -802,6 +795,7 @@ class TestPostgres(Validator):
|
|||
)
|
||||
self.validate_identity("SELECT OVERLAY(a PLACING b FROM 1)")
|
||||
self.validate_identity("SELECT OVERLAY(a PLACING b FROM 1 FOR 1)")
|
||||
self.validate_identity("ARRAY[1, 2, 3] && ARRAY[1, 2]").assert_is(exp.ArrayOverlaps)
|
||||
|
||||
def test_ddl(self):
|
||||
# Checks that user-defined types are parsed into DataType instead of Identifier
|
||||
|
|
|
@ -213,6 +213,12 @@ class TestRedshift(Validator):
|
|||
"redshift": "SELECT CAST('abc' AS VARBYTE)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"CREATE TABLE a (b BINARY VARYING(10))",
|
||||
write={
|
||||
"redshift": "CREATE TABLE a (b VARBYTE(10))",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT 'abc'::CHARACTER",
|
||||
write={
|
||||
|
|
|
@ -2,7 +2,6 @@ from unittest import mock
|
|||
|
||||
from sqlglot import exp, parse_one
|
||||
from sqlglot.dialects.dialect import Dialects
|
||||
from sqlglot.helper import logger as helper_logger
|
||||
from tests.dialects.test_dialect import Validator
|
||||
|
||||
|
||||
|
@ -294,19 +293,19 @@ TBLPROPERTIES (
|
|||
"SELECT STR_TO_MAP('a:1,b:2,c:3')",
|
||||
"SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')",
|
||||
)
|
||||
|
||||
with self.assertLogs(helper_logger):
|
||||
self.validate_all(
|
||||
"SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
|
||||
read={
|
||||
"databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
|
||||
},
|
||||
write={
|
||||
"databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
|
||||
"duckdb": "SELECT ([1, 2, 3])[3]",
|
||||
"spark": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
|
||||
read={
|
||||
"databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
|
||||
"presto": "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 2)",
|
||||
},
|
||||
write={
|
||||
"databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
|
||||
"spark": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
|
||||
"duckdb": "SELECT ([1, 2, 3])[2]",
|
||||
"presto": "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 2)",
|
||||
},
|
||||
)
|
||||
|
||||
self.validate_all(
|
||||
"SELECT ARRAY_AGG(x) FILTER (WHERE x = 5) FROM (SELECT 1 UNION ALL SELECT NULL) AS t(x)",
|
||||
|
|
|
@ -26,6 +26,7 @@ class TestSQLite(Validator):
|
|||
"""SELECT item AS "item", some AS "some" FROM data WHERE (item = 'value_1' COLLATE NOCASE) AND (some = 't' COLLATE NOCASE) ORDER BY item ASC LIMIT 1 OFFSET 0"""
|
||||
)
|
||||
self.validate_identity("SELECT * FROM GENERATE_SERIES(1, 5)")
|
||||
self.validate_identity("SELECT INSTR(haystack, needle)")
|
||||
|
||||
self.validate_all("SELECT LIKE(y, x)", write={"sqlite": "SELECT x LIKE y"})
|
||||
self.validate_all("SELECT GLOB('*y*', 'xyz')", write={"sqlite": "SELECT 'xyz' GLOB '*y*'"})
|
||||
|
|
|
@ -4,6 +4,12 @@ from tests.dialects.test_dialect import Validator
|
|||
class TestTrino(Validator):
|
||||
dialect = "trino"
|
||||
|
||||
def test_trino(self):
|
||||
self.validate_identity("JSON_EXTRACT(content, json_path)")
|
||||
self.validate_identity("JSON_QUERY(content, 'lax $.HY.*')")
|
||||
self.validate_identity("JSON_QUERY(content, 'strict $.HY.*' WITH UNCONDITIONAL WRAPPER)")
|
||||
self.validate_identity("JSON_QUERY(content, 'strict $.HY.*' WITHOUT CONDITIONAL WRAPPER)")
|
||||
|
||||
def test_trim(self):
|
||||
self.validate_identity("SELECT TRIM('!' FROM '!foo!')")
|
||||
self.validate_identity("SELECT TRIM(BOTH '$' FROM '$var$')")
|
||||
|
|
|
@ -8,6 +8,11 @@ class TestTSQL(Validator):
|
|||
dialect = "tsql"
|
||||
|
||||
def test_tsql(self):
|
||||
self.validate_identity(
|
||||
"with x as (select 1) select * from x union select * from x order by 1 limit 0",
|
||||
"WITH x AS (SELECT 1 AS [1]) SELECT TOP 0 * FROM (SELECT * FROM x UNION SELECT * FROM x) AS _l_0 ORDER BY 1",
|
||||
)
|
||||
|
||||
# https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN
|
||||
# tsql allows .. which means use the default schema
|
||||
self.validate_identity("SELECT * FROM a..b")
|
||||
|
@ -46,6 +51,10 @@ class TestTSQL(Validator):
|
|||
self.validate_identity(
|
||||
"COPY INTO test_1 FROM 'path' WITH (FORMAT_NAME = test, FILE_TYPE = 'CSV', CREDENTIAL = (IDENTITY='Shared Access Signature', SECRET='token'), FIELDTERMINATOR = ';', ROWTERMINATOR = '0X0A', ENCODING = 'UTF8', DATEFORMAT = 'ymd', MAXERRORS = 10, ERRORFILE = 'errorsfolder', IDENTITY_INSERT = 'ON')"
|
||||
)
|
||||
self.validate_identity(
|
||||
'SELECT 1 AS "[x]"',
|
||||
"SELECT 1 AS [[x]]]",
|
||||
)
|
||||
self.assertEqual(
|
||||
annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"),
|
||||
"SELECT 1 WHERE EXISTS(SELECT 1)",
|
||||
|
|
189
tests/fixtures/optimizer/annotate_functions.sql
vendored
Normal file
189
tests/fixtures/optimizer/annotate_functions.sql
vendored
Normal file
|
@ -0,0 +1,189 @@
|
|||
--------------------------------------
|
||||
-- Dialect
|
||||
--------------------------------------
|
||||
ABS(1);
|
||||
INT;
|
||||
|
||||
ABS(1.5);
|
||||
DOUBLE;
|
||||
|
||||
GREATEST(1, 2, 3);
|
||||
INT;
|
||||
|
||||
GREATEST(1, 2.5, 3);
|
||||
DOUBLE;
|
||||
|
||||
LEAST(1, 2, 3);
|
||||
INT;
|
||||
|
||||
LEAST(1, 2.5, 3);
|
||||
DOUBLE;
|
||||
|
||||
--------------------------------------
|
||||
-- Spark2 / Spark3 / Databricks
|
||||
--------------------------------------
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
SUBSTRING(tbl.str_col, 0, 0);
|
||||
STRING;
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
SUBSTRING(tbl.bin_col, 0, 0);
|
||||
BINARY;
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
CONCAT(tbl.bin_col, tbl.bin_col);
|
||||
BINARY;
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
CONCAT(tbl.bin_col, tbl.str_col);
|
||||
STRING;
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
CONCAT(tbl.str_col, tbl.bin_col);
|
||||
STRING;
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
CONCAT(tbl.str_col, tbl.str_col);
|
||||
STRING;
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
CONCAT(tbl.str_col, unknown);
|
||||
STRING;
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
CONCAT(tbl.bin_col, unknown);
|
||||
UNKNOWN;
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
CONCAT(unknown, unknown);
|
||||
UNKNOWN;
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
LPAD(tbl.bin_col, 1, tbl.bin_col);
|
||||
BINARY;
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
RPAD(tbl.bin_col, 1, tbl.bin_col);
|
||||
BINARY;
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
LPAD(tbl.bin_col, 1, tbl.str_col);
|
||||
STRING;
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
RPAD(tbl.bin_col, 1, tbl.str_col);
|
||||
STRING;
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
LPAD(tbl.str_col, 1, tbl.bin_col);
|
||||
STRING;
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
RPAD(tbl.str_col, 1, tbl.bin_col);
|
||||
STRING;
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
LPAD(tbl.str_col, 1, tbl.str_col);
|
||||
STRING;
|
||||
|
||||
# dialect: spark2, spark, databricks
|
||||
RPAD(tbl.str_col, 1, tbl.str_col);
|
||||
STRING;
|
||||
|
||||
|
||||
--------------------------------------
|
||||
-- BigQuery
|
||||
--------------------------------------
|
||||
|
||||
# dialect: bigquery
|
||||
SIGN(1);
|
||||
INT;
|
||||
|
||||
# dialect: bigquery
|
||||
SIGN(1.5);
|
||||
DOUBLE;
|
||||
|
||||
# dialect: bigquery
|
||||
CEIL(1);
|
||||
DOUBLE;
|
||||
|
||||
# dialect: bigquery
|
||||
CEIL(5.5);
|
||||
DOUBLE;
|
||||
|
||||
# dialect: bigquery
|
||||
CEIL(tbl.bignum_col);
|
||||
BIGDECIMAL;
|
||||
|
||||
# dialect: bigquery
|
||||
FLOOR(1);
|
||||
DOUBLE;
|
||||
|
||||
# dialect: bigquery
|
||||
FLOOR(5.5);
|
||||
DOUBLE;
|
||||
|
||||
# dialect: bigquery
|
||||
FLOOR(tbl.bignum_col);
|
||||
BIGDECIMAL;
|
||||
|
||||
# dialect: bigquery
|
||||
SQRT(1);
|
||||
DOUBLE;
|
||||
|
||||
# dialect: bigquery
|
||||
SQRT(5.5);
|
||||
DOUBLE;
|
||||
|
||||
# dialect: bigquery
|
||||
SQRT(tbl.bignum_col);
|
||||
BIGDECIMAL;
|
||||
|
||||
# dialect: bigquery
|
||||
LN(1);
|
||||
DOUBLE;
|
||||
|
||||
# dialect: bigquery
|
||||
LN(5.5);
|
||||
DOUBLE;
|
||||
|
||||
# dialect: bigquery
|
||||
LN(tbl.bignum_col);
|
||||
BIGDECIMAL;
|
||||
|
||||
# dialect: bigquery
|
||||
LOG(1);
|
||||
DOUBLE;
|
||||
|
||||
# dialect: bigquery
|
||||
LOG(5.5);
|
||||
DOUBLE;
|
||||
|
||||
# dialect: bigquery
|
||||
LOG(tbl.bignum_col);
|
||||
BIGDECIMAL;
|
||||
|
||||
# dialect: bigquery
|
||||
ROUND(1);
|
||||
DOUBLE;
|
||||
|
||||
# dialect: bigquery
|
||||
ROUND(5.5);
|
||||
DOUBLE;
|
||||
|
||||
# dialect: bigquery
|
||||
ROUND(tbl.bignum_col);
|
||||
BIGDECIMAL;
|
||||
|
||||
# dialect: bigquery
|
||||
EXP(1);
|
||||
DOUBLE;
|
||||
|
||||
# dialect: bigquery
|
||||
EXP(5.5);
|
||||
DOUBLE;
|
||||
|
||||
# dialect: bigquery
|
||||
EXP(tbl.bignum_col);
|
||||
BIGDECIMAL;
|
31
tests/fixtures/pretty.sql
vendored
31
tests/fixtures/pretty.sql
vendored
|
@ -418,3 +418,34 @@ INSERT FIRST
|
|||
SELECT
|
||||
salary
|
||||
FROM employees;
|
||||
|
||||
SELECT *
|
||||
FROM foo
|
||||
wHERE 1=1
|
||||
AND
|
||||
-- my comment
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM bar
|
||||
);
|
||||
SELECT
|
||||
*
|
||||
FROM foo
|
||||
WHERE
|
||||
1 = 1 AND EXISTS(
|
||||
SELECT
|
||||
1
|
||||
FROM bar
|
||||
) /* my comment */;
|
||||
|
||||
SELECT 1
|
||||
FROM foo
|
||||
WHERE 1=1
|
||||
AND -- first comment
|
||||
-- second comment
|
||||
foo.a = 1;
|
||||
SELECT
|
||||
1
|
||||
FROM foo
|
||||
WHERE
|
||||
1 = 1 AND /* first comment */ foo.a /* second comment */ = 1;
|
||||
|
|
|
@ -577,6 +577,36 @@ class TestBuild(unittest.TestCase):
|
|||
lambda: exp.update("tbl", {"x": 1}, from_="tbl2 cross join tbl3"),
|
||||
"UPDATE tbl SET x = 1 FROM tbl2 CROSS JOIN tbl3",
|
||||
),
|
||||
(
|
||||
lambda: exp.update(
|
||||
"my_table",
|
||||
{"x": 1},
|
||||
from_="baz",
|
||||
where="my_table.id = baz.id",
|
||||
with_={"baz": "SELECT id FROM foo UNION SELECT id FROM bar"},
|
||||
),
|
||||
"WITH baz AS (SELECT id FROM foo UNION SELECT id FROM bar) UPDATE my_table SET x = 1 FROM baz WHERE my_table.id = baz.id",
|
||||
),
|
||||
(
|
||||
lambda: exp.update("my_table").set_("x = 1"),
|
||||
"UPDATE my_table SET x = 1",
|
||||
),
|
||||
(
|
||||
lambda: exp.update("my_table").set_("x = 1").where("y = 2"),
|
||||
"UPDATE my_table SET x = 1 WHERE y = 2",
|
||||
),
|
||||
(
|
||||
lambda: exp.update("my_table").set_("a = 1").set_("b = 2"),
|
||||
"UPDATE my_table SET a = 1, b = 2",
|
||||
),
|
||||
(
|
||||
lambda: exp.update("my_table")
|
||||
.set_("x = 1")
|
||||
.where("my_table.id = baz.id")
|
||||
.from_("baz")
|
||||
.with_("baz", "SELECT id FROM foo"),
|
||||
"WITH baz AS (SELECT id FROM foo) UPDATE my_table SET x = 1 FROM baz WHERE my_table.id = baz.id",
|
||||
),
|
||||
(
|
||||
lambda: union("SELECT * FROM foo", "SELECT * FROM bla"),
|
||||
"SELECT * FROM foo UNION SELECT * FROM bla",
|
||||
|
|
|
@ -157,11 +157,20 @@ class TestDiff(unittest.TestCase):
|
|||
self._validate_delta_only(
|
||||
diff_delta_only(expr_src, expr_tgt),
|
||||
[
|
||||
Remove(parse_one("ROW_NUMBER()")), # the Anonymous node
|
||||
Insert(parse_one("RANK()")), # the Anonymous node
|
||||
Remove(parse_one("ROW_NUMBER()")),
|
||||
Insert(parse_one("RANK()")),
|
||||
Update(source=expr_src.selects[0], target=expr_tgt.selects[0]),
|
||||
],
|
||||
)
|
||||
|
||||
expr_src = parse_one("SELECT MAX(x) OVER (ORDER BY y) FROM z", "oracle")
|
||||
expr_tgt = parse_one("SELECT MAX(x) KEEP (DENSE_RANK LAST ORDER BY y) FROM z", "oracle")
|
||||
|
||||
self._validate_delta_only(
|
||||
diff_delta_only(expr_src, expr_tgt),
|
||||
[Update(source=expr_src.selects[0], target=expr_tgt.selects[0])],
|
||||
)
|
||||
|
||||
def test_pre_matchings(self):
|
||||
expr_src = parse_one("SELECT 1")
|
||||
expr_tgt = parse_one("SELECT 1, 2, 3, 4")
|
||||
|
@ -202,5 +211,34 @@ class TestDiff(unittest.TestCase):
|
|||
],
|
||||
)
|
||||
|
||||
expr_src = parse_one("SELECT 1 AS c1, 2 AS c2")
|
||||
expr_tgt = parse_one("SELECT 2 AS c1, 3 AS c2")
|
||||
|
||||
self._validate_delta_only(
|
||||
diff_delta_only(expr_src, expr_tgt),
|
||||
[
|
||||
Remove(expression=exp.alias_(1, "c1")),
|
||||
Remove(expression=exp.Literal.number(1)),
|
||||
Insert(expression=exp.alias_(3, "c2")),
|
||||
Insert(expression=exp.Literal.number(3)),
|
||||
Update(source=exp.alias_(2, "c2"), target=exp.alias_(2, "c1")),
|
||||
],
|
||||
)
|
||||
|
||||
def test_dialect_aware_diff(self):
|
||||
from sqlglot.generator import logger
|
||||
|
||||
with self.assertLogs(logger) as cm:
|
||||
# We want to assert there are no warnings, but the 'assertLogs' method does not support that.
|
||||
# Therefore, we are adding a dummy warning, and then we will assert it is the only warning.
|
||||
logger.warning("Dummy warning")
|
||||
|
||||
expression = parse_one("SELECT foo FROM bar FOR UPDATE", dialect="oracle")
|
||||
self._validate_delta_only(
|
||||
diff_delta_only(expression, expression.copy(), dialect="oracle"), []
|
||||
)
|
||||
|
||||
self.assertEqual(["WARNING:sqlglot:Dummy warning"], cm.output)
|
||||
|
||||
def _validate_delta_only(self, actual_delta, expected_delta):
|
||||
self.assertEqual(set(actual_delta), set(expected_delta))
|
||||
|
|
|
@ -350,6 +350,7 @@ class TestExpressions(unittest.TestCase):
|
|||
)
|
||||
|
||||
self.assertIsInstance(exp.func("instr", "x", "b", dialect="mysql"), exp.StrPosition)
|
||||
self.assertIsInstance(exp.func("instr", "x", "b", dialect="sqlite"), exp.StrPosition)
|
||||
self.assertIsInstance(exp.func("bla", 1, "foo"), exp.Anonymous)
|
||||
self.assertIsInstance(
|
||||
exp.func("cast", this=exp.Literal.number(5), to=exp.DataType.build("DOUBLE")),
|
||||
|
|
|
@ -54,6 +54,18 @@ def simplify(expression, **kwargs):
|
|||
return optimizer.simplify.simplify(expression, constant_propagation=True, **kwargs)
|
||||
|
||||
|
||||
def annotate_functions(expression, **kwargs):
|
||||
from sqlglot.dialects import Dialect
|
||||
|
||||
dialect = kwargs.get("dialect")
|
||||
schema = kwargs.get("schema")
|
||||
|
||||
annotators = Dialect.get_or_raise(dialect).ANNOTATORS
|
||||
annotated = annotate_types(expression, annotators=annotators, schema=schema)
|
||||
|
||||
return annotated.expressions[0]
|
||||
|
||||
|
||||
class TestOptimizer(unittest.TestCase):
|
||||
maxDiff = None
|
||||
|
||||
|
@ -787,6 +799,28 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
|
|||
with self.subTest(title):
|
||||
self.assertEqual(result.type.sql(), exp.DataType.build(expected).sql())
|
||||
|
||||
def test_annotate_funcs(self):
|
||||
test_schema = {
|
||||
"tbl": {"bin_col": "BINARY", "str_col": "STRING", "bignum_col": "BIGNUMERIC"}
|
||||
}
|
||||
|
||||
for i, (meta, sql, expected) in enumerate(
|
||||
load_sql_fixture_pairs("optimizer/annotate_functions.sql"), start=1
|
||||
):
|
||||
title = meta.get("title") or f"{i}, {sql}"
|
||||
dialect = meta.get("dialect") or ""
|
||||
sql = f"SELECT {sql} FROM tbl"
|
||||
|
||||
for dialect in dialect.split(", "):
|
||||
result = parse_and_optimize(
|
||||
annotate_functions, sql, dialect, schema=test_schema, dialect=dialect
|
||||
)
|
||||
|
||||
with self.subTest(title):
|
||||
self.assertEqual(
|
||||
result.type.sql(dialect), exp.DataType.build(expected).sql(dialect)
|
||||
)
|
||||
|
||||
def test_cast_type_annotation(self):
|
||||
expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))"))
|
||||
self.assertEqual(expression.type.this, exp.DataType.Type.TIMESTAMPTZ)
|
||||
|
@ -1377,26 +1411,3 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
|
|||
self.assertEqual(4, normalization_distance(gen_expr(2), max_=100))
|
||||
self.assertEqual(18, normalization_distance(gen_expr(3), max_=100))
|
||||
self.assertEqual(110, normalization_distance(gen_expr(10), max_=100))
|
||||
|
||||
def test_custom_annotators(self):
|
||||
# In Spark hierarchy, SUBSTRING result type is dependent on input expr type
|
||||
for dialect in ("spark2", "spark", "databricks"):
|
||||
for expr_type_pair in (
|
||||
("col", "STRING"),
|
||||
("col", "BINARY"),
|
||||
("'str_literal'", "STRING"),
|
||||
("CAST('str_literal' AS BINARY)", "BINARY"),
|
||||
):
|
||||
with self.subTest(
|
||||
f"Testing {dialect}'s SUBSTRING() result type for {expr_type_pair}"
|
||||
):
|
||||
expr, type = expr_type_pair
|
||||
ast = parse_one(f"SELECT substring({expr}, 2, 3) AS x FROM tbl", read=dialect)
|
||||
|
||||
subst_type = (
|
||||
optimizer.optimize(ast, schema={"tbl": {"col": type}}, dialect=dialect)
|
||||
.expressions[0]
|
||||
.type
|
||||
)
|
||||
|
||||
self.assertEqual(subst_type.sql(dialect), exp.DataType.build(type).sql(dialect))
|
||||
|
|
|
@ -563,7 +563,36 @@ FROM x""",
|
|||
)
|
||||
self.validate(
|
||||
"""with a as /* comment */ ( select * from b) select * from a""",
|
||||
"""WITH a AS (SELECT * FROM b) /* comment */ SELECT * FROM a""",
|
||||
"""WITH a /* comment */ AS (SELECT * FROM b) SELECT * FROM a""",
|
||||
)
|
||||
self.validate(
|
||||
"""
|
||||
-- comment at the top
|
||||
WITH
|
||||
-- comment for tbl1
|
||||
tbl1 AS (SELECT 1)
|
||||
-- comment for tbl2
|
||||
, tbl2 AS (SELECT 2)
|
||||
-- comment for tbl3
|
||||
, tbl3 AS (SELECT 3)
|
||||
-- comment for final select
|
||||
SELECT * FROM tbl1""",
|
||||
"""/* comment at the top */
|
||||
WITH tbl1 /* comment for tbl1 */ AS (
|
||||
SELECT
|
||||
1
|
||||
), tbl2 /* comment for tbl2 */ AS (
|
||||
SELECT
|
||||
2
|
||||
), tbl3 /* comment for tbl3 */ AS (
|
||||
SELECT
|
||||
3
|
||||
)
|
||||
/* comment for final select */
|
||||
SELECT
|
||||
*
|
||||
FROM tbl1""",
|
||||
pretty=True,
|
||||
)
|
||||
|
||||
def test_types(self):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue