Merging upstream version 10.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 528822bfd4
commit b7d21c45b7

98 changed files with 4080 additions and 1666 deletions
@@ -1,9 +1,9 @@
-import sys
 import typing as t
 import unittest
 import warnings
 
 import sqlglot
+from sqlglot.helper import PYTHON_VERSION
 from tests.helpers import SKIP_INTEGRATION
 
 if t.TYPE_CHECKING:
@@ -11,7 +11,8 @@ if t.TYPE_CHECKING:
 
 
 @unittest.skipIf(
-    SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set"
+    SKIP_INTEGRATION or PYTHON_VERSION > (3, 10),
+    "Skipping Integration Tests since `SKIP_INTEGRATION` is set",
 )
 class DataFrameValidator(unittest.TestCase):
     spark = None
@@ -36,7 +37,12 @@ class DataFrameValidator(unittest.TestCase):
 
         # This is for test `test_branching_root_dataframes`
         config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
-        cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate()
+        cls.spark = (
+            SparkSession.builder.master("local[*]")
+            .appName("Unit-tests")
+            .config(conf=config)
+            .getOrCreate()
+        )
         cls.spark.sparkContext.setLogLevel("ERROR")
         cls.sqlglot = SqlglotSparkSession()
         cls.spark_employee_schema = types.StructType(
@@ -50,7 +56,9 @@ class DataFrameValidator(unittest.TestCase):
         )
         cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
             [
-                sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False),
+                sqlglotSparkTypes.StructField(
+                    "employee_id", sqlglotSparkTypes.IntegerType(), False
+                ),
                 sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
                 sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
                 sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
@@ -64,8 +72,12 @@ class DataFrameValidator(unittest.TestCase):
             (4, "Claire", "Littleton", 27, 2),
             (5, "Hugo", "Reyes", 29, 100),
         ]
-        cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema)
-        cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema)
+        cls.df_employee = cls.spark.createDataFrame(
+            data=employee_data, schema=cls.spark_employee_schema
+        )
+        cls.dfs_employee = cls.sqlglot.createDataFrame(
+            data=employee_data, schema=cls.sqlglot_employee_schema
+        )
         cls.df_employee.createOrReplaceTempView("employee")
 
         cls.spark_store_schema = types.StructType(
@@ -80,7 +92,9 @@ class DataFrameValidator(unittest.TestCase):
             [
                 sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
                 sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
-                sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
+                sqlglotSparkTypes.StructField(
+                    "district_id", sqlglotSparkTypes.IntegerType(), False
+                ),
                 sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
             ]
         )
@@ -89,7 +103,9 @@ class DataFrameValidator(unittest.TestCase):
             (2, "Arrow", 2, 2000),
         ]
         cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
-        cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema)
+        cls.dfs_store = cls.sqlglot.createDataFrame(
+            data=store_data, schema=cls.sqlglot_store_schema
+        )
         cls.df_store.createOrReplaceTempView("store")
 
         cls.spark_district_schema = types.StructType(
@@ -101,17 +117,27 @@ class DataFrameValidator(unittest.TestCase):
         )
         cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
             [
-                sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
-                sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False),
-                sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False),
+                sqlglotSparkTypes.StructField(
+                    "district_id", sqlglotSparkTypes.IntegerType(), False
+                ),
+                sqlglotSparkTypes.StructField(
+                    "district_name", sqlglotSparkTypes.StringType(), False
+                ),
+                sqlglotSparkTypes.StructField(
+                    "manager_name", sqlglotSparkTypes.StringType(), False
+                ),
             ]
         )
         district_data = [
             (1, "Temple", "Dogen"),
             (2, "Lighthouse", "Jacob"),
         ]
-        cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema)
-        cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema)
+        cls.df_district = cls.spark.createDataFrame(
+            data=district_data, schema=cls.spark_district_schema
+        )
+        cls.dfs_district = cls.sqlglot.createDataFrame(
+            data=district_data, schema=cls.sqlglot_district_schema
+        )
         cls.df_district.createOrReplaceTempView("district")
         sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
         sqlglot.schema.add_table("store", cls.sqlglot_store_schema)
