Merging upstream version 10.0.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
  parent 528822bfd4
  commit b7d21c45b7

98 changed files with 4080 additions and 1666 deletions
@@ -1,9 +1,9 @@
-import sys
 import typing as t
 import unittest
 import warnings
 
 import sqlglot
+from sqlglot.helper import PYTHON_VERSION
 from tests.helpers import SKIP_INTEGRATION
 
 if t.TYPE_CHECKING:
@@ -11,7 +11,8 @@ if t.TYPE_CHECKING:
 
 
 @unittest.skipIf(
-    SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set"
+    SKIP_INTEGRATION or PYTHON_VERSION > (3, 10),
+    "Skipping Integration Tests since `SKIP_INTEGRATION` is set",
 )
 class DataFrameValidator(unittest.TestCase):
     spark = None
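Note: the hunk above swaps the raw sys.version_info check for sqlglot's PYTHON_VERSION helper, which is why `import sys` drops out of the module. A minimal sketch of what that helper presumably looks like (an assumption, not copied from sqlglot.helper):

    import sys

    # A (major, minor) tuple such as (3, 10); tuple comparison keeps the
    # skip condition readable.
    PYTHON_VERSION = sys.version_info[:2]

    print(PYTHON_VERSION > (3, 10))  # False on Python 3.10 and earlier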
@@ -36,7 +37,12 @@ class DataFrameValidator(unittest.TestCase):
 
         # This is for test `test_branching_root_dataframes`
         config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
-        cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate()
+        cls.spark = (
+            SparkSession.builder.master("local[*]")
+            .appName("Unit-tests")
+            .config(conf=config)
+            .getOrCreate()
+        )
         cls.spark.sparkContext.setLogLevel("ERROR")
         cls.sqlglot = SqlglotSparkSession()
         cls.spark_employee_schema = types.StructType(
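Note: the fixture builds two sessions side by side, a real PySpark SparkSession and sqlglot's drop-in equivalent (SqlglotSparkSession). The alias is presumably imported along these lines; the import block itself is not shown in this diff, so treat this as a hypothetical sketch:

    # Real Spark and sqlglot's dataframe API, kept distinct by alias.
    from pyspark import SparkConf
    from pyspark.sql import SparkSession, types
    from sqlglot.dataframe.sql import types as sqlglotSparkTypes
    from sqlglot.dataframe.sql.session import SparkSession as SqlglotSparkSession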
@@ -50,7 +56,9 @@ class DataFrameValidator(unittest.TestCase):
         )
         cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
             [
-                sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False),
+                sqlglotSparkTypes.StructField(
+                    "employee_id", sqlglotSparkTypes.IntegerType(), False
+                ),
                 sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
                 sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
                 sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
@@ -64,8 +72,12 @@ class DataFrameValidator(unittest.TestCase):
             (4, "Claire", "Littleton", 27, 2),
             (5, "Hugo", "Reyes", 29, 100),
         ]
-        cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema)
-        cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema)
+        cls.df_employee = cls.spark.createDataFrame(
+            data=employee_data, schema=cls.spark_employee_schema
+        )
+        cls.dfs_employee = cls.sqlglot.createDataFrame(
+            data=employee_data, schema=cls.sqlglot_employee_schema
+        )
         cls.df_employee.createOrReplaceTempView("employee")
 
         cls.spark_store_schema = types.StructType(
@@ -80,7 +92,9 @@ class DataFrameValidator(unittest.TestCase):
             [
                 sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
                 sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
-                sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
+                sqlglotSparkTypes.StructField(
+                    "district_id", sqlglotSparkTypes.IntegerType(), False
+                ),
                 sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
             ]
         )
@@ -89,7 +103,9 @@ class DataFrameValidator(unittest.TestCase):
             (2, "Arrow", 2, 2000),
         ]
         cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
-        cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema)
+        cls.dfs_store = cls.sqlglot.createDataFrame(
+            data=store_data, schema=cls.sqlglot_store_schema
+        )
         cls.df_store.createOrReplaceTempView("store")
 
         cls.spark_district_schema = types.StructType(
@@ -101,17 +117,27 @@ class DataFrameValidator(unittest.TestCase):
         )
         cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
             [
-                sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
-                sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False),
-                sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False),
+                sqlglotSparkTypes.StructField(
+                    "district_id", sqlglotSparkTypes.IntegerType(), False
+                ),
+                sqlglotSparkTypes.StructField(
+                    "district_name", sqlglotSparkTypes.StringType(), False
+                ),
+                sqlglotSparkTypes.StructField(
+                    "manager_name", sqlglotSparkTypes.StringType(), False
+                ),
             ]
         )
         district_data = [
             (1, "Temple", "Dogen"),
             (2, "Lighthouse", "Jacob"),
         ]
-        cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema)
-        cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema)
+        cls.df_district = cls.spark.createDataFrame(
+            data=district_data, schema=cls.spark_district_schema
+        )
+        cls.dfs_district = cls.sqlglot.createDataFrame(
+            data=district_data, schema=cls.sqlglot_district_schema
+        )
         cls.df_district.createOrReplaceTempView("district")
         sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
         sqlglot.schema.add_table("store", cls.sqlglot_store_schema)
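Note: the remaining hunks reformat the dataframe test suite itself (the hunk context names TestDataframeFunc). The fixture above registers each temp view's schema with sqlglot so columns can be resolved without a live cluster; per sqlglot's documented schema API, a plain column-to-type mapping should work as well (a sketch, assuming the public sqlglot.schema interface):

    import sqlglot

    # Equivalent registration with type strings instead of StructType objects.
    sqlglot.schema.add_table("employee", {"employee_id": "INT", "fname": "STRING"})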
@@ -41,22 +41,32 @@ class TestDataframeFunc(DataFrameValidator):
 
     def test_alias_with_select(self):
         df_employee = self.df_spark_employee.alias("df_employee").select(
-            self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname
+            self.df_spark_employee["employee_id"],
+            F.col("df_employee.fname"),
+            self.df_spark_employee.lname,
         )
         dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select(
-            self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname
+            self.df_sqlglot_employee["employee_id"],
+            SF.col("dfs_employee.fname"),
+            self.df_sqlglot_employee.lname,
         )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)
 
     def test_case_when_otherwise(self):
         df = self.df_spark_employee.select(
-            F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60"))
+            F.when(
+                (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+                F.lit("between 40 and 60"),
+            )
             .when(F.col("age") < F.lit(40), "less than 40")
             .otherwise("greater than 60")
         )
 
         dfs = self.df_sqlglot_employee.select(
-            SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60"))
+            SF.when(
+                (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+                SF.lit("between 40 and 60"),
+            )
             .when(SF.col("age") < SF.lit(40), "less than 40")
             .otherwise("greater than 60")
         )
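Note: both sides of test_case_when_otherwise build the same CASE expression; the sqlglot flavor compiles to SQL rather than executing on Spark. A rough sketch of inspecting that compilation (this assumes sqlglot Column objects expose their underlying expression as `.expression`, which may differ between versions):

    from sqlglot.dataframe.sql import functions as SF

    col = SF.when(SF.col("age") < SF.lit(40), "less than 40").otherwise("greater than 60")
    # Should render a CASE WHEN ... ELSE ... END expression in Spark SQL.
    print(col.expression.sql(dialect="spark"))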
@@ -65,15 +75,17 @@ class TestDataframeFunc(DataFrameValidator):
 
     def test_case_when_no_otherwise(self):
         df = self.df_spark_employee.select(
-            F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when(
-                F.col("age") < F.lit(40), "less than 40"
-            )
+            F.when(
+                (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+                F.lit("between 40 and 60"),
+            ).when(F.col("age") < F.lit(40), "less than 40")
         )
 
         dfs = self.df_sqlglot_employee.select(
-            SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when(
-                SF.col("age") < SF.lit(40), "less than 40"
-            )
+            SF.when(
+                (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+                SF.lit("between 40 and 60"),
+            ).when(SF.col("age") < SF.lit(40), "less than 40")
         )
 
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -84,7 +96,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)
 
     def test_where_clause_multiple_and(self):
-        df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")))
+        df_employee = self.df_spark_employee.where(
+            (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))
+        )
         dfs_employee = self.df_sqlglot_employee.where(
             (SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack"))
         )
@@ -106,7 +120,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)
 
     def test_where_clause_multiple_or(self):
-        df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")))
+        df_employee = self.df_spark_employee.where(
+            (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))
+        )
         dfs_employee = self.df_sqlglot_employee.where(
             (SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate"))
         )
@@ -172,28 +188,43 @@ class TestDataframeFunc(DataFrameValidator):
         dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37))
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)
 
-        df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0))
-        dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0))
+        df_employee = self.df_spark_employee.where(
+            self.df_spark_employee["age"] % F.lit(5) == F.lit(0)
+        )
+        dfs_employee = self.df_sqlglot_employee.where(
+            self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)
+        )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)
 
-        df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28))
-        dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28))
+        df_employee = self.df_spark_employee.where(
+            self.df_spark_employee["age"] + F.lit(5) > F.lit(28)
+        )
+        dfs_employee = self.df_sqlglot_employee.where(
+            self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)
+        )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)
 
-        df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28))
-        dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28))
+        df_employee = self.df_spark_employee.where(
+            self.df_spark_employee["age"] - F.lit(5) > F.lit(28)
+        )
+        dfs_employee = self.df_sqlglot_employee.where(
+            self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)
+        )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)
 
         df_employee = self.df_spark_employee.where(
             self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2)
         )
         dfs_employee = self.df_sqlglot_employee.where(
-            self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2)
+            self.df_sqlglot_employee["age"] * SF.lit(0.5)
+            == self.df_sqlglot_employee["age"] / SF.lit(2)
         )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)
 
     def test_join_inner(self):
-        df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select(
+        df_joined = self.df_spark_employee.join(
+            self.df_spark_store, on=["store_id"], how="inner"
+        ).select(
             self.df_spark_employee.employee_id,
             self.df_spark_employee["fname"],
             F.col("lname"),
@@ -202,7 +233,9 @@ class TestDataframeFunc(DataFrameValidator):
             self.df_spark_store.store_name,
             self.df_spark_store["num_sales"],
         )
-        dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select(
+        dfs_joined = self.df_sqlglot_employee.join(
+            self.df_sqlglot_store, on=["store_id"], how="inner"
+        ).select(
             self.df_sqlglot_employee.employee_id,
             self.df_sqlglot_employee["fname"],
             SF.col("lname"),
@@ -214,17 +247,27 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_joined, dfs_joined)
 
     def test_join_inner_no_select(self):
-        df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join(
-            self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner"
+        df_joined = self.df_spark_employee.select(
+            F.col("store_id"), F.col("fname"), F.col("lname")
+        ).join(
+            self.df_spark_store.select(F.col("store_id"), F.col("store_name")),
+            on=["store_id"],
+            how="inner",
         )
-        dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner"
+        dfs_joined = self.df_sqlglot_employee.select(
+            SF.col("store_id"), SF.col("fname"), SF.col("lname")
+        ).join(
+            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")),
+            on=["store_id"],
+            how="inner",
         )
         self.compare_spark_with_sqlglot(df_joined, dfs_joined)
 
     def test_join_inner_equality_single(self):
         df_joined = self.df_spark_employee.join(
-            self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner"
+            self.df_spark_store,
+            on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
+            how="inner",
         ).select(
             self.df_spark_employee.employee_id,
             self.df_spark_employee["fname"],
@@ -235,7 +278,9 @@ class TestDataframeFunc(DataFrameValidator):
             self.df_spark_store["num_sales"],
         )
         dfs_joined = self.df_sqlglot_employee.join(
-            self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner"
+            self.df_sqlglot_store,
+            on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
+            how="inner",
         ).select(
             self.df_sqlglot_employee.employee_id,
             self.df_sqlglot_employee["fname"],
@@ -343,7 +388,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_joined, dfs_joined)
 
     def test_join_full_outer(self):
-        df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select(
+        df_joined = self.df_spark_employee.join(
+            self.df_spark_store, on=["store_id"], how="full_outer"
+        ).select(
             self.df_spark_employee.employee_id,
             self.df_spark_employee["fname"],
             F.col("lname"),
@@ -352,7 +399,9 @@ class TestDataframeFunc(DataFrameValidator):
             self.df_spark_store.store_name,
             self.df_spark_store["num_sales"],
         )
-        dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select(
+        dfs_joined = self.df_sqlglot_employee.join(
+            self.df_sqlglot_store, on=["store_id"], how="full_outer"
+        ).select(
             self.df_sqlglot_employee.employee_id,
             self.df_sqlglot_employee["fname"],
             SF.col("lname"),
@@ -365,7 +414,9 @@ class TestDataframeFunc(DataFrameValidator):
 
     def test_triple_join(self):
         df = (
-            self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id)
+            self.df_employee.join(
+                self.df_store, on=self.df_employee.employee_id == self.df_store.store_id
+            )
             .join(self.df_district, on=self.df_store.store_id == self.df_district.district_id)
             .select(
                 self.df_employee.employee_id,
@@ -377,7 +428,9 @@ class TestDataframeFunc(DataFrameValidator):
             )
         )
         dfs = (
-            self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id)
+            self.dfs_employee.join(
+                self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id
+            )
             .join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id)
             .select(
                 self.dfs_employee.employee_id,
@@ -391,13 +444,13 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_join_select_and_select_start(self):
-        df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join(
-            self.df_spark_store, "store_id", "inner"
-        )
+        df = self.df_spark_employee.select(
+            F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")
+        ).join(self.df_spark_store, "store_id", "inner")
 
-        dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join(
-            self.df_sqlglot_store, "store_id", "inner"
-        )
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
+        ).join(self.df_sqlglot_store, "store_id", "inner")
 
         self.compare_spark_with_sqlglot(df, dfs)
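Note: the join tests compare real Spark results against SQL generated by sqlglot's dataframe layer. A hedged sketch of how the sqlglot side can be inspected outside the suite (this assumes DataFrame.sql() returns the generated statements, as sqlglot's dataframe docs describe):

    import sqlglot
    from sqlglot.dataframe.sql import functions as SF
    from sqlglot.dataframe.sql.session import SparkSession

    # Register column types so sqlglot can resolve and qualify columns.
    sqlglot.schema.add_table("employee", {"fname": "STRING", "store_id": "INT"})
    sqlglot.schema.add_table("store", {"store_id": "INT", "store_name": "STRING"})

    spark = SparkSession()
    df = (
        spark.sql("SELECT fname, store_id FROM employee")
        .join(spark.sql("SELECT store_id, store_name FROM store"), on="store_id", how="inner")
        .select(SF.col("fname"), SF.col("store_name"))
    )
    print(df.sql(dialect="spark"))  # should contain one generated SELECT statement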
@@ -485,13 +538,17 @@ class TestDataframeFunc(DataFrameValidator):
         dfs_unioned = (
             self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"))
             .unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")))
-            .unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")))
+            .unionAll(
+                self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))
+            )
         )
 
         self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)
 
     def test_union_by_name(self):
-        df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName(
+        df = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("fname"), F.col("lname")
+        ).unionByName(
             self.df_spark_store.select(
                 F.col("store_name").alias("lname"),
                 F.col("store_id").alias("employee_id"),
@@ -499,7 +556,9 @@ class TestDataframeFunc(DataFrameValidator):
             )
         )
 
-        dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName(
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("fname"), SF.col("lname")
+        ).unionByName(
             self.df_sqlglot_store.select(
                 SF.col("store_name").alias("lname"),
                 SF.col("store_id").alias("employee_id"),
@@ -537,10 +596,16 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_order_by_default(self):
-        df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id"))
+        df = (
+            self.df_spark_store.groupBy(F.col("district_id"))
+            .agg(F.min("num_sales"))
+            .orderBy(F.col("district_id"))
+        )
 
         dfs = (
-            self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id"))
+            self.df_sqlglot_store.groupBy(SF.col("district_id"))
+            .agg(SF.min("num_sales"))
+            .orderBy(SF.col("district_id"))
         )
 
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -594,13 +659,17 @@ class TestDataframeFunc(DataFrameValidator):
         df = (
             self.df_spark_store.groupBy(F.col("district_id"))
             .agg(F.min("num_sales").alias("total_sales"))
-            .orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last())
+            .orderBy(
+                F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()
+            )
         )
 
         dfs = (
             self.df_sqlglot_store.groupBy(SF.col("district_id"))
             .agg(SF.min("num_sales").alias("total_sales"))
-            .orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last())
+            .orderBy(
+                SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()
+            )
         )
 
         self.compare_spark_with_sqlglot(df, dfs)
@@ -609,81 +678,87 @@ class TestDataframeFunc(DataFrameValidator):
         df = (
             self.df_spark_store.groupBy(F.col("district_id"))
             .agg(F.min("num_sales").alias("total_sales"))
-            .orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first())
+            .orderBy(
+                F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()
+            )
         )
 
         dfs = (
             self.df_sqlglot_store.groupBy(SF.col("district_id"))
             .agg(SF.min("num_sales").alias("total_sales"))
-            .orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first())
+            .orderBy(
+                SF.when(
+                    SF.col("district_id") == SF.lit(1), SF.col("district_id")
+                ).desc_nulls_first()
+            )
         )
 
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_intersect(self):
-        df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
-            self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
-        )
+        df_employee_duplicate = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("store_id")
+        ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
 
-        df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
-            self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
-        )
+        df_store_duplicate = self.df_spark_store.select(
+            F.col("store_id"), F.col("district_id")
+        ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
 
         df = df_employee_duplicate.intersect(df_store_duplicate)
 
-        dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
-            self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
-        )
+        dfs_employee_duplicate = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("store_id")
+        ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
 
-        dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
-        )
+        dfs_store_duplicate = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("district_id")
+        ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
 
         dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate)
 
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_intersect_all(self):
-        df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
-            self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
-        )
+        df_employee_duplicate = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("store_id")
+        ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
 
-        df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
-            self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
-        )
+        df_store_duplicate = self.df_spark_store.select(
+            F.col("store_id"), F.col("district_id")
+        ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
 
         df = df_employee_duplicate.intersectAll(df_store_duplicate)
 
-        dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
-            self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
-        )
+        dfs_employee_duplicate = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("store_id")
+        ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
 
-        dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
-        )
+        dfs_store_duplicate = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("district_id")
+        ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
 
         dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate)
 
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_except_all(self):
-        df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
-            self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
-        )
+        df_employee_duplicate = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("store_id")
+        ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
 
-        df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
-            self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
-        )
+        df_store_duplicate = self.df_spark_store.select(
+            F.col("store_id"), F.col("district_id")
+        ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
 
         df = df_employee_duplicate.exceptAll(df_store_duplicate)
 
-        dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
-            self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
-        )
+        dfs_employee_duplicate = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("store_id")
+        ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
 
-        dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
-        )
+        dfs_store_duplicate = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("district_id")
+        ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
 
         dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate)
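Note: the intersect/except tests deliberately union each input with itself to create duplicate rows, since INTERSECT deduplicates while INTERSECT ALL and EXCEPT ALL respect multiplicity. The same semantics can be checked on plain SQL with sqlglot's transpiler (a small aside, not part of the test suite):

    import sqlglot

    sql = "SELECT store_id FROM employee INTERSECT ALL SELECT store_id FROM store"
    # Spark supports INTERSECT ALL natively, so the statement should survive
    # transpilation essentially unchanged.
    print(sqlglot.transpile(sql, read="spark", write="spark")[0])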
@@ -721,7 +796,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_drop_na_default(self):
-        df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna()
+        df = self.df_spark_employee.select(
+            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+        ).dropna()
 
         dfs = self.df_sqlglot_employee.select(
             SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -746,7 +823,9 @@ class TestDataframeFunc(DataFrameValidator):
         ).dropna(how="any", thresh=2)
 
         dfs = self.df_sqlglot_employee.select(
-            SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
+            SF.lit(None),
+            SF.lit(1),
+            SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
         ).dropna(how="any", thresh=2)
 
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -757,13 +836,17 @@ class TestDataframeFunc(DataFrameValidator):
         ).dropna(thresh=1, subset="the_age")
 
         dfs = self.df_sqlglot_employee.select(
-            SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
+            SF.lit(None),
+            SF.lit(1),
+            SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
         ).dropna(thresh=1, subset="the_age")
 
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
 
     def test_dropna_na_function(self):
-        df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop()
+        df = self.df_spark_employee.select(
+            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+        ).na.drop()
 
         dfs = self.df_sqlglot_employee.select(
             SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -772,7 +855,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_fillna_default(self):
-        df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100)
+        df = self.df_spark_employee.select(
+            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+        ).fillna(100)
 
         dfs = self.df_sqlglot_employee.select(
             SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -798,7 +883,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
 
     def test_fillna_na_func(self):
-        df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100)
+        df = self.df_spark_employee.select(
+            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+        ).na.fill(100)
 
         dfs = self.df_sqlglot_employee.select(
             SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -807,7 +894,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_replace_basic(self):
-        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100)
+        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
+            to_replace=37, value=100
+        )
 
         dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
             to_replace=37, value=100
@@ -827,9 +916,13 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
 
     def test_replace_mapping(self):
-        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100})
+        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
+            {37: 100}
+        )
 
-        dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100})
+        dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
+            {37: 100}
+        )
 
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -849,9 +942,9 @@ class TestDataframeFunc(DataFrameValidator):
             to_replace=37, value=100
         )
 
-        dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace(
-            to_replace=37, value=100
-        )
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("age"), SF.lit(37).alias("test_col")
+        ).na.replace(to_replace=37, value=100)
 
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -881,16 +974,18 @@ class TestDataframeFunc(DataFrameValidator):
             "first_name", "first_name_again"
         )
 
-        dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed(
-            "first_name", "first_name_again"
-        )
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("fname").alias("first_name")
+        ).withColumnRenamed("first_name", "first_name_again")
 
         self.compare_spark_with_sqlglot(df, dfs)
 
     def test_drop_column_single(self):
         df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age")
 
-        dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age")
+        dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop(
+            "age"
+        )
 
         self.compare_spark_with_sqlglot(df, dfs)
@@ -906,7 +1001,9 @@ class TestDataframeFunc(DataFrameValidator):
         df_sqlglot_employee_cols = self.df_sqlglot_employee.select(
             SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
         )
-        df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))
+        df_sqlglot_store_cols = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("store_name")
+        )
         dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop(
            df_sqlglot_employee_cols.age,
        )
@@ -23,6 +23,14 @@ class TestSessionFunc(DataFrameValidator):
         ON
             e.store_id = s.store_id
         """
-        df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id")))
-        dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id")))
+        df = (
+            self.spark.sql(query)
+            .groupBy(F.col("store_id"))
+            .agg(F.countDistinct(F.col("employee_id")))
+        )
+        dfs = (
+            self.sqlglot.sql(query)
+            .groupBy(SF.col("store_id"))
+            .agg(SF.countDistinct(SF.col("employee_id")))
+        )
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
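Note: this last hunk belongs to the session tests (TestSessionFunc) and closes the loop: a raw SQL string goes through self.sqlglot.sql(query), further dataframe operations are applied on top, and the result is compared with real Spark. A hedged sketch of that round trip (it assumes sql() accepts a SELECT string and that DataFrame.sql() returns the regenerated statements; countDistinct is the same helper the test itself uses):

    import sqlglot
    from sqlglot.dataframe.sql import functions as SF
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("employee", {"employee_id": "INT", "store_id": "INT"})

    spark = SparkSession()
    dfs = (
        spark.sql("SELECT employee_id, store_id FROM employee")
        .groupBy(SF.col("store_id"))
        .agg(SF.countDistinct(SF.col("employee_id")))
    )
    for stmt in dfs.sql(dialect="spark"):
        print(stmt)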