71 lines
3.1 KiB
Python
71 lines
3.1 KiB
Python
from pyspark.sql import functions as F
|
|
|
|
from sqlglot.dataframe.sql import functions as SF
|
|
from tests.dataframe.integration.dataframe_validator import DataFrameValidator
|
|
|
|
|
|
class TestDataframeFunc(DataFrameValidator):
|
|
def test_group_by(self):
|
|
df_employee = self.df_spark_employee.groupBy(self.df_spark_employee.age).agg(
|
|
F.min(self.df_spark_employee.employee_id)
|
|
)
|
|
dfs_employee = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).agg(
|
|
SF.min(self.df_sqlglot_employee.employee_id)
|
|
)
|
|
self.compare_spark_with_sqlglot(df_employee, dfs_employee, skip_schema_compare=True)
|
|
|
|
def test_group_by_where_non_aggregate(self):
|
|
df_employee = (
|
|
self.df_spark_employee.groupBy(self.df_spark_employee.age)
|
|
.agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
|
|
.where(F.col("age") > F.lit(50))
|
|
)
|
|
dfs_employee = (
|
|
self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
|
|
.agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
|
|
.where(SF.col("age") > SF.lit(50))
|
|
)
|
|
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
|
|
|
|
def test_group_by_where_aggregate_like_having(self):
|
|
df_employee = (
|
|
self.df_spark_employee.groupBy(self.df_spark_employee.age)
|
|
.agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
|
|
.where(F.col("min_employee_id") > F.lit(1))
|
|
)
|
|
dfs_employee = (
|
|
self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
|
|
.agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
|
|
.where(SF.col("min_employee_id") > SF.lit(1))
|
|
)
|
|
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
|
|
|
|
def test_count(self):
|
|
df = self.df_spark_employee.groupBy(self.df_spark_employee.age).count()
|
|
dfs = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).count()
|
|
self.compare_spark_with_sqlglot(df, dfs)
|
|
|
|
def test_mean(self):
|
|
df = self.df_spark_employee.groupBy().mean("age", "store_id")
|
|
dfs = self.df_sqlglot_employee.groupBy().mean("age", "store_id")
|
|
self.compare_spark_with_sqlglot(df, dfs)
|
|
|
|
def test_avg(self):
|
|
df = self.df_spark_employee.groupBy("age").avg("store_id")
|
|
dfs = self.df_sqlglot_employee.groupBy("age").avg("store_id")
|
|
self.compare_spark_with_sqlglot(df, dfs)
|
|
|
|
def test_max(self):
|
|
df = self.df_spark_employee.groupBy("age").max("store_id")
|
|
dfs = self.df_sqlglot_employee.groupBy("age").max("store_id")
|
|
self.compare_spark_with_sqlglot(df, dfs)
|
|
|
|
def test_min(self):
|
|
df = self.df_spark_employee.groupBy("age").min("store_id")
|
|
dfs = self.df_sqlglot_employee.groupBy("age").min("store_id")
|
|
self.compare_spark_with_sqlglot(df, dfs)
|
|
|
|
def test_sum(self):
|
|
df = self.df_spark_employee.groupBy("age").sum("store_id")
|
|
dfs = self.df_sqlglot_employee.groupBy("age").sum("store_id")
|
|
self.compare_spark_with_sqlglot(df, dfs)
|