@@ -41,22 +41,32 @@ class TestDataframeFunc(DataFrameValidator):
 
     def test_alias_with_select(self):
         df_employee = self.df_spark_employee.alias("df_employee").select(
-            self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname
+            self.df_spark_employee["employee_id"],
+            F.col("df_employee.fname"),
+            self.df_spark_employee.lname,
         )
         dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select(
-            self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname
+            self.df_sqlglot_employee["employee_id"],
+            SF.col("dfs_employee.fname"),
+            self.df_sqlglot_employee.lname,
         )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)
 
     def test_case_when_otherwise(self):
         df = self.df_spark_employee.select(
-            F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60"))
+            F.when(
+                (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+                F.lit("between 40 and 60"),
+            )
             .when(F.col("age") < F.lit(40), "less than 40")
             .otherwise("greater than 60")
         )
 
         dfs = self.df_sqlglot_employee.select(
-            SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60"))
+            SF.when(
+                (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+                SF.lit("between 40 and 60"),
+            )
             .when(SF.col("age") < SF.lit(40), "less than 40")
             .otherwise("greater than 60")
         )
@@ -65,15 +75,17 @@ class TestDataframeFunc(DataFrameValidator):
 
     def test_case_when_no_otherwise(self):
         df = self.df_spark_employee.select(
-            F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when(
-                F.col("age") < F.lit(40), "less than 40"
-            )
+            F.when(
+                (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+                F.lit("between 40 and 60"),
+            ).when(F.col("age") < F.lit(40), "less than 40")
         )
 
         dfs = self.df_sqlglot_employee.select(
-            SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when(
-                SF.col("age") < SF.lit(40), "less than 40"
-            )
+            SF.when(
+                (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+                SF.lit("between 40 and 60"),
+            ).when(SF.col("age") < SF.lit(40), "less than 40")
         )
 
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -84,7 +96,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)
 
     def test_where_clause_multiple_and(self):
-        df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")))
+        df_employee = self.df_spark_employee.where(
+            (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))
+        )
         dfs_employee = self.df_sqlglot_employee.where(
             (SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack"))
         )
@@ -106,7 +120,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)
 
     def test_where_clause_multiple_or(self):
-        df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")))
+        df_employee = self.df_spark_employee.where(
+            (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))
+        )
         dfs_employee = self.df_sqlglot_employee.where(
             (SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate"))
         )
@@ -172,28 +188,43 @@ class TestDataframeFunc(DataFrameValidator):
         dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37))
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)
 
-        df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0))
-        dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0))
+        df_employee = self.df_spark_employee.where(
+            self.df_spark_employee["age"] % F.lit(5) == F.lit(0)
+        )
+        dfs_employee = self.df_sqlglot_employee.where(
+            self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)
+        )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)
 
-        df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28))
-        dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28))
+        df_employee = self.df_spark_employee.where(
+            self.df_spark_employee["age"] + F.lit(5) > F.lit(28)
+        )
+        dfs_employee = self.df_sqlglot_employee.where(
+            self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)
+        )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)
 
-        df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28))
-        dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28))
+        df_employee = self.df_spark_employee.where(
+            self.df_spark_employee["age"] - F.lit(5) > F.lit(28)
+        )
+        dfs_employee = self.df_sqlglot_employee.where(
+            self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)
+        )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)
 
         df_employee = self.df_spark_employee.where(
             self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2)
         )
         dfs_employee = self.df_sqlglot_employee.where(
-            self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2)
+            self.df_sqlglot_employee["age"] * SF.lit(0.5)
+            == self.df_sqlglot_employee["age"] / SF.lit(2)
         )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)
 
     def test_join_inner(self):
