Merging upstream version 10.0.8.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
407314e8d2
commit
efc1e37108
67 changed files with 2461 additions and 840 deletions
|
@ -4,6 +4,8 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
|
|||
|
||||
|
||||
class TestDataframe(DataFrameSQLValidator):
|
||||
maxDiff = None
|
||||
|
||||
def test_hash_select_expression(self):
|
||||
expression = exp.select("cola").from_("table")
|
||||
self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))
|
||||
|
@ -16,26 +18,26 @@ class TestDataframe(DataFrameSQLValidator):
|
|||
def test_cache(self):
|
||||
df = self.df_employee.select("fname").cache()
|
||||
expected_statements = [
|
||||
"DROP VIEW IF EXISTS t11623",
|
||||
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
|
||||
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
|
||||
"DROP VIEW IF EXISTS t31563",
|
||||
"CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
|
||||
"SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
|
||||
]
|
||||
self.compare_sql(df, expected_statements)
|
||||
|
||||
def test_persist_default(self):
|
||||
df = self.df_employee.select("fname").persist()
|
||||
expected_statements = [
|
||||
"DROP VIEW IF EXISTS t11623",
|
||||
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
|
||||
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
|
||||
"DROP VIEW IF EXISTS t31563",
|
||||
"CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
|
||||
"SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
|
||||
]
|
||||
self.compare_sql(df, expected_statements)
|
||||
|
||||
def test_persist_storagelevel(self):
|
||||
df = self.df_employee.select("fname").persist("DISK_ONLY_2")
|
||||
expected_statements = [
|
||||
"DROP VIEW IF EXISTS t11623",
|
||||
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
|
||||
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
|
||||
"DROP VIEW IF EXISTS t31563",
|
||||
"CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
|
||||
"SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
|
||||
]
|
||||
self.compare_sql(df, expected_statements)
|
||||
|
|
|
@ -6,39 +6,41 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
|
|||
|
||||
|
||||
class TestDataFrameWriter(DataFrameSQLValidator):
|
||||
maxDiff = None
|
||||
|
||||
def test_insertInto_full_path(self):
|
||||
df = self.df_employee.write.insertInto("catalog.db.table_name")
|
||||
expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_insertInto_db_table(self):
|
||||
df = self.df_employee.write.insertInto("db.table_name")
|
||||
expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_insertInto_table(self):
|
||||
df = self.df_employee.write.insertInto("table_name")
|
||||
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_insertInto_overwrite(self):
|
||||
df = self.df_employee.write.insertInto("table_name", overwrite=True)
|
||||
expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
@mock.patch("sqlglot.schema", MappingSchema())
|
||||
def test_insertInto_byName(self):
|
||||
sqlglot.schema.add_table("table_name", {"employee_id": "INT"})
|
||||
df = self.df_employee.write.byName.insertInto("table_name")
|
||||
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_insertInto_cache(self):
|
||||
df = self.df_employee.cache().write.insertInto("table_name")
|
||||
expected_statements = [
|
||||
"DROP VIEW IF EXISTS t35612",
|
||||
"CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
|
||||
"INSERT INTO table_name SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
|
||||
"DROP VIEW IF EXISTS t37164",
|
||||
"CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
|
||||
"INSERT INTO table_name SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
|
||||
]
|
||||
self.compare_sql(df, expected_statements)
|
||||
|
||||
|
@ -48,39 +50,39 @@ class TestDataFrameWriter(DataFrameSQLValidator):
|
|||
|
||||
def test_saveAsTable_append(self):
|
||||
df = self.df_employee.write.saveAsTable("table_name", mode="append")
|
||||
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_saveAsTable_overwrite(self):
|
||||
df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
|
||||
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_saveAsTable_error(self):
|
||||
df = self.df_employee.write.saveAsTable("table_name", mode="error")
|
||||
expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_saveAsTable_ignore(self):
|
||||
df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
|
||||
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_mode_standalone(self):
|
||||
df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
|
||||
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_mode_override(self):
|
||||
df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
|
||||
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_saveAsTable_cache(self):
|
||||
df = self.df_employee.cache().write.saveAsTable("table_name")
|
||||
expected_statements = [
|
||||
"DROP VIEW IF EXISTS t35612",
|
||||
"CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
|
||||
"CREATE TABLE table_name AS SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
|
||||
"DROP VIEW IF EXISTS t37164",
|
||||
"CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
|
||||
"CREATE TABLE table_name AS SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
|
||||
]
|
||||
self.compare_sql(df, expected_statements)
|
||||
|
|
|
@ -11,32 +11,32 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
|
|||
class TestDataframeSession(DataFrameSQLValidator):
|
||||
def test_cdf_one_row(self):
|
||||
df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"])
|
||||
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2)) AS `a2`(`cola`, `colb`)"
|
||||
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2) AS `a2`(`cola`, `colb`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_cdf_multiple_rows(self):
|
||||
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"])
|
||||
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`cola`, `colb`)"
|
||||
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`cola`, `colb`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_cdf_no_schema(self):
|
||||
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
|
||||
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
|
||||
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`_1`, `_2`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_cdf_row_mixed_primitives(self):
|
||||
df = self.spark.createDataFrame([[1, 10.1, "test", False, None]])
|
||||
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM (VALUES (1, 10.1, 'test', FALSE, NULL)) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
|
||||
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM VALUES (1, 10.1, 'test', FALSE, NULL) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_cdf_dict_rows(self):
|
||||
df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}])
|
||||
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 'test'), (2, 'test2')) AS `a2`(`cola`, `colb`)"
|
||||
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 'test'), (2, 'test2') AS `a2`(`cola`, `colb`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_cdf_str_schema(self):
|
||||
df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
|
||||
expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
|
||||
expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_typed_schema_basic(self):
|
||||
|
@ -47,7 +47,7 @@ class TestDataframeSession(DataFrameSQLValidator):
|
|||
]
|
||||
)
|
||||
df = self.spark.createDataFrame([[1, "test"]], schema)
|
||||
expected = "SELECT CAST(`a2`.`cola` AS int) AS `cola`, CAST(`a2`.`colb` AS string) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
|
||||
expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
def test_typed_schema_nested(self):
|
||||
|
@ -65,7 +65,8 @@ class TestDataframeSession(DataFrameSQLValidator):
|
|||
]
|
||||
)
|
||||
df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema)
|
||||
expected = "SELECT CAST(`a2`.`cola` AS struct<sub_cola:int, sub_colb:string>) AS `cola` FROM (VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`))) AS `a2`(`cola`)"
|
||||
expected = "SELECT CAST(`a2`.`cola` AS STRUCT<`sub_cola`: INT, `sub_colb`: STRING>) AS `cola` FROM VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`)) AS `a2`(`cola`)"
|
||||
|
||||
self.compare_sql(df, expected)
|
||||
|
||||
@mock.patch("sqlglot.schema", MappingSchema())
|
||||
|
|
|
@ -286,6 +286,10 @@ class TestBigQuery(Validator):
|
|||
"bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))",
|
||||
},
|
||||
)
|
||||
self.validate_identity("BEGIN A B C D E F")
|
||||
self.validate_identity("BEGIN TRANSACTION")
|
||||
self.validate_identity("COMMIT TRANSACTION")
|
||||
self.validate_identity("ROLLBACK TRANSACTION")
|
||||
|
||||
def test_user_defined_functions(self):
|
||||
self.validate_identity(
|
||||
|
|
|
@ -69,6 +69,7 @@ class TestDialect(Validator):
|
|||
write={
|
||||
"bigquery": "CAST(a AS STRING)",
|
||||
"clickhouse": "CAST(a AS TEXT)",
|
||||
"drill": "CAST(a AS VARCHAR)",
|
||||
"duckdb": "CAST(a AS TEXT)",
|
||||
"mysql": "CAST(a AS TEXT)",
|
||||
"hive": "CAST(a AS STRING)",
|
||||
|
@ -86,6 +87,7 @@ class TestDialect(Validator):
|
|||
write={
|
||||
"bigquery": "CAST(a AS BINARY(4))",
|
||||
"clickhouse": "CAST(a AS BINARY(4))",
|
||||
"drill": "CAST(a AS VARBINARY(4))",
|
||||
"duckdb": "CAST(a AS BINARY(4))",
|
||||
"mysql": "CAST(a AS BINARY(4))",
|
||||
"hive": "CAST(a AS BINARY(4))",
|
||||
|
@ -146,6 +148,7 @@ class TestDialect(Validator):
|
|||
"CAST(a AS STRING)",
|
||||
write={
|
||||
"bigquery": "CAST(a AS STRING)",
|
||||
"drill": "CAST(a AS VARCHAR)",
|
||||
"duckdb": "CAST(a AS TEXT)",
|
||||
"mysql": "CAST(a AS TEXT)",
|
||||
"hive": "CAST(a AS STRING)",
|
||||
|
@ -162,6 +165,7 @@ class TestDialect(Validator):
|
|||
"CAST(a AS VARCHAR)",
|
||||
write={
|
||||
"bigquery": "CAST(a AS STRING)",
|
||||
"drill": "CAST(a AS VARCHAR)",
|
||||
"duckdb": "CAST(a AS TEXT)",
|
||||
"mysql": "CAST(a AS VARCHAR)",
|
||||
"hive": "CAST(a AS STRING)",
|
||||
|
@ -178,6 +182,7 @@ class TestDialect(Validator):
|
|||
"CAST(a AS VARCHAR(3))",
|
||||
write={
|
||||
"bigquery": "CAST(a AS STRING(3))",
|
||||
"drill": "CAST(a AS VARCHAR(3))",
|
||||
"duckdb": "CAST(a AS TEXT(3))",
|
||||
"mysql": "CAST(a AS VARCHAR(3))",
|
||||
"hive": "CAST(a AS VARCHAR(3))",
|
||||
|
@ -194,6 +199,7 @@ class TestDialect(Validator):
|
|||
"CAST(a AS SMALLINT)",
|
||||
write={
|
||||
"bigquery": "CAST(a AS INT64)",
|
||||
"drill": "CAST(a AS INTEGER)",
|
||||
"duckdb": "CAST(a AS SMALLINT)",
|
||||
"mysql": "CAST(a AS SMALLINT)",
|
||||
"hive": "CAST(a AS SMALLINT)",
|
||||
|
@ -215,6 +221,7 @@ class TestDialect(Validator):
|
|||
},
|
||||
write={
|
||||
"duckdb": "TRY_CAST(a AS DOUBLE)",
|
||||
"drill": "CAST(a AS DOUBLE)",
|
||||
"postgres": "CAST(a AS DOUBLE PRECISION)",
|
||||
"redshift": "CAST(a AS DOUBLE PRECISION)",
|
||||
},
|
||||
|
@ -225,6 +232,7 @@ class TestDialect(Validator):
|
|||
write={
|
||||
"bigquery": "CAST(a AS FLOAT64)",
|
||||
"clickhouse": "CAST(a AS Float64)",
|
||||
"drill": "CAST(a AS DOUBLE)",
|
||||
"duckdb": "CAST(a AS DOUBLE)",
|
||||
"mysql": "CAST(a AS DOUBLE)",
|
||||
"hive": "CAST(a AS DOUBLE)",
|
||||
|
@ -279,6 +287,7 @@ class TestDialect(Validator):
|
|||
"duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')",
|
||||
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)",
|
||||
"presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')",
|
||||
"drill": "TO_TIMESTAMP(x, 'yyyy-MM-dd''T''HH:mm:ss')",
|
||||
"redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')",
|
||||
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')",
|
||||
},
|
||||
|
@ -286,6 +295,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"STR_TO_TIME('2020-01-01', '%Y-%m-%d')",
|
||||
write={
|
||||
"drill": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
|
||||
"duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')",
|
||||
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
|
||||
"oracle": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
|
||||
|
@ -298,6 +308,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"STR_TO_TIME(x, '%y')",
|
||||
write={
|
||||
"drill": "TO_TIMESTAMP(x, 'yy')",
|
||||
"duckdb": "STRPTIME(x, '%y')",
|
||||
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)",
|
||||
"presto": "DATE_PARSE(x, '%y')",
|
||||
|
@ -319,6 +330,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"TIME_STR_TO_DATE('2020-01-01')",
|
||||
write={
|
||||
"drill": "CAST('2020-01-01' AS DATE)",
|
||||
"duckdb": "CAST('2020-01-01' AS DATE)",
|
||||
"hive": "TO_DATE('2020-01-01')",
|
||||
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
|
||||
|
@ -328,6 +340,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"TIME_STR_TO_TIME('2020-01-01')",
|
||||
write={
|
||||
"drill": "CAST('2020-01-01' AS TIMESTAMP)",
|
||||
"duckdb": "CAST('2020-01-01' AS TIMESTAMP)",
|
||||
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
|
||||
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
|
||||
|
@ -344,6 +357,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"TIME_TO_STR(x, '%Y-%m-%d')",
|
||||
write={
|
||||
"drill": "TO_CHAR(x, 'yyyy-MM-dd')",
|
||||
"duckdb": "STRFTIME(x, '%Y-%m-%d')",
|
||||
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
|
||||
"oracle": "TO_CHAR(x, 'YYYY-MM-DD')",
|
||||
|
@ -355,6 +369,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"TIME_TO_TIME_STR(x)",
|
||||
write={
|
||||
"drill": "CAST(x AS VARCHAR)",
|
||||
"duckdb": "CAST(x AS TEXT)",
|
||||
"hive": "CAST(x AS STRING)",
|
||||
"presto": "CAST(x AS VARCHAR)",
|
||||
|
@ -364,6 +379,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"TIME_TO_UNIX(x)",
|
||||
write={
|
||||
"drill": "UNIX_TIMESTAMP(x)",
|
||||
"duckdb": "EPOCH(x)",
|
||||
"hive": "UNIX_TIMESTAMP(x)",
|
||||
"presto": "TO_UNIXTIME(x)",
|
||||
|
@ -425,6 +441,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"DATE_TO_DATE_STR(x)",
|
||||
write={
|
||||
"drill": "CAST(x AS VARCHAR)",
|
||||
"duckdb": "CAST(x AS TEXT)",
|
||||
"hive": "CAST(x AS STRING)",
|
||||
"presto": "CAST(x AS VARCHAR)",
|
||||
|
@ -433,6 +450,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"DATE_TO_DI(x)",
|
||||
write={
|
||||
"drill": "CAST(TO_DATE(x, 'yyyyMMdd') AS INT)",
|
||||
"duckdb": "CAST(STRFTIME(x, '%Y%m%d') AS INT)",
|
||||
"hive": "CAST(DATE_FORMAT(x, 'yyyyMMdd') AS INT)",
|
||||
"presto": "CAST(DATE_FORMAT(x, '%Y%m%d') AS INT)",
|
||||
|
@ -441,6 +459,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"DI_TO_DATE(x)",
|
||||
write={
|
||||
"drill": "TO_DATE(CAST(x AS VARCHAR), 'yyyyMMdd')",
|
||||
"duckdb": "CAST(STRPTIME(CAST(x AS TEXT), '%Y%m%d') AS DATE)",
|
||||
"hive": "TO_DATE(CAST(x AS STRING), 'yyyyMMdd')",
|
||||
"presto": "CAST(DATE_PARSE(CAST(x AS VARCHAR), '%Y%m%d') AS DATE)",
|
||||
|
@ -463,6 +482,7 @@ class TestDialect(Validator):
|
|||
},
|
||||
write={
|
||||
"bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
|
||||
"drill": "DATE_ADD(x, INTERVAL '1' DAY)",
|
||||
"duckdb": "x + INTERVAL 1 day",
|
||||
"hive": "DATE_ADD(x, 1)",
|
||||
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
|
||||
|
@ -477,6 +497,7 @@ class TestDialect(Validator):
|
|||
"DATE_ADD(x, 1)",
|
||||
write={
|
||||
"bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
|
||||
"drill": "DATE_ADD(x, INTERVAL '1' DAY)",
|
||||
"duckdb": "x + INTERVAL 1 DAY",
|
||||
"hive": "DATE_ADD(x, 1)",
|
||||
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
|
||||
|
@ -546,6 +567,7 @@ class TestDialect(Validator):
|
|||
"starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
|
||||
},
|
||||
write={
|
||||
"drill": "TO_DATE(x, 'yyyy-MM-dd''T''HH:mm:ss')",
|
||||
"mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
|
||||
"starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
|
||||
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)",
|
||||
|
@ -556,6 +578,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"STR_TO_DATE(x, '%Y-%m-%d')",
|
||||
write={
|
||||
"drill": "CAST(x AS DATE)",
|
||||
"mysql": "STR_TO_DATE(x, '%Y-%m-%d')",
|
||||
"starrocks": "STR_TO_DATE(x, '%Y-%m-%d')",
|
||||
"hive": "CAST(x AS DATE)",
|
||||
|
@ -566,6 +589,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"DATE_STR_TO_DATE(x)",
|
||||
write={
|
||||
"drill": "CAST(x AS DATE)",
|
||||
"duckdb": "CAST(x AS DATE)",
|
||||
"hive": "TO_DATE(x)",
|
||||
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
|
||||
|
@ -575,6 +599,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"TS_OR_DS_ADD('2021-02-01', 1, 'DAY')",
|
||||
write={
|
||||
"drill": "DATE_ADD(CAST('2021-02-01' AS DATE), INTERVAL '1' DAY)",
|
||||
"duckdb": "CAST('2021-02-01' AS DATE) + INTERVAL 1 DAY",
|
||||
"hive": "DATE_ADD('2021-02-01', 1)",
|
||||
"presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2021-02-01', 1, 10), '%Y-%m-%d'))",
|
||||
|
@ -584,6 +609,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
|
||||
write={
|
||||
"drill": "DATE_ADD(CAST('2020-01-01' AS DATE), INTERVAL '1' DAY)",
|
||||
"duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY",
|
||||
"hive": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
|
||||
"presto": "DATE_ADD('day', 1, CAST('2020-01-01' AS DATE))",
|
||||
|
@ -593,6 +619,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"TIMESTAMP '2022-01-01'",
|
||||
write={
|
||||
"drill": "CAST('2022-01-01' AS TIMESTAMP)",
|
||||
"mysql": "CAST('2022-01-01' AS TIMESTAMP)",
|
||||
"starrocks": "CAST('2022-01-01' AS DATETIME)",
|
||||
"hive": "CAST('2022-01-01' AS TIMESTAMP)",
|
||||
|
@ -614,6 +641,7 @@ class TestDialect(Validator):
|
|||
dialect: f"{unit}(x)"
|
||||
for dialect in (
|
||||
"bigquery",
|
||||
"drill",
|
||||
"duckdb",
|
||||
"mysql",
|
||||
"presto",
|
||||
|
@ -624,6 +652,7 @@ class TestDialect(Validator):
|
|||
dialect: f"{unit}(x)"
|
||||
for dialect in (
|
||||
"bigquery",
|
||||
"drill",
|
||||
"duckdb",
|
||||
"mysql",
|
||||
"presto",
|
||||
|
@ -649,6 +678,7 @@ class TestDialect(Validator):
|
|||
write={
|
||||
"bigquery": "ARRAY_LENGTH(x)",
|
||||
"duckdb": "ARRAY_LENGTH(x)",
|
||||
"drill": "REPEATED_COUNT(x)",
|
||||
"presto": "CARDINALITY(x)",
|
||||
"spark": "SIZE(x)",
|
||||
},
|
||||
|
@ -736,6 +766,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"SELECT a FROM x CROSS JOIN UNNEST(y) AS t (a)",
|
||||
write={
|
||||
"drill": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
|
||||
"presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
|
||||
"spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
|
||||
},
|
||||
|
@ -743,6 +774,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t (a, b)",
|
||||
write={
|
||||
"drill": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
|
||||
"presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
|
||||
"spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b",
|
||||
},
|
||||
|
@ -775,6 +807,7 @@ class TestDialect(Validator):
|
|||
},
|
||||
write={
|
||||
"bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b",
|
||||
"drill": "SELECT * FROM a UNION SELECT * FROM b",
|
||||
"duckdb": "SELECT * FROM a UNION SELECT * FROM b",
|
||||
"presto": "SELECT * FROM a UNION SELECT * FROM b",
|
||||
"spark": "SELECT * FROM a UNION SELECT * FROM b",
|
||||
|
@ -887,6 +920,7 @@ class TestDialect(Validator):
|
|||
write={
|
||||
"bigquery": "LOWER(x) LIKE '%y'",
|
||||
"clickhouse": "x ILIKE '%y'",
|
||||
"drill": "x `ILIKE` '%y'",
|
||||
"duckdb": "x ILIKE '%y'",
|
||||
"hive": "LOWER(x) LIKE '%y'",
|
||||
"mysql": "LOWER(x) LIKE '%y'",
|
||||
|
@ -910,32 +944,38 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"POSITION(' ' in x)",
|
||||
write={
|
||||
"drill": "STRPOS(x, ' ')",
|
||||
"duckdb": "STRPOS(x, ' ')",
|
||||
"postgres": "STRPOS(x, ' ')",
|
||||
"presto": "STRPOS(x, ' ')",
|
||||
"spark": "LOCATE(' ', x)",
|
||||
"clickhouse": "position(x, ' ')",
|
||||
"snowflake": "POSITION(' ', x)",
|
||||
"mysql": "LOCATE(' ', x)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"STR_POSITION('a', x)",
|
||||
write={
|
||||
"drill": "STRPOS(x, 'a')",
|
||||
"duckdb": "STRPOS(x, 'a')",
|
||||
"postgres": "STRPOS(x, 'a')",
|
||||
"presto": "STRPOS(x, 'a')",
|
||||
"spark": "LOCATE('a', x)",
|
||||
"clickhouse": "position(x, 'a')",
|
||||
"snowflake": "POSITION('a', x)",
|
||||
"mysql": "LOCATE('a', x)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"POSITION('a', x, 3)",
|
||||
write={
|
||||
"drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
|
||||
"presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
|
||||
"spark": "LOCATE('a', x, 3)",
|
||||
"clickhouse": "position(x, 'a', 3)",
|
||||
"snowflake": "POSITION('a', x, 3)",
|
||||
"mysql": "LOCATE('a', x, 3)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
|
@ -960,6 +1000,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"IF(x > 1, 1, 0)",
|
||||
write={
|
||||
"drill": "`IF`(x > 1, 1, 0)",
|
||||
"duckdb": "CASE WHEN x > 1 THEN 1 ELSE 0 END",
|
||||
"presto": "IF(x > 1, 1, 0)",
|
||||
"hive": "IF(x > 1, 1, 0)",
|
||||
|
@ -970,6 +1011,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"CASE WHEN 1 THEN x ELSE 0 END",
|
||||
write={
|
||||
"drill": "CASE WHEN 1 THEN x ELSE 0 END",
|
||||
"duckdb": "CASE WHEN 1 THEN x ELSE 0 END",
|
||||
"presto": "CASE WHEN 1 THEN x ELSE 0 END",
|
||||
"hive": "CASE WHEN 1 THEN x ELSE 0 END",
|
||||
|
@ -980,6 +1022,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"x[y]",
|
||||
write={
|
||||
"drill": "x[y]",
|
||||
"duckdb": "x[y]",
|
||||
"presto": "x[y]",
|
||||
"hive": "x[y]",
|
||||
|
@ -1000,6 +1043,7 @@ class TestDialect(Validator):
|
|||
'true or null as "foo"',
|
||||
write={
|
||||
"bigquery": "TRUE OR NULL AS `foo`",
|
||||
"drill": "TRUE OR NULL AS `foo`",
|
||||
"duckdb": 'TRUE OR NULL AS "foo"',
|
||||
"presto": 'TRUE OR NULL AS "foo"',
|
||||
"hive": "TRUE OR NULL AS `foo`",
|
||||
|
@ -1020,6 +1064,7 @@ class TestDialect(Validator):
|
|||
"LEVENSHTEIN(col1, col2)",
|
||||
write={
|
||||
"duckdb": "LEVENSHTEIN(col1, col2)",
|
||||
"drill": "LEVENSHTEIN_DISTANCE(col1, col2)",
|
||||
"presto": "LEVENSHTEIN_DISTANCE(col1, col2)",
|
||||
"hive": "LEVENSHTEIN(col1, col2)",
|
||||
"spark": "LEVENSHTEIN(col1, col2)",
|
||||
|
@ -1029,6 +1074,7 @@ class TestDialect(Validator):
|
|||
"LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))",
|
||||
write={
|
||||
"duckdb": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
|
||||
"drill": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
|
||||
"presto": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
|
||||
"hive": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
|
||||
"spark": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
|
||||
|
@ -1152,6 +1198,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"SELECT a AS b FROM x GROUP BY b",
|
||||
write={
|
||||
"drill": "SELECT a AS b FROM x GROUP BY b",
|
||||
"duckdb": "SELECT a AS b FROM x GROUP BY b",
|
||||
"presto": "SELECT a AS b FROM x GROUP BY 1",
|
||||
"hive": "SELECT a AS b FROM x GROUP BY 1",
|
||||
|
@ -1162,6 +1209,7 @@ class TestDialect(Validator):
|
|||
self.validate_all(
|
||||
"SELECT y x FROM my_table t",
|
||||
write={
|
||||
"drill": "SELECT y AS x FROM my_table AS t",
|
||||
"hive": "SELECT y AS x FROM my_table AS t",
|
||||
"oracle": "SELECT y AS x FROM my_table t",
|
||||
"postgres": "SELECT y AS x FROM my_table AS t",
|
||||
|
@ -1230,3 +1278,36 @@ SELECT
|
|||
},
|
||||
pretty=True,
|
||||
)
|
||||
|
||||
def test_transactions(self):
|
||||
self.validate_all(
|
||||
"BEGIN TRANSACTION",
|
||||
write={
|
||||
"bigquery": "BEGIN TRANSACTION",
|
||||
"mysql": "BEGIN",
|
||||
"postgres": "BEGIN",
|
||||
"presto": "START TRANSACTION",
|
||||
"trino": "START TRANSACTION",
|
||||
"redshift": "BEGIN",
|
||||
"snowflake": "BEGIN",
|
||||
"sqlite": "BEGIN TRANSACTION",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"BEGIN",
|
||||
read={
|
||||
"presto": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE",
|
||||
"trino": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"BEGIN",
|
||||
read={
|
||||
"presto": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ",
|
||||
"trino": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"BEGIN IMMEDIATE TRANSACTION",
|
||||
write={"sqlite": "BEGIN IMMEDIATE TRANSACTION"},
|
||||
)
|
||||
|
|
53
tests/dialects/test_drill.py
Normal file
53
tests/dialects/test_drill.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
from tests.dialects.test_dialect import Validator
|
||||
|
||||
|
||||
class TestDrill(Validator):
|
||||
dialect = "drill"
|
||||
|
||||
def test_string_literals(self):
|
||||
self.validate_all(
|
||||
"SELECT '2021-01-01' + INTERVAL 1 MONTH",
|
||||
write={
|
||||
"mysql": "SELECT '2021-01-01' + INTERVAL 1 MONTH",
|
||||
},
|
||||
)
|
||||
|
||||
def test_quotes(self):
|
||||
self.validate_all(
|
||||
"'\\''",
|
||||
write={
|
||||
"duckdb": "''''",
|
||||
"presto": "''''",
|
||||
"hive": "'\\''",
|
||||
"spark": "'\\''",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"'\"x\"'",
|
||||
write={
|
||||
"duckdb": "'\"x\"'",
|
||||
"presto": "'\"x\"'",
|
||||
"hive": "'\"x\"'",
|
||||
"spark": "'\"x\"'",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"'\\\\a'",
|
||||
read={
|
||||
"presto": "'\\a'",
|
||||
},
|
||||
write={
|
||||
"duckdb": "'\\a'",
|
||||
"presto": "'\\a'",
|
||||
"hive": "'\\\\a'",
|
||||
"spark": "'\\\\a'",
|
||||
},
|
||||
)
|
||||
|
||||
def test_table_function(self):
|
||||
self.validate_all(
|
||||
"SELECT * FROM table( dfs.`test_data.xlsx` (type => 'excel', sheetName => 'secondSheet'))",
|
||||
write={
|
||||
"drill": "SELECT * FROM table(dfs.`test_data.xlsx`(type => 'excel', sheetName => 'secondSheet'))",
|
||||
},
|
||||
)
|
|
@ -58,6 +58,16 @@ class TestMySQL(Validator):
|
|||
self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'")
|
||||
self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci")
|
||||
self.validate_identity("SET autocommit = ON")
|
||||
self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL SERIALIZABLE")
|
||||
self.validate_identity("SET TRANSACTION READ ONLY")
|
||||
self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE")
|
||||
self.validate_identity("SELECT SCHEMA()")
|
||||
|
||||
def test_canonical_functions(self):
|
||||
self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)")
|
||||
self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')")
|
||||
self.validate_identity("SELECT UCASE('foo')", "SELECT UPPER('foo')")
|
||||
self.validate_identity("SELECT LCASE('foo')", "SELECT LOWER('foo')")
|
||||
|
||||
def test_escape(self):
|
||||
self.validate_all(
|
||||
|
|
|
@ -177,6 +177,15 @@ class TestPresto(Validator):
|
|||
"spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"CREATE TABLE test STORED = 'PARQUET' AS SELECT 1",
|
||||
write={
|
||||
"duckdb": "CREATE TABLE test AS SELECT 1",
|
||||
"presto": "CREATE TABLE test WITH (FORMAT='PARQUET') AS SELECT 1",
|
||||
"hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
|
||||
"spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1",
|
||||
write={
|
||||
|
@ -427,3 +436,69 @@ class TestPresto(Validator):
|
|||
"spark": UnsupportedError,
|
||||
},
|
||||
)
|
||||
self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE")
|
||||
self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")
|
||||
|
||||
def test_encode_decode(self):
|
||||
self.validate_all(
|
||||
"TO_UTF8(x)",
|
||||
write={
|
||||
"spark": "ENCODE(x, 'utf-8')",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"FROM_UTF8(x)",
|
||||
write={
|
||||
"spark": "DECODE(x, 'utf-8')",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"ENCODE(x, 'utf-8')",
|
||||
write={
|
||||
"presto": "TO_UTF8(x)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"DECODE(x, 'utf-8')",
|
||||
write={
|
||||
"presto": "FROM_UTF8(x)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"ENCODE(x, 'invalid')",
|
||||
write={
|
||||
"presto": UnsupportedError,
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"DECODE(x, 'invalid')",
|
||||
write={
|
||||
"presto": UnsupportedError,
|
||||
},
|
||||
)
|
||||
|
||||
def test_hex_unhex(self):
|
||||
self.validate_all(
|
||||
"TO_HEX(x)",
|
||||
write={
|
||||
"spark": "HEX(x)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"FROM_HEX(x)",
|
||||
write={
|
||||
"spark": "UNHEX(x)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"HEX(x)",
|
||||
write={
|
||||
"presto": "TO_HEX(x)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"UNHEX(x)",
|
||||
write={
|
||||
"presto": "FROM_HEX(x)",
|
||||
},
|
||||
)
|
||||
|
|
|
@ -169,6 +169,17 @@ class TestSnowflake(Validator):
|
|||
"snowflake": "SELECT a FROM test AS unpivot",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"trim(date_column, 'UTC')",
|
||||
write={
|
||||
"snowflake": "TRIM(date_column, 'UTC')",
|
||||
"postgres": "TRIM('UTC' FROM date_column)",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"trim(date_column)",
|
||||
write={"snowflake": "TRIM(date_column)"},
|
||||
)
|
||||
|
||||
def test_null_treatment(self):
|
||||
self.validate_all(
|
||||
|
|
21
tests/fixtures/identity.sql
vendored
21
tests/fixtures/identity.sql
vendored
|
@ -122,13 +122,6 @@ x AT TIME ZONE 'UTC'
|
|||
CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo'
|
||||
SET x = 1
|
||||
SET -v
|
||||
ADD JAR s3://bucket
|
||||
ADD JARS s3://bucket, c
|
||||
ADD FILE s3://file
|
||||
ADD FILES s3://file, s3://a
|
||||
ADD ARCHIVE s3://file
|
||||
ADD ARCHIVES s3://file, s3://a
|
||||
BEGIN IMMEDIATE TRANSACTION
|
||||
COMMIT
|
||||
USE db
|
||||
NOT 1
|
||||
|
@ -278,6 +271,7 @@ SELECT CEIL(a, b) FROM test
|
|||
SELECT COUNT(a) FROM test
|
||||
SELECT COUNT(1) FROM test
|
||||
SELECT COUNT(*) FROM test
|
||||
SELECT COUNT() FROM test
|
||||
SELECT COUNT(DISTINCT a) FROM test
|
||||
SELECT EXP(a) FROM test
|
||||
SELECT FLOOR(a) FROM test
|
||||
|
@ -372,6 +366,8 @@ WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2
|
|||
WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2
|
||||
WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2
|
||||
WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2
|
||||
WITH sub_query AS (SELECT a FROM table) (SELECT a FROM sub_query)
|
||||
WITH sub_query AS (SELECT a FROM table) ((((SELECT a FROM sub_query))))
|
||||
(SELECT 1) UNION (SELECT 2)
|
||||
(SELECT 1) UNION SELECT 2
|
||||
SELECT 1 UNION (SELECT 2)
|
||||
|
@ -463,6 +459,7 @@ CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECI
|
|||
CREATE TABLE z (a INT(11) DEFAULT UUID())
|
||||
CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id')
|
||||
CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1)
|
||||
CREATE TABLE z (a INT(11) NOT NULL DEFAULT -1)
|
||||
CREATE TABLE z (a INT(11) NOT NULL COLLATE utf8_bin AUTO_INCREMENT)
|
||||
CREATE TABLE z (a INT, PRIMARY KEY(a))
|
||||
CREATE TABLE z WITH (FORMAT='parquet') AS SELECT 1
|
||||
|
@ -476,6 +473,9 @@ CREATE TABLE z AS ((WITH cte AS (SELECT 1) SELECT * FROM cte))
|
|||
CREATE TABLE z (a INT UNIQUE)
|
||||
CREATE TABLE z (a INT AUTO_INCREMENT)
|
||||
CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT)
|
||||
CREATE TABLE z (a INT REFERENCES parent(b, c))
|
||||
CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id))
|
||||
CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c))
|
||||
CREATE TEMPORARY FUNCTION f
|
||||
CREATE TEMPORARY FUNCTION f AS 'g'
|
||||
CREATE FUNCTION f
|
||||
|
@ -514,17 +514,23 @@ DELETE FROM x WHERE y > 1
|
|||
DELETE FROM y
|
||||
DELETE FROM event USING sales WHERE event.eventid = sales.eventid
|
||||
DELETE FROM event USING sales, USING bla WHERE event.eventid = sales.eventid
|
||||
DELETE FROM event USING sales AS s WHERE event.eventid = s.eventid
|
||||
PREPARE statement
|
||||
EXECUTE statement
|
||||
DROP TABLE a
|
||||
DROP TABLE a.b
|
||||
DROP TABLE IF EXISTS a
|
||||
DROP TABLE IF EXISTS a.b
|
||||
DROP TABLE a CASCADE
|
||||
DROP VIEW a
|
||||
DROP VIEW a.b
|
||||
DROP VIEW IF EXISTS a
|
||||
DROP VIEW IF EXISTS a.b
|
||||
SHOW TABLES
|
||||
USE db
|
||||
BEGIN
|
||||
ROLLBACK
|
||||
ROLLBACK TO b
|
||||
EXPLAIN SELECT * FROM x
|
||||
INSERT INTO x SELECT * FROM y
|
||||
INSERT INTO x (SELECT * FROM y)
|
||||
|
@ -581,3 +587,4 @@ SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */
|
|||
SELECT x FROM a.b.c /* x */, e.f.g /* x */
|
||||
SELECT FOO(x /* c */) /* FOO */, b /* b */
|
||||
SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */
|
||||
SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b'
|
||||
|
|
5
tests/fixtures/optimizer/canonicalize.sql
vendored
Normal file
5
tests/fixtures/optimizer/canonicalize.sql
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
SELECT w.d + w.e AS c FROM w AS w;
|
||||
SELECT CONCAT(w.d, w.e) AS c FROM w AS w;
|
||||
|
||||
SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w;
|
||||
SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w;
|
4
tests/fixtures/optimizer/optimizer.sql
vendored
4
tests/fixtures/optimizer/optimizer.sql
vendored
|
@ -119,7 +119,7 @@ GROUP BY
|
|||
LIMIT 1;
|
||||
|
||||
# title: Root subquery is union
|
||||
(SELECT b FROM x UNION SELECT b FROM y) LIMIT 1;
|
||||
(SELECT b FROM x UNION SELECT b FROM y ORDER BY b) LIMIT 1;
|
||||
(
|
||||
SELECT
|
||||
"x"."b" AS "b"
|
||||
|
@ -128,6 +128,8 @@ LIMIT 1;
|
|||
SELECT
|
||||
"y"."b" AS "b"
|
||||
FROM "y" AS "y"
|
||||
ORDER BY
|
||||
"b"
|
||||
)
|
||||
LIMIT 1;
|
||||
|
||||
|
|
50
tests/fixtures/optimizer/tpc-h/tpc-h.sql
vendored
50
tests/fixtures/optimizer/tpc-h/tpc-h.sql
vendored
|
@ -15,7 +15,7 @@ select
|
|||
from
|
||||
lineitem
|
||||
where
|
||||
CAST(l_shipdate AS DATE) <= date '1998-12-01' - interval '90' day
|
||||
l_shipdate <= date '1998-12-01' - interval '90' day
|
||||
group by
|
||||
l_returnflag,
|
||||
l_linestatus
|
||||
|
@ -250,8 +250,8 @@ FROM "orders" AS "orders"
|
|||
LEFT JOIN "_u_0" AS "_u_0"
|
||||
ON "_u_0"."l_orderkey" = "orders"."o_orderkey"
|
||||
WHERE
|
||||
"orders"."o_orderdate" < CAST('1993-10-01' AS DATE)
|
||||
AND "orders"."o_orderdate" >= CAST('1993-07-01' AS DATE)
|
||||
CAST("orders"."o_orderdate" AS DATE) < CAST('1993-10-01' AS DATE)
|
||||
AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1993-07-01' AS DATE)
|
||||
AND NOT "_u_0"."l_orderkey" IS NULL
|
||||
GROUP BY
|
||||
"orders"."o_orderpriority"
|
||||
|
@ -293,8 +293,8 @@ SELECT
|
|||
FROM "customer" AS "customer"
|
||||
JOIN "orders" AS "orders"
|
||||
ON "customer"."c_custkey" = "orders"."o_custkey"
|
||||
AND "orders"."o_orderdate" < CAST('1995-01-01' AS DATE)
|
||||
AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE)
|
||||
AND CAST("orders"."o_orderdate" AS DATE) < CAST('1995-01-01' AS DATE)
|
||||
AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1994-01-01' AS DATE)
|
||||
JOIN "region" AS "region"
|
||||
ON "region"."r_name" = 'ASIA'
|
||||
JOIN "nation" AS "nation"
|
||||
|
@ -328,8 +328,8 @@ FROM "lineitem" AS "lineitem"
|
|||
WHERE
|
||||
"lineitem"."l_discount" BETWEEN 0.05 AND 0.07
|
||||
AND "lineitem"."l_quantity" < 24
|
||||
AND "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE)
|
||||
AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE);
|
||||
AND CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE)
|
||||
AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE);
|
||||
|
||||
--------------------------------------
|
||||
-- TPC-H 7
|
||||
|
@ -384,13 +384,13 @@ WITH "n1" AS (
|
|||
SELECT
|
||||
"n1"."n_name" AS "supp_nation",
|
||||
"n2"."n_name" AS "cust_nation",
|
||||
EXTRACT(year FROM "lineitem"."l_shipdate") AS "l_year",
|
||||
EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) AS "l_year",
|
||||
SUM("lineitem"."l_extendedprice" * (
|
||||
1 - "lineitem"."l_discount"
|
||||
)) AS "revenue"
|
||||
FROM "supplier" AS "supplier"
|
||||
JOIN "lineitem" AS "lineitem"
|
||||
ON "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
|
||||
ON CAST("lineitem"."l_shipdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
|
||||
AND "supplier"."s_suppkey" = "lineitem"."l_suppkey"
|
||||
JOIN "orders" AS "orders"
|
||||
ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
|
||||
|
@ -409,7 +409,7 @@ JOIN "n1" AS "n2"
|
|||
GROUP BY
|
||||
"n1"."n_name",
|
||||
"n2"."n_name",
|
||||
EXTRACT(year FROM "lineitem"."l_shipdate")
|
||||
EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME))
|
||||
ORDER BY
|
||||
"supp_nation",
|
||||
"cust_nation",
|
||||
|
@ -456,7 +456,7 @@ group by
|
|||
order by
|
||||
o_year;
|
||||
SELECT
|
||||
EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
|
||||
EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
|
||||
SUM(
|
||||
CASE
|
||||
WHEN "nation_2"."n_name" = 'BRAZIL'
|
||||
|
@ -477,7 +477,7 @@ JOIN "customer" AS "customer"
|
|||
ON "customer"."c_nationkey" = "nation"."n_nationkey"
|
||||
JOIN "orders" AS "orders"
|
||||
ON "orders"."o_custkey" = "customer"."c_custkey"
|
||||
AND "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
|
||||
AND CAST("orders"."o_orderdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
|
||||
JOIN "lineitem" AS "lineitem"
|
||||
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
|
||||
AND "part"."p_partkey" = "lineitem"."l_partkey"
|
||||
|
@ -488,7 +488,7 @@ JOIN "nation" AS "nation_2"
|
|||
WHERE
|
||||
"part"."p_type" = 'ECONOMY ANODIZED STEEL'
|
||||
GROUP BY
|
||||
EXTRACT(year FROM "orders"."o_orderdate")
|
||||
EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
|
||||
ORDER BY
|
||||
"o_year";
|
||||
|
||||
|
@ -529,7 +529,7 @@ order by
|
|||
o_year desc;
|
||||
SELECT
|
||||
"nation"."n_name" AS "nation",
|
||||
EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
|
||||
EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
|
||||
SUM(
|
||||
"lineitem"."l_extendedprice" * (
|
||||
1 - "lineitem"."l_discount"
|
||||
|
@ -551,7 +551,7 @@ WHERE
|
|||
"part"."p_name" LIKE '%green%'
|
||||
GROUP BY
|
||||
"nation"."n_name",
|
||||
EXTRACT(year FROM "orders"."o_orderdate")
|
||||
EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
|
||||
ORDER BY
|
||||
"nation",
|
||||
"o_year" DESC;
|
||||
|
@ -606,8 +606,8 @@ SELECT
|
|||
FROM "customer" AS "customer"
|
||||
JOIN "orders" AS "orders"
|
||||
ON "customer"."c_custkey" = "orders"."o_custkey"
|
||||
AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE)
|
||||
AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE)
|
||||
AND CAST("orders"."o_orderdate" AS DATE) < CAST('1994-01-01' AS DATE)
|
||||
AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1993-10-01' AS DATE)
|
||||
JOIN "lineitem" AS "lineitem"
|
||||
ON "lineitem"."l_orderkey" = "orders"."o_orderkey" AND "lineitem"."l_returnflag" = 'R'
|
||||
JOIN "nation" AS "nation"
|
||||
|
@ -740,8 +740,8 @@ SELECT
|
|||
FROM "orders" AS "orders"
|
||||
JOIN "lineitem" AS "lineitem"
|
||||
ON "lineitem"."l_commitdate" < "lineitem"."l_receiptdate"
|
||||
AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE)
|
||||
AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE)
|
||||
AND CAST("lineitem"."l_receiptdate" AS DATE) < CAST('1995-01-01' AS DATE)
|
||||
AND CAST("lineitem"."l_receiptdate" AS DATE) >= CAST('1994-01-01' AS DATE)
|
||||
AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate"
|
||||
AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP')
|
||||
AND "orders"."o_orderkey" = "lineitem"."l_orderkey"
|
||||
|
@ -832,8 +832,8 @@ FROM "lineitem" AS "lineitem"
|
|||
JOIN "part" AS "part"
|
||||
ON "lineitem"."l_partkey" = "part"."p_partkey"
|
||||
WHERE
|
||||
"lineitem"."l_shipdate" < CAST('1995-10-01' AS DATE)
|
||||
AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE);
|
||||
CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-10-01' AS DATE)
|
||||
AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1995-09-01' AS DATE);
|
||||
|
||||
--------------------------------------
|
||||
-- TPC-H 15
|
||||
|
@ -876,8 +876,8 @@ WITH "revenue" AS (
|
|||
)) AS "total_revenue"
|
||||
FROM "lineitem" AS "lineitem"
|
||||
WHERE
|
||||
"lineitem"."l_shipdate" < CAST('1996-04-01' AS DATE)
|
||||
AND "lineitem"."l_shipdate" >= CAST('1996-01-01' AS DATE)
|
||||
CAST("lineitem"."l_shipdate" AS DATE) < CAST('1996-04-01' AS DATE)
|
||||
AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE)
|
||||
GROUP BY
|
||||
"lineitem"."l_suppkey"
|
||||
)
|
||||
|
@ -1220,8 +1220,8 @@ WITH "_u_0" AS (
|
|||
"lineitem"."l_suppkey" AS "_u_2"
|
||||
FROM "lineitem" AS "lineitem"
|
||||
WHERE
|
||||
"lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE)
|
||||
AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE)
|
||||
CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE)
|
||||
AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE)
|
||||
GROUP BY
|
||||
"lineitem"."l_partkey",
|
||||
"lineitem"."l_suppkey"
|
||||
|
|
7
tests/fixtures/pretty.sql
vendored
7
tests/fixtures/pretty.sql
vendored
|
@ -315,3 +315,10 @@ FROM (
|
|||
WHERE
|
||||
id = 1
|
||||
) /* x */;
|
||||
SELECT * /* multi
|
||||
line
|
||||
comment */;
|
||||
SELECT
|
||||
* /* multi
|
||||
line
|
||||
comment */;
|
||||
|
|
|
@ -57,79 +57,79 @@ SKIP_INTEGRATION = string_to_bool(os.environ.get("SKIP_INTEGRATION", "0").lower(
|
|||
|
||||
TPCH_SCHEMA = {
|
||||
"lineitem": {
|
||||
"l_orderkey": "uint64",
|
||||
"l_partkey": "uint64",
|
||||
"l_suppkey": "uint64",
|
||||
"l_linenumber": "uint64",
|
||||
"l_quantity": "float64",
|
||||
"l_extendedprice": "float64",
|
||||
"l_discount": "float64",
|
||||
"l_tax": "float64",
|
||||
"l_orderkey": "bigint",
|
||||
"l_partkey": "bigint",
|
||||
"l_suppkey": "bigint",
|
||||
"l_linenumber": "bigint",
|
||||
"l_quantity": "double",
|
||||
"l_extendedprice": "double",
|
||||
"l_discount": "double",
|
||||
"l_tax": "double",
|
||||
"l_returnflag": "string",
|
||||
"l_linestatus": "string",
|
||||
"l_shipdate": "date32",
|
||||
"l_commitdate": "date32",
|
||||
"l_receiptdate": "date32",
|
||||
"l_shipdate": "string",
|
||||
"l_commitdate": "string",
|
||||
"l_receiptdate": "string",
|
||||
"l_shipinstruct": "string",
|
||||
"l_shipmode": "string",
|
||||
"l_comment": "string",
|
||||
},
|
||||
"orders": {
|
||||
"o_orderkey": "uint64",
|
||||
"o_custkey": "uint64",
|
||||
"o_orderkey": "bigint",
|
||||
"o_custkey": "bigint",
|
||||
"o_orderstatus": "string",
|
||||
"o_totalprice": "float64",
|
||||
"o_orderdate": "date32",
|
||||
"o_totalprice": "double",
|
||||
"o_orderdate": "string",
|
||||
"o_orderpriority": "string",
|
||||
"o_clerk": "string",
|
||||
"o_shippriority": "int32",
|
||||
"o_shippriority": "int",
|
||||
"o_comment": "string",
|
||||
},
|
||||
"customer": {
|
||||
"c_custkey": "uint64",
|
||||
"c_custkey": "bigint",
|
||||
"c_name": "string",
|
||||
"c_address": "string",
|
||||
"c_nationkey": "uint64",
|
||||
"c_nationkey": "bigint",
|
||||
"c_phone": "string",
|
||||
"c_acctbal": "float64",
|
||||
"c_acctbal": "double",
|
||||
"c_mktsegment": "string",
|
||||
"c_comment": "string",
|
||||
},
|
||||
"part": {
|
||||
"p_partkey": "uint64",
|
||||
"p_partkey": "bigint",
|
||||
"p_name": "string",
|
||||
"p_mfgr": "string",
|
||||
"p_brand": "string",
|
||||
"p_type": "string",
|
||||
"p_size": "int32",
|
||||
"p_size": "int",
|
||||
"p_container": "string",
|
||||
"p_retailprice": "float64",
|
||||
"p_retailprice": "double",
|
||||
"p_comment": "string",
|
||||
},
|
||||
"supplier": {
|
||||
"s_suppkey": "uint64",
|
||||
"s_suppkey": "bigint",
|
||||
"s_name": "string",
|
||||
"s_address": "string",
|
||||
"s_nationkey": "uint64",
|
||||
"s_nationkey": "bigint",
|
||||
"s_phone": "string",
|
||||
"s_acctbal": "float64",
|
||||
"s_acctbal": "double",
|
||||
"s_comment": "string",
|
||||
},
|
||||
"partsupp": {
|
||||
"ps_partkey": "uint64",
|
||||
"ps_suppkey": "uint64",
|
||||
"ps_availqty": "int32",
|
||||
"ps_supplycost": "float64",
|
||||
"ps_partkey": "bigint",
|
||||
"ps_suppkey": "bigint",
|
||||
"ps_availqty": "int",
|
||||
"ps_supplycost": "double",
|
||||
"ps_comment": "string",
|
||||
},
|
||||
"nation": {
|
||||
"n_nationkey": "uint64",
|
||||
"n_nationkey": "bigint",
|
||||
"n_name": "string",
|
||||
"n_regionkey": "uint64",
|
||||
"n_regionkey": "bigint",
|
||||
"n_comment": "string",
|
||||
},
|
||||
"region": {
|
||||
"r_regionkey": "uint64",
|
||||
"r_regionkey": "bigint",
|
||||
"r_name": "string",
|
||||
"r_comment": "string",
|
||||
},
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
import unittest
|
||||
from datetime import date
|
||||
|
||||
import duckdb
|
||||
import pandas as pd
|
||||
from pandas.testing import assert_frame_equal
|
||||
|
||||
from sqlglot import exp, parse_one
|
||||
from sqlglot.errors import ExecuteError
|
||||
from sqlglot.executor import execute
|
||||
from sqlglot.executor.python import Python
|
||||
from sqlglot.executor.table import Table, ensure_tables
|
||||
from tests.helpers import (
|
||||
FIXTURES_DIR,
|
||||
SKIP_INTEGRATION,
|
||||
|
@ -67,13 +70,399 @@ class TestExecutor(unittest.TestCase):
|
|||
def to_csv(expression):
|
||||
if isinstance(expression, exp.Table):
|
||||
return parse_one(
|
||||
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
|
||||
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
|
||||
)
|
||||
return expression
|
||||
|
||||
for sql, _ in self.sqls[0:3]:
|
||||
a = self.cached_execute(sql)
|
||||
sql = parse_one(sql).transform(to_csv).sql(pretty=True)
|
||||
table = execute(sql, TPCH_SCHEMA)
|
||||
b = pd.DataFrame(table.rows, columns=table.columns)
|
||||
assert_frame_equal(a, b, check_dtype=False)
|
||||
for i, (sql, _) in enumerate(self.sqls[0:7]):
|
||||
with self.subTest(f"tpch-h {i + 1}"):
|
||||
a = self.cached_execute(sql)
|
||||
sql = parse_one(sql).transform(to_csv).sql(pretty=True)
|
||||
table = execute(sql, TPCH_SCHEMA)
|
||||
b = pd.DataFrame(table.rows, columns=table.columns)
|
||||
assert_frame_equal(a, b, check_dtype=False)
|
||||
|
||||
def test_execute_callable(self):
|
||||
tables = {
|
||||
"x": [
|
||||
{"a": "a", "b": "d"},
|
||||
{"a": "b", "b": "e"},
|
||||
{"a": "c", "b": "f"},
|
||||
],
|
||||
"y": [
|
||||
{"b": "d", "c": "g"},
|
||||
{"b": "e", "c": "h"},
|
||||
{"b": "f", "c": "i"},
|
||||
],
|
||||
"z": [],
|
||||
}
|
||||
schema = {
|
||||
"x": {
|
||||
"a": "VARCHAR",
|
||||
"b": "VARCHAR",
|
||||
},
|
||||
"y": {
|
||||
"b": "VARCHAR",
|
||||
"c": "VARCHAR",
|
||||
},
|
||||
"z": {"d": "VARCHAR"},
|
||||
}
|
||||
|
||||
for sql, cols, rows in [
|
||||
("SELECT * FROM x", ["a", "b"], [("a", "d"), ("b", "e"), ("c", "f")]),
|
||||
(
|
||||
"SELECT * FROM x JOIN y ON x.b = y.b",
|
||||
["a", "b", "b", "c"],
|
||||
[("a", "d", "d", "g"), ("b", "e", "e", "h"), ("c", "f", "f", "i")],
|
||||
),
|
||||
(
|
||||
"SELECT j.c AS d FROM x AS i JOIN y AS j ON i.b = j.b",
|
||||
["d"],
|
||||
[("g",), ("h",), ("i",)],
|
||||
),
|
||||
(
|
||||
"SELECT CONCAT(x.a, y.c) FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'",
|
||||
["_col_0"],
|
||||
[("bh",)],
|
||||
),
|
||||
(
|
||||
"SELECT * FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'",
|
||||
["a", "b", "b", "c"],
|
||||
[("b", "e", "e", "h")],
|
||||
),
|
||||
(
|
||||
"SELECT * FROM z",
|
||||
["d"],
|
||||
[],
|
||||
),
|
||||
(
|
||||
"SELECT d FROM z ORDER BY d",
|
||||
["d"],
|
||||
[],
|
||||
),
|
||||
(
|
||||
"SELECT a FROM x WHERE x.a <> 'b'",
|
||||
["a"],
|
||||
[("a",), ("c",)],
|
||||
),
|
||||
(
|
||||
"SELECT a AS i FROM x ORDER BY a",
|
||||
["i"],
|
||||
[("a",), ("b",), ("c",)],
|
||||
),
|
||||
(
|
||||
"SELECT a AS i FROM x ORDER BY i",
|
||||
["i"],
|
||||
[("a",), ("b",), ("c",)],
|
||||
),
|
||||
(
|
||||
"SELECT 100 - ORD(a) AS a, a AS i FROM x ORDER BY a",
|
||||
["a", "i"],
|
||||
[(1, "c"), (2, "b"), (3, "a")],
|
||||
),
|
||||
(
|
||||
"SELECT a /* test */ FROM x LIMIT 1",
|
||||
["a"],
|
||||
[("a",)],
|
||||
),
|
||||
]:
|
||||
with self.subTest(sql):
|
||||
result = execute(sql, schema=schema, tables=tables)
|
||||
self.assertEqual(result.columns, tuple(cols))
|
||||
self.assertEqual(result.rows, rows)
|
||||
|
||||
def test_set_operations(self):
|
||||
tables = {
|
||||
"x": [
|
||||
{"a": "a"},
|
||||
{"a": "b"},
|
||||
{"a": "c"},
|
||||
],
|
||||
"y": [
|
||||
{"a": "b"},
|
||||
{"a": "c"},
|
||||
{"a": "d"},
|
||||
],
|
||||
}
|
||||
schema = {
|
||||
"x": {
|
||||
"a": "VARCHAR",
|
||||
},
|
||||
"y": {
|
||||
"a": "VARCHAR",
|
||||
},
|
||||
}
|
||||
|
||||
for sql, cols, rows in [
|
||||
(
|
||||
"SELECT a FROM x UNION ALL SELECT a FROM y",
|
||||
["a"],
|
||||
[("a",), ("b",), ("c",), ("b",), ("c",), ("d",)],
|
||||
),
|
||||
(
|
||||
"SELECT a FROM x UNION SELECT a FROM y",
|
||||
["a"],
|
||||
[("a",), ("b",), ("c",), ("d",)],
|
||||
),
|
||||
(
|
||||
"SELECT a FROM x EXCEPT SELECT a FROM y",
|
||||
["a"],
|
||||
[("a",)],
|
||||
),
|
||||
(
|
||||
"SELECT a FROM x INTERSECT SELECT a FROM y",
|
||||
["a"],
|
||||
[("b",), ("c",)],
|
||||
),
|
||||
(
|
||||
"""SELECT i.a
|
||||
FROM (
|
||||
SELECT a FROM x UNION SELECT a FROM y
|
||||
) AS i
|
||||
JOIN (
|
||||
SELECT a FROM x UNION SELECT a FROM y
|
||||
) AS j
|
||||
ON i.a = j.a""",
|
||||
["a"],
|
||||
[("a",), ("b",), ("c",), ("d",)],
|
||||
),
|
||||
(
|
||||
"SELECT 1 AS a UNION SELECT 2 AS a UNION SELECT 3 AS a",
|
||||
["a"],
|
||||
[(1,), (2,), (3,)],
|
||||
),
|
||||
]:
|
||||
with self.subTest(sql):
|
||||
result = execute(sql, schema=schema, tables=tables)
|
||||
self.assertEqual(result.columns, tuple(cols))
|
||||
self.assertEqual(set(result.rows), set(rows))
|
||||
|
||||
def test_execute_catalog_db_table(self):
|
||||
tables = {
|
||||
"catalog": {
|
||||
"db": {
|
||||
"x": [
|
||||
{"a": "a"},
|
||||
{"a": "b"},
|
||||
{"a": "c"},
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
schema = {
|
||||
"catalog": {
|
||||
"db": {
|
||||
"x": {
|
||||
"a": "VARCHAR",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result1 = execute("SELECT * FROM x", schema=schema, tables=tables)
|
||||
result2 = execute("SELECT * FROM catalog.db.x", schema=schema, tables=tables)
|
||||
assert result1.columns == result2.columns
|
||||
assert result1.rows == result2.rows
|
||||
|
||||
def test_execute_tables(self):
|
||||
tables = {
|
||||
"sushi": [
|
||||
{"id": 1, "price": 1.0},
|
||||
{"id": 2, "price": 2.0},
|
||||
{"id": 3, "price": 3.0},
|
||||
],
|
||||
"order_items": [
|
||||
{"sushi_id": 1, "order_id": 1},
|
||||
{"sushi_id": 1, "order_id": 1},
|
||||
{"sushi_id": 2, "order_id": 1},
|
||||
{"sushi_id": 3, "order_id": 2},
|
||||
],
|
||||
"orders": [
|
||||
{"id": 1, "user_id": 1},
|
||||
{"id": 2, "user_id": 2},
|
||||
],
|
||||
}
|
||||
|
||||
self.assertEqual(
|
||||
execute(
|
||||
"""
|
||||
SELECT
|
||||
o.user_id,
|
||||
SUM(s.price) AS price
|
||||
FROM orders o
|
||||
JOIN order_items i
|
||||
ON o.id = i.order_id
|
||||
JOIN sushi s
|
||||
ON i.sushi_id = s.id
|
||||
GROUP BY o.user_id
|
||||
""",
|
||||
tables=tables,
|
||||
).rows,
|
||||
[
|
||||
(1, 4.0),
|
||||
(2, 3.0),
|
||||
],
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
execute(
|
||||
"""
|
||||
SELECT
|
||||
o.id, x.*
|
||||
FROM orders o
|
||||
LEFT JOIN (
|
||||
SELECT
|
||||
1 AS id, 'b' AS x
|
||||
UNION ALL
|
||||
SELECT
|
||||
3 AS id, 'c' AS x
|
||||
) x
|
||||
ON o.id = x.id
|
||||
""",
|
||||
tables=tables,
|
||||
).rows,
|
||||
[(1, 1, "b"), (2, None, None)],
|
||||
)
|
||||
self.assertEqual(
|
||||
execute(
|
||||
"""
|
||||
SELECT
|
||||
o.id, x.*
|
||||
FROM orders o
|
||||
RIGHT JOIN (
|
||||
SELECT
|
||||
1 AS id,
|
||||
'b' AS x
|
||||
UNION ALL
|
||||
SELECT
|
||||
3 AS id, 'c' AS x
|
||||
) x
|
||||
ON o.id = x.id
|
||||
""",
|
||||
tables=tables,
|
||||
).rows,
|
||||
[
|
||||
(1, 1, "b"),
|
||||
(None, 3, "c"),
|
||||
],
|
||||
)
|
||||
|
||||
def test_table_depth_mismatch(self):
|
||||
tables = {"table": []}
|
||||
schema = {"db": {"table": {"col": "VARCHAR"}}}
|
||||
with self.assertRaises(ExecuteError):
|
||||
execute("SELECT * FROM table", schema=schema, tables=tables)
|
||||
|
||||
def test_tables(self):
|
||||
tables = ensure_tables(
|
||||
{
|
||||
"catalog1": {
|
||||
"db1": {
|
||||
"t1": [
|
||||
{"a": 1},
|
||||
],
|
||||
"t2": [
|
||||
{"a": 1},
|
||||
],
|
||||
},
|
||||
"db2": {
|
||||
"t3": [
|
||||
{"a": 1},
|
||||
],
|
||||
"t4": [
|
||||
{"a": 1},
|
||||
],
|
||||
},
|
||||
},
|
||||
"catalog2": {
|
||||
"db3": {
|
||||
"t5": Table(columns=("a",), rows=[(1,)]),
|
||||
"t6": Table(columns=("a",), rows=[(1,)]),
|
||||
},
|
||||
"db4": {
|
||||
"t7": Table(columns=("a",), rows=[(1,)]),
|
||||
"t8": Table(columns=("a",), rows=[(1,)]),
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
t1 = tables.find(exp.table_(table="t1", db="db1", catalog="catalog1"))
|
||||
self.assertEqual(t1.columns, ("a",))
|
||||
self.assertEqual(t1.rows, [(1,)])
|
||||
|
||||
t8 = tables.find(exp.table_(table="t8"))
|
||||
self.assertEqual(t1.columns, t8.columns)
|
||||
self.assertEqual(t1.rows, t8.rows)
|
||||
|
||||
def test_static_queries(self):
|
||||
for sql, cols, rows in [
|
||||
("SELECT 1", ["_col_0"], [(1,)]),
|
||||
("SELECT 1 + 2 AS x", ["x"], [(3,)]),
|
||||
("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
|
||||
("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
|
||||
("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
|
||||
]:
|
||||
result = execute(sql)
|
||||
self.assertEqual(result.columns, tuple(cols))
|
||||
self.assertEqual(result.rows, rows)
|
||||
|
||||
def test_aggregate_without_group_by(self):
|
||||
result = execute("SELECT SUM(x) FROM t", tables={"t": [{"x": 1}, {"x": 2}]})
|
||||
self.assertEqual(result.columns, ("_col_0",))
|
||||
self.assertEqual(result.rows, [(3,)])
|
||||
|
||||
def test_scalar_functions(self):
|
||||
for sql, expected in [
|
||||
("CONCAT('a', 'b')", "ab"),
|
||||
("CONCAT('a', NULL)", None),
|
||||
("CONCAT_WS('_', 'a', 'b')", "a_b"),
|
||||
("STR_POSITION('bar', 'foobarbar')", 4),
|
||||
("STR_POSITION('bar', 'foobarbar', 5)", 7),
|
||||
("STR_POSITION(NULL, 'foobarbar')", None),
|
||||
("STR_POSITION('bar', NULL)", None),
|
||||
("UPPER('foo')", "FOO"),
|
||||
("UPPER(NULL)", None),
|
||||
("LOWER('FOO')", "foo"),
|
||||
("LOWER(NULL)", None),
|
||||
("IFNULL('a', 'b')", "a"),
|
||||
("IFNULL(NULL, 'b')", "b"),
|
||||
("IFNULL(NULL, NULL)", None),
|
||||
("SUBSTRING('12345')", "12345"),
|
||||
("SUBSTRING('12345', 3)", "345"),
|
||||
("SUBSTRING('12345', 3, 0)", ""),
|
||||
("SUBSTRING('12345', 3, 1)", "3"),
|
||||
("SUBSTRING('12345', 3, 2)", "34"),
|
||||
("SUBSTRING('12345', 3, 3)", "345"),
|
||||
("SUBSTRING('12345', 3, 4)", "345"),
|
||||
("SUBSTRING('12345', -3)", "345"),
|
||||
("SUBSTRING('12345', -3, 0)", ""),
|
||||
("SUBSTRING('12345', -3, 1)", "3"),
|
||||
("SUBSTRING('12345', -3, 2)", "34"),
|
||||
("SUBSTRING('12345', 0)", ""),
|
||||
("SUBSTRING('12345', 0, 1)", ""),
|
||||
("SUBSTRING(NULL)", None),
|
||||
("SUBSTRING(NULL, 1)", None),
|
||||
("CAST(1 AS TEXT)", "1"),
|
||||
("CAST('1' AS LONG)", 1),
|
||||
("CAST('1.1' AS FLOAT)", 1.1),
|
||||
("COALESCE(NULL)", None),
|
||||
("COALESCE(NULL, NULL)", None),
|
||||
("COALESCE(NULL, 'b')", "b"),
|
||||
("COALESCE('a', 'b')", "a"),
|
||||
("1 << 1", 2),
|
||||
("1 >> 1", 0),
|
||||
("1 & 1", 1),
|
||||
("1 | 1", 1),
|
||||
("1 < 1", False),
|
||||
("1 <= 1", True),
|
||||
("1 > 1", False),
|
||||
("1 >= 1", True),
|
||||
("1 + NULL", None),
|
||||
("IF(true, 1, 0)", 1),
|
||||
("IF(false, 1, 0)", 0),
|
||||
("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"),
|
||||
("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)),
|
||||
]:
|
||||
with self.subTest(sql):
|
||||
result = execute(f"SELECT {sql}")
|
||||
self.assertEqual(result.rows, [(expected,)])
|
||||
|
|
|
@ -441,6 +441,9 @@ class TestExpressions(unittest.TestCase):
|
|||
self.assertIsInstance(parse_one("VARIANCE(a)"), exp.Variance)
|
||||
self.assertIsInstance(parse_one("VARIANCE_POP(a)"), exp.VariancePop)
|
||||
self.assertIsInstance(parse_one("YEAR(a)"), exp.Year)
|
||||
self.assertIsInstance(parse_one("BEGIN DEFERRED TRANSACTION"), exp.Transaction)
|
||||
self.assertIsInstance(parse_one("COMMIT"), exp.Commit)
|
||||
self.assertIsInstance(parse_one("ROLLBACK"), exp.Rollback)
|
||||
|
||||
def test_column(self):
|
||||
dot = parse_one("a.b.c")
|
||||
|
@ -479,9 +482,9 @@ class TestExpressions(unittest.TestCase):
|
|||
self.assertEqual(column.text("expression"), "c")
|
||||
self.assertEqual(column.text("y"), "")
|
||||
self.assertEqual(parse_one("select * from x.y").find(exp.Table).text("db"), "x")
|
||||
self.assertEqual(parse_one("select *").text("this"), "")
|
||||
self.assertEqual(parse_one("1 + 1").text("this"), "1")
|
||||
self.assertEqual(parse_one("'a'").text("this"), "a")
|
||||
self.assertEqual(parse_one("select *").name, "")
|
||||
self.assertEqual(parse_one("1 + 1").name, "1")
|
||||
self.assertEqual(parse_one("'a'").name, "a")
|
||||
|
||||
def test_alias(self):
|
||||
self.assertEqual(alias("foo", "bar").sql(), "foo AS bar")
|
||||
|
@ -538,8 +541,8 @@ class TestExpressions(unittest.TestCase):
|
|||
this=exp.Literal.string("TABLE_FORMAT"),
|
||||
value=exp.to_identifier("test_format"),
|
||||
),
|
||||
exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL),
|
||||
exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE),
|
||||
exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.null()),
|
||||
exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.true()),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
|
|
@ -29,6 +29,7 @@ class TestOptimizer(unittest.TestCase):
|
|||
CREATE TABLE x (a INT, b INT);
|
||||
CREATE TABLE y (b INT, c INT);
|
||||
CREATE TABLE z (b INT, c INT);
|
||||
CREATE TABLE w (d TEXT, e TEXT);
|
||||
|
||||
INSERT INTO x VALUES (1, 1);
|
||||
INSERT INTO x VALUES (2, 2);
|
||||
|
@ -47,6 +48,8 @@ class TestOptimizer(unittest.TestCase):
|
|||
INSERT INTO y VALUES (4, 4);
|
||||
INSERT INTO y VALUES (5, 5);
|
||||
INSERT INTO y VALUES (null, null);
|
||||
|
||||
INSERT INTO w VALUES ('a', 'b');
|
||||
"""
|
||||
)
|
||||
|
||||
|
@ -64,6 +67,10 @@ class TestOptimizer(unittest.TestCase):
|
|||
"b": "INT",
|
||||
"c": "INT",
|
||||
},
|
||||
"w": {
|
||||
"d": "TEXT",
|
||||
"e": "TEXT",
|
||||
},
|
||||
}
|
||||
|
||||
def check_file(self, file, func, pretty=False, execute=False, **kwargs):
|
||||
|
@ -224,6 +231,18 @@ class TestOptimizer(unittest.TestCase):
|
|||
def test_eliminate_subqueries(self):
|
||||
self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries)
|
||||
|
||||
def test_canonicalize(self):
|
||||
optimize = partial(
|
||||
optimizer.optimize,
|
||||
rules=[
|
||||
optimizer.qualify_tables.qualify_tables,
|
||||
optimizer.qualify_columns.qualify_columns,
|
||||
annotate_types,
|
||||
optimizer.canonicalize.canonicalize,
|
||||
],
|
||||
)
|
||||
self.check_file("canonicalize", optimize, schema=self.schema)
|
||||
|
||||
def test_tpch(self):
|
||||
self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)
|
||||
|
||||
|
|
|
@ -41,12 +41,41 @@ class TestParser(unittest.TestCase):
|
|||
)
|
||||
|
||||
def test_command(self):
|
||||
expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1")
|
||||
expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1", read="hive")
|
||||
self.assertEqual(len(expressions), 3)
|
||||
self.assertEqual(expressions[0].sql(), "SET x = 1")
|
||||
self.assertEqual(expressions[1].sql(), "ADD JAR s3://a")
|
||||
self.assertEqual(expressions[2].sql(), "SELECT 1")
|
||||
|
||||
def test_transactions(self):
|
||||
expression = parse_one("BEGIN TRANSACTION")
|
||||
self.assertIsNone(expression.this)
|
||||
self.assertEqual(expression.args["modes"], [])
|
||||
self.assertEqual(expression.sql(), "BEGIN")
|
||||
|
||||
expression = parse_one("START TRANSACTION", read="mysql")
|
||||
self.assertIsNone(expression.this)
|
||||
self.assertEqual(expression.args["modes"], [])
|
||||
self.assertEqual(expression.sql(), "BEGIN")
|
||||
|
||||
expression = parse_one("BEGIN DEFERRED TRANSACTION")
|
||||
self.assertEqual(expression.this, "DEFERRED")
|
||||
self.assertEqual(expression.args["modes"], [])
|
||||
self.assertEqual(expression.sql(), "BEGIN")
|
||||
|
||||
expression = parse_one(
|
||||
"START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", read="presto"
|
||||
)
|
||||
self.assertIsNone(expression.this)
|
||||
self.assertEqual(expression.args["modes"][0], "READ WRITE")
|
||||
self.assertEqual(expression.args["modes"][1], "ISOLATION LEVEL SERIALIZABLE")
|
||||
self.assertEqual(expression.sql(), "BEGIN")
|
||||
|
||||
expression = parse_one("BEGIN", read="bigquery")
|
||||
self.assertNotIsInstance(expression, exp.Transaction)
|
||||
self.assertIsNone(expression.expression)
|
||||
self.assertEqual(expression.sql(), "BEGIN")
|
||||
|
||||
def test_identify(self):
|
||||
expression = parse_one(
|
||||
"""
|
||||
|
@ -55,14 +84,14 @@ class TestParser(unittest.TestCase):
|
|||
"""
|
||||
)
|
||||
|
||||
assert expression.expressions[0].text("this") == "a"
|
||||
assert expression.expressions[1].text("this") == "b"
|
||||
assert expression.expressions[2].text("alias") == "c"
|
||||
assert expression.expressions[3].text("alias") == "D"
|
||||
assert expression.expressions[4].text("alias") == "y|z'"
|
||||
assert expression.expressions[0].name == "a"
|
||||
assert expression.expressions[1].name == "b"
|
||||
assert expression.expressions[2].alias == "c"
|
||||
assert expression.expressions[3].alias == "D"
|
||||
assert expression.expressions[4].alias == "y|z'"
|
||||
table = expression.args["from"].expressions[0]
|
||||
assert table.args["this"].args["this"] == "z"
|
||||
assert table.args["db"].args["this"] == "y"
|
||||
assert table.this.name == "z"
|
||||
assert table.args["db"].name == "y"
|
||||
|
||||
def test_multi(self):
|
||||
expressions = parse(
|
||||
|
@ -72,8 +101,8 @@ class TestParser(unittest.TestCase):
|
|||
)
|
||||
|
||||
assert len(expressions) == 2
|
||||
assert expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a"
|
||||
assert expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b"
|
||||
assert expressions[0].args["from"].expressions[0].this.name == "a"
|
||||
assert expressions[1].args["from"].expressions[0].this.name == "b"
|
||||
|
||||
def test_expression(self):
|
||||
ignore = Parser(error_level=ErrorLevel.IGNORE)
|
||||
|
@ -200,7 +229,7 @@ class TestParser(unittest.TestCase):
|
|||
@patch("sqlglot.parser.logger")
|
||||
def test_comment_error_n(self, logger):
|
||||
parse_one(
|
||||
"""CREATE TABLE x
|
||||
"""SUM
|
||||
(
|
||||
-- test
|
||||
)""",
|
||||
|
@ -208,19 +237,19 @@ class TestParser(unittest.TestCase):
|
|||
)
|
||||
|
||||
assert_logger_contains(
|
||||
"Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 4, Col: 1.",
|
||||
"Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. Line 4, Col: 1.",
|
||||
logger,
|
||||
)
|
||||
|
||||
@patch("sqlglot.parser.logger")
|
||||
def test_comment_error_r(self, logger):
|
||||
parse_one(
|
||||
"""CREATE TABLE x (-- test\r)""",
|
||||
"""SUM(-- test\r)""",
|
||||
error_level=ErrorLevel.WARN,
|
||||
)
|
||||
|
||||
assert_logger_contains(
|
||||
"Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 2, Col: 1.",
|
||||
"Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. Line 2, Col: 1.",
|
||||
logger,
|
||||
)
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ class TestTokens(unittest.TestCase):
|
|||
("--comment\nfoo --test", "comment"),
|
||||
("foo --comment", "comment"),
|
||||
("foo", None),
|
||||
("foo /*comment 1*/ /*comment 2*/", "comment 1"),
|
||||
]
|
||||
|
||||
for sql, comment in sql_comment:
|
||||
|
|
|
@ -20,6 +20,13 @@ class TestTranspile(unittest.TestCase):
|
|||
self.assertEqual(transpile(sql, **kwargs)[0], target)
|
||||
|
||||
def test_alias(self):
|
||||
self.assertEqual(transpile("SELECT 1 current_time")[0], "SELECT 1 AS current_time")
|
||||
self.assertEqual(
|
||||
transpile("SELECT 1 current_timestamp")[0], "SELECT 1 AS current_timestamp"
|
||||
)
|
||||
self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date")
|
||||
self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
|
||||
|
||||
for key in ("union", "filter", "over", "from", "join"):
|
||||
with self.subTest(f"alias {key}"):
|
||||
self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}")
|
||||
|
@ -69,6 +76,10 @@ class TestTranspile(unittest.TestCase):
|
|||
self.validate("SELECT 3>=3", "SELECT 3 >= 3")
|
||||
|
||||
def test_comments(self):
|
||||
self.validate("SELECT */*comment*/", "SELECT * /* comment */")
|
||||
self.validate(
|
||||
"SELECT * FROM table /*comment 1*/ /*comment 2*/", "SELECT * FROM table /* comment 1 */"
|
||||
)
|
||||
self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
|
||||
self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
|
||||
self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue