1
0
Fork 0

Merging upstream version 9.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:48:46 +01:00
parent ebb36a5fc5
commit 4483b8ff47
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
87 changed files with 7994 additions and 421 deletions

View file

View file

@ -0,0 +1,35 @@
import typing as t
import unittest
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.session import SparkSession
class DataFrameSQLValidator(unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession()
self.employee_schema = types.StructType(
[
types.StructField("employee_id", types.IntegerType(), False),
types.StructField("fname", types.StringType(), False),
types.StructField("lname", types.StringType(), False),
types.StructField("age", types.IntegerType(), False),
types.StructField("store_id", types.IntegerType(), False),
]
)
employee_data = [
(1, "Jack", "Shephard", 37, 1),
(2, "John", "Locke", 65, 1),
(3, "Kate", "Austen", 37, 2),
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema)
def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False):
actual_sqls = df.sql(pretty=pretty)
expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements
self.assertEqual(len(expected_statements), len(actual_sqls))
for expected, actual in zip(expected_statements, actual_sqls):
self.assertEqual(expected, actual)

View file

@ -0,0 +1,167 @@
import datetime
import unittest
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window
class TestDataframeColumn(unittest.TestCase):
def test_eq(self):
self.assertEqual("cola = 1", (F.col("cola") == 1).sql())
def test_neq(self):
self.assertEqual("cola <> 1", (F.col("cola") != 1).sql())
def test_gt(self):
self.assertEqual("cola > 1", (F.col("cola") > 1).sql())
def test_lt(self):
self.assertEqual("cola < 1", (F.col("cola") < 1).sql())
def test_le(self):
self.assertEqual("cola <= 1", (F.col("cola") <= 1).sql())
def test_ge(self):
self.assertEqual("cola >= 1", (F.col("cola") >= 1).sql())
def test_and(self):
self.assertEqual(
"cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql()
)
def test_or(self):
self.assertEqual(
"cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql()
)
def test_mod(self):
self.assertEqual("cola % 2", (F.col("cola") % 2).sql())
def test_add(self):
self.assertEqual("cola + 1", (F.col("cola") + 1).sql())
def test_sub(self):
self.assertEqual("cola - 1", (F.col("cola") - 1).sql())
def test_mul(self):
self.assertEqual("cola * 2", (F.col("cola") * 2).sql())
def test_div(self):
self.assertEqual("cola / 2", (F.col("cola") / 2).sql())
def test_radd(self):
self.assertEqual("1 + cola", (1 + F.col("cola")).sql())
def test_rsub(self):
self.assertEqual("1 - cola", (1 - F.col("cola")).sql())
def test_rmul(self):
self.assertEqual("1 * cola", (1 * F.col("cola")).sql())
def test_rdiv(self):
self.assertEqual("1 / cola", (1 / F.col("cola")).sql())
def test_pow(self):
self.assertEqual("POWER(cola, 2)", (F.col("cola") ** 2).sql())
def test_rpow(self):
self.assertEqual("POWER(2, cola)", (2 ** F.col("cola")).sql())
def test_invert(self):
self.assertEqual("NOT cola", (~F.col("cola")).sql())
def test_startswith(self):
self.assertEqual("STARTSWITH(cola, 'test')", F.col("cola").startswith("test").sql())
def test_endswith(self):
self.assertEqual("ENDSWITH(cola, 'test')", F.col("cola").endswith("test").sql())
def test_rlike(self):
self.assertEqual("cola RLIKE 'foo'", F.col("cola").rlike("foo").sql())
def test_like(self):
self.assertEqual("cola LIKE 'foo%'", F.col("cola").like("foo%").sql())
def test_ilike(self):
self.assertEqual("cola ILIKE 'foo%'", F.col("cola").ilike("foo%").sql())
def test_substring(self):
self.assertEqual("SUBSTRING(cola, 2, 3)", F.col("cola").substr(2, 3).sql())
def test_isin(self):
self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin([1, 2, 3]).sql())
self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin(1, 2, 3).sql())
def test_asc(self):
self.assertEqual("cola", F.col("cola").asc().sql())
def test_desc(self):
self.assertEqual("cola DESC", F.col("cola").desc().sql())
def test_asc_nulls_first(self):
self.assertEqual("cola", F.col("cola").asc_nulls_first().sql())
def test_asc_nulls_last(self):
self.assertEqual("cola NULLS LAST", F.col("cola").asc_nulls_last().sql())
def test_desc_nulls_first(self):
self.assertEqual("cola DESC NULLS FIRST", F.col("cola").desc_nulls_first().sql())
def test_desc_nulls_last(self):
self.assertEqual("cola DESC", F.col("cola").desc_nulls_last().sql())
def test_when_otherwise(self):
self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql())
self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql())
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
(F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(),
)
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
F.col("cola").when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).sql(),
)
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 ELSE 4 END",
F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).otherwise(4).sql(),
)
def test_is_null(self):
self.assertEqual("cola IS NULL", F.col("cola").isNull().sql())
def test_is_not_null(self):
self.assertEqual("NOT cola IS NULL", F.col("cola").isNotNull().sql())
def test_cast(self):
self.assertEqual("CAST(cola AS INT)", F.col("cola").cast("INT").sql())
def test_alias(self):
self.assertEqual("cola AS new_name", F.col("cola").alias("new_name").sql())
def test_between(self):
self.assertEqual("cola BETWEEN 1 AND 3", F.col("cola").between(1, 3).sql())
self.assertEqual("cola BETWEEN 10.1 AND 12.1", F.col("cola").between(10.1, 12.1).sql())
self.assertEqual(
"cola BETWEEN TO_DATE('2022-01-01') AND TO_DATE('2022-03-01')",
F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(),
)
self.assertEqual(
"cola BETWEEN CAST('2022-01-01 01:01:01' AS TIMESTAMP) " "AND CAST('2022-03-01 01:01:01' AS TIMESTAMP)",
F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(),
)
def test_over(self):
over_rows = F.sum("cola").over(
Window.partitionBy("colb").orderBy("colc").rowsBetween(1, Window.unboundedFollowing)
)
self.assertEqual(
"SUM(cola) OVER (PARTITION BY colb ORDER BY colc ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
over_rows.sql(),
)
over_range = F.sum("cola").over(
Window.partitionBy("colb").orderBy("colc").rangeBetween(1, Window.unboundedFollowing)
)
self.assertEqual(
"SUM(cola) OVER (PARTITION BY colb ORDER BY colc RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
over_range.sql(),
)