-        df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select(
+        df_joined = self.df_spark_employee.join(
+            self.df_spark_store, on=["store_id"], how="inner"
+        ).select(
             self.df_spark_employee.employee_id,
             self.df_spark_employee["fname"],
             F.col("lname"),
@@ -202,7 +233,9 @@ class TestDataframeFunc(DataFrameValidator):
             self.df_spark_store.store_name,
             self.df_spark_store["num_sales"],
         )
-        dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select(
+        dfs_joined = self.df_sqlglot_employee.join(
+            self.df_sqlglot_store, on=["store_id"], how="inner"
+        ).select(
             self.df_sqlglot_employee.employee_id,
             self.df_sqlglot_employee["fname"],
             SF.col("lname"),
@@ -214,17 +247,27 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_joined, dfs_joined)
 
     def test_join_inner_no_select(self):
-        df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join(
-            self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner"
+        df_joined = self.df_spark_employee.select(
+            F.col("store_id"), F.col("fname"), F.col("lname")
+        ).join(
+            self.df_spark_store.select(F.col("store_id"), F.col("store_name")),
+            on=["store_id"],
+            how="inner",
         )
-        dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner"
+        dfs_joined = self.df_sqlglot_employee.select(
+            SF.col("store_id"), SF.col("fname"), SF.col("lname")
+        ).join(
+            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")),
+            on=["store_id"],
+            how="inner",
         )
         self.compare_spark_with_sqlglot(df_joined, dfs_joined)
 
     def test_join_inner_equality_single(self):
         df_joined = self.df_spark_employee.join(
-            self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner"
+            self.df_spark_store,
+            on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
+            how="inner",
         ).select(
             self.df_spark_employee.employee_id,
             self.df_spark_employee["fname"],
@@ -235,7 +278,9 @@ class TestDataframeFunc(DataFrameValidator):
             self.df_spark_store["num_sales"],
         )
         dfs_joined = self.df_sqlglot_employee.join(
-            self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner"
+            self.df_sqlglot_store,
+            on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
+            how="inner",
         ).select(
             self.df_sqlglot_employee.employee_id,
             self.df_sqlglot_employee["fname"],
@@ -343,7 +388,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_joined, dfs_joined)
 
     def test_join_full_outer(self):
-        df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select(
+        df_joined = self.df_spark_employee.join(
+            self.df_spark_store, on=["store_id"], how="full_outer"
+        ).select(
             self.df_spark_employee.employee_id,
             self.df_spark_employee["fname"],
             F.col("lname"),
@@ -352,7 +399,9 @@ class TestDataframeFunc(DataFrameValidator):
             self.df_spark_store.store_name,
             self.df_spark_store["num_sales"],
         )
-        dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select(
+        dfs_joined = self.df_sqlglot_employee.join(
+            self.df_sqlglot_store, on=["store_id"], how="full_outer"
+        ).select(
             self.df_sqlglot_employee.employee_id,
             self.df_sqlglot_employee["fname"],
             SF.col("lname"),
@@ -365,7 +414,9 @@ class TestDataframeFunc(DataFrameValidator):
 
     def test_triple_join(self):
         df = (
-            self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id)
+            self.df_employee.join(
+                self.df_store, on=self.df_employee.employee_id == self.df_store.store_id
+            )
             .join(self.df_district, on=self.df_store.store_id == self.df_district.district_id)
             .select(
                 self.df_employee.employee_id,
@@ -377,7 +428,9 @@ class TestDataframeFunc(DataFrameValidator):
             )
         )
         dfs = (
-            self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id)
+            self.dfs_employee.join(
+                self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id
+            )
             .join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id)
             .select(
                 self.dfs_employee.employee_id,
@@ -391,13 +444,13 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_join_select_and_select_start(self):
-        df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join(
-            self.df_spark_store, "store_id", "inner"
-        )
+        df = self.df_spark_employee.select(
+            F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")
+        ).join(self.df_spark_store, "store_id", "inner")
 
-        dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join(
-            self.df_sqlglot_store, "store_id", "inner"
-        )
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
+        ).join(self.df_sqlglot_store, "store_id", "inner")
 
         self.compare_spark_with_sqlglot(df, dfs)
@@ -485,13 +538,17 @@ class TestDataframeFunc(DataFrameValidator):
         dfs_unioned = (
             self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"))
             .unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")))
-            .unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")))
+            .unionAll(
+                self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))
+            )
         )
 
         self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)
 
     def test_union_by_name(self):
-        df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName(
+        df = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("fname"), F.col("lname")
+        ).unionByName(
             self.df_spark_store.select(
                 F.col("store_name").alias("lname"),
                 F.col("store_id").alias("employee_id"),
@@ -499,7 +556,9 @@ class TestDataframeFunc(DataFrameValidator):
             )
         )
 
-        dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName(
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("fname"), SF.col("lname")
+        ).unionByName(
             self.df_sqlglot_store.select(
                 SF.col("store_name").alias("lname"),
                 SF.col("store_id").alias("employee_id"),
@@ -537,10 +596,16 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_order_by_default(self):
-        df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id"))
+        df = (
+            self.df_spark_store.groupBy(F.col("district_id"))
+            .agg(F.min("num_sales"))
+            .orderBy(F.col("district_id"))
+        )
 
         dfs = (
-            self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id"))
+            self.df_sqlglot_store.groupBy(SF.col("district_id"))
+            .agg(SF.min("num_sales"))
+            .orderBy(SF.col("district_id"))
         )
 
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -594,13 +659,17 @@ class TestDataframeFunc(DataFrameValidator):
         df = (
             self.df_spark_store.groupBy(F.col("district_id"))
             .agg(F.min("num_sales").alias("total_sales"))
-            .orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last())
+            .orderBy(
+                F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()
+            )
         )
 
         dfs = (
             self.df_sqlglot_store.groupBy(SF.col("district_id"))
             .agg(SF.min("num_sales").alias("total_sales"))
-            .orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last())
+            .orderBy(
+                SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()
+            )
         )
 
         self.compare_spark_with_sqlglot(df, dfs)
@@ -609,81 +678,87 @@ class TestDataframeFunc(DataFrameValidator):
         df = (
             self.df_spark_store.groupBy(F.col("district_id"))
             .agg(F.min("num_sales").alias("total_sales"))
-            .orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first())
+            .orderBy(
+                F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()
+            )
         )
 
         dfs = (
             self.df_sqlglot_store.groupBy(SF.col("district_id"))
             .agg(SF.min("num_sales").alias("total_sales"))
-            .orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first())
+            .orderBy(
+                SF.when(
+                    SF.col("district_id") == SF.lit(1), SF.col("district_id")
+                ).desc_nulls_first()
+            )
         )
 
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_intersect(self):
-        df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
-            self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
-        )
+        df_employee_duplicate = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("store_id")
+        ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
 
