Adding upstream version 9.0.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
768d386bf5
commit
fca0265317
87 changed files with 7994 additions and 421 deletions
0
tests/dataframe/integration/__init__.py
Normal file
0
tests/dataframe/integration/__init__.py
Normal file
149
tests/dataframe/integration/dataframe_validator.py
Normal file
149
tests/dataframe/integration/dataframe_validator.py
Normal file
|
@ -0,0 +1,149 @@
|
|||
import typing as t
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import sqlglot
|
||||
from tests.helpers import SKIP_INTEGRATION
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from pyspark.sql import DataFrame as SparkDataFrame
|
||||
|
||||
|
||||
@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set")
|
||||
class DataFrameValidator(unittest.TestCase):
|
||||
spark = None
|
||||
sqlglot = None
|
||||
df_employee = None
|
||||
df_store = None
|
||||
df_district = None
|
||||
spark_employee_schema = None
|
||||
sqlglot_employee_schema = None
|
||||
spark_store_schema = None
|
||||
sqlglot_store_schema = None
|
||||
spark_district_schema = None
|
||||
sqlglot_district_schema = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
from pyspark import SparkConf
|
||||
from pyspark.sql import SparkSession, types
|
||||
|
||||
from sqlglot.dataframe.sql import types as sqlglotSparkTypes
|
||||
from sqlglot.dataframe.sql.session import SparkSession as SqlglotSparkSession
|
||||
|
||||
# This is for test `test_branching_root_dataframes`
|
||||
config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
|
||||
cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate()
|
||||
cls.spark.sparkContext.setLogLevel("ERROR")
|
||||
cls.sqlglot = SqlglotSparkSession()
|
||||
cls.spark_employee_schema = types.StructType(
|
||||
[
|
||||
types.StructField("employee_id", types.IntegerType(), False),
|
||||
types.StructField("fname", types.StringType(), False),
|
||||
types.StructField("lname", types.StringType(), False),
|
||||
types.StructField("age", types.IntegerType(), False),
|
||||
types.StructField("store_id", types.IntegerType(), False),
|
||||
]
|
||||
)
|
||||
cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
|
||||
[
|
||||
sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False),
|
||||
sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
|
||||
sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
|
||||
sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
|
||||
sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
|
||||
]
|
||||
)
|
||||
employee_data = [
|
||||
(1, "Jack", "Shephard", 37, 1),
|
||||
(2, "John", "Locke", 65, 1),
|
||||
(3, "Kate", "Austen", 37, 2),
|
||||
(4, "Claire", "Littleton", 27, 2),
|
||||
(5, "Hugo", "Reyes", 29, 100),
|
||||
]
|
||||
cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema)
|
||||
cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema)
|
||||
cls.df_employee.createOrReplaceTempView("employee")
|
||||
|
||||
cls.spark_store_schema = types.StructType(
|
||||
[
|
||||
types.StructField("store_id", types.IntegerType(), False),
|
||||
types.StructField("store_name", types.StringType(), False),
|
||||
types.StructField("district_id", types.IntegerType(), False),
|
||||
types.StructField("num_sales", types.IntegerType(), False),
|
||||
]
|
||||
)
|
||||
cls.sqlglot_store_schema = sqlglotSparkTypes.StructType(
|
||||
[
|
||||
sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
|
||||
sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
|
||||
sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
|
||||
sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
|
||||
]
|
||||
)
|
||||
store_data = [
|
||||
(1, "Hydra", 1, 37),
|
||||
(2, "Arrow", 2, 2000),
|
||||
]
|
||||
cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
|
||||
cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema)
|
||||
cls.df_store.createOrReplaceTempView("store")
|
||||
|
||||
cls.spark_district_schema = types.StructType(
|
||||
[
|
||||
types.StructField("district_id", types.IntegerType(), False),
|
||||
types.StructField("district_name", types.StringType(), False),
|
||||
types.StructField("manager_name", types.StringType(), False),
|
||||
]
|
||||
)
|
||||
cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
|
||||
[
|
||||
sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
|
||||
sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False),
|
||||
sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False),
|
||||
]
|
||||
)
|
||||
district_data = [
|
||||
(1, "Temple", "Dogen"),
|
||||
(2, "Lighthouse", "Jacob"),
|
||||
]
|
||||
cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema)
|
||||
cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema)
|
||||
cls.df_district.createOrReplaceTempView("district")
|
||||
sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
|
||||
sqlglot.schema.add_table("store", cls.sqlglot_store_schema)
|
||||
sqlglot.schema.add_table("district", cls.sqlglot_district_schema)
|
||||
|
||||
def setUp(self) -> None:
|
||||
warnings.filterwarnings("ignore", category=ResourceWarning)
|
||||
self.df_spark_store = self.df_store.alias("df_store") # type: ignore
|
||||
self.df_spark_employee = self.df_employee.alias("df_employee") # type: ignore
|
||||
self.df_spark_district = self.df_district.alias("df_district") # type: ignore
|
||||
self.df_sqlglot_store = self.dfs_store.alias("store") # type: ignore
|
||||
self.df_sqlglot_employee = self.dfs_employee.alias("employee") # type: ignore
|
||||
self.df_sqlglot_district = self.dfs_district.alias("district") # type: ignore
|
||||
|
||||
def compare_spark_with_sqlglot(
|
||||
self, df_spark, df_sqlglot, no_empty=True, skip_schema_compare=False
|
||||
) -> t.Tuple["SparkDataFrame", "SparkDataFrame"]:
|
||||
def compare_schemas(schema_1, schema_2):
|
||||
for schema in [schema_1, schema_2]:
|
||||
for struct_field in schema.fields:
|
||||
struct_field.metadata = {}
|
||||
self.assertEqual(schema_1, schema_2)
|
||||
|
||||
for statement in df_sqlglot.sql():
|
||||
actual_df_sqlglot = self.spark.sql(statement) # type: ignore
|
||||
df_sqlglot_results = actual_df_sqlglot.collect()
|
||||
df_spark_results = df_spark.collect()
|
||||
if not skip_schema_compare:
|
||||
compare_schemas(df_spark.schema, actual_df_sqlglot.schema)
|
||||
self.assertEqual(df_spark_results, df_sqlglot_results)
|
||||
if no_empty:
|
||||
self.assertNotEqual(len(df_spark_results), 0)
|
||||
self.assertNotEqual(len(df_sqlglot_results), 0)
|
||||
return df_spark, actual_df_sqlglot
|
||||
|
||||
@classmethod
|
||||
def get_explain_plan(cls, df: "SparkDataFrame", mode: str = "extended") -> str:
|
||||
return df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), mode) # type: ignore
|
1103
tests/dataframe/integration/test_dataframe.py
Normal file
1103
tests/dataframe/integration/test_dataframe.py
Normal file
File diff suppressed because it is too large
Load diff
71
tests/dataframe/integration/test_grouped_data.py
Normal file
71
tests/dataframe/integration/test_grouped_data.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
from pyspark.sql import functions as F
|
||||
|
||||
from sqlglot.dataframe.sql import functions as SF
|
||||
from tests.dataframe.integration.dataframe_validator import DataFrameValidator
|
||||
|
||||
|
||||
class TestDataframeFunc(DataFrameValidator):
|
||||
def test_group_by(self):
|
||||
df_employee = self.df_spark_employee.groupBy(self.df_spark_employee.age).agg(
|
||||
F.min(self.df_spark_employee.employee_id)
|
||||
)
|
||||
dfs_employee = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).agg(
|
||||
SF.min(self.df_sqlglot_employee.employee_id)
|
||||
)
|
||||
self.compare_spark_with_sqlglot(df_employee, dfs_employee, skip_schema_compare=True)
|
||||
|
||||
def test_group_by_where_non_aggregate(self):
|
||||
df_employee = (
|
||||
self.df_spark_employee.groupBy(self.df_spark_employee.age)
|
||||
.agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
|
||||
.where(F.col("age") > F.lit(50))
|
||||
)
|
||||
dfs_employee = (
|
||||
self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
|
||||
.agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
|
||||
.where(SF.col("age") > SF.lit(50))
|
||||
)
|
||||
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
|
||||
|
||||
def test_group_by_where_aggregate_like_having(self):
|
||||
df_employee = (
|
||||
self.df_spark_employee.groupBy(self.df_spark_employee.age)
|
||||
.agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
|
||||
.where(F.col("min_employee_id") > F.lit(1))
|
||||
)
|
||||
dfs_employee = (
|
||||
self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
|
||||
.agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
|
||||
.where(SF.col("min_employee_id") > SF.lit(1))
|
||||
)
|
||||
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
|
||||
|
||||
def test_count(self):
|
||||
df = self.df_spark_employee.groupBy(self.df_spark_employee.age).count()
|
||||
dfs = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).count()
|
||||
self.compare_spark_with_sqlglot(df, dfs)
|
||||
|
||||
def test_mean(self):
|
||||
df = self.df_spark_employee.groupBy().mean("age", "store_id")
|
||||
dfs = self.df_sqlglot_employee.groupBy().mean("age", "store_id")
|
||||
self.compare_spark_with_sqlglot(df, dfs)
|
||||
|
||||
def test_avg(self):
|
||||
df = self.df_spark_employee.groupBy("age").avg("store_id")
|
||||
dfs = self.df_sqlglot_employee.groupBy("age").avg("store_id")
|
||||
self.compare_spark_with_sqlglot(df, dfs)
|
||||
|
||||
def test_max(self):
|
||||
df = self.df_spark_employee.groupBy("age").max("store_id")
|
||||
dfs = self.df_sqlglot_employee.groupBy("age").max("store_id")
|
||||
self.compare_spark_with_sqlglot(df, dfs)
|
||||
|
||||
def test_min(self):
|
||||
df = self.df_spark_employee.groupBy("age").min("store_id")
|
||||
dfs = self.df_sqlglot_employee.groupBy("age").min("store_id")
|
||||
self.compare_spark_with_sqlglot(df, dfs)
|
||||
|
||||
def test_sum(self):
|
||||
df = self.df_spark_employee.groupBy("age").sum("store_id")
|
||||
dfs = self.df_sqlglot_employee.groupBy("age").sum("store_id")
|
||||
self.compare_spark_with_sqlglot(df, dfs)
|
28
tests/dataframe/integration/test_session.py
Normal file
28
tests/dataframe/integration/test_session.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
from pyspark.sql import functions as F
|
||||
|
||||
from sqlglot.dataframe.sql import functions as SF
|
||||
from tests.dataframe.integration.dataframe_validator import DataFrameValidator
|
||||
|
||||
|
||||
class TestSessionFunc(DataFrameValidator):
|
||||
def test_sql_simple_select(self):
|
||||
query = "SELECT fname, lname FROM employee"
|
||||
df = self.spark.sql(query)
|
||||
dfs = self.sqlglot.sql(query)
|
||||
self.compare_spark_with_sqlglot(df, dfs)
|
||||
|
||||
def test_sql_with_join(self):
|
||||
query = """
|
||||
SELECT
|
||||
e.employee_id
|
||||
, s.store_id
|
||||
FROM
|
||||
employee e
|
||||
INNER JOIN
|
||||
store s
|
||||
ON
|
||||
e.store_id = s.store_id
|
||||
"""
|
||||
df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id")))
|
||||
dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id")))
|
||||
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
|
Loading…
Add table
Add a link
Reference in a new issue