View file

@ -0,0 +1,39 @@
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.dataframe import DataFrame
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataframe(DataFrameSQLValidator):
def test_hash_select_expression(self):
expression = exp.select("cola").from_("table")
self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))
def test_columns(self):
self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns)
def test_cache(self):
df = self.df_employee.select("fname").cache()
expected_statements = [
"DROP VIEW IF EXISTS t11623",
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
]
self.compare_sql(df, expected_statements)
def test_persist_default(self):
df = self.df_employee.select("fname").persist()
expected_statements = [
"DROP VIEW IF EXISTS t11623",
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
]
self.compare_sql(df, expected_statements)
def test_persist_storagelevel(self):
df = self.df_employee.select("fname").persist("DISK_ONLY_2")
expected_statements = [
"DROP VIEW IF EXISTS t11623",
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
]
self.compare_sql(df, expected_statements)

View file

@ -0,0 +1,86 @@
from unittest import mock
import sqlglot
from sqlglot.schema import MappingSchema
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataFrameWriter(DataFrameSQLValidator):
def test_insertInto_full_path(self):
df = self.df_employee.write.insertInto("catalog.db.table_name")
expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_db_table(self):
df = self.df_employee.write.insertInto("db.table_name")
expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_table(self):
df = self.df_employee.write.insertInto("table_name")
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_overwrite(self):
df = self.df_employee.write.insertInto("table_name", overwrite=True)
expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
def test_insertInto_byName(self):
sqlglot.schema.add_table("table_name", {"employee_id": "INT"})
df = self.df_employee.write.byName.insertInto("table_name")
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_cache(self):
df = self.df_employee.cache().write.insertInto("table_name")
expected_statements = [
"DROP VIEW IF EXISTS t35612",
"CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"INSERT INTO table_name SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
]
self.compare_sql(df, expected_statements)
def test_saveAsTable_format(self):
with self.assertRaises(NotImplementedError):
self.df_employee.write.saveAsTable("table_name", format="parquet").sql(pretty=False)[0]
def test_saveAsTable_append(self):
df = self.df_employee.write.saveAsTable("table_name", mode="append")
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_overwrite(self):
df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_error(self):
df = self.df_employee.write.saveAsTable("table_name", mode="error")
expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_ignore(self):
df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_standalone(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_override(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_cache(self):
df = self.df_employee.cache().write.saveAsTable("table_name")
expected_statements = [
"DROP VIEW IF EXISTS t35612",
"CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"CREATE TABLE table_name AS SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
]
self.compare_sql(df, expected_statements)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,114 @@
from unittest import mock
import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.schema import MappingSchema
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataframeSession(DataFrameSQLValidator):
def test_cdf_one_row(self):
df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2)) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_multiple_rows(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_no_schema(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
expected = (
"SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
)
self.compare_sql(df, expected)
def test_cdf_row_mixed_primitives(self):
df = self.spark.createDataFrame([[1, 10.1, "test", False, None]])
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM (VALUES (1, 10.1, 'test', FALSE, NULL)) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
self.compare_sql(df, expected)
def test_cdf_dict_rows(self):
df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 'test'), (2, 'test2')) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_str_schema(self):
df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_basic(self):
schema = types.StructType(
[
types.StructField("cola", types.IntegerType()),
types.StructField("colb", types.StringType()),
]
)
df = self.spark.createDataFrame([[1, "test"]], schema)
expected = "SELECT CAST(`a2`.`cola` AS int) AS `cola`, CAST(`a2`.`colb` AS string) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_nested(self):
schema = types.StructType(
[
types.StructField(
"cola",
types.StructType(
[
types.StructField("sub_cola", types.IntegerType()),
types.StructField("sub_colb", types.StringType()),
]
),
)
]
)
df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema)
expected = "SELECT CAST(`a2`.`cola` AS struct<sub_cola:int, sub_colb:string>) AS `cola` FROM (VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`))) AS `a2`(`cola`)"
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_select_only(self):
# TODO: Do exact matches once CTE names are deterministic
query = "SELECT cola, colb FROM table"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
self.assertIn(
"SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False)
)
@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_with_aggs(self):
# TODO: Do exact matches once CTE names are deterministic
query = "SELECT cola, colb FROM table"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query).groupBy(F.col("cola")).agg(F.sum("colb"))
result = df.sql(pretty=False, optimize=False)[0]
self.assertIn("SELECT cola, colb FROM table", result)
self.assertIn("SUM(colb)", result)
self.assertIn("GROUP BY cola", result)
@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_create(self):
query = "CREATE TABLE new_table AS WITH t1 AS (SELECT cola, colb FROM table) SELECT cola, colb, FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
expected = "CREATE TABLE new_table AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_insert(self):
query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
expected = (
"INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
)
self.compare_sql(df, expected)
def test_session_create_builder_patterns(self):
spark = SparkSession()
self.assertEqual(spark.builder.appName("abc").getOrCreate(), spark)