-        df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
-            self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
-        )
+        df_store_duplicate = self.df_spark_store.select(
+            F.col("store_id"), F.col("district_id")
+        ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
 
         df = df_employee_duplicate.intersect(df_store_duplicate)
 
-        dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
-            self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
-        )
+        dfs_employee_duplicate = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("store_id")
+        ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
 
-        dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
-        )
+        dfs_store_duplicate = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("district_id")
+        ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
 
         dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate)
 
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_intersect_all(self):
-        df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
-            self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
-        )
+        df_employee_duplicate = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("store_id")
+        ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
 
-        df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
-            self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
-        )
+        df_store_duplicate = self.df_spark_store.select(
+            F.col("store_id"), F.col("district_id")
+        ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
 
         df = df_employee_duplicate.intersectAll(df_store_duplicate)
 
-        dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
-            self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
-        )
+        dfs_employee_duplicate = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("store_id")
+        ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
 
-        dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
-        )
+        dfs_store_duplicate = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("district_id")
+        ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
 
         dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate)
 
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_except_all(self):
-        df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
-            self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
-        )
+        df_employee_duplicate = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("store_id")
+        ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
 
-        df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
-            self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
-        )
+        df_store_duplicate = self.df_spark_store.select(
+            F.col("store_id"), F.col("district_id")
+        ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
 
         df = df_employee_duplicate.exceptAll(df_store_duplicate)
 
-        dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
-            self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
-        )
+        dfs_employee_duplicate = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("store_id")
+        ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
 
-        dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
-        )
+        dfs_store_duplicate = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("district_id")
+        ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
 
         dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate)
 
