2025-02-13 14:48:46 +01:00
import sqlglot
2025-02-13 15:40:23 +01:00
from sqlglot . dataframe . sql import functions as F , types
2025-02-13 14:48:46 +01:00
from sqlglot . dataframe . sql . session import SparkSession
from tests . dataframe . unit . dataframe_sql_validator import DataFrameSQLValidator
class TestDataframeSession(DataFrameSQLValidator):
    """Unit tests for the SQL produced by ``SparkSession``.

    Covers ``createDataFrame`` with list rows, dict rows, no schema, a string
    schema and ``StructType`` schemas (flat and nested). Also covers
    ``spark.sql(...)`` for SELECT, aggregation, CREATE and INSERT statements,
    and the ``SparkSession.builder`` entry point.

    ``self.spark`` and ``compare_sql`` come from ``DataFrameSQLValidator``,
    which is defined elsewhere.
    """

    def _register_table_schema(self):
        """Register ``table`` (columns ``cola``/``colb``, both string) in ``sqlglot.schema``.

        Used by the ``spark.sql(...)`` tests so the generated SQL can qualify
        columns. NOTE(review): this writes to module-level state. Whether it is
        reset between tests depends on the base class, which is not visible
        here.
        """
        # Build a fresh mapping per call, in case add_table keeps a reference to it.
        sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark")

    def test_cdf_one_row(self):
        df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"])
        expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2) AS `a2`(`cola`, `colb`)"
        self.compare_sql(df, expected)

    def test_cdf_multiple_rows(self):
        df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"])
        expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`cola`, `colb`)"
        self.compare_sql(df, expected)

    def test_cdf_no_schema(self):
        # With no schema given, columns are expected to be named positionally: _1, _2, ...
        df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
        expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`_1`, `_2`)"
        self.compare_sql(df, expected)

    def test_cdf_row_mixed_primitives(self):
        df = self.spark.createDataFrame([[1, 10.1, "test", False, None]])
        expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM VALUES (1, 10.1, 'test', FALSE, NULL) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
        self.compare_sql(df, expected)

    def test_cdf_dict_rows(self):
        # Column names are taken from the dict keys.
        df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}])
        expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 'test'), (2, 'test2') AS `a2`(`cola`, `colb`)"
        self.compare_sql(df, expected)

    def test_cdf_str_schema(self):
        df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
        # Only the STRING column is wrapped in a CAST; the INT column is selected as-is.
        expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
        self.compare_sql(df, expected)

    def test_typed_schema_basic(self):
        # Same schema as test_cdf_str_schema, expressed as a StructType; same SQL is expected.
        schema = types.StructType(
            [
                types.StructField("cola", types.IntegerType()),
                types.StructField("colb", types.StringType()),
            ]
        )
        df = self.spark.createDataFrame([[1, "test"]], schema)
        expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
        self.compare_sql(df, expected)

    def test_typed_schema_nested(self):
        schema = types.StructType(
            [
                types.StructField(
                    "cola",
                    types.StructType(
                        [
                            types.StructField("sub_cola", types.IntegerType()),
                            types.StructField("sub_colb", types.StringType()),
                        ]
                    ),
                )
            ]
        )
        df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema)
        # The nested dict row is expected to render as a STRUCT literal with named fields.
        expected = "SELECT `a2`.`cola` AS `cola` FROM VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`)) AS `a2`(`cola`)"
        self.compare_sql(df, expected)

    def test_sql_select_only(self):
        query = "SELECT cola, colb FROM table"
        self._register_table_schema()
        df = self.spark.sql(query)
        self.assertEqual(
            "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`",
            df.sql(pretty=False)[0],
        )

    def test_sql_with_aggs(self):
        query = "SELECT cola, colb FROM table"
        self._register_table_schema()
        df = self.spark.sql(query).groupBy(F.col("cola")).agg(F.sum("colb"))
        # optimize=False keeps the unqualified, CTE-per-step form. The CTE names
        # (t38189, t42330) are pinned; they presumably come from deterministic
        # name generation and must be updated if that scheme changes.
        self.assertEqual(
            "WITH t38189 AS (SELECT cola, colb FROM table), t42330 AS (SELECT cola, colb FROM t38189) SELECT cola, SUM(colb) FROM t42330 GROUP BY cola",
            df.sql(pretty=False, optimize=False)[0],
        )

    def test_sql_create(self):
        # NOTE(review): the query has a trailing comma before FROM ("colb, FROM");
        # presumably the parser tolerates it. Confirm whether that is intentional.
        query = "CREATE TABLE new_table AS WITH t1 AS (SELECT cola, colb FROM table) SELECT cola, colb, FROM t1"
        self._register_table_schema()
        df = self.spark.sql(query)
        # The expected SQL no longer contains the t1 CTE and reads from `table` directly.
        expected = "CREATE TABLE new_table AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
        self.compare_sql(df, expected)

    def test_sql_insert(self):
        query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
        self._register_table_schema()
        df = self.spark.sql(query)
        # As in test_sql_create, the leading CTE is expected to be folded away.
        expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
        self.compare_sql(df, expected)

    def test_session_create_builder_patterns(self):
        # NOTE(review): relies on SparkSession equality (e.g. a singleton or a
        # custom __eq__). That behavior is not visible in this file.
        self.assertEqual(SparkSession.builder.appName("abc").getOrCreate(), SparkSession())