View file

@ -0,0 +1,70 @@
import unittest
from sqlglot.dataframe.sql import types
class TestDataframeTypes(unittest.TestCase):
def test_string(self):
self.assertEqual("string", types.StringType().simpleString())
def test_char(self):
self.assertEqual("char(100)", types.CharType(100).simpleString())
def test_varchar(self):
self.assertEqual("varchar(65)", types.VarcharType(65).simpleString())
def test_binary(self):
self.assertEqual("binary", types.BinaryType().simpleString())
def test_boolean(self):
self.assertEqual("boolean", types.BooleanType().simpleString())
def test_date(self):
self.assertEqual("date", types.DateType().simpleString())
def test_timestamp(self):
self.assertEqual("timestamp", types.TimestampType().simpleString())
def test_timestamp_ntz(self):
self.assertEqual("timestamp_ntz", types.TimestampNTZType().simpleString())
def test_decimal(self):
self.assertEqual("decimal(10, 3)", types.DecimalType(10, 3).simpleString())
def test_double(self):
self.assertEqual("double", types.DoubleType().simpleString())
def test_float(self):
self.assertEqual("float", types.FloatType().simpleString())
def test_byte(self):
self.assertEqual("tinyint", types.ByteType().simpleString())
def test_integer(self):
self.assertEqual("int", types.IntegerType().simpleString())
def test_long(self):
self.assertEqual("bigint", types.LongType().simpleString())
def test_short(self):
self.assertEqual("smallint", types.ShortType().simpleString())
def test_array(self):
self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString())
def test_map(self):
self.assertEqual("map<int, string>", types.MapType(types.IntegerType(), types.StringType()).simpleString())
def test_struct_field(self):
self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString())
def test_struct_type(self):
self.assertEqual(
"struct<cola:int, colb:string>",
types.StructType(
[
types.StructField("cola", types.IntegerType()),
types.StructField("colb", types.StringType()),
]
).simpleString(),
)

View file

@ -0,0 +1,60 @@
import unittest
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window, WindowSpec
class TestDataframeWindow(unittest.TestCase):
def test_window_spec_partition_by(self):
partition_by = WindowSpec().partitionBy(F.col("cola"), F.col("colb"))
self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql())
def test_window_spec_order_by(self):
order_by = WindowSpec().orderBy("cola", "colb")
self.assertEqual("OVER (ORDER BY cola, colb)", order_by.sql())
def test_window_spec_rows_between(self):
rows_between = WindowSpec().rowsBetween(3, 5)
self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
def test_window_spec_range_between(self):
range_between = WindowSpec().rangeBetween(3, 5)
self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
def test_window_partition_by(self):
partition_by = Window.partitionBy(F.col("cola"), F.col("colb"))
self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql())
def test_window_order_by(self):
order_by = Window.orderBy("cola", "colb")
self.assertEqual("OVER (ORDER BY cola, colb)", order_by.sql())
def test_window_rows_between(self):
rows_between = Window.rowsBetween(3, 5)
self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
def test_window_range_between(self):
range_between = Window.rangeBetween(3, 5)
self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
def test_window_rows_unbounded(self):
rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql())
rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql())
rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
self.assertEqual(
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql()
)
def test_window_range_unbounded(self):
range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
self.assertEqual(
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql()
)
range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql())
range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)
self.assertEqual(
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql()
)