@@ -721,7 +796,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_drop_na_default(self):
-        df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna()
+        df = self.df_spark_employee.select(
+            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+        ).dropna()
 
         dfs = self.df_sqlglot_employee.select(
             SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -746,7 +823,9 @@ class TestDataframeFunc(DataFrameValidator):
         ).dropna(how="any", thresh=2)
 
         dfs = self.df_sqlglot_employee.select(
-            SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
+            SF.lit(None),
+            SF.lit(1),
+            SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
         ).dropna(how="any", thresh=2)
 
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -757,13 +836,17 @@ class TestDataframeFunc(DataFrameValidator):
         ).dropna(thresh=1, subset="the_age")
 
         dfs = self.df_sqlglot_employee.select(
-            SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
+            SF.lit(None),
+            SF.lit(1),
+            SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
         ).dropna(thresh=1, subset="the_age")
 
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
 
     def test_dropna_na_function(self):
-        df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop()
+        df = self.df_spark_employee.select(
+            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+        ).na.drop()
 
         dfs = self.df_sqlglot_employee.select(
             SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -772,7 +855,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_fillna_default(self):
-        df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100)
+        df = self.df_spark_employee.select(
+            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+        ).fillna(100)
 
         dfs = self.df_sqlglot_employee.select(
             SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -798,7 +883,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
 
     def test_fillna_na_func(self):
-        df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100)
+        df = self.df_spark_employee.select(
+            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+        ).na.fill(100)
 
         dfs = self.df_sqlglot_employee.select(
             SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -807,7 +894,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_replace_basic(self):
-        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100)
+        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
+            to_replace=37, value=100
+        )
 
         dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
             to_replace=37, value=100
@@ -827,9 +916,13 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
 
     def test_replace_mapping(self):
-        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100})
+        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
+            {37: 100}
+        )
 
-        dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100})
+        dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
+            {37: 100}
+        )
 
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
 
@@ -849,9 +942,9 @@ class TestDataframeFunc(DataFrameValidator):
             to_replace=37, value=100
         )
 
-        dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace(
-            to_replace=37, value=100
-        )
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("age"), SF.lit(37).alias("test_col")
+        ).na.replace(to_replace=37, value=100)
 
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
 
@@ -881,16 +974,18 @@ class TestDataframeFunc(DataFrameValidator):
             "first_name", "first_name_again"
         )
 
-        dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed(
-            "first_name", "first_name_again"
-        )
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("fname").alias("first_name")
+        ).withColumnRenamed("first_name", "first_name_again")
 
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_drop_column_single(self):
         df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age")
 
-        dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age")
+        dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop(
+            "age"
+        )
 
         self.compare_spark_with_sqlglot(df, dfs)
 
@@ -906,7 +1001,9 @@ class TestDataframeFunc(DataFrameValidator):
         df_sqlglot_employee_cols = self.df_sqlglot_employee.select(
             SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
         )
-        df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))
+        df_sqlglot_store_cols = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("store_name")
+        )
         dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop(
             df_sqlglot_employee_cols.age,
         )
