1
0
Fork 0

Merging upstream version 12.2.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:53:39 +01:00
parent fffa0d5761
commit 62b2b24d3b
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
100 changed files with 35022 additions and 30936 deletions

View file

@ -3,17 +3,13 @@ import unittest
import warnings
import sqlglot
from sqlglot.helper import PYTHON_VERSION
from tests.helpers import SKIP_INTEGRATION
if t.TYPE_CHECKING:
from pyspark.sql import DataFrame as SparkDataFrame
@unittest.skipIf(
SKIP_INTEGRATION or PYTHON_VERSION > (3, 10),
"Skipping Integration Tests since `SKIP_INTEGRATION` is set",
)
@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set")
class DataFrameValidator(unittest.TestCase):
spark = None
sqlglot = None

View file

@ -6,10 +6,19 @@ class TestBigQuery(Validator):
dialect = "bigquery"
def test_bigquery(self):
self.validate_identity("DATE_TRUNC(col, WEEK(MONDAY))")
self.validate_identity("SELECT b'abc'")
self.validate_identity("""SELECT * FROM UNNEST(ARRAY<STRUCT<x INT64>>[1, 2])""")
self.validate_identity("SELECT AS STRUCT 1 AS a, 2 AS b")
self.validate_identity("SELECT DISTINCT AS STRUCT 1 AS a, 2 AS b")
self.validate_identity("SELECT AS VALUE STRUCT(1 AS a, 2 AS b)")
self.validate_identity("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])")
self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))")
self.validate_identity("""CREATE TABLE x (a STRUCT<values ARRAY<INT64>>)""")
self.validate_identity("""CREATE TABLE x (a STRUCT<b STRING OPTIONS (description='b')>)""")
self.validate_identity(
"""CREATE TABLE x (a STRING OPTIONS (description='x')) OPTIONS (table_expiration_days=1)"""
)
self.validate_identity(
"SELECT * FROM (SELECT * FROM `t`) AS a UNPIVOT((c) FOR c_name IN (v1, v2))"
)
@ -97,6 +106,16 @@ class TestBigQuery(Validator):
"spark": "CAST(a AS LONG)",
},
)
self.validate_all(
"CAST(a AS BYTES)",
write={
"bigquery": "CAST(a AS BYTES)",
"duckdb": "CAST(a AS BLOB)",
"presto": "CAST(a AS VARBINARY)",
"hive": "CAST(a AS BINARY)",
"spark": "CAST(a AS BINARY)",
},
)
self.validate_all(
"CAST(a AS NUMERIC)",
write={
@ -173,7 +192,6 @@ class TestBigQuery(Validator):
"current_datetime",
write={
"bigquery": "CURRENT_DATETIME()",
"duckdb": "CURRENT_DATETIME()",
"presto": "CURRENT_DATETIME()",
"hive": "CURRENT_DATETIME()",
"spark": "CURRENT_DATETIME()",
@ -183,7 +201,7 @@ class TestBigQuery(Validator):
"current_time",
write={
"bigquery": "CURRENT_TIME()",
"duckdb": "CURRENT_TIME()",
"duckdb": "CURRENT_TIME",
"presto": "CURRENT_TIME()",
"hive": "CURRENT_TIME()",
"spark": "CURRENT_TIME()",
@ -193,7 +211,7 @@ class TestBigQuery(Validator):
"current_timestamp",
write={
"bigquery": "CURRENT_TIMESTAMP()",
"duckdb": "CURRENT_TIMESTAMP()",
"duckdb": "CURRENT_TIMESTAMP",
"postgres": "CURRENT_TIMESTAMP",
"presto": "CURRENT_TIMESTAMP",
"hive": "CURRENT_TIMESTAMP()",
@ -204,7 +222,7 @@ class TestBigQuery(Validator):
"current_timestamp()",
write={
"bigquery": "CURRENT_TIMESTAMP()",
"duckdb": "CURRENT_TIMESTAMP()",
"duckdb": "CURRENT_TIMESTAMP",
"postgres": "CURRENT_TIMESTAMP",
"presto": "CURRENT_TIMESTAMP",
"hive": "CURRENT_TIMESTAMP()",
@ -342,6 +360,18 @@ class TestBigQuery(Validator):
"snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
},
)
self.validate_all(
"SELECT cola, colb FROM (VALUES (1, 'test')) AS tab",
write={
"bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS _c0, 'test' AS _c1)])",
},
)
self.validate_all(
"SELECT cola, colb FROM (VALUES (1, 'test'))",
write={
"bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS _c0, 'test' AS _c1)])",
},
)
self.validate_all(
"SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)",
write={

View file

@ -65,3 +65,24 @@ class TestClickhouse(Validator):
self.validate_identity("WITH SUM(bytes) AS foo SELECT foo FROM system.parts")
self.validate_identity("WITH (SELECT foo) AS bar SELECT bar + 5")
self.validate_identity("WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1")
def test_signed_and_unsigned_types(self):
data_types = [
"UInt8",
"UInt16",
"UInt32",
"UInt64",
"UInt128",
"UInt256",
"Int8",
"Int16",
"Int32",
"Int64",
"Int128",
"Int256",
]
for data_type in data_types:
self.validate_all(
f"pow(2, 32)::{data_type}",
write={"clickhouse": f"CAST(POWER(2, 32) AS {data_type})"},
)

View file

@ -95,7 +95,7 @@ class TestDialect(Validator):
self.validate_all(
"CAST(a AS BINARY(4))",
write={
"bigquery": "CAST(a AS BINARY)",
"bigquery": "CAST(a AS BYTES)",
"clickhouse": "CAST(a AS BINARY(4))",
"drill": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS BLOB(4))",
@ -114,7 +114,7 @@ class TestDialect(Validator):
self.validate_all(
"CAST(a AS VARBINARY(4))",
write={
"bigquery": "CAST(a AS VARBINARY)",
"bigquery": "CAST(a AS BYTES)",
"clickhouse": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS BLOB(4))",
"mysql": "CAST(a AS VARBINARY(4))",

View file

@ -6,6 +6,9 @@ class TestDuckDB(Validator):
dialect = "duckdb"
def test_time(self):
self.validate_identity("SELECT CURRENT_DATE")
self.validate_identity("SELECT CURRENT_TIMESTAMP")
self.validate_all(
"EPOCH(x)",
read={
@ -24,7 +27,7 @@ class TestDuckDB(Validator):
"bigquery": "UNIX_TO_TIME(x / 1000)",
"duckdb": "TO_TIMESTAMP(x / 1000)",
"presto": "FROM_UNIXTIME(x / 1000)",
"spark": "FROM_UNIXTIME(x / 1000)",
"spark": "CAST(FROM_UNIXTIME(x / 1000) AS TIMESTAMP)",
},
)
self.validate_all(
@ -124,18 +127,34 @@ class TestDuckDB(Validator):
self.validate_identity("SELECT {'a': 1} AS x")
self.validate_identity("SELECT {'a': {'b': {'c': 1}}, 'd': {'e': 2}} AS x")
self.validate_identity("SELECT {'x': 1, 'y': 2, 'z': 3}")
self.validate_identity(
"SELECT {'yes': 'duck', 'maybe': 'goose', 'huh': NULL, 'no': 'heron'}"
)
self.validate_identity("SELECT {'key1': 'string', 'key2': 1, 'key3': 12.345}")
self.validate_identity("SELECT ROW(x, x + 1, y) FROM (SELECT 1 AS x, 'a' AS y)")
self.validate_identity("SELECT (x, x + 1, y) FROM (SELECT 1 AS x, 'a' AS y)")
self.validate_identity("SELECT a.x FROM (SELECT {'x': 1, 'y': 2, 'z': 3} AS a)")
self.validate_identity("ATTACH DATABASE ':memory:' AS new_database")
self.validate_identity(
"SELECT {'yes': 'duck', 'maybe': 'goose', 'huh': NULL, 'no': 'heron'}"
)
self.validate_identity(
"SELECT a['x space'] FROM (SELECT {'x space': 1, 'y': 2, 'z': 3} AS a)"
)
self.validate_all("0b1010", write={"": "0 AS b1010"})
self.validate_all("0x1010", write={"": "0 AS x1010"})
self.validate_all(
"""SELECT DATEDIFF('day', t1."A", t1."B") FROM "table" AS t1""",
write={
"duckdb": """SELECT DATE_DIFF('day', t1."A", t1."B") FROM "table" AS t1""",
"trino": """SELECT DATE_DIFF('day', t1."A", t1."B") FROM "table" AS t1""",
},
)
self.validate_all(
"SELECT DATE_DIFF('day', DATE '2020-01-01', DATE '2020-01-05')",
write={
"duckdb": "SELECT DATE_DIFF('day', CAST('2020-01-01' AS DATE), CAST('2020-01-05' AS DATE))",
"trino": "SELECT DATE_DIFF('day', CAST('2020-01-01' AS DATE), CAST('2020-01-05' AS DATE))",
},
)
self.validate_all("x ~ y", write={"duckdb": "REGEXP_MATCHES(x, y)"})
self.validate_all("SELECT * FROM 'x.y'", write={"duckdb": 'SELECT * FROM "x.y"'})
self.validate_all(

View file

@ -12,6 +12,7 @@ class TestMySQL(Validator):
"duckdb": "CREATE TABLE z (a INT)",
"mysql": "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'",
"spark": "CREATE TABLE z (a INT) COMMENT 'x'",
"sqlite": "CREATE TABLE z (a INTEGER)",
},
)
self.validate_all(
@ -24,6 +25,19 @@ class TestMySQL(Validator):
"INSERT INTO x VALUES (1, 'a', 2.0) ON DUPLICATE KEY UPDATE SET x.id = 1"
)
self.validate_all(
"CREATE TABLE x (id int not null auto_increment, primary key (id))",
write={
"sqlite": "CREATE TABLE x (id INTEGER NOT NULL AUTOINCREMENT PRIMARY KEY)",
},
)
self.validate_all(
"CREATE TABLE x (id int not null auto_increment)",
write={
"sqlite": "CREATE TABLE x (id INTEGER NOT NULL)",
},
)
def test_identity(self):
self.validate_identity("SELECT CURRENT_TIMESTAMP(6)")
self.validate_identity("x ->> '$.name'")
@ -150,47 +164,81 @@ class TestMySQL(Validator):
)
def test_hexadecimal_literal(self):
self.validate_all(
"SELECT 0xCC",
write={
"mysql": "SELECT x'CC'",
"sqlite": "SELECT x'CC'",
"spark": "SELECT X'CC'",
"trino": "SELECT X'CC'",
"bigquery": "SELECT 0xCC",
"oracle": "SELECT 204",
},
)
self.validate_all(
"SELECT X'1A'",
write={
"mysql": "SELECT x'1A'",
},
)
self.validate_all(
"SELECT 0xz",
write={
"mysql": "SELECT `0xz`",
},
)
write_CC = {
"bigquery": "SELECT 0xCC",
"clickhouse": "SELECT 0xCC",
"databricks": "SELECT 204",
"drill": "SELECT 204",
"duckdb": "SELECT 204",
"hive": "SELECT 204",
"mysql": "SELECT x'CC'",
"oracle": "SELECT 204",
"postgres": "SELECT x'CC'",
"presto": "SELECT 204",
"redshift": "SELECT 204",
"snowflake": "SELECT x'CC'",
"spark": "SELECT X'CC'",
"sqlite": "SELECT x'CC'",
"starrocks": "SELECT x'CC'",
"tableau": "SELECT 204",
"teradata": "SELECT 204",
"trino": "SELECT X'CC'",
"tsql": "SELECT 0xCC",
}
write_CC_with_leading_zeros = {
"bigquery": "SELECT 0x0000CC",
"clickhouse": "SELECT 0x0000CC",
"databricks": "SELECT 204",
"drill": "SELECT 204",
"duckdb": "SELECT 204",
"hive": "SELECT 204",
"mysql": "SELECT x'0000CC'",
"oracle": "SELECT 204",
"postgres": "SELECT x'0000CC'",
"presto": "SELECT 204",
"redshift": "SELECT 204",
"snowflake": "SELECT x'0000CC'",
"spark": "SELECT X'0000CC'",
"sqlite": "SELECT x'0000CC'",
"starrocks": "SELECT x'0000CC'",
"tableau": "SELECT 204",
"teradata": "SELECT 204",
"trino": "SELECT X'0000CC'",
"tsql": "SELECT 0x0000CC",
}
self.validate_all("SELECT X'1A'", write={"mysql": "SELECT x'1A'"})
self.validate_all("SELECT 0xz", write={"mysql": "SELECT `0xz`"})
self.validate_all("SELECT 0xCC", write=write_CC)
self.validate_all("SELECT 0xCC ", write=write_CC)
self.validate_all("SELECT x'CC'", write=write_CC)
self.validate_all("SELECT 0x0000CC", write=write_CC_with_leading_zeros)
self.validate_all("SELECT x'0000CC'", write=write_CC_with_leading_zeros)
def test_bits_literal(self):
self.validate_all(
"SELECT 0b1011",
write={
"mysql": "SELECT b'1011'",
"postgres": "SELECT b'1011'",
"oracle": "SELECT 11",
},
)
self.validate_all(
"SELECT B'1011'",
write={
"mysql": "SELECT b'1011'",
"postgres": "SELECT b'1011'",
"oracle": "SELECT 11",
},
)
write_1011 = {
"bigquery": "SELECT 11",
"clickhouse": "SELECT 0b1011",
"databricks": "SELECT 11",
"drill": "SELECT 11",
"hive": "SELECT 11",
"mysql": "SELECT b'1011'",
"oracle": "SELECT 11",
"postgres": "SELECT b'1011'",
"presto": "SELECT 11",
"redshift": "SELECT 11",
"snowflake": "SELECT 11",
"spark": "SELECT 11",
"sqlite": "SELECT 11",
"mysql": "SELECT b'1011'",
"tableau": "SELECT 11",
"teradata": "SELECT 11",
"trino": "SELECT 11",
"tsql": "SELECT 11",
}
self.validate_all("SELECT 0b1011", write=write_1011)
self.validate_all("SELECT b'1011'", write=write_1011)
def test_string_literals(self):
self.validate_all(

View file

@ -5,6 +5,8 @@ class TestOracle(Validator):
dialect = "oracle"
def test_oracle(self):
self.validate_identity("SELECT * FROM table_name@dblink_name.database_link_domain")
self.validate_identity("SELECT * FROM table_name SAMPLE (25) s")
self.validate_identity("SELECT * FROM V$SESSION")
self.validate_identity(
"SELECT MIN(column_name) KEEP (DENSE_RANK FIRST ORDER BY column_name DESC) FROM table_name"
@ -17,7 +19,6 @@ class TestOracle(Validator):
"": "IFNULL(NULL, 1)",
},
)
self.validate_all(
"DATE '2022-01-01'",
write={
@ -28,6 +29,21 @@ class TestOracle(Validator):
},
)
self.validate_all(
"x::binary_double",
write={
"oracle": "CAST(x AS DOUBLE PRECISION)",
"": "CAST(x AS DOUBLE)",
},
)
self.validate_all(
"x::binary_float",
write={
"oracle": "CAST(x AS FLOAT)",
"": "CAST(x AS FLOAT)",
},
)
def test_join_marker(self):
self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y")
self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)")

View file

@ -98,6 +98,21 @@ class TestPostgres(Validator):
self.validate_identity("STRING_AGG(x, ',' ORDER BY y DESC)")
self.validate_identity("STRING_AGG(DISTINCT x, ',' ORDER BY y DESC)")
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
self.validate_identity("COMMENT ON TABLE mytable IS 'this'")
self.validate_identity("SELECT e'\\xDEADBEEF'")
self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')")
self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
self.validate_identity("""SELECT * FROM JSON_TO_RECORDSET(z) AS y("rank" INT)""")
self.validate_identity("x ~ 'y'")
self.validate_identity("x ~* 'y'")
self.validate_identity(
"SELECT SUM(x) OVER a, SUM(y) OVER b FROM c WINDOW a AS (PARTITION BY d), b AS (PARTITION BY e)"
)
self.validate_identity(
"CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)"
)
self.validate_identity(
"SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END"
)
@ -107,37 +122,31 @@ class TestPostgres(Validator):
self.validate_identity(
'SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')'
)
self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')")
self.validate_identity(
"SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')"
)
self.validate_identity(
"SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))"
)
self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
self.validate_identity(
"SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')"
)
self.validate_identity("COMMENT ON TABLE mytable IS 'this'")
self.validate_identity("SELECT e'\\xDEADBEEF'")
self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
self.validate_all(
"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY amount)",
write={
"databricks": "SELECT PERCENTILE_APPROX(amount, 0.5)",
"presto": "SELECT APPROX_PERCENTILE(amount, 0.5)",
"spark": "SELECT PERCENTILE_APPROX(amount, 0.5)",
"trino": "SELECT APPROX_PERCENTILE(amount, 0.5)",
},
)
self.validate_all(
"e'x'",
write={
"mysql": "x",
},
)
self.validate_identity("""SELECT * FROM JSON_TO_RECORDSET(z) AS y("rank" INT)""")
self.validate_identity(
"SELECT SUM(x) OVER a, SUM(y) OVER b FROM c WINDOW a AS (PARTITION BY d), b AS (PARTITION BY e)"
)
self.validate_identity(
"CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)"
)
self.validate_identity("x ~ 'y'")
self.validate_identity("x ~* 'y'")
self.validate_all(
"SELECT DATE_PART('isodow'::varchar(6), current_date)",
write={
@ -197,6 +206,33 @@ class TestPostgres(Validator):
"trino": "SEQUENCE(TRY_CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)",
},
)
self.validate_all(
"GENERATE_SERIES(a, b)",
write={
"postgres": "GENERATE_SERIES(a, b)",
"presto": "SEQUENCE(a, b)",
"trino": "SEQUENCE(a, b)",
"tsql": "GENERATE_SERIES(a, b)",
},
)
self.validate_all(
"GENERATE_SERIES(a, b)",
read={
"postgres": "GENERATE_SERIES(a, b)",
"presto": "SEQUENCE(a, b)",
"trino": "SEQUENCE(a, b)",
"tsql": "GENERATE_SERIES(a, b)",
},
)
self.validate_all(
"SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4)",
write={
"postgres": "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4)",
"presto": "SELECT * FROM t CROSS JOIN UNNEST(SEQUENCE(2, 4))",
"trino": "SELECT * FROM t CROSS JOIN UNNEST(SEQUENCE(2, 4))",
"tsql": "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4)",
},
)
self.validate_all(
"END WORK AND NO CHAIN",
write={"postgres": "COMMIT AND NO CHAIN"},
@ -464,6 +500,14 @@ class TestPostgres(Validator):
},
)
self.validate_all(
"x / y ^ z",
write={
"": "x / POWER(y, z)",
"postgres": "x / y ^ z",
},
)
self.assertIsInstance(parse_one("id::UUID", read="postgres"), exp.TryCast)
def test_bool_or(self):

View file

@ -6,6 +6,26 @@ class TestPresto(Validator):
dialect = "presto"
def test_cast(self):
self.validate_all(
"FROM_BASE64(x)",
read={
"hive": "UNBASE64(x)",
},
write={
"hive": "UNBASE64(x)",
"presto": "FROM_BASE64(x)",
},
)
self.validate_all(
"TO_BASE64(x)",
read={
"hive": "BASE64(x)",
},
write={
"hive": "BASE64(x)",
"presto": "TO_BASE64(x)",
},
)
self.validate_all(
"CAST(a AS ARRAY(INT))",
write={
@ -105,6 +125,13 @@ class TestPresto(Validator):
"spark": "SIZE(x)",
},
)
self.validate_all(
"ARRAY_JOIN(x, '-', 'a')",
write={
"hive": "CONCAT_WS('-', x)",
"spark": "ARRAY_JOIN(x, '-', 'a')",
},
)
def test_interval_plural_to_singular(self):
# Microseconds, weeks and quarters are not supported in Presto/Trino INTERVAL literals
@ -133,6 +160,14 @@ class TestPresto(Validator):
self.validate_identity("TRIM(a, b)")
self.validate_identity("VAR_POP(a)")
self.validate_all(
"SELECT FROM_UNIXTIME(col) FROM tbl",
write={
"presto": "SELECT FROM_UNIXTIME(col) FROM tbl",
"spark": "SELECT CAST(FROM_UNIXTIME(col) AS TIMESTAMP) FROM tbl",
"trino": "SELECT FROM_UNIXTIME(col) FROM tbl",
},
)
self.validate_all(
"DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')",
write={
@ -181,7 +216,7 @@ class TestPresto(Validator):
"duckdb": "TO_TIMESTAMP(x)",
"presto": "FROM_UNIXTIME(x)",
"hive": "FROM_UNIXTIME(x)",
"spark": "FROM_UNIXTIME(x)",
"spark": "CAST(FROM_UNIXTIME(x) AS TIMESTAMP)",
},
)
self.validate_all(
@ -583,6 +618,14 @@ class TestPresto(Validator):
},
)
self.validate_all(
"JSON_FORMAT(JSON 'x')",
write={
"presto": "JSON_FORMAT(CAST('x' AS JSON))",
"spark": "TO_JSON('x')",
},
)
def test_encode_decode(self):
self.validate_all(
"TO_UTF8(x)",

View file

@ -101,7 +101,22 @@ class TestRedshift(Validator):
self.validate_all(
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
write={
"bigquery": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
"databricks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
"drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
"hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
"mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE `_row_number` = 1",
"oracle": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
"presto": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
"redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1',
"snowflake": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1',
"spark": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
"sqlite": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
"starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE `_row_number` = 1",
"tableau": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
"teradata": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
"trino": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
"tsql": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
},
)
self.validate_all(

View file

@ -227,7 +227,7 @@ class TestSnowflake(Validator):
write={
"bigquery": "SELECT UNIX_TO_TIME(1659981729)",
"snowflake": "SELECT TO_TIMESTAMP(1659981729)",
"spark": "SELECT FROM_UNIXTIME(1659981729)",
"spark": "SELECT CAST(FROM_UNIXTIME(1659981729) AS TIMESTAMP)",
},
)
self.validate_all(
@ -243,7 +243,7 @@ class TestSnowflake(Validator):
write={
"bigquery": "SELECT UNIX_TO_TIME('1659981729')",
"snowflake": "SELECT TO_TIMESTAMP('1659981729')",
"spark": "SELECT FROM_UNIXTIME('1659981729')",
"spark": "SELECT CAST(FROM_UNIXTIME('1659981729') AS TIMESTAMP)",
},
)
self.validate_all(
@ -401,7 +401,7 @@ class TestSnowflake(Validator):
self.validate_all(
r"SELECT FIRST_VALUE(TABLE1.COLUMN1 RESPECT NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
write={
"snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
"snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1 RESPECT NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
},
)
self.validate_all(
@ -426,29 +426,18 @@ class TestSnowflake(Validator):
def test_sample(self):
self.validate_identity("SELECT * FROM testtable TABLESAMPLE BERNOULLI (20.3)")
self.validate_identity("SELECT * FROM testtable TABLESAMPLE (100)")
self.validate_identity("SELECT * FROM testtable TABLESAMPLE SYSTEM (3) SEED (82)")
self.validate_identity("SELECT * FROM testtable TABLESAMPLE (10 ROWS)")
self.validate_identity("SELECT * FROM testtable SAMPLE (10)")
self.validate_identity("SELECT * FROM testtable SAMPLE ROW (0)")
self.validate_identity("SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)")
self.validate_identity(
"SELECT i, j FROM table1 AS t1 INNER JOIN table2 AS t2 TABLESAMPLE (50) WHERE t2.j = t1.i"
)
self.validate_identity(
"SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE (1)"
)
self.validate_identity("SELECT * FROM testtable TABLESAMPLE SYSTEM (3) SEED (82)")
self.validate_identity("SELECT * FROM testtable TABLESAMPLE (10 ROWS)")
self.validate_all(
"SELECT * FROM testtable SAMPLE (10)",
write={"snowflake": "SELECT * FROM testtable TABLESAMPLE (10)"},
)
self.validate_all(
"SELECT * FROM testtable SAMPLE ROW (0)",
write={"snowflake": "SELECT * FROM testtable TABLESAMPLE ROW (0)"},
)
self.validate_all(
"SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)",
write={
"snowflake": "SELECT a FROM test TABLESAMPLE BLOCK (0.5) SEED (42)",
},
)
self.validate_all(
"""
SELECT i, j
@ -458,13 +447,13 @@ class TestSnowflake(Validator):
table2 AS t2 SAMPLE (50) -- 50% of rows in table2
WHERE t2.j = t1.i""",
write={
"snowflake": "SELECT i, j FROM table1 AS t1 TABLESAMPLE (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 TABLESAMPLE (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i",
"snowflake": "SELECT i, j FROM table1 AS t1 SAMPLE (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 SAMPLE (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i",
},
)
self.validate_all(
"SELECT * FROM testtable SAMPLE BLOCK (0.012) REPEATABLE (99992)",
write={
"snowflake": "SELECT * FROM testtable TABLESAMPLE BLOCK (0.012) SEED (99992)",
"snowflake": "SELECT * FROM testtable SAMPLE BLOCK (0.012) SEED (99992)",
},
)

View file

@ -215,19 +215,37 @@ TBLPROPERTIES (
self.validate_identity("SPLIT(str, pattern, lim)")
self.validate_all(
"BOOLEAN(x)",
write={
"": "CAST(x AS BOOLEAN)",
"spark": "CAST(x AS BOOLEAN)",
"SELECT * FROM produce PIVOT(SUM(produce.sales) FOR quarter IN ('Q1', 'Q2'))",
read={
"snowflake": "SELECT * FROM produce PIVOT (SUM(produce.sales) FOR produce.quarter IN ('Q1', 'Q2'))",
},
)
self.validate_all(
"INT(x)",
write={
"": "CAST(x AS INT)",
"spark": "CAST(x AS INT)",
"SELECT * FROM produce AS p PIVOT(SUM(p.sales) AS sales FOR quarter IN ('Q1' AS Q1, 'Q2' AS Q1))",
read={
"bigquery": "SELECT * FROM produce AS p PIVOT(SUM(p.sales) AS sales FOR p.quarter IN ('Q1' AS Q1, 'Q2' AS Q1))",
},
)
self.validate_all(
"SELECT DATEDIFF(MONTH, '2020-01-01', '2020-03-05')",
write={
"databricks": "SELECT DATEDIFF(MONTH, TO_DATE('2020-01-01'), TO_DATE('2020-03-05'))",
"hive": "SELECT MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01'))",
"presto": "SELECT DATE_DIFF('MONTH', CAST(SUBSTR(CAST('2020-01-01' AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST('2020-03-05' AS VARCHAR), 1, 10) AS DATE))",
"spark": "SELECT DATEDIFF(MONTH, TO_DATE('2020-01-01'), TO_DATE('2020-03-05'))",
"spark2": "SELECT MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01'))",
"trino": "SELECT DATE_DIFF('MONTH', CAST(SUBSTR(CAST('2020-01-01' AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST('2020-03-05' AS VARCHAR), 1, 10) AS DATE))",
},
)
for data_type in ("BOOLEAN", "DATE", "DOUBLE", "FLOAT", "INT", "TIMESTAMP"):
self.validate_all(
f"{data_type}(x)",
write={
"": f"CAST(x AS {data_type})",
"spark": f"CAST(x AS {data_type})",
},
)
self.validate_all(
"STRING(x)",
write={
@ -235,20 +253,7 @@ TBLPROPERTIES (
"spark": "CAST(x AS STRING)",
},
)
self.validate_all(
"DATE(x)",
write={
"": "CAST(x AS DATE)",
"spark": "CAST(x AS DATE)",
},
)
self.validate_all(
"TIMESTAMP(x)",
write={
"": "CAST(x AS TIMESTAMP)",
"spark": "CAST(x AS TIMESTAMP)",
},
)
self.validate_all(
"CAST(x AS TIMESTAMP)", read={"trino": "CAST(x AS TIMESTAMP(6) WITH TIME ZONE)"}
)

View file

@ -10,3 +10,11 @@ class TestMySQL(Validator):
def test_time(self):
self.validate_identity("TIMESTAMP('2022-01-01')")
def test_regex(self):
self.validate_all(
"SELECT REGEXP_LIKE(abc, '%foo%')",
write={
"starrocks": "SELECT REGEXP(abc, '%foo%')",
},
)

View file

@ -485,26 +485,30 @@ WHERE
def test_date_diff(self):
self.validate_identity("SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')")
self.validate_all(
"SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')",
write={
"tsql": "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')",
"spark": "SELECT MONTHS_BETWEEN('2021/01/01', '2020/01/01') / 12",
"spark": "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')",
"spark2": "SELECT MONTHS_BETWEEN('2021/01/01', '2020/01/01') / 12",
},
)
self.validate_all(
"SELECT DATEDIFF(mm, 'start','end')",
write={
"spark": "SELECT MONTHS_BETWEEN('end', 'start')",
"tsql": "SELECT DATEDIFF(month, 'start', 'end')",
"databricks": "SELECT DATEDIFF(month, 'start', 'end')",
"spark2": "SELECT MONTHS_BETWEEN('end', 'start')",
"tsql": "SELECT DATEDIFF(month, 'start', 'end')",
},
)
self.validate_all(
"SELECT DATEDIFF(quarter, 'start', 'end')",
write={
"spark": "SELECT MONTHS_BETWEEN('end', 'start') / 3",
"databricks": "SELECT DATEDIFF(quarter, 'start', 'end')",
"spark": "SELECT DATEDIFF(quarter, 'start', 'end')",
"spark2": "SELECT MONTHS_BETWEEN('end', 'start') / 3",
"tsql": "SELECT DATEDIFF(quarter, 'start', 'end')",
},
)

View file

@ -85,6 +85,7 @@ x IS TRUE
x IS FALSE
x IS TRUE IS TRUE
x LIKE y IS TRUE
TRIM('a' || 'b')
MAP()
GREATEST(x)
LEAST(y)
@ -104,6 +105,7 @@ ARRAY(time, foo)
ARRAY(foo, time)
ARRAY(LENGTH(waiter_name) > 0)
ARRAY_CONTAINS(x, 1)
x.EXTRACT(1)
EXTRACT(x FROM y)
EXTRACT(DATE FROM y)
EXTRACT(WEEK(monday) FROM created_at)
@ -215,6 +217,7 @@ SELECT COUNT(DISTINCT a, b)
SELECT COUNT(DISTINCT a, b + 1)
SELECT SUM(DISTINCT x)
SELECT SUM(x IGNORE NULLS) AS x
SELECT COUNT(x RESPECT NULLS)
SELECT TRUNCATE(a, b)
SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x
SELECT ARRAY_AGG(STRUCT(x, x AS y) ORDER BY z DESC) AS x
@ -820,3 +823,6 @@ JSON_OBJECT('x': 1 RETURNING VARBINARY FORMAT JSON ENCODING UTF8)
SELECT if.x
SELECT NEXT VALUE FOR db.schema.sequence_name
SELECT NEXT VALUE FOR db.schema.sequence_name OVER (ORDER BY foo), col
SELECT PERCENTILE_CONT(x, 0.5) OVER ()
SELECT PERCENTILE_CONT(x, 0.5 RESPECT NULLS) OVER ()
SELECT PERCENTILE_CONT(x, 0.5 IGNORE NULLS) OVER ()

View file

@ -4,6 +4,9 @@
SELECT a FROM x;
SELECT x.a AS a FROM x AS x;
SELECT "a" FROM x;
SELECT x."a" AS "a" FROM x AS x;
# execute: false
SELECT a FROM zz GROUP BY a ORDER BY a;
SELECT zz.a AS a FROM zz AS zz GROUP BY zz.a ORDER BY a;
@ -212,6 +215,10 @@ SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT x.b AS b FROM y AS x);
SELECT a FROM x AS i WHERE b IN (SELECT b FROM y AS j WHERE j.b IN (SELECT c FROM y AS k WHERE k.b = j.b));
SELECT i.a AS a FROM x AS i WHERE i.b IN (SELECT j.b AS b FROM y AS j WHERE j.b IN (SELECT k.c AS c FROM y AS k WHERE k.b = j.b));
# execute: false
SELECT (SELECT n.a FROM n WHERE n.id = m.id) FROM m AS m;
SELECT (SELECT n.a AS a FROM n AS n WHERE n.id = m.id) AS _col_0 FROM m AS m;
--------------------------------------
-- Expand *
--------------------------------------

View file

@ -15,3 +15,26 @@ WITH a AS (SELECT 1 FROM c.db.z AS z) SELECT 1 FROM a;
SELECT (SELECT y.c FROM y AS y) FROM x;
SELECT (SELECT y.c FROM c.db.y AS y) FROM c.db.x AS x;
-------------------------
-- Expand join constructs
-------------------------
-- This is valid in Trino, so we treat the (tbl AS tbl) as a "join construct" per postgres' terminology.
SELECT * FROM (tbl AS tbl) AS _q_0;
SELECT * FROM (SELECT * FROM c.db.tbl AS tbl) AS _q_0;
SELECT * FROM ((tbl AS tbl)) AS _q_0;
SELECT * FROM (SELECT * FROM c.db.tbl AS tbl) AS _q_0;
SELECT * FROM (((tbl AS tbl))) AS _q_0;
SELECT * FROM (SELECT * FROM c.db.tbl AS tbl) AS _q_0;
SELECT * FROM (tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2 JOIN tbl3 AS tbl3 ON id1 = id3) AS _q_0;
SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 JOIN c.db.tbl2 AS tbl2 ON id1 = id2 JOIN c.db.tbl3 AS tbl3 ON id1 = id3) AS _q_0;
SELECT * FROM ((tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2 JOIN tbl3 AS tbl3 ON id1 = id3)) AS _q_0;
SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 JOIN c.db.tbl2 AS tbl2 ON id1 = id2 JOIN c.db.tbl3 AS tbl3 ON id1 = id3) AS _q_0;
SELECT * FROM (tbl1 AS tbl1 JOIN (tbl2 AS tbl2 JOIN tbl3 AS tbl3 ON id2 = id3) AS _q_0 ON id1 = id3) AS _q_1;
SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 JOIN (SELECT * FROM c.db.tbl2 AS tbl2 JOIN c.db.tbl3 AS tbl3 ON id2 = id3) AS _q_0 ON id1 = id3) AS _q_1;

View file

@ -6385,6 +6385,14 @@ WITH "tmp1" AS (
"item"."i_brand" IN ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', 'importoamalg #1')
OR "item"."i_class" IN ('personal', 'portable', 'reference', 'self-help')
)
AND (
"item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')
OR "item"."i_category" IN ('Women', 'Music', 'Men')
)
AND (
"item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')
OR "item"."i_class" IN ('accessories', 'classical', 'fragrances', 'pants')
)
AND (
"item"."i_category" IN ('Books', 'Children', 'Electronics')
OR "item"."i_category" IN ('Women', 'Music', 'Men')
@ -7589,6 +7597,14 @@ WITH "tmp1" AS (
"item"."i_brand" IN ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', 'importoamalg #1')
OR "item"."i_class" IN ('personal', 'portable', 'reference', 'self-help')
)
AND (
"item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')
OR "item"."i_category" IN ('Women', 'Music', 'Men')
)
AND (
"item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')
OR "item"."i_class" IN ('accessories', 'classical', 'fragrances', 'pants')
)
AND (
"item"."i_category" IN ('Books', 'Children', 'Electronics')
OR "item"."i_category" IN ('Women', 'Music', 'Men')

View file

@ -19,6 +19,11 @@ from sqlglot import (
class TestBuild(unittest.TestCase):
def test_build(self):
x = condition("x")
x_plus_one = x + 1
# Make sure we're not mutating x by changing its parent to be x_plus_one
self.assertIsNone(x.parent)
self.assertNotEqual(id(x_plus_one.this), id(x))
for expression, sql, *dialect in [
(lambda: x + 1, "x + 1"),
@ -51,6 +56,7 @@ class TestBuild(unittest.TestCase):
(lambda: x.neq(1), "x <> 1"),
(lambda: x.isin(1, "2"), "x IN (1, '2')"),
(lambda: x.isin(query="select 1"), "x IN (SELECT 1)"),
(lambda: x.between(1, 2), "x BETWEEN 1 AND 2"),
(lambda: 1 + x + 2 + 3, "1 + x + 2 + 3"),
(lambda: 1 + x * 2 + 3, "1 + (x * 2) + 3"),
(lambda: x * 1 * 2 + 3, "(x * 1 * 2) + 3"),
@ -137,10 +143,14 @@ class TestBuild(unittest.TestCase):
"SELECT x, y, z, a FROM tbl GROUP BY x, y, z, a",
),
(
lambda: select("x").distinct(True).from_("tbl"),
lambda: select("x").distinct("a", "b").from_("tbl"),
"SELECT DISTINCT ON (a, b) x FROM tbl",
),
(
lambda: select("x").distinct(distinct=True).from_("tbl"),
"SELECT DISTINCT x FROM tbl",
),
(lambda: select("x").distinct(False).from_("tbl"), "SELECT x FROM tbl"),
(lambda: select("x").distinct(distinct=False).from_("tbl"), "SELECT x FROM tbl"),
(
lambda: select("x").lateral("OUTER explode(y) tbl2 AS z").from_("tbl"),
"SELECT x FROM tbl LATERAL VIEW OUTER EXPLODE(y) tbl2 AS z",
@ -583,6 +593,11 @@ class TestBuild(unittest.TestCase):
"DELETE FROM tbl WHERE x = 1 RETURNING *",
"postgres",
),
(
lambda: exp.convert((exp.column("x"), exp.column("y"))).isin((1, 2), (3, 4)),
"(x, y) IN ((1, 2), (3, 4))",
"postgres",
),
]:
with self.subTest(sql):
self.assertEqual(expression().sql(dialect[0] if dialect else None), sql)

View file

@ -297,6 +297,9 @@ class TestExpressions(unittest.TestCase):
expression = parse_one("SELECT a, b FROM x")
self.assertEqual([s.sql() for s in expression.selects], ["a", "b"])
expression = parse_one("(SELECT a, b FROM x)")
self.assertEqual([s.sql() for s in expression.selects], ["a", "b"])
def test_alias_column_names(self):
expression = parse_one("SELECT * FROM (SELECT * FROM x) AS y")
subquery = expression.find(exp.Subquery)
@ -761,7 +764,7 @@ FROM foo""",
"t",
{"a": exp.DataType.build("TEXT"), "b": exp.DataType.build("TEXT")},
).sql(),
"(VALUES (CAST(1 AS TEXT), CAST(2 AS TEXT)), (3, 4)) AS t(a, b)",
"(VALUES (1, 2), (3, 4)) AS t(a, b)",
)
with self.assertRaises(ValueError):
exp.values([(1, 2), (3, 4)], columns=["a"])

View file

@ -47,6 +47,7 @@ class TestOptimizer(unittest.TestCase):
@classmethod
def setUpClass(cls):
sqlglot.schema = MappingSchema()
cls.conn = duckdb.connect()
cls.conn.execute(
"""
@ -221,6 +222,12 @@ class TestOptimizer(unittest.TestCase):
self.check_file("pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates)
def test_expand_laterals(self):
# check order of lateral expansion with no schema
self.assertEqual(
optimizer.optimize("SELECT a + 1 AS d, d + 1 AS e FROM x " "").sql(),
'SELECT "x"."a" + 1 AS "d", "x"."a" + 2 AS "e" FROM "x" AS "x"',
)
self.check_file(
"expand_laterals",
optimizer.expand_laterals.expand_laterals,
@ -612,6 +619,12 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
expression = annotate_types(parse_one("CONCAT('A', 'B')"))
self.assertEqual(expression.type.this, exp.DataType.Type.VARCHAR)
def test_root_subquery_annotation(self):
expression = annotate_types(parse_one("(SELECT 1, 2 FROM x) LIMIT 0"))
self.assertIsInstance(expression, exp.Subquery)
self.assertEqual(exp.DataType.Type.INT, expression.selects[0].type.this)
self.assertEqual(exp.DataType.Type.INT, expression.selects[1].type.this)
def test_recursive_cte(self):
query = parse_one(
"""

View file

@ -102,3 +102,14 @@ x"""
(TokenType.SEMICOLON, ";"),
],
)
tokens = tokenizer.tokenize("""'{{ var('x') }}'""")
tokens = [(token.token_type, token.text) for token in tokens]
self.assertEqual(
tokens,
[
(TokenType.STRING, "{{ var("),
(TokenType.VAR, "x"),
(TokenType.STRING, ") }}"),
],
)

View file

@ -263,6 +263,18 @@ FROM v""",
"(/* 1 */ 1 ) /* 2 */",
"(1) /* 1 */ /* 2 */",
)
self.validate(
"select * from t where not a in (23) /*test*/ and b in (14)",
"SELECT * FROM t WHERE NOT a IN (23) /* test */ AND b IN (14)",
)
self.validate(
"select * from t where a in (23) /*test*/ and b in (14)",
"SELECT * FROM t WHERE a IN (23) /* test */ AND b IN (14)",
)
self.validate(
"select * from t where ((condition = 1)/*test*/)",
"SELECT * FROM t WHERE ((condition = 1) /* test */)",
)
def test_types(self):
self.validate("INT 1", "CAST(1 AS INT)")
@ -324,9 +336,6 @@ FROM v""",
)
self.validate("SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo")
def test_ignore_nulls(self):
self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)")
def test_with(self):
self.validate(
"WITH a AS (SELECT 1) WITH b AS (SELECT 2) SELECT *",
@ -482,7 +491,7 @@ FROM v""",
self.validate("UNIX_TO_STR(123, 'y')", "FROM_UNIXTIME(123, 'y')", write="spark")
self.validate(
"UNIX_TO_TIME(123)",
"FROM_UNIXTIME(123)",
"CAST(FROM_UNIXTIME(123) AS TIMESTAMP)",
write="spark",
)
self.validate(