
Adding upstream version 24.0.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:35:53 +01:00
parent b6ae88ec81
commit 8b1190270c
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
127 changed files with 40727 additions and 46460 deletions


@@ -1,174 +0,0 @@
import typing as t
import unittest
import warnings
import sqlglot
from tests.helpers import SKIP_INTEGRATION
if t.TYPE_CHECKING:
from pyspark.sql import DataFrame as SparkDataFrame
@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set")
class DataFrameValidator(unittest.TestCase):
spark = None
sqlglot = None
df_employee = None
df_store = None
df_district = None
spark_employee_schema = None
sqlglot_employee_schema = None
spark_store_schema = None
sqlglot_store_schema = None
spark_district_schema = None
sqlglot_district_schema = None
@classmethod
def setUpClass(cls):
from pyspark import SparkConf
from pyspark.sql import SparkSession, types
from sqlglot.dataframe.sql import types as sqlglotSparkTypes
from sqlglot.dataframe.sql.session import SparkSession as SqlglotSparkSession
# This is for test `test_branching_root_dataframes`
config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
cls.spark = (
SparkSession.builder.master("local[*]")
.appName("Unit-tests")
.config(conf=config)
.getOrCreate()
)
cls.spark.sparkContext.setLogLevel("ERROR")
cls.sqlglot = SqlglotSparkSession()
cls.spark_employee_schema = types.StructType(
[
types.StructField("employee_id", types.IntegerType(), False),
types.StructField("fname", types.StringType(), False),
types.StructField("lname", types.StringType(), False),
types.StructField("age", types.IntegerType(), False),
types.StructField("store_id", types.IntegerType(), False),
]
)
cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
[
sqlglotSparkTypes.StructField(
"employee_id", sqlglotSparkTypes.IntegerType(), False
),
sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
]
)
employee_data = [
(1, "Jack", "Shephard", 37, 1),
(2, "John", "Locke", 65, 1),
(3, "Kate", "Austen", 37, 2),
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
cls.df_employee = cls.spark.createDataFrame(
data=employee_data, schema=cls.spark_employee_schema
)
cls.dfs_employee = cls.sqlglot.createDataFrame(
data=employee_data, schema=cls.sqlglot_employee_schema
)
cls.df_employee.createOrReplaceTempView("employee")
cls.spark_store_schema = types.StructType(
[
types.StructField("store_id", types.IntegerType(), False),
types.StructField("store_name", types.StringType(), False),
types.StructField("district_id", types.IntegerType(), False),
types.StructField("num_sales", types.IntegerType(), False),
]
)
cls.sqlglot_store_schema = sqlglotSparkTypes.StructType(
[
sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField(
"district_id", sqlglotSparkTypes.IntegerType(), False
),
sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
]
)
store_data = [
(1, "Hydra", 1, 37),
(2, "Arrow", 2, 2000),
]
cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
cls.dfs_store = cls.sqlglot.createDataFrame(
data=store_data, schema=cls.sqlglot_store_schema
)
cls.df_store.createOrReplaceTempView("store")
cls.spark_district_schema = types.StructType(
[
types.StructField("district_id", types.IntegerType(), False),
types.StructField("district_name", types.StringType(), False),
types.StructField("manager_name", types.StringType(), False),
]
)
cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
[
sqlglotSparkTypes.StructField(
"district_id", sqlglotSparkTypes.IntegerType(), False
),
sqlglotSparkTypes.StructField(
"district_name", sqlglotSparkTypes.StringType(), False
),
sqlglotSparkTypes.StructField(
"manager_name", sqlglotSparkTypes.StringType(), False
),
]
)
district_data = [
(1, "Temple", "Dogen"),
(2, "Lighthouse", "Jacob"),
]
cls.df_district = cls.spark.createDataFrame(
data=district_data, schema=cls.spark_district_schema
)
cls.dfs_district = cls.sqlglot.createDataFrame(
data=district_data, schema=cls.sqlglot_district_schema
)
cls.df_district.createOrReplaceTempView("district")
sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema, dialect="spark")
sqlglot.schema.add_table("store", cls.sqlglot_store_schema, dialect="spark")
sqlglot.schema.add_table("district", cls.sqlglot_district_schema, dialect="spark")
def setUp(self) -> None:
warnings.filterwarnings("ignore", category=ResourceWarning)
self.df_spark_store = self.df_store.alias("df_store") # type: ignore
self.df_spark_employee = self.df_employee.alias("df_employee") # type: ignore
self.df_spark_district = self.df_district.alias("df_district") # type: ignore
self.df_sqlglot_store = self.dfs_store.alias("store") # type: ignore
self.df_sqlglot_employee = self.dfs_employee.alias("employee") # type: ignore
self.df_sqlglot_district = self.dfs_district.alias("district") # type: ignore
def compare_spark_with_sqlglot(
self, df_spark, df_sqlglot, no_empty=True, skip_schema_compare=False
) -> t.Tuple["SparkDataFrame", "SparkDataFrame"]:
def compare_schemas(schema_1, schema_2):
for schema in [schema_1, schema_2]:
for struct_field in schema.fields:
struct_field.metadata = {}
self.assertEqual(schema_1, schema_2)
for statement in df_sqlglot.sql():
actual_df_sqlglot = self.spark.sql(statement) # type: ignore
df_sqlglot_results = actual_df_sqlglot.collect()
df_spark_results = df_spark.collect()
if not skip_schema_compare:
compare_schemas(df_spark.schema, actual_df_sqlglot.schema)
self.assertEqual(df_spark_results, df_sqlglot_results)
if no_empty:
self.assertNotEqual(len(df_spark_results), 0)
self.assertNotEqual(len(df_sqlglot_results), 0)
return df_spark, actual_df_sqlglot
@classmethod
def get_explain_plan(cls, df: "SparkDataFrame", mode: str = "extended") -> str:
return df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), mode) # type: ignore
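For orientation, a minimal sketch of the round trip these integration tests perform (illustrative only; spark stands for the PySpark session configured in setUpClass, and the sample data is hypothetical):

from sqlglot.dataframe.sql.session import SparkSession as SqlglotSparkSession

sqlglot_session = SqlglotSparkSession()
dfs = sqlglot_session.createDataFrame([[1, "Jack"]], ["employee_id", "fname"])
for statement in dfs.sql():  # sqlglot renders the DataFrame plan as Spark SQL strings
    result = spark.sql(statement)  # each string is executed on the real PySpark session
    # result.collect() is then compared against the equivalent PySpark DataFrame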

File diff suppressed because it is too large


@@ -1,71 +0,0 @@
from pyspark.sql import functions as F
from sqlglot.dataframe.sql import functions as SF
from tests.dataframe.integration.dataframe_validator import DataFrameValidator
class TestDataframeFunc(DataFrameValidator):
def test_group_by(self):
df_employee = self.df_spark_employee.groupBy(self.df_spark_employee.age).agg(
F.min(self.df_spark_employee.employee_id)
)
dfs_employee = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).agg(
SF.min(self.df_sqlglot_employee.employee_id)
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee, skip_schema_compare=True)
def test_group_by_where_non_aggregate(self):
df_employee = (
self.df_spark_employee.groupBy(self.df_spark_employee.age)
.agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
.where(F.col("age") > F.lit(50))
)
dfs_employee = (
self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
.agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
.where(SF.col("age") > SF.lit(50))
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_group_by_where_aggregate_like_having(self):
df_employee = (
self.df_spark_employee.groupBy(self.df_spark_employee.age)
.agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
.where(F.col("min_employee_id") > F.lit(1))
)
dfs_employee = (
self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
.agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
.where(SF.col("min_employee_id") > SF.lit(1))
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_count(self):
df = self.df_spark_employee.groupBy(self.df_spark_employee.age).count()
dfs = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).count()
self.compare_spark_with_sqlglot(df, dfs)
def test_mean(self):
df = self.df_spark_employee.groupBy().mean("age", "store_id")
dfs = self.df_sqlglot_employee.groupBy().mean("age", "store_id")
self.compare_spark_with_sqlglot(df, dfs)
def test_avg(self):
df = self.df_spark_employee.groupBy("age").avg("store_id")
dfs = self.df_sqlglot_employee.groupBy("age").avg("store_id")
self.compare_spark_with_sqlglot(df, dfs)
def test_max(self):
df = self.df_spark_employee.groupBy("age").max("store_id")
dfs = self.df_sqlglot_employee.groupBy("age").max("store_id")
self.compare_spark_with_sqlglot(df, dfs)
def test_min(self):
df = self.df_spark_employee.groupBy("age").min("store_id")
dfs = self.df_sqlglot_employee.groupBy("age").min("store_id")
self.compare_spark_with_sqlglot(df, dfs)
def test_sum(self):
df = self.df_spark_employee.groupBy("age").sum("store_id")
dfs = self.df_sqlglot_employee.groupBy("age").sum("store_id")
self.compare_spark_with_sqlglot(df, dfs)


@@ -1,43 +0,0 @@
from pyspark.sql import functions as F
from sqlglot.dataframe.sql import functions as SF
from tests.dataframe.integration.dataframe_validator import DataFrameValidator
class TestSessionFunc(DataFrameValidator):
def test_sql_simple_select(self):
query = "SELECT fname, lname FROM employee"
df = self.spark.sql(query)
dfs = self.sqlglot.sql(query)
self.compare_spark_with_sqlglot(df, dfs)
def test_sql_with_join(self):
query = """
SELECT
e.employee_id
, s.store_id
FROM
employee e
INNER JOIN
store s
ON
e.store_id = s.store_id
"""
df = (
self.spark.sql(query)
.groupBy(F.col("store_id"))
.agg(F.countDistinct(F.col("employee_id")))
)
dfs = (
self.sqlglot.sql(query)
.groupBy(SF.col("store_id"))
.agg(SF.countDistinct(SF.col("employee_id")))
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_nameless_column(self):
query = "SELECT MAX(age) FROM employee"
df = self.spark.sql(query)
dfs = self.sqlglot.sql(query)
# Spark will alias the column to `max(age)` while sqlglot will alias to `_col_0` so their schemas will differ
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)


@@ -1,28 +0,0 @@
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql.session import SparkSession
from tests.dataframe.unit.dataframe_test_base import DataFrameTestBase
class DataFrameSQLValidator(DataFrameTestBase):
def setUp(self) -> None:
super().setUp()
self.spark = SparkSession()
self.employee_schema = types.StructType(
[
types.StructField("employee_id", types.IntegerType(), False),
types.StructField("fname", types.StringType(), False),
types.StructField("lname", types.StringType(), False),
types.StructField("age", types.IntegerType(), False),
types.StructField("store_id", types.IntegerType(), False),
]
)
employee_data = [
(1, "Jack", "Shephard", 37, 1),
(2, "John", "Locke", 65, 1),
(3, "Kate", "Austen", 37, 2),
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
self.df_employee = self.spark.createDataFrame(
data=employee_data, schema=self.employee_schema
)


@@ -1,23 +0,0 @@
import typing as t
import unittest
import sqlglot
from sqlglot import MappingSchema
from sqlglot.dataframe.sql import SparkSession
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.helper import ensure_list
class DataFrameTestBase(unittest.TestCase):
def setUp(self) -> None:
sqlglot.schema = MappingSchema()
SparkSession._instance = None
def compare_sql(
self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
):
actual_sqls = df.sql(pretty=pretty)
expected_statements = ensure_list(expected_statements)
self.assertEqual(len(expected_statements), len(actual_sqls))
for expected, actual in zip(expected_statements, actual_sqls):
self.assertEqual(expected, actual)


@@ -1,174 +0,0 @@
import datetime
import unittest
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window
class TestDataframeColumn(unittest.TestCase):
def test_eq(self):
self.assertEqual("cola = 1", (F.col("cola") == 1).sql())
def test_neq(self):
self.assertEqual("cola <> 1", (F.col("cola") != 1).sql())
def test_gt(self):
self.assertEqual("cola > 1", (F.col("cola") > 1).sql())
def test_lt(self):
self.assertEqual("cola < 1", (F.col("cola") < 1).sql())
def test_le(self):
self.assertEqual("cola <= 1", (F.col("cola") <= 1).sql())
def test_ge(self):
self.assertEqual("cola >= 1", (F.col("cola") >= 1).sql())
def test_and(self):
self.assertEqual(
"cola = colb AND colc = cold",
((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql(),
)
def test_or(self):
self.assertEqual(
"cola = colb OR colc = cold",
((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql(),
)
def test_mod(self):
self.assertEqual("cola % 2", (F.col("cola") % 2).sql())
def test_add(self):
self.assertEqual("cola + 1", (F.col("cola") + 1).sql())
def test_sub(self):
self.assertEqual("cola - 1", (F.col("cola") - 1).sql())
def test_mul(self):
self.assertEqual("cola * 2", (F.col("cola") * 2).sql())
def test_div(self):
self.assertEqual("cola / 2", (F.col("cola") / 2).sql())
def test_radd(self):
self.assertEqual("1 + cola", (1 + F.col("cola")).sql())
def test_rsub(self):
self.assertEqual("1 - cola", (1 - F.col("cola")).sql())
def test_rmul(self):
self.assertEqual("1 * cola", (1 * F.col("cola")).sql())
def test_rdiv(self):
self.assertEqual("1 / cola", (1 / F.col("cola")).sql())
def test_pow(self):
self.assertEqual("POWER(cola, 2)", (F.col("cola") ** 2).sql())
def test_rpow(self):
self.assertEqual("POWER(2, cola)", (2 ** F.col("cola")).sql())
def test_invert(self):
self.assertEqual("NOT cola", (~F.col("cola")).sql())
def test_startswith(self):
self.assertEqual("STARTSWITH(cola, 'test')", F.col("cola").startswith("test").sql())
def test_endswith(self):
self.assertEqual("ENDSWITH(cola, 'test')", F.col("cola").endswith("test").sql())
def test_rlike(self):
self.assertEqual("cola RLIKE 'foo'", F.col("cola").rlike("foo").sql())
def test_like(self):
self.assertEqual("cola LIKE 'foo%'", F.col("cola").like("foo%").sql())
def test_ilike(self):
self.assertEqual("cola ILIKE 'foo%'", F.col("cola").ilike("foo%").sql())
def test_substring(self):
self.assertEqual("SUBSTRING(cola, 2, 3)", F.col("cola").substr(2, 3).sql())
def test_isin(self):
self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin([1, 2, 3]).sql())
self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin(1, 2, 3).sql())
def test_asc(self):
self.assertEqual("cola ASC", F.col("cola").asc().sql())
def test_desc(self):
self.assertEqual("cola DESC", F.col("cola").desc().sql())
def test_asc_nulls_first(self):
self.assertEqual("cola ASC", F.col("cola").asc_nulls_first().sql())
def test_asc_nulls_last(self):
self.assertEqual("cola ASC NULLS LAST", F.col("cola").asc_nulls_last().sql())
def test_desc_nulls_first(self):
self.assertEqual("cola DESC NULLS FIRST", F.col("cola").desc_nulls_first().sql())
def test_desc_nulls_last(self):
self.assertEqual("cola DESC", F.col("cola").desc_nulls_last().sql())
def test_when_otherwise(self):
self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql())
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()
)
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
(F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(),
)
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
F.col("cola").when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).sql(),
)
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 ELSE 4 END",
F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).otherwise(4).sql(),
)
def test_is_null(self):
self.assertEqual("cola IS NULL", F.col("cola").isNull().sql())
def test_is_not_null(self):
self.assertEqual("NOT cola IS NULL", F.col("cola").isNotNull().sql())
def test_cast(self):
self.assertEqual("CAST(cola AS INT)", F.col("cola").cast("INT").sql())
def test_alias(self):
self.assertEqual("cola AS new_name", F.col("cola").alias("new_name").sql())
def test_between(self):
self.assertEqual("cola BETWEEN 1 AND 3", F.col("cola").between(1, 3).sql())
self.assertEqual("cola BETWEEN 10.1 AND 12.1", F.col("cola").between(10.1, 12.1).sql())
self.assertEqual(
"cola BETWEEN CAST('2022-01-01' AS DATE) AND CAST('2022-03-01' AS DATE)",
F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(),
)
self.assertEqual(
"cola BETWEEN CAST('2022-01-01 01:01:01+00:00' AS TIMESTAMP) "
"AND CAST('2022-03-01 01:01:01+00:00' AS TIMESTAMP)",
F.col("cola")
.between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1))
.sql(),
)
def test_over(self):
over_rows = F.sum("cola").over(
Window.partitionBy("colb").orderBy("colc").rowsBetween(1, Window.unboundedFollowing)
)
self.assertEqual(
"SUM(cola) OVER (PARTITION BY colb ORDER BY colc ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
over_rows.sql(),
)
over_range = F.sum("cola").over(
Window.partitionBy("colb").orderBy("colc").rangeBetween(1, Window.unboundedFollowing)
)
self.assertEqual(
"SUM(cola) OVER (PARTITION BY colb ORDER BY colc RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
over_range.sql(),
)
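A standalone sketch of the column-expression API exercised above; the expected strings in the comments are taken verbatim from assertions in this file:

from sqlglot.dataframe.sql import functions as F

print((F.col("cola") + 1).sql())            # cola + 1
print(F.col("cola").between(1, 3).sql())    # cola BETWEEN 1 AND 3
print(F.when(F.col("cola") == 1, 2).sql())  # CASE WHEN cola = 1 THEN 2 END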


@@ -1,43 +0,0 @@
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.dataframe import DataFrame
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataframe(DataFrameSQLValidator):
maxDiff = None
def test_hash_select_expression(self):
expression = exp.select("cola").from_("table")
self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))
def test_columns(self):
self.assertEqual(
["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns
)
def test_cache(self):
df = self.df_employee.select("fname").cache()
expected_statements = [
"DROP VIEW IF EXISTS t31563",
"CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
]
self.compare_sql(df, expected_statements)
def test_persist_default(self):
df = self.df_employee.select("fname").persist()
expected_statements = [
"DROP VIEW IF EXISTS t31563",
"CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
]
self.compare_sql(df, expected_statements)
def test_persist_storagelevel(self):
df = self.df_employee.select("fname").persist("DISK_ONLY_2")
expected_statements = [
"DROP VIEW IF EXISTS t31563",
"CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
]
self.compare_sql(df, expected_statements)


@@ -1,95 +0,0 @@
from unittest import mock
import sqlglot
from sqlglot.schema import MappingSchema
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataFrameWriter(DataFrameSQLValidator):
maxDiff = None
def test_insertInto_full_path(self):
df = self.df_employee.write.insertInto("catalog.db.table_name")
expected = "INSERT INTO catalog.db.table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_db_table(self):
df = self.df_employee.write.insertInto("db.table_name")
expected = "INSERT INTO db.table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_table(self):
df = self.df_employee.write.insertInto("table_name")
expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_overwrite(self):
df = self.df_employee.write.insertInto("table_name", overwrite=True)
expected = "INSERT OVERWRITE TABLE table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
def test_insertInto_byName(self):
sqlglot.schema.add_table("table_name", {"employee_id": "INT"}, dialect="spark")
df = self.df_employee.write.byName.insertInto("table_name")
expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_cache(self):
df = self.df_employee.cache().write.insertInto("table_name")
expected_statements = [
"DROP VIEW IF EXISTS t12441",
"CACHE LAZY TABLE t12441 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"INSERT INTO table_name SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`",
]
self.compare_sql(df, expected_statements)
def test_saveAsTable_format(self):
with self.assertRaises(NotImplementedError):
self.df_employee.write.saveAsTable("table_name", format="parquet").sql(pretty=False)[0]
def test_saveAsTable_append(self):
df = self.df_employee.write.saveAsTable("table_name", mode="append")
expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_overwrite(self):
df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
expected = "CREATE OR REPLACE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_error(self):
df = self.df_employee.write.saveAsTable("table_name", mode="error")
expected = "CREATE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_ignore(self):
df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_standalone(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_override(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
expected = "CREATE OR REPLACE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_cache(self):
df = self.df_employee.cache().write.saveAsTable("table_name")
expected_statements = [
"DROP VIEW IF EXISTS t12441",
"CACHE LAZY TABLE t12441 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"CREATE TABLE table_name AS SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`",
]
self.compare_sql(df, expected_statements)
def test_quotes(self):
sqlglot.schema.add_table("`Test`", {"`ID`": "STRING"}, dialect="spark")
df = self.spark.table("`Test`")
self.compare_sql(
df.select(df["`ID`"]), ["SELECT `test`.`id` AS `id` FROM `test` AS `test`"]
)
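As a rough illustration of the writer API above (hypothetical snippet; df_employee is the DataFrame built in DataFrameSQLValidator.setUp), the write calls only render SQL strings and never execute anything:

statements = df_employee.write.mode("overwrite").saveAsTable("table_name").sql(pretty=False)
# expected: a single "CREATE OR REPLACE TABLE table_name AS SELECT ..." statement,
# per test_saveAsTable_overwrite and test_mode_override above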

File diff suppressed because it is too large


@@ -1,101 +0,0 @@
import sqlglot
from sqlglot.dataframe.sql import functions as F, types
from sqlglot.dataframe.sql.session import SparkSession
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataframeSession(DataFrameSQLValidator):
def test_cdf_one_row(self):
df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
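Equivalently, outside the test harness (a minimal sketch using the same inputs as test_cdf_one_row):

from sqlglot.dataframe.sql.session import SparkSession

df = SparkSession().createDataFrame([[1, 2]], ["cola", "colb"])
print(df.sql(pretty=False)[0])
# yields a VALUES-backed SELECT like the expected string above; the a2 sequence
# alias depends on how many DataFrames the session has already produced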
def test_cdf_multiple_rows(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_no_schema(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`_1`, `_2`)"
self.compare_sql(df, expected)
def test_cdf_row_mixed_primitives(self):
df = self.spark.createDataFrame([[1, 10.1, "test", False, None]])
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM VALUES (1, 10.1, 'test', FALSE, NULL) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
self.compare_sql(df, expected)
def test_cdf_dict_rows(self):
df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 'test'), (2, 'test2') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_str_schema(self):
df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_basic(self):
schema = types.StructType(
[
types.StructField("cola", types.IntegerType()),
types.StructField("colb", types.StringType()),
]
)
df = self.spark.createDataFrame([[1, "test"]], schema)
expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_nested(self):
schema = types.StructType(
[
types.StructField(
"cola",
types.StructType(
[
types.StructField("sub_cola", types.IntegerType()),
types.StructField("sub_colb", types.StringType()),
]
),
)
]
)
df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema)
expected = "SELECT `a2`.`cola` AS `cola` FROM VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`)) AS `a2`(`cola`)"
self.compare_sql(df, expected)
def test_sql_select_only(self):
query = "SELECT cola, colb FROM table"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark")
df = self.spark.sql(query)
self.assertEqual(
"SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`",
df.sql(pretty=False)[0],
)
def test_sql_with_aggs(self):
query = "SELECT cola, colb FROM table"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark")
df = self.spark.sql(query).groupBy(F.col("cola")).agg(F.sum("colb"))
self.assertEqual(
"WITH t26614 AS (SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`), t23454 AS (SELECT cola, colb FROM t26614) SELECT cola, SUM(colb) FROM t23454 GROUP BY cola",
df.sql(pretty=False, optimize=False)[0],
)
def test_sql_create(self):
query = "CREATE TABLE new_table AS WITH t1 AS (SELECT cola, colb FROM table) SELECT cola, colb, FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark")
df = self.spark.sql(query)
expected = "CREATE TABLE `new_table` AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
self.compare_sql(df, expected)
def test_sql_insert(self):
query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark")
df = self.spark.sql(query)
expected = "INSERT INTO `new_table` SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
self.compare_sql(df, expected)
def test_session_create_builder_patterns(self):
self.assertEqual(SparkSession.builder.appName("abc").getOrCreate(), SparkSession())


@@ -1,87 +0,0 @@
import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.errors import OptimizeError
from tests.dataframe.unit.dataframe_test_base import DataFrameTestBase
class TestSessionCaseSensitivity(DataFrameTestBase):
def setUp(self) -> None:
super().setUp()
self.spark = SparkSession.builder.config("sqlframe.dialect", "snowflake").getOrCreate()
tests = [
(
"All lower no intention of CS",
"test",
"test",
{"name": "VARCHAR"},
"name",
'''SELECT "TEST"."NAME" AS "NAME" FROM "TEST" AS "TEST"''',
),
(
"Table has CS while column does not",
'"Test"',
'"Test"',
{"name": "VARCHAR"},
"name",
'''SELECT "Test"."NAME" AS "NAME" FROM "Test" AS "Test"''',
),
(
"Column has CS while table does not",
"test",
"test",
{'"Name"': "VARCHAR"},
'"Name"',
'''SELECT "TEST"."Name" AS "Name" FROM "TEST" AS "TEST"''',
),
(
"Both Table and column have CS",
'"Test"',
'"Test"',
{'"Name"': "VARCHAR"},
'"Name"',
'''SELECT "Test"."Name" AS "Name" FROM "Test" AS "Test"''',
),
(
"Lowercase CS table and column",
'"test"',
'"test"',
{'"name"': "VARCHAR"},
'"name"',
'''SELECT "test"."name" AS "name" FROM "test" AS "test"''',
),
(
"CS table and column and query table but no CS in query column",
'"test"',
'"test"',
{'"name"': "VARCHAR"},
"name",
OptimizeError(),
),
(
"CS table and column and query column but no CS in query table",
'"test"',
"test",
{'"name"': "VARCHAR"},
'"name"',
OptimizeError(),
),
]
def test_basic_case_sensitivity(self):
for test_name, table_name, spark_table, schema, spark_column, expected in self.tests:
with self.subTest(test_name):
sqlglot.schema.add_table(table_name, schema, dialect=self.spark.dialect)
df = self.spark.table(spark_table).select(F.col(spark_column))
if isinstance(expected, OptimizeError):
with self.assertRaises(OptimizeError):
df.sql()
else:
self.compare_sql(df, expected)
def test_alias(self):
col = F.col('"Name"')
self.assertEqual(col.sql(dialect=self.spark.dialect), '"Name"')
self.assertEqual(col.alias("nAME").sql(dialect=self.spark.dialect), '"Name" AS NAME')
self.assertEqual(col.alias('"nAME"').sql(dialect=self.spark.dialect), '"Name" AS "nAME"')


@@ -1,73 +0,0 @@
import unittest
from sqlglot.dataframe.sql import types
class TestDataframeTypes(unittest.TestCase):
def test_string(self):
self.assertEqual("string", types.StringType().simpleString())
def test_char(self):
self.assertEqual("char(100)", types.CharType(100).simpleString())
def test_varchar(self):
self.assertEqual("varchar(65)", types.VarcharType(65).simpleString())
def test_binary(self):
self.assertEqual("binary", types.BinaryType().simpleString())
def test_boolean(self):
self.assertEqual("boolean", types.BooleanType().simpleString())
def test_date(self):
self.assertEqual("date", types.DateType().simpleString())
def test_timestamp(self):
self.assertEqual("timestamp", types.TimestampType().simpleString())
def test_timestamp_ntz(self):
self.assertEqual("timestamp_ntz", types.TimestampNTZType().simpleString())
def test_decimal(self):
self.assertEqual("decimal(10, 3)", types.DecimalType(10, 3).simpleString())
def test_double(self):
self.assertEqual("double", types.DoubleType().simpleString())
def test_float(self):
self.assertEqual("float", types.FloatType().simpleString())
def test_byte(self):
self.assertEqual("tinyint", types.ByteType().simpleString())
def test_integer(self):
self.assertEqual("int", types.IntegerType().simpleString())
def test_long(self):
self.assertEqual("bigint", types.LongType().simpleString())
def test_short(self):
self.assertEqual("smallint", types.ShortType().simpleString())
def test_array(self):
self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString())
def test_map(self):
self.assertEqual(
"map<int, string>",
types.MapType(types.IntegerType(), types.StringType()).simpleString(),
)
def test_struct_field(self):
self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString())
def test_struct_type(self):
self.assertEqual(
"struct<cola:int, colb:string>",
types.StructType(
[
types.StructField("cola", types.IntegerType()),
types.StructField("colb", types.StringType()),
]
).simpleString(),
)


@@ -1,75 +0,0 @@
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window, WindowSpec
from tests.dataframe.unit.dataframe_test_base import DataFrameTestBase
class TestDataframeWindow(DataFrameTestBase):
def test_window_spec_partition_by(self):
partition_by = WindowSpec().partitionBy(F.col("cola"), F.col("colb"))
self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql())
def test_window_spec_order_by(self):
order_by = WindowSpec().orderBy("cola", "colb")
self.assertEqual("OVER (ORDER BY cola, colb)", order_by.sql())
def test_window_spec_rows_between(self):
rows_between = WindowSpec().rowsBetween(3, 5)
self.assertEqual("OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
def test_window_spec_range_between(self):
range_between = WindowSpec().rangeBetween(3, 5)
self.assertEqual("OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
def test_window_partition_by(self):
partition_by = Window.partitionBy(F.col("cola"), F.col("colb"))
self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql())
def test_window_order_by(self):
order_by = Window.orderBy("cola", "colb")
self.assertEqual("OVER (ORDER BY cola, colb)", order_by.sql())
def test_window_rows_between(self):
rows_between = Window.rowsBetween(3, 5)
self.assertEqual("OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
def test_window_range_between(self):
range_between = Window.rangeBetween(3, 5)
self.assertEqual("OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
def test_window_rows_unbounded(self):
rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
self.assertEqual(
"OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
rows_between_unbounded_start.sql(),
)
rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
self.assertEqual(
"OVER (ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
rows_between_unbounded_end.sql(),
)
rows_between_unbounded_both = Window.rowsBetween(
Window.unboundedPreceding, Window.unboundedFollowing
)
self.assertEqual(
"OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
rows_between_unbounded_both.sql(),
)
def test_window_range_unbounded(self):
range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
self.assertEqual(
"OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
range_between_unbounded_start.sql(),
)
range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
self.assertEqual(
"OVER (RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
range_between_unbounded_end.sql(),
)
range_between_unbounded_both = Window.rangeBetween(
Window.unboundedPreceding, Window.unboundedFollowing
)
self.assertEqual(
"OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
range_between_unbounded_both.sql(),
)
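A minimal sketch combining a window spec with an aggregate, mirroring test_over in the column tests above (the expected string is copied from that assertion):

from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window

spec = Window.partitionBy("colb").orderBy("colc").rowsBetween(1, Window.unboundedFollowing)
print(F.sum("cola").over(spec).sql())
# SUM(cola) OVER (PARTITION BY colb ORDER BY colc ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)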


@@ -20,7 +20,6 @@ class TestDatabricks(Validator):
self.validate_identity("SELECT CAST('23:00:00' AS INTERVAL MINUTE TO SECOND)")
self.validate_identity("CREATE TABLE target SHALLOW CLONE source")
self.validate_identity("INSERT INTO a REPLACE WHERE cond VALUES (1), (2)")
self.validate_identity("SELECT c1 : price")
self.validate_identity("CREATE FUNCTION a.b(x INT) RETURNS INT RETURN x + 1")
self.validate_identity("CREATE FUNCTION a AS b")
self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1")
@@ -68,6 +67,20 @@ class TestDatabricks(Validator):
},
)
self.validate_all(
"SELECT X'1A2B'",
read={
"spark2": "SELECT X'1A2B'",
"spark": "SELECT X'1A2B'",
"databricks": "SELECT x'1A2B'",
},
write={
"spark2": "SELECT X'1A2B'",
"spark": "SELECT X'1A2B'",
"databricks": "SELECT X'1A2B'",
},
)
with self.assertRaises(ParseError):
transpile(
"CREATE FUNCTION add_one(x INT) RETURNS INT LANGUAGE PYTHON AS $foo$def add_one(x):\n return x+1$$",
@@ -82,37 +95,33 @@ class TestDatabricks(Validator):
# https://docs.databricks.com/sql/language-manual/functions/colonsign.html
def test_json(self):
self.validate_identity("""SELECT c1 : price FROM VALUES ('{ "price": 5 }') AS T(c1)""")
self.validate_all(
self.validate_identity(
"""SELECT c1 : price FROM VALUES ('{ "price": 5 }') AS T(c1)""",
"""SELECT GET_JSON_OBJECT(c1, '$.price') FROM VALUES ('{ "price": 5 }') AS T(c1)""",
)
self.validate_identity(
"""SELECT c1:['price'] FROM VALUES('{ "price": 5 }') AS T(c1)""",
write={
"databricks": """SELECT c1 : ARRAY('price') FROM VALUES ('{ "price": 5 }') AS T(c1)""",
},
"""SELECT GET_JSON_OBJECT(c1, '$.price') FROM VALUES ('{ "price": 5 }') AS T(c1)""",
)
self.validate_all(
self.validate_identity(
"""SELECT c1:item[1].price FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
write={
"databricks": """SELECT c1 : item[1].price FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
},
"""SELECT GET_JSON_OBJECT(c1, '$.item[1].price') FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
)
self.validate_all(
self.validate_identity(
"""SELECT c1:item[*].price FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
write={
"databricks": """SELECT c1 : item[*].price FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
},
"""SELECT GET_JSON_OBJECT(c1, '$.item[*].price') FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
)
self.validate_all(
self.validate_identity(
"""SELECT from_json(c1:item[*].price, 'ARRAY<DOUBLE>')[0] FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
write={
"databricks": """SELECT FROM_JSON(c1 : item[*].price, 'ARRAY<DOUBLE>')[0] FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
},
"""SELECT FROM_JSON(GET_JSON_OBJECT(c1, '$.item[*].price'), 'ARRAY<DOUBLE>')[0] FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
)
self.validate_all(
self.validate_identity(
"""SELECT inline(from_json(c1:item[*], 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
write={
"databricks": """SELECT INLINE(FROM_JSON(c1 : item[*], 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
},
"""SELECT INLINE(FROM_JSON(GET_JSON_OBJECT(c1, '$.item[*]'), 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
)
self.validate_identity(
"SELECT c1 : price",
"SELECT GET_JSON_OBJECT(c1, '$.price')",
)
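Outside the test harness, the two-argument validate_identity calls above correspond roughly to a transpile round trip within the Databricks dialect (a sketch, not part of the diff):

import sqlglot

print(sqlglot.transpile("SELECT c1 : price", read="databricks", write="databricks")[0])
# SELECT GET_JSON_OBJECT(c1, '$.price')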
def test_datediff(self):


@@ -1163,7 +1163,7 @@ class TestDialect(Validator):
read={
"bigquery": "JSON_EXTRACT(x, '$.y')",
"duckdb": "x -> 'y'",
"doris": "x -> '$.y'",
"doris": "JSON_EXTRACT(x, '$.y')",
"mysql": "JSON_EXTRACT(x, '$.y')",
"postgres": "x->'y'",
"presto": "JSON_EXTRACT(x, '$.y')",
@@ -1174,7 +1174,7 @@ class TestDialect(Validator):
write={
"bigquery": "JSON_EXTRACT(x, '$.y')",
"clickhouse": "JSONExtractString(x, 'y')",
"doris": "x -> '$.y'",
"doris": "JSON_EXTRACT(x, '$.y')",
"duckdb": "x -> '$.y'",
"mysql": "JSON_EXTRACT(x, '$.y')",
"oracle": "JSON_EXTRACT(x, '$.y')",
@@ -1218,7 +1218,7 @@ class TestDialect(Validator):
read={
"bigquery": "JSON_EXTRACT(x, '$.y[0].z')",
"duckdb": "x -> '$.y[0].z'",
"doris": "x -> '$.y[0].z'",
"doris": "JSON_EXTRACT(x, '$.y[0].z')",
"mysql": "JSON_EXTRACT(x, '$.y[0].z')",
"presto": "JSON_EXTRACT(x, '$.y[0].z')",
"snowflake": "GET_PATH(x, 'y[0].z')",
@@ -1228,7 +1228,7 @@ class TestDialect(Validator):
write={
"bigquery": "JSON_EXTRACT(x, '$.y[0].z')",
"clickhouse": "JSONExtractString(x, 'y', 1, 'z')",
"doris": "x -> '$.y[0].z'",
"doris": "JSON_EXTRACT(x, '$.y[0].z')",
"duckdb": "x -> '$.y[0].z'",
"mysql": "JSON_EXTRACT(x, '$.y[0].z')",
"oracle": "JSON_EXTRACT(x, '$.y[0].z')",


@@ -14,7 +14,9 @@ class TestDoris(Validator):
)
self.validate_all(
"SELECT MAX_BY(a, b), MIN_BY(c, d)",
read={"clickhouse": "SELECT argMax(a, b), argMin(c, d)"},
read={
"clickhouse": "SELECT argMax(a, b), argMin(c, d)",
},
)
self.validate_all(
"SELECT ARRAY_SUM(x -> x * x, ARRAY(2, 3))",
@@ -36,6 +38,16 @@ class TestDoris(Validator):
"oracle": "ADD_MONTHS(d, n)",
},
)
self.validate_all(
"""SELECT JSON_EXTRACT(CAST('{"key": 1}' AS JSONB), '$.key')""",
read={
"postgres": """SELECT '{"key": 1}'::jsonb ->> 'key'""",
},
write={
"doris": """SELECT JSON_EXTRACT(CAST('{"key": 1}' AS JSONB), '$.key')""",
"postgres": """SELECT JSON_EXTRACT_PATH(CAST('{"key": 1}' AS JSONB), 'key')""",
},
)
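Equivalently, a standalone sketch of the new Postgres-to-Doris expectation added above (same dialects assumed):

import sqlglot

sql = """SELECT '{"key": 1}'::jsonb ->> 'key'"""
print(sqlglot.transpile(sql, read="postgres", write="doris")[0])
# SELECT JSON_EXTRACT(CAST('{"key": 1}' AS JSONB), '$.key')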
def test_identity(self):
self.validate_identity("COALECSE(a, b, c, d)")


@@ -155,6 +155,10 @@ class TestMySQL(Validator):
"""SELECT * FROM foo WHERE 3 MEMBER OF(info->'$.value')""",
"""SELECT * FROM foo WHERE 3 MEMBER OF(JSON_EXTRACT(info, '$.value'))""",
)
self.validate_identity(
"SELECT 1 AS row",
"SELECT 1 AS `row`",
)
# Index hints
self.validate_identity(
@@ -334,7 +338,7 @@ class TestMySQL(Validator):
write_CC = {
"bigquery": "SELECT 0xCC",
"clickhouse": "SELECT 0xCC",
"databricks": "SELECT 204",
"databricks": "SELECT X'CC'",
"drill": "SELECT 204",
"duckdb": "SELECT 204",
"hive": "SELECT 204",
@@ -355,7 +359,7 @@ class TestMySQL(Validator):
write_CC_with_leading_zeros = {
"bigquery": "SELECT 0x0000CC",
"clickhouse": "SELECT 0x0000CC",
"databricks": "SELECT 204",
"databricks": "SELECT X'0000CC'",
"drill": "SELECT 204",
"duckdb": "SELECT 204",
"hive": "SELECT 204",


@@ -38,8 +38,6 @@ class TestPostgres(Validator):
self.validate_identity("CAST(x AS TSTZMULTIRANGE)")
self.validate_identity("CAST(x AS DATERANGE)")
self.validate_identity("CAST(x AS DATEMULTIRANGE)")
self.validate_identity("SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]")
self.validate_identity("SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]")
self.validate_identity("x$")
self.validate_identity("SELECT ARRAY[1, 2, 3]")
self.validate_identity("SELECT ARRAY(SELECT 1)")
@@ -64,6 +62,10 @@ class TestPostgres(Validator):
self.validate_identity("EXEC AS myfunc @id = 123", check_command_warning=True)
self.validate_identity("SELECT CURRENT_USER")
self.validate_identity("SELECT * FROM ONLY t1")
self.validate_identity(
"SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]",
"SELECT ARRAY[1, 2] @> ARRAY[1, 2, 3]",
)
self.validate_identity(
"""UPDATE "x" SET "y" = CAST('0 days 60.000000 seconds' AS INTERVAL) WHERE "x"."id" IN (2, 3)"""
)
@@ -325,6 +327,17 @@ class TestPostgres(Validator):
"CAST(x AS BIGINT)",
)
self.validate_all(
"SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]",
read={
"duckdb": "SELECT ARRAY_HAS_ALL([1, 2, 3], [1, 2])",
},
write={
"duckdb": "SELECT ARRAY_HAS_ALL([1, 2, 3], [1, 2])",
"mysql": UnsupportedError,
"postgres": "SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]",
},
)
self.validate_all(
"SELECT REGEXP_REPLACE('mr .', '[^a-zA-Z]', '', 'g')",
write={
@@ -740,6 +753,9 @@ class TestPostgres(Validator):
self.validate_identity("ALTER TABLE t1 SET ACCESS METHOD method")
self.validate_identity("ALTER TABLE t1 SET TABLESPACE tablespace")
self.validate_identity("ALTER TABLE t1 SET (fillfactor = 5, autovacuum_enabled = TRUE)")
self.validate_identity(
"CREATE FUNCTION pymax(a INT, b INT) RETURNS INT LANGUAGE plpython3u AS $$\n if a > b:\n return a\n return b\n$$",
)
self.validate_identity(
"CREATE TABLE t (vid INT NOT NULL, CONSTRAINT ht_vid_nid_fid_idx EXCLUDE (INT4RANGE(vid, nid) WITH &&, INT4RANGE(fid, fid, '[]') WITH &&))"
)
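For reference, the array-containment expectations above translate to roughly this standalone usage (a sketch under the same dialect settings):

import sqlglot

print(sqlglot.transpile("SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]", read="postgres", write="duckdb")[0])
# SELECT ARRAY_HAS_ALL([1, 2, 3], [1, 2])
print(sqlglot.transpile("SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]", read="postgres", write="postgres")[0])
# SELECT ARRAY[1, 2] @> ARRAY[1, 2, 3], i.e. the <@ form is canonicalized by swapping the operands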


@@ -10,6 +10,11 @@ class TestSnowflake(Validator):
dialect = "snowflake"
def test_snowflake(self):
self.validate_identity(
"transform(x, a int -> a + a + 1)",
"TRANSFORM(x, a -> CAST(a AS INT) + CAST(a AS INT) + 1)",
)
self.validate_all(
"ARRAY_CONSTRUCT_COMPACT(1, null, 2)",
write={
@@ -321,10 +326,12 @@ WHERE
"""SELECT PARSE_JSON('{"fruit":"banana"}'):fruit""",
write={
"bigquery": """SELECT JSON_EXTRACT(PARSE_JSON('{"fruit":"banana"}'), '$.fruit')""",
"databricks": """SELECT GET_JSON_OBJECT('{"fruit":"banana"}', '$.fruit')""",
"duckdb": """SELECT JSON('{"fruit":"banana"}') -> '$.fruit'""",
"mysql": """SELECT JSON_EXTRACT('{"fruit":"banana"}', '$.fruit')""",
"presto": """SELECT JSON_EXTRACT(JSON_PARSE('{"fruit":"banana"}'), '$.fruit')""",
"snowflake": """SELECT GET_PATH(PARSE_JSON('{"fruit":"banana"}'), 'fruit')""",
"spark": """SELECT GET_JSON_OBJECT('{"fruit":"banana"}', '$.fruit')""",
"tsql": """SELECT ISNULL(JSON_QUERY('{"fruit":"banana"}', '$.fruit'), JSON_VALUE('{"fruit":"banana"}', '$.fruit'))""",
},
)
@@ -1198,6 +1205,8 @@ WHERE
self.validate_identity("CREATE TABLE IDENTIFIER('foo') (COLUMN1 VARCHAR, COLUMN2 VARCHAR)")
self.validate_identity("CREATE TABLE IDENTIFIER($foo) (col1 VARCHAR, col2 VARCHAR)")
self.validate_identity("CREATE TAG cost_center ALLOWED_VALUES 'a', 'b'")
self.validate_identity("CREATE WAREHOUSE x").this.assert_is(exp.Identifier)
self.validate_identity("CREATE STREAMLIT x").this.assert_is(exp.Identifier)
self.validate_identity(
"CREATE OR REPLACE TAG IF NOT EXISTS cost_center COMMENT='cost_center tag'"
).this.assert_is(exp.Identifier)
@@ -1825,7 +1834,7 @@ STORAGE_AWS_ROLE_ARN='arn:aws:iam::001234567890:role/myrole'
ENABLED=TRUE
STORAGE_ALLOWED_LOCATIONS=('s3://mybucket1/path1/', 's3://mybucket2/path2/')""",
pretty=True,
)
).this.assert_is(exp.Identifier)
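As a rough standalone equivalent of the PARSE_JSON expectations above, including the databricks and spark entries added in this change (sketch only):

import sqlglot

sql = """SELECT PARSE_JSON('{"fruit":"banana"}'):fruit"""
print(sqlglot.transpile(sql, read="snowflake", write="duckdb")[0])
# SELECT JSON('{"fruit":"banana"}') -> '$.fruit'
print(sqlglot.transpile(sql, read="snowflake", write="databricks")[0])
# SELECT GET_JSON_OBJECT('{"fruit":"banana"}', '$.fruit')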
def test_swap(self):
ast = parse_one("ALTER TABLE a SWAP WITH b", read="snowflake")


@@ -38,7 +38,7 @@ class TestTeradata(Validator):
"UPDATE A FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B SET col2 = '' WHERE A.col1 = B.col1",
write={
"teradata": "UPDATE A FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B SET col2 = '' WHERE A.col1 = B.col1",
"mysql": "UPDATE A SET col2 = '' FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B WHERE A.col1 = B.col1",
"mysql": "UPDATE A SET col2 = '' FROM `schema`.tableA AS A, (SELECT col1 FROM `schema`.tableA GROUP BY col1) AS B WHERE A.col1 = B.col1",
},
)


@@ -870,4 +870,5 @@ SELECT enum
SELECT unlogged
SELECT name
SELECT copy
SELECT rollup
SELECT rollup
SELECT unnest