@@ -23,6 +23,14 @@ class TestSessionFunc(DataFrameValidator):
         ON
             e.store_id = s.store_id
         """
-        df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id")))
-        dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id")))
+        df = (
+            self.spark.sql(query)
+            .groupBy(F.col("store_id"))
+            .agg(F.countDistinct(F.col("employee_id")))
+        )
+        dfs = (
+            self.sqlglot.sql(query)
+            .groupBy(SF.col("store_id"))
+            .agg(SF.countDistinct(SF.col("employee_id")))
+        )
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -25,11 +25,17 @@ class DataFrameSQLValidator(unittest.TestCase):
             (4, "Claire", "Littleton", 27, 2),
             (5, "Hugo", "Reyes", 29, 100),
         ]
-        self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema)
+        self.df_employee = self.spark.createDataFrame(
+            data=employee_data, schema=self.employee_schema
+        )
 
-    def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False):
+    def compare_sql(
+        self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
+    ):
         actual_sqls = df.sql(pretty=pretty)
-        expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements
+        expected_statements = (
+            [expected_statements] if isinstance(expected_statements, str) else expected_statements
+        )
         self.assertEqual(len(expected_statements), len(actual_sqls))
         for expected, actual in zip(expected_statements, actual_sqls):
             self.assertEqual(expected, actual)
@@ -26,12 +26,14 @@ class TestDataframeColumn(unittest.TestCase):
 
     def test_and(self):
         self.assertEqual(
-            "cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql()
+            "cola = colb AND colc = cold",
+            ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql(),
         )
 
     def test_or(self):
         self.assertEqual(
-            "cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql()
+            "cola = colb OR colc = cold",
+            ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql(),
         )
 
     def test_mod(self):
@@ -112,7 +114,9 @@ class TestDataframeColumn(unittest.TestCase):
 
     def test_when_otherwise(self):
         self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql())
-        self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql())
+        self.assertEqual(
+            "CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()
+        )
         self.assertEqual(
             "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
             (F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(),
@@ -148,7 +152,9 @@ class TestDataframeColumn(unittest.TestCase):
         self.assertEqual(
             "cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) "
             "AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)",
-            F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(),
+            F.col("cola")
+            .between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1))
+            .sql(),
         )
 
     def test_over(self):
@@ -9,7 +9,9 @@ class TestDataframe(DataFrameSQLValidator):
         self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))
 
     def test_columns(self):
-        self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns)
+        self.assertEqual(
+            ["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns
+        )
 
     def test_cache(self):
         df = self.df_employee.select("fname").cache()
@@ -925,12 +925,17 @@ class TestFunctions(unittest.TestCase):
         col = SF.window(SF.col("cola"), "10 minutes")
         self.assertEqual("WINDOW(cola, '10 minutes')", col.sql())
         col_all_values = SF.window("cola", "2 minutes 30 seconds", "30 seconds", "15 seconds")
-        self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql())
+        self.assertEqual(
+            "WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql()
+        )
         col_no_start_time = SF.window("cola", "2 minutes 30 seconds", "30 seconds")
-        self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql())
+        self.assertEqual(
+            "WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql()
+        )
         col_no_slide = SF.window("cola", "2 minutes 30 seconds", startTime="15 seconds")
         self.assertEqual(
-            "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')", col_no_slide.sql()
+            "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')",
+            col_no_slide.sql(),
         )
 
     def test_session_window(self):
@@ -1359,9 +1364,13 @@ class TestFunctions(unittest.TestCase):
 
     def test_from_json(self):
         col_str = SF.from_json("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
-        self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+        self.assertEqual(
+            "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+        )
         col = SF.from_json(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
-        self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
+        self.assertEqual(
+            "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
+        )
         col_no_option = SF.from_json("cola", "cola INT")
         self.assertEqual("FROM_JSON(cola, 'cola INT')", col_no_option.sql())
 
@@ -1375,7 +1384,9 @@ class TestFunctions(unittest.TestCase):
 
     def test_schema_of_json(self):
         col_str = SF.schema_of_json("cola", dict(timestampFormat="dd/MM/yyyy"))
-        self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+        self.assertEqual(
+            "SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+        )
         col = SF.schema_of_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy"))
         self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
         col_no_option = SF.schema_of_json("cola")
@@ -1429,7 +1440,10 @@ class TestFunctions(unittest.TestCase):
         col = SF.array_sort(SF.col("cola"))
         self.assertEqual("ARRAY_SORT(cola)", col.sql())
         col_comparator = SF.array_sort(
-            "cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x))
+            "cola",
+            lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(
+                SF.length(y) - SF.length(x)
+            ),
         )
         self.assertEqual(
             "ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)",
@@ -1504,9 +1518,13 @@ class TestFunctions(unittest.TestCase):
 
     def test_from_csv(self):
         col_str = SF.from_csv("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
-        self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+        self.assertEqual(
+            "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+        )
         col = SF.from_csv(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
-        self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
+        self.assertEqual(
+            "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
+        )
         col_no_option = SF.from_csv("cola", "cola INT")
         self.assertEqual("FROM_CSV(cola, 'cola INT')", col_no_option.sql())
 
@@ -1535,7 +1553,9 @@ class TestFunctions(unittest.TestCase):
         self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql())
         col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count)
 
-        self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql())
+        self.assertEqual(
+            "TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()
+        )
 
     def test_exists(self):
         col_str = SF.exists("cola", lambda x: x % 2 == 0)
@@ -1558,10 +1578,13 @@ class TestFunctions(unittest.TestCase):
         self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql())
         col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i))
         self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql())
-        col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count))
+        col_custom_names = SF.filter(
+            "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count)
+        )
 
         self.assertEqual(
-            "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql()
+            "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)",
+            col_custom_names.sql(),
        )
 
     def test_zip_with(self):
@@ -1570,7 +1593,9 @@ class TestFunctions(unittest.TestCase):
         col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y))
         self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql())
         col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r))
-        self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql())
+        self.assertEqual(
+            "ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()
+        )
 
     def test_transform_keys(self):
         col_str = SF.transform_keys("cola", lambda k, v: SF.upper(k))
@@ -1586,7 +1611,9 @@ class TestFunctions(unittest.TestCase):
         col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v))
         self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql())
         col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value))
-        self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql())
+        self.assertEqual(
+            "TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()
+        )
 
     def test_map_filter(self):
         col_str = SF.map_filter("cola", lambda k, v: k > v)
@@ -21,9 +21,7 @@ class TestDataframeSession(DataFrameSQLValidator):
 
     def test_cdf_no_schema(self):
         df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
-        expected = (
-            "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
-        )
+        expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
         self.compare_sql(df, expected)
 
     def test_cdf_row_mixed_primitives(self):
@@ -77,7 +75,8 @@ class TestDataframeSession(DataFrameSQLValidator):
         sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
         df = self.spark.sql(query)
         self.assertIn(
-            "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False)
+            "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`",
+            df.sql(pretty=False),
         )
 
     @mock.patch("sqlglot.schema", MappingSchema())
@@ -104,9 +103,7 @@ class TestDataframeSession(DataFrameSQLValidator):
         query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
         sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
         df = self.spark.sql(query)
-        expected = (
-            "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
-        )
+        expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
         self.compare_sql(df, expected)
 
     def test_session_create_builder_patterns(self):
@@ -53,7 +53,10 @@ class TestDataframeTypes(unittest.TestCase):
         self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString())
 
     def test_map(self):
-        self.assertEqual("map<int, string>", types.MapType(types.IntegerType(), types.StringType()).simpleString())
+        self.assertEqual(
+            "map<int, string>",
+            types.MapType(types.IntegerType(), types.StringType()).simpleString(),
+        )
 
     def test_struct_field(self):
         self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString())
@@ -39,22 +39,38 @@ class TestDataframeWindow(unittest.TestCase):
 
     def test_window_rows_unbounded(self):
         rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
-        self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql())
-        rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
-        self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql())
-        rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
         self.assertEqual(
-            "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql()
+            "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
+            rows_between_unbounded_start.sql(),
+        )
+        rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
+        self.assertEqual(
+            "OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
+            rows_between_unbounded_end.sql(),
+        )
+        rows_between_unbounded_both = Window.rowsBetween(
+            Window.unboundedPreceding, Window.unboundedFollowing
+        )
+        self.assertEqual(
+            "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
+            rows_between_unbounded_both.sql(),
         )
 
     def test_window_range_unbounded(self):
         range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
         self.assertEqual(
-            "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql()
+            "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
+            range_between_unbounded_start.sql(),
         )
         range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
-        self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql())
-        range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)
         self.assertEqual(
-            "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql()
+            "OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
+            range_between_unbounded_end.sql(),
+        )
+        range_between_unbounded_both = Window.rangeBetween(
+            Window.unboundedPreceding, Window.unboundedFollowing
+        )
+        self.assertEqual(
+            "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
+            range_between_unbounded_both.sql(),
         )