1
0
Fork 0

Adding upstream version 9.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:47:39 +01:00
parent 768d386bf5
commit fca0265317
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
87 changed files with 7994 additions and 421 deletions

View file

View file

View file

@ -0,0 +1,149 @@
import typing as t
import unittest
import warnings
import sqlglot
from tests.helpers import SKIP_INTEGRATION
if t.TYPE_CHECKING:
from pyspark.sql import DataFrame as SparkDataFrame
@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set")
class DataFrameValidator(unittest.TestCase):
spark = None
sqlglot = None
df_employee = None
df_store = None
df_district = None
spark_employee_schema = None
sqlglot_employee_schema = None
spark_store_schema = None
sqlglot_store_schema = None
spark_district_schema = None
sqlglot_district_schema = None
@classmethod
def setUpClass(cls):
from pyspark import SparkConf
from pyspark.sql import SparkSession, types
from sqlglot.dataframe.sql import types as sqlglotSparkTypes
from sqlglot.dataframe.sql.session import SparkSession as SqlglotSparkSession
# This is for test `test_branching_root_dataframes`
config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate()
cls.spark.sparkContext.setLogLevel("ERROR")
cls.sqlglot = SqlglotSparkSession()
cls.spark_employee_schema = types.StructType(
[
types.StructField("employee_id", types.IntegerType(), False),
types.StructField("fname", types.StringType(), False),
types.StructField("lname", types.StringType(), False),
types.StructField("age", types.IntegerType(), False),
types.StructField("store_id", types.IntegerType(), False),
]
)
cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
[
sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
]
)
employee_data = [
(1, "Jack", "Shephard", 37, 1),
(2, "John", "Locke", 65, 1),
(3, "Kate", "Austen", 37, 2),
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema)
cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema)
cls.df_employee.createOrReplaceTempView("employee")
cls.spark_store_schema = types.StructType(
[
types.StructField("store_id", types.IntegerType(), False),
types.StructField("store_name", types.StringType(), False),
types.StructField("district_id", types.IntegerType(), False),
types.StructField("num_sales", types.IntegerType(), False),
]
)
cls.sqlglot_store_schema = sqlglotSparkTypes.StructType(
[
sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
]
)
store_data = [
(1, "Hydra", 1, 37),
(2, "Arrow", 2, 2000),
]
cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema)
cls.df_store.createOrReplaceTempView("store")
cls.spark_district_schema = types.StructType(
[
types.StructField("district_id", types.IntegerType(), False),
types.StructField("district_name", types.StringType(), False),
types.StructField("manager_name", types.StringType(), False),
]
)
cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
[
sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False),
]
)
district_data = [
(1, "Temple", "Dogen"),
(2, "Lighthouse", "Jacob"),
]
cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema)
cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema)
cls.df_district.createOrReplaceTempView("district")
sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
sqlglot.schema.add_table("store", cls.sqlglot_store_schema)
sqlglot.schema.add_table("district", cls.sqlglot_district_schema)
def setUp(self) -> None:
warnings.filterwarnings("ignore", category=ResourceWarning)
self.df_spark_store = self.df_store.alias("df_store") # type: ignore
self.df_spark_employee = self.df_employee.alias("df_employee") # type: ignore
self.df_spark_district = self.df_district.alias("df_district") # type: ignore
self.df_sqlglot_store = self.dfs_store.alias("store") # type: ignore
self.df_sqlglot_employee = self.dfs_employee.alias("employee") # type: ignore
self.df_sqlglot_district = self.dfs_district.alias("district") # type: ignore
def compare_spark_with_sqlglot(
self, df_spark, df_sqlglot, no_empty=True, skip_schema_compare=False
) -> t.Tuple["SparkDataFrame", "SparkDataFrame"]:
def compare_schemas(schema_1, schema_2):
for schema in [schema_1, schema_2]:
for struct_field in schema.fields:
struct_field.metadata = {}
self.assertEqual(schema_1, schema_2)
for statement in df_sqlglot.sql():
actual_df_sqlglot = self.spark.sql(statement) # type: ignore
df_sqlglot_results = actual_df_sqlglot.collect()
df_spark_results = df_spark.collect()
if not skip_schema_compare:
compare_schemas(df_spark.schema, actual_df_sqlglot.schema)
self.assertEqual(df_spark_results, df_sqlglot_results)
if no_empty:
self.assertNotEqual(len(df_spark_results), 0)
self.assertNotEqual(len(df_sqlglot_results), 0)
return df_spark, actual_df_sqlglot
@classmethod
def get_explain_plan(cls, df: "SparkDataFrame", mode: str = "extended") -> str:
return df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), mode) # type: ignore

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,71 @@
from pyspark.sql import functions as F
from sqlglot.dataframe.sql import functions as SF
from tests.dataframe.integration.dataframe_validator import DataFrameValidator
class TestDataframeFunc(DataFrameValidator):
def test_group_by(self):
df_employee = self.df_spark_employee.groupBy(self.df_spark_employee.age).agg(
F.min(self.df_spark_employee.employee_id)
)
dfs_employee = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).agg(
SF.min(self.df_sqlglot_employee.employee_id)
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee, skip_schema_compare=True)
def test_group_by_where_non_aggregate(self):
df_employee = (
self.df_spark_employee.groupBy(self.df_spark_employee.age)
.agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
.where(F.col("age") > F.lit(50))
)
dfs_employee = (
self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
.agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
.where(SF.col("age") > SF.lit(50))
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_group_by_where_aggregate_like_having(self):
df_employee = (
self.df_spark_employee.groupBy(self.df_spark_employee.age)
.agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
.where(F.col("min_employee_id") > F.lit(1))
)
dfs_employee = (
self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
.agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
.where(SF.col("min_employee_id") > SF.lit(1))
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_count(self):
df = self.df_spark_employee.groupBy(self.df_spark_employee.age).count()
dfs = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).count()
self.compare_spark_with_sqlglot(df, dfs)
def test_mean(self):
df = self.df_spark_employee.groupBy().mean("age", "store_id")
dfs = self.df_sqlglot_employee.groupBy().mean("age", "store_id")
self.compare_spark_with_sqlglot(df, dfs)
def test_avg(self):
df = self.df_spark_employee.groupBy("age").avg("store_id")
dfs = self.df_sqlglot_employee.groupBy("age").avg("store_id")
self.compare_spark_with_sqlglot(df, dfs)
def test_max(self):
df = self.df_spark_employee.groupBy("age").max("store_id")
dfs = self.df_sqlglot_employee.groupBy("age").max("store_id")
self.compare_spark_with_sqlglot(df, dfs)
def test_min(self):
df = self.df_spark_employee.groupBy("age").min("store_id")
dfs = self.df_sqlglot_employee.groupBy("age").min("store_id")
self.compare_spark_with_sqlglot(df, dfs)
def test_sum(self):
df = self.df_spark_employee.groupBy("age").sum("store_id")
dfs = self.df_sqlglot_employee.groupBy("age").sum("store_id")
self.compare_spark_with_sqlglot(df, dfs)

View file

@ -0,0 +1,28 @@
from pyspark.sql import functions as F
from sqlglot.dataframe.sql import functions as SF
from tests.dataframe.integration.dataframe_validator import DataFrameValidator
class TestSessionFunc(DataFrameValidator):
def test_sql_simple_select(self):
query = "SELECT fname, lname FROM employee"
df = self.spark.sql(query)
dfs = self.sqlglot.sql(query)
self.compare_spark_with_sqlglot(df, dfs)
def test_sql_with_join(self):
query = """
SELECT
e.employee_id
, s.store_id
FROM
employee e
INNER JOIN
store s
ON
e.store_id = s.store_id
"""
df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id")))
dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id")))
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

View file

View file

@ -0,0 +1,35 @@
import typing as t
import unittest
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.session import SparkSession
class DataFrameSQLValidator(unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession()
self.employee_schema = types.StructType(
[
types.StructField("employee_id", types.IntegerType(), False),
types.StructField("fname", types.StringType(), False),
types.StructField("lname", types.StringType(), False),
types.StructField("age", types.IntegerType(), False),
types.StructField("store_id", types.IntegerType(), False),
]
)
employee_data = [
(1, "Jack", "Shephard", 37, 1),
(2, "John", "Locke", 65, 1),
(3, "Kate", "Austen", 37, 2),
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema)
def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False):
actual_sqls = df.sql(pretty=pretty)
expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements
self.assertEqual(len(expected_statements), len(actual_sqls))
for expected, actual in zip(expected_statements, actual_sqls):
self.assertEqual(expected, actual)

View file

@ -0,0 +1,167 @@
import datetime
import unittest
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window
class TestDataframeColumn(unittest.TestCase):
def test_eq(self):
self.assertEqual("cola = 1", (F.col("cola") == 1).sql())
def test_neq(self):
self.assertEqual("cola <> 1", (F.col("cola") != 1).sql())
def test_gt(self):
self.assertEqual("cola > 1", (F.col("cola") > 1).sql())
def test_lt(self):
self.assertEqual("cola < 1", (F.col("cola") < 1).sql())
def test_le(self):
self.assertEqual("cola <= 1", (F.col("cola") <= 1).sql())
def test_ge(self):
self.assertEqual("cola >= 1", (F.col("cola") >= 1).sql())
def test_and(self):
self.assertEqual(
"cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql()
)
def test_or(self):
self.assertEqual(
"cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql()
)
def test_mod(self):
self.assertEqual("cola % 2", (F.col("cola") % 2).sql())
def test_add(self):
self.assertEqual("cola + 1", (F.col("cola") + 1).sql())
def test_sub(self):
self.assertEqual("cola - 1", (F.col("cola") - 1).sql())
def test_mul(self):
self.assertEqual("cola * 2", (F.col("cola") * 2).sql())
def test_div(self):
self.assertEqual("cola / 2", (F.col("cola") / 2).sql())
def test_radd(self):
self.assertEqual("1 + cola", (1 + F.col("cola")).sql())
def test_rsub(self):
self.assertEqual("1 - cola", (1 - F.col("cola")).sql())
def test_rmul(self):
self.assertEqual("1 * cola", (1 * F.col("cola")).sql())
def test_rdiv(self):
self.assertEqual("1 / cola", (1 / F.col("cola")).sql())
def test_pow(self):
self.assertEqual("POWER(cola, 2)", (F.col("cola") ** 2).sql())
def test_rpow(self):
self.assertEqual("POWER(2, cola)", (2 ** F.col("cola")).sql())
def test_invert(self):
self.assertEqual("NOT cola", (~F.col("cola")).sql())
def test_startswith(self):
self.assertEqual("STARTSWITH(cola, 'test')", F.col("cola").startswith("test").sql())
def test_endswith(self):
self.assertEqual("ENDSWITH(cola, 'test')", F.col("cola").endswith("test").sql())
def test_rlike(self):
self.assertEqual("cola RLIKE 'foo'", F.col("cola").rlike("foo").sql())
def test_like(self):
self.assertEqual("cola LIKE 'foo%'", F.col("cola").like("foo%").sql())
def test_ilike(self):
self.assertEqual("cola ILIKE 'foo%'", F.col("cola").ilike("foo%").sql())
def test_substring(self):
self.assertEqual("SUBSTRING(cola, 2, 3)", F.col("cola").substr(2, 3).sql())
def test_isin(self):
self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin([1, 2, 3]).sql())
self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin(1, 2, 3).sql())
def test_asc(self):
self.assertEqual("cola", F.col("cola").asc().sql())
def test_desc(self):
self.assertEqual("cola DESC", F.col("cola").desc().sql())
def test_asc_nulls_first(self):
self.assertEqual("cola", F.col("cola").asc_nulls_first().sql())
def test_asc_nulls_last(self):
self.assertEqual("cola NULLS LAST", F.col("cola").asc_nulls_last().sql())
def test_desc_nulls_first(self):
self.assertEqual("cola DESC NULLS FIRST", F.col("cola").desc_nulls_first().sql())
def test_desc_nulls_last(self):
self.assertEqual("cola DESC", F.col("cola").desc_nulls_last().sql())
def test_when_otherwise(self):
self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql())
self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql())
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
(F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(),
)
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
F.col("cola").when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).sql(),
)
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 ELSE 4 END",
F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).otherwise(4).sql(),
)
def test_is_null(self):
self.assertEqual("cola IS NULL", F.col("cola").isNull().sql())
def test_is_not_null(self):
self.assertEqual("NOT cola IS NULL", F.col("cola").isNotNull().sql())
def test_cast(self):
self.assertEqual("CAST(cola AS INT)", F.col("cola").cast("INT").sql())
def test_alias(self):
self.assertEqual("cola AS new_name", F.col("cola").alias("new_name").sql())
def test_between(self):
self.assertEqual("cola BETWEEN 1 AND 3", F.col("cola").between(1, 3).sql())
self.assertEqual("cola BETWEEN 10.1 AND 12.1", F.col("cola").between(10.1, 12.1).sql())
self.assertEqual(
"cola BETWEEN TO_DATE('2022-01-01') AND TO_DATE('2022-03-01')",
F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(),
)
self.assertEqual(
"cola BETWEEN CAST('2022-01-01 01:01:01' AS TIMESTAMP) " "AND CAST('2022-03-01 01:01:01' AS TIMESTAMP)",
F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(),
)
def test_over(self):
over_rows = F.sum("cola").over(
Window.partitionBy("colb").orderBy("colc").rowsBetween(1, Window.unboundedFollowing)
)
self.assertEqual(
"SUM(cola) OVER (PARTITION BY colb ORDER BY colc ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
over_rows.sql(),
)
over_range = F.sum("cola").over(
Window.partitionBy("colb").orderBy("colc").rangeBetween(1, Window.unboundedFollowing)
)
self.assertEqual(
"SUM(cola) OVER (PARTITION BY colb ORDER BY colc RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
over_range.sql(),
)

View file

@ -0,0 +1,39 @@
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.dataframe import DataFrame
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataframe(DataFrameSQLValidator):
def test_hash_select_expression(self):
expression = exp.select("cola").from_("table")
self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))
def test_columns(self):
self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns)
def test_cache(self):
df = self.df_employee.select("fname").cache()
expected_statements = [
"DROP VIEW IF EXISTS t11623",
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
]
self.compare_sql(df, expected_statements)
def test_persist_default(self):
df = self.df_employee.select("fname").persist()
expected_statements = [
"DROP VIEW IF EXISTS t11623",
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
]
self.compare_sql(df, expected_statements)
def test_persist_storagelevel(self):
df = self.df_employee.select("fname").persist("DISK_ONLY_2")
expected_statements = [
"DROP VIEW IF EXISTS t11623",
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
]
self.compare_sql(df, expected_statements)

View file

@ -0,0 +1,86 @@
from unittest import mock
import sqlglot
from sqlglot.schema import MappingSchema
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataFrameWriter(DataFrameSQLValidator):
def test_insertInto_full_path(self):
df = self.df_employee.write.insertInto("catalog.db.table_name")
expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_db_table(self):
df = self.df_employee.write.insertInto("db.table_name")
expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_table(self):
df = self.df_employee.write.insertInto("table_name")
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_overwrite(self):
df = self.df_employee.write.insertInto("table_name", overwrite=True)
expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
def test_insertInto_byName(self):
sqlglot.schema.add_table("table_name", {"employee_id": "INT"})
df = self.df_employee.write.byName.insertInto("table_name")
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_cache(self):
df = self.df_employee.cache().write.insertInto("table_name")
expected_statements = [
"DROP VIEW IF EXISTS t35612",
"CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"INSERT INTO table_name SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
]
self.compare_sql(df, expected_statements)
def test_saveAsTable_format(self):
with self.assertRaises(NotImplementedError):
self.df_employee.write.saveAsTable("table_name", format="parquet").sql(pretty=False)[0]
def test_saveAsTable_append(self):
df = self.df_employee.write.saveAsTable("table_name", mode="append")
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_overwrite(self):
df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_error(self):
df = self.df_employee.write.saveAsTable("table_name", mode="error")
expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_ignore(self):
df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_standalone(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_override(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_cache(self):
df = self.df_employee.cache().write.saveAsTable("table_name")
expected_statements = [
"DROP VIEW IF EXISTS t35612",
"CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"CREATE TABLE table_name AS SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
]
self.compare_sql(df, expected_statements)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,114 @@
from unittest import mock
import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.schema import MappingSchema
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataframeSession(DataFrameSQLValidator):
def test_cdf_one_row(self):
df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2)) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_multiple_rows(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_no_schema(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
expected = (
"SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
)
self.compare_sql(df, expected)
def test_cdf_row_mixed_primitives(self):
df = self.spark.createDataFrame([[1, 10.1, "test", False, None]])
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM (VALUES (1, 10.1, 'test', FALSE, NULL)) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
self.compare_sql(df, expected)
def test_cdf_dict_rows(self):
df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 'test'), (2, 'test2')) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_str_schema(self):
df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_basic(self):
schema = types.StructType(
[
types.StructField("cola", types.IntegerType()),
types.StructField("colb", types.StringType()),
]
)
df = self.spark.createDataFrame([[1, "test"]], schema)
expected = "SELECT CAST(`a2`.`cola` AS int) AS `cola`, CAST(`a2`.`colb` AS string) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_nested(self):
schema = types.StructType(
[
types.StructField(
"cola",
types.StructType(
[
types.StructField("sub_cola", types.IntegerType()),
types.StructField("sub_colb", types.StringType()),
]
),
)
]
)
df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema)
expected = "SELECT CAST(`a2`.`cola` AS struct<sub_cola:int, sub_colb:string>) AS `cola` FROM (VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`))) AS `a2`(`cola`)"
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_select_only(self):
# TODO: Do exact matches once CTE names are deterministic
query = "SELECT cola, colb FROM table"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
self.assertIn(
"SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False)
)
@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_with_aggs(self):
# TODO: Do exact matches once CTE names are deterministic
query = "SELECT cola, colb FROM table"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query).groupBy(F.col("cola")).agg(F.sum("colb"))
result = df.sql(pretty=False, optimize=False)[0]
self.assertIn("SELECT cola, colb FROM table", result)
self.assertIn("SUM(colb)", result)
self.assertIn("GROUP BY cola", result)
@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_create(self):
query = "CREATE TABLE new_table AS WITH t1 AS (SELECT cola, colb FROM table) SELECT cola, colb, FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
expected = "CREATE TABLE new_table AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_insert(self):
query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
expected = (
"INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
)
self.compare_sql(df, expected)
def test_session_create_builder_patterns(self):
spark = SparkSession()
self.assertEqual(spark.builder.appName("abc").getOrCreate(), spark)

View file

@ -0,0 +1,70 @@
import unittest
from sqlglot.dataframe.sql import types
class TestDataframeTypes(unittest.TestCase):
def test_string(self):
self.assertEqual("string", types.StringType().simpleString())
def test_char(self):
self.assertEqual("char(100)", types.CharType(100).simpleString())
def test_varchar(self):
self.assertEqual("varchar(65)", types.VarcharType(65).simpleString())
def test_binary(self):
self.assertEqual("binary", types.BinaryType().simpleString())
def test_boolean(self):
self.assertEqual("boolean", types.BooleanType().simpleString())
def test_date(self):
self.assertEqual("date", types.DateType().simpleString())
def test_timestamp(self):
self.assertEqual("timestamp", types.TimestampType().simpleString())
def test_timestamp_ntz(self):
self.assertEqual("timestamp_ntz", types.TimestampNTZType().simpleString())
def test_decimal(self):
self.assertEqual("decimal(10, 3)", types.DecimalType(10, 3).simpleString())
def test_double(self):
self.assertEqual("double", types.DoubleType().simpleString())
def test_float(self):
self.assertEqual("float", types.FloatType().simpleString())
def test_byte(self):
self.assertEqual("tinyint", types.ByteType().simpleString())
def test_integer(self):
self.assertEqual("int", types.IntegerType().simpleString())
def test_long(self):
self.assertEqual("bigint", types.LongType().simpleString())
def test_short(self):
self.assertEqual("smallint", types.ShortType().simpleString())
def test_array(self):
self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString())
def test_map(self):
self.assertEqual("map<int, string>", types.MapType(types.IntegerType(), types.StringType()).simpleString())
def test_struct_field(self):
self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString())
def test_struct_type(self):
self.assertEqual(
"struct<cola:int, colb:string>",
types.StructType(
[
types.StructField("cola", types.IntegerType()),
types.StructField("colb", types.StringType()),
]
).simpleString(),
)

View file

@ -0,0 +1,60 @@
import unittest
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window, WindowSpec
class TestDataframeWindow(unittest.TestCase):
def test_window_spec_partition_by(self):
partition_by = WindowSpec().partitionBy(F.col("cola"), F.col("colb"))
self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql())
def test_window_spec_order_by(self):
order_by = WindowSpec().orderBy("cola", "colb")
self.assertEqual("OVER (ORDER BY cola, colb)", order_by.sql())
def test_window_spec_rows_between(self):
rows_between = WindowSpec().rowsBetween(3, 5)
self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
def test_window_spec_range_between(self):
range_between = WindowSpec().rangeBetween(3, 5)
self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
def test_window_partition_by(self):
partition_by = Window.partitionBy(F.col("cola"), F.col("colb"))
self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql())
def test_window_order_by(self):
order_by = Window.orderBy("cola", "colb")
self.assertEqual("OVER (ORDER BY cola, colb)", order_by.sql())
def test_window_rows_between(self):
rows_between = Window.rowsBetween(3, 5)
self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
def test_window_range_between(self):
range_between = Window.rangeBetween(3, 5)
self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
def test_window_rows_unbounded(self):
rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql())
rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql())
rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
self.assertEqual(
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql()
)
def test_window_range_unbounded(self):
range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
self.assertEqual(
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql()
)
range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql())
range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)
self.assertEqual(
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql()
)

View file

@ -694,29 +694,6 @@ class TestDialect(Validator):
},
)
# https://dev.mysql.com/doc/refman/8.0/en/join.html
# https://www.postgresql.org/docs/current/queries-table-expressions.html
def test_joined_tables(self):
self.validate_identity("SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)")
self.validate_identity("SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)")
self.validate_identity("SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)")
self.validate_identity("SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)")
self.validate_all(
"SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
write={
"postgres": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
"mysql": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
},
)
self.validate_all(
"SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
write={
"postgres": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
"mysql": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
},
)
def test_lateral_subquery(self):
self.validate_identity(
"SELECT art FROM tbl1 INNER JOIN LATERAL (SELECT art FROM tbl2) AS tbl2 ON tbl1.art = tbl2.art"
@ -856,7 +833,7 @@ class TestDialect(Validator):
"postgres": "x ILIKE '%y'",
"presto": "LOWER(x) LIKE '%y'",
"snowflake": "x ILIKE '%y'",
"spark": "LOWER(x) LIKE '%y'",
"spark": "x ILIKE '%y'",
"sqlite": "LOWER(x) LIKE '%y'",
"starrocks": "LOWER(x) LIKE '%y'",
"trino": "LOWER(x) LIKE '%y'",

View file

@ -48,7 +48,7 @@ class TestDuckDB(Validator):
self.validate_all(
"STRPTIME(x, '%y-%-m')",
write={
"bigquery": "STR_TO_TIME(x, '%y-%-m')",
"bigquery": "PARSE_TIMESTAMP('%y-%m', x)",
"duckdb": "STRPTIME(x, '%y-%-m')",
"presto": "DATE_PARSE(x, '%y-%c')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy-M')) AS TIMESTAMP)",
@ -63,6 +63,16 @@ class TestDuckDB(Validator):
"hive": "CAST(x AS TIMESTAMP)",
},
)
self.validate_all(
"STRPTIME(x, '%-m/%-d/%y %-I:%M %p')",
write={
"bigquery": "PARSE_TIMESTAMP('%m/%d/%y %I:%M %p', x)",
"duckdb": "STRPTIME(x, '%-m/%-d/%y %-I:%M %p')",
"presto": "DATE_PARSE(x, '%c/%e/%y %l:%i %p')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'M/d/yy h:mm a')) AS TIMESTAMP)",
"spark": "TO_TIMESTAMP(x, 'M/d/yy h:mm a')",
},
)
def test_duckdb(self):
self.validate_all(
@ -268,6 +278,17 @@ class TestDuckDB(Validator):
"spark": "MONTH('2021-03-01')",
},
)
self.validate_all(
"ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))",
write={
"duckdb": "ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))",
"presto": "CONCAT(ARRAY[1, 2], ARRAY[3, 4])",
"hive": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))",
"spark": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))",
"snowflake": "ARRAY_CAT([1, 2], [3, 4])",
"bigquery": "ARRAY_CONCAT([1, 2], [3, 4])",
},
)
with self.assertRaises(UnsupportedError):
transpile(

View file

@ -31,6 +31,24 @@ class TestMySQL(Validator):
"mysql": "_utf8mb4 'hola'",
},
)
self.validate_all(
"N 'some text'",
read={
"mysql": "N'some text'",
},
write={
"mysql": "N 'some text'",
},
)
self.validate_all(
"_latin1 x'4D7953514C'",
read={
"mysql": "_latin1 X'4D7953514C'",
},
write={
"mysql": "_latin1 x'4D7953514C'",
},
)
def test_hexadecimal_literal(self):
self.validate_all(

View file

@ -69,6 +69,8 @@ class TestPostgres(Validator):
self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')")
self.validate_identity("COMMENT ON TABLE mytable IS 'this'")
self.validate_identity("SELECT e'\\xDEADBEEF'")
self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
self.validate_all(
"CREATE TABLE x (a UUID, b BYTEA)",
@ -204,3 +206,11 @@ class TestPostgres(Validator):
"""'{"a":[1,2,3],"b":[4,5,6]}'::json#>>'{a,2}'""",
write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON)#>>'{a,2}'"""},
)
self.validate_all(
"SELECT $$a$$",
write={"postgres": "SELECT 'a'"},
)
self.validate_all(
"SELECT $$Dianne's horse$$",
write={"postgres": "SELECT 'Dianne''s horse'"},
)

View file

@ -321,7 +321,7 @@ class TestPresto(Validator):
"duckdb": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
"presto": "SELECT APPROX_DISTINCT(a, 0.1) FROM foo",
"hive": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
"spark": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
"spark": "SELECT APPROX_COUNT_DISTINCT(a, 0.1) FROM foo",
},
)
self.validate_all(
@ -329,7 +329,7 @@ class TestPresto(Validator):
write={
"presto": "SELECT APPROX_DISTINCT(a, 0.1) FROM foo",
"hive": UnsupportedError,
"spark": UnsupportedError,
"spark": "SELECT APPROX_COUNT_DISTINCT(a, 0.1) FROM foo",
},
)
self.validate_all(

View file

@ -65,7 +65,7 @@ class TestSnowflake(Validator):
self.validate_all(
"SELECT TO_TIMESTAMP('2013-04-05 01:02:03')",
write={
"bigquery": "SELECT STR_TO_TIME('2013-04-05 01:02:03', '%Y-%m-%d %H:%M:%S')",
"bigquery": "SELECT PARSE_TIMESTAMP('%Y-%m-%d %H:%M:%S', '2013-04-05 01:02:03')",
"snowflake": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-mm-dd hh24:mi:ss')",
"spark": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-MM-dd HH:mm:ss')",
},
@ -73,16 +73,17 @@ class TestSnowflake(Validator):
self.validate_all(
"SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')",
read={
"bigquery": "SELECT STR_TO_TIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')",
"bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %H:%M:%S', '04/05/2013 01:02:03')",
"duckdb": "SELECT STRPTIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')",
"snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')",
},
write={
"bigquery": "SELECT STR_TO_TIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')",
"bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %H:%M:%S', '04/05/2013 01:02:03')",
"snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')",
"spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'MM/dd/yyyy HH:mm:ss')",
},
)
self.validate_all(
"SELECT IFF(TRUE, 'true', 'false')",
write={
@ -240,11 +241,25 @@ class TestSnowflake(Validator):
},
)
self.validate_all(
"SELECT DATE_PART(month FROM a::DATETIME)",
"SELECT DATE_PART(month, a::DATETIME)",
write={
"snowflake": "SELECT EXTRACT(month FROM CAST(a AS DATETIME))",
},
)
self.validate_all(
"SELECT DATE_PART(epoch_second, foo) as ddate from table_name",
write={
"snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMPNTZ)) AS ddate FROM table_name",
"presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) AS ddate FROM table_name",
},
)
self.validate_all(
"SELECT DATE_PART(epoch_milliseconds, foo) as ddate from table_name",
write={
"snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMPNTZ)) * 1000 AS ddate FROM table_name",
"presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) * 1000 AS ddate FROM table_name",
},
)
def test_semi_structured_types(self):
self.validate_identity("SELECT CAST(a AS VARIANT)")

View file

@ -45,3 +45,29 @@ class TestTSQL(Validator):
"tsql": "CAST(x AS DATETIME2)",
},
)
def test_charindex(self):
self.validate_all(
"CHARINDEX(x, y, 9)",
write={
"spark": "LOCATE(x, y, 9)",
},
)
self.validate_all(
"CHARINDEX(x, y)",
write={
"spark": "LOCATE(x, y)",
},
)
self.validate_all(
"CHARINDEX('sub', 'testsubstring', 3)",
write={
"spark": "LOCATE('sub', 'testsubstring', 3)",
},
)
self.validate_all(
"CHARINDEX('sub', 'testsubstring')",
write={
"spark": "LOCATE('sub', 'testsubstring')",
},
)

View file

@ -513,6 +513,8 @@ ALTER TYPE electronic_mail RENAME TO email
ANALYZE a.y
DELETE FROM x WHERE y > 1
DELETE FROM y
DELETE FROM event USING sales WHERE event.eventid = sales.eventid
DELETE FROM event USING sales, USING bla WHERE event.eventid = sales.eventid
DROP TABLE a
DROP TABLE a.b
DROP TABLE IF EXISTS a
@ -563,3 +565,8 @@ WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a
SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z
SELECT ((SELECT 1) + 1)
SELECT * FROM project.dataset.INFORMATION_SCHEMA.TABLES
SELECT * FROM (table1 AS t1 LEFT JOIN table2 AS t2 ON 1 = 1)
SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)
SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)
SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)
SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)

View file

@ -287,3 +287,27 @@ SELECT
FROM
t1;
SELECT x.a AS a, x.b AS b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x AS x;
# title: Values Test
# dialect: spark
WITH t1 AS (
SELECT
a1.cola
FROM
VALUES (1) AS a1(cola)
), t2 AS (
SELECT
a2.cola
FROM
VALUES (1) AS a2(cola)
)
SELECT /*+ BROADCAST(t2) */
t1.cola,
t2.cola,
FROM
t1
JOIN
t2
ON
t1.cola = t2.cola;
SELECT /*+ BROADCAST(a2) */ a1.cola AS cola, a2.cola AS cola FROM VALUES (1) AS a1(cola) JOIN VALUES (1) AS a2(cola) ON a1.cola = a2.cola;

View file

@ -33,3 +33,6 @@ SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y.
with t1 as (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) as row_num FROM x) SELECT t1.a, t1.b FROM t1 WHERE row_num = 1;
WITH t1 AS (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x) SELECT t1.a, t1.b FROM t1 WHERE row_num = 1;
WITH m AS (SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)), n AS (SELECT a, b FROM m WHERE m.a = 1), o AS (SELECT a, b FROM m WHERE m.a = 2) SELECT n.a, n.b, n.a, o.b FROM n FULL OUTER JOIN o ON n.a = o.a;
WITH m AS (SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)), n AS (SELECT a, b FROM m WHERE m.a = 1), o AS (SELECT a, b FROM m WHERE m.a = 2) SELECT n.a, n.b, n.a, o.b FROM n FULL OUTER JOIN o ON n.a = o.a;

View file

@ -22,6 +22,9 @@ SELECT "_q_0".a AS a FROM (SELECT DISTINCT x.a AS a, x.b AS b FROM x AS x) AS "_
SELECT a FROM (SELECT a, b FROM x UNION ALL SELECT a, b FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION ALL SELECT x.a AS a FROM x AS x) AS "_q_0";
WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x UNION ALL SELECT z.b AS b, z.c AS c FROM z) SELECT a, b FROM t1;
WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x AS x UNION ALL SELECT z.b AS b, z.c AS c FROM z AS z) SELECT t1.a AS a, t1.b AS b FROM t1;
SELECT a FROM (SELECT a, b FROM x UNION SELECT a, b FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x UNION SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0";

View file

@ -72,6 +72,9 @@ SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a;
SELECT a FROM x ORDER BY b;
SELECT x.a AS a FROM x AS x ORDER BY x.b;
SELECT SUM(a) AS a FROM x ORDER BY SUM(a);
SELECT SUM(x.a) AS a FROM x AS x ORDER BY SUM(x.a);
# dialect: bigquery
SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS row_num FROM x QUALIFY row_num = 1;
SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x QUALIFY row_num = 1;

View file

@ -53,6 +53,8 @@ def string_to_bool(string):
return string and string.lower() in ("true", "1")
SKIP_INTEGRATION = string_to_bool(os.environ.get("SKIP_INTEGRATION", "0").lower())
TPCH_SCHEMA = {
"lineitem": {
"l_orderkey": "uint64",

View file

@ -7,11 +7,17 @@ from pandas.testing import assert_frame_equal
from sqlglot import exp, parse_one
from sqlglot.executor import execute
from sqlglot.executor.python import Python
from tests.helpers import FIXTURES_DIR, TPCH_SCHEMA, load_sql_fixture_pairs
from tests.helpers import (
FIXTURES_DIR,
SKIP_INTEGRATION,
TPCH_SCHEMA,
load_sql_fixture_pairs,
)
DIR = FIXTURES_DIR + "/optimizer/tpc-h/"
@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set")
class TestExecutor(unittest.TestCase):
@classmethod
def setUpClass(cls):

View file

@ -123,13 +123,16 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(exp.table_name(parse_one("a.b.c", into=exp.Table)), "a.b.c")
self.assertEqual(exp.table_name("a.b.c"), "a.b.c")
def test_table(self):
self.assertEqual(exp.table_("a", alias="b"), parse_one("select * from a b").find(exp.Table))
def test_replace_tables(self):
self.assertEqual(
exp.replace_tables(
parse_one("select * from a join b join c.a join d.a join e.a"),
parse_one("select * from a AS a join b join c.a join d.a join e.a"),
{"a": "a1", "b": "b.a", "c.a": "c.a2", "d.a": "d2"},
).sql(),
'SELECT * FROM "a1" JOIN "b"."a" JOIN "c"."a2" JOIN "d2" JOIN e.a',
"SELECT * FROM a1 AS a JOIN b.a JOIN c.a2 JOIN d2 JOIN e.a",
)
def test_named_selects(self):
@ -495,11 +498,15 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(exp.convert(value).sql(), expected)
def test_annotation_alias(self):
expression = parse_one("SELECT a, b AS B, c #comment, d AS D #another_comment FROM foo")
sql = "SELECT a, b AS B, c # comment, d AS D # another_comment FROM foo"
expression = parse_one(sql)
self.assertEqual(
[e.alias_or_name for e in expression.expressions],
["a", "B", "c", "D"],
)
self.assertEqual(expression.sql(), sql)
self.assertEqual(expression.expressions[2].name, "comment")
self.assertEqual(expression.sql(annotations=False), "SELECT a, b AS B, c, d AS D")
def test_to_table(self):
table_only = exp.to_table("table_name")
@ -514,6 +521,18 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(catalog_db_and_table.name, "table_name")
self.assertEqual(catalog_db_and_table.args.get("db"), exp.to_identifier("db"))
self.assertEqual(catalog_db_and_table.args.get("catalog"), exp.to_identifier("catalog"))
with self.assertRaises(ValueError):
exp.to_table(1)
def test_to_column(self):
column_only = exp.to_column("column_name")
self.assertEqual(column_only.name, "column_name")
self.assertIsNone(column_only.args.get("table"))
table_and_column = exp.to_column("table_name.column_name")
self.assertEqual(table_and_column.name, "column_name")
self.assertEqual(table_and_column.args.get("table"), exp.to_identifier("table_name"))
with self.assertRaises(ValueError):
exp.to_column(1)
def test_union(self):
expression = parse_one("SELECT cola, colb UNION SELECT colx, coly")

View file

@ -5,11 +5,11 @@ import duckdb
from pandas.testing import assert_frame_equal
import sqlglot
from sqlglot import exp, optimizer, parse_one, table
from sqlglot import exp, optimizer, parse_one
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.schema import MappingSchema, ensure_schema
from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
from sqlglot.schema import MappingSchema
from tests.helpers import (
TPCH_SCHEMA,
load_sql_fixture_pairs,
@ -29,19 +29,19 @@ class TestOptimizer(unittest.TestCase):
CREATE TABLE x (a INT, b INT);
CREATE TABLE y (b INT, c INT);
CREATE TABLE z (b INT, c INT);
INSERT INTO x VALUES (1, 1);
INSERT INTO x VALUES (2, 2);
INSERT INTO x VALUES (2, 2);
INSERT INTO x VALUES (3, 3);
INSERT INTO x VALUES (null, null);
INSERT INTO y VALUES (2, 2);
INSERT INTO y VALUES (2, 2);
INSERT INTO y VALUES (3, 3);
INSERT INTO y VALUES (4, 4);
INSERT INTO y VALUES (null, null);
INSERT INTO y VALUES (3, 3);
INSERT INTO y VALUES (3, 3);
INSERT INTO y VALUES (4, 4);
@ -80,8 +80,8 @@ class TestOptimizer(unittest.TestCase):
with self.subTest(title):
self.assertEqual(
optimized.sql(pretty=pretty, dialect=dialect),
expected,
optimized.sql(pretty=pretty, dialect=dialect),
)
should_execute = meta.get("execute")
@ -223,85 +223,6 @@ class TestOptimizer(unittest.TestCase):
def test_tpch(self):
self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)
def test_schema(self):
schema = ensure_schema(
{
"x": {
"a": "uint64",
}
}
)
self.assertEqual(
schema.column_names(
table(
"x",
)
),
["a"],
)
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db", catalog="c"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db"))
with self.assertRaises(ValueError):
schema.column_names(table("x2"))
schema = ensure_schema(
{
"db": {
"x": {
"a": "uint64",
}
}
}
)
self.assertEqual(schema.column_names(table("x", db="db")), ["a"])
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db", catalog="c"))
with self.assertRaises(ValueError):
schema.column_names(table("x"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db2"))
with self.assertRaises(ValueError):
schema.column_names(table("x2", db="db"))
schema = ensure_schema(
{
"c": {
"db": {
"x": {
"a": "uint64",
}
}
}
}
)
self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"])
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db"))
with self.assertRaises(ValueError):
schema.column_names(table("x"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db", catalog="c2"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db2"))
with self.assertRaises(ValueError):
schema.column_names(table("x2", db="db"))
schema = ensure_schema(
MappingSchema(
{
"x": {
"a": "uint64",
}
}
)
)
self.assertEqual(schema.column_names(table("x")), ["a"])
with self.assertRaises(OptimizeError):
ensure_schema({})
def test_file_schema(self):
expression = parse_one(
"""
@ -327,6 +248,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
SELECT x.b FROM x
), r AS (
SELECT y.b FROM y
), z as (
SELECT cola, colb FROM (VALUES(1, 'test')) AS tab(cola, colb)
)
SELECT
r.b,
@ -340,19 +263,23 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
"""
expression = parse_one(sql)
for scopes in traverse_scope(expression), list(build_scope(expression).traverse()):
self.assertEqual(len(scopes), 5)
self.assertEqual(len(scopes), 7)
self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
self.assertEqual(scopes[2].expression.sql(), "SELECT y.c AS b FROM y")
self.assertEqual(scopes[3].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())
self.assertEqual(scopes[2].expression.sql(), "(VALUES (1, 'test')) AS tab(cola, colb)")
self.assertEqual(
scopes[3].expression.sql(), "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)"
)
self.assertEqual(scopes[4].expression.sql(), "SELECT y.c AS b FROM y")
self.assertEqual(scopes[5].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
self.assertEqual(scopes[6].expression.sql(), parse_one(sql).sql())
self.assertEqual(set(scopes[4].sources), {"q", "r", "s"})
self.assertEqual(len(scopes[4].columns), 6)
self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"})
self.assertEqual(scopes[4].source_columns("q"), [])
self.assertEqual(len(scopes[4].source_columns("r")), 2)
self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"})
self.assertEqual(set(scopes[6].sources), {"q", "z", "r", "s"})
self.assertEqual(len(scopes[6].columns), 6)
self.assertEqual(set(c.table for c in scopes[6].columns), {"r", "s"})
self.assertEqual(scopes[6].source_columns("q"), [])
self.assertEqual(len(scopes[6].source_columns("r")), 2)
self.assertEqual(set(c.table for c in scopes[6].source_columns("r")), {"r"})
self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"})
self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b")

View file

@ -81,7 +81,7 @@ class TestParser(unittest.TestCase):
self.assertIsInstance(ignore.expression(exp.Hint, y=""), exp.Hint)
self.assertIsInstance(ignore.expression(exp.Hint), exp.Hint)
default = Parser()
default = Parser(error_level=ErrorLevel.RAISE)
self.assertIsInstance(default.expression(exp.Hint, expressions=[""]), exp.Hint)
default.expression(exp.Hint, y="")
default.expression(exp.Hint)
@ -139,12 +139,12 @@ class TestParser(unittest.TestCase):
)
assert expression.expressions[0].name == "annotation1"
assert expression.expressions[1].name == "annotation2:testing "
assert expression.expressions[1].name == "annotation2:testing"
assert expression.expressions[2].name == "test#annotation"
assert expression.expressions[3].name == "annotation3"
assert expression.expressions[4].name == "annotation4"
assert expression.expressions[5].name == ""
assert expression.expressions[6].name == " space"
assert expression.expressions[6].name == "space"
def test_pretty_config_override(self):
self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x")

290
tests/test_schema.py Normal file
View file

@ -0,0 +1,290 @@
import unittest
from sqlglot import table
from sqlglot.dataframe.sql import types as df_types
from sqlglot.schema import MappingSchema, ensure_schema
class TestSchema(unittest.TestCase):
def test_schema(self):
schema = ensure_schema(
{
"x": {
"a": "uint64",
}
}
)
self.assertEqual(
schema.column_names(
table(
"x",
)
),
["a"],
)
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db", catalog="c"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db"))
with self.assertRaises(ValueError):
schema.column_names(table("x2"))
with self.assertRaises(ValueError):
schema.add_table(table("y", db="db"), {"b": "string"})
with self.assertRaises(ValueError):
schema.add_table(table("y", db="db", catalog="c"), {"b": "string"})
schema.add_table(table("y"), {"b": "string"})
schema_with_y = {
"x": {
"a": "uint64",
},
"y": {
"b": "string",
},
}
self.assertEqual(schema.schema, schema_with_y)
new_schema = schema.copy()
new_schema.add_table(table("z"), {"c": "string"})
self.assertEqual(schema.schema, schema_with_y)
self.assertEqual(
new_schema.schema,
{
"x": {
"a": "uint64",
},
"y": {
"b": "string",
},
"z": {
"c": "string",
},
},
)
schema.add_table(table("m"), {"d": "string"})
schema.add_table(table("n"), {"e": "string"})
schema_with_m_n = {
"x": {
"a": "uint64",
},
"y": {
"b": "string",
},
"m": {
"d": "string",
},
"n": {
"e": "string",
},
}
self.assertEqual(schema.schema, schema_with_m_n)
new_schema = schema.copy()
new_schema.add_table(table("o"), {"f": "string"})
new_schema.add_table(table("p"), {"g": "string"})
self.assertEqual(schema.schema, schema_with_m_n)
self.assertEqual(
new_schema.schema,
{
"x": {
"a": "uint64",
},
"y": {
"b": "string",
},
"m": {
"d": "string",
},
"n": {
"e": "string",
},
"o": {
"f": "string",
},
"p": {
"g": "string",
},
},
)
schema = ensure_schema(
{
"db": {
"x": {
"a": "uint64",
}
}
}
)
self.assertEqual(schema.column_names(table("x", db="db")), ["a"])
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db", catalog="c"))
with self.assertRaises(ValueError):
schema.column_names(table("x"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db2"))
with self.assertRaises(ValueError):
schema.column_names(table("x2", db="db"))
with self.assertRaises(ValueError):
schema.add_table(table("y"), {"b": "string"})
with self.assertRaises(ValueError):
schema.add_table(table("y", db="db", catalog="c"), {"b": "string"})
schema.add_table(table("y", db="db"), {"b": "string"})
self.assertEqual(
schema.schema,
{
"db": {
"x": {
"a": "uint64",
},
"y": {
"b": "string",
},
}
},
)
schema = ensure_schema(
{
"c": {
"db": {
"x": {
"a": "uint64",
}
}
}
}
)
self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"])
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db"))
with self.assertRaises(ValueError):
schema.column_names(table("x"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db", catalog="c2"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db2"))
with self.assertRaises(ValueError):
schema.column_names(table("x2", db="db"))
with self.assertRaises(ValueError):
schema.add_table(table("x"), {"b": "string"})
with self.assertRaises(ValueError):
schema.add_table(table("x", db="db"), {"b": "string"})
schema.add_table(table("y", db="db", catalog="c"), {"a": "string", "b": "int"})
self.assertEqual(
schema.schema,
{
"c": {
"db": {
"x": {
"a": "uint64",
},
"y": {
"a": "string",
"b": "int",
},
}
}
},
)
schema.add_table(table("z", db="db2", catalog="c"), {"c": "string", "d": "int"})
self.assertEqual(
schema.schema,
{
"c": {
"db": {
"x": {
"a": "uint64",
},
"y": {
"a": "string",
"b": "int",
},
},
"db2": {
"z": {
"c": "string",
"d": "int",
}
},
}
},
)
schema.add_table(table("m", db="db2", catalog="c2"), {"e": "string", "f": "int"})
self.assertEqual(
schema.schema,
{
"c": {
"db": {
"x": {
"a": "uint64",
},
"y": {
"a": "string",
"b": "int",
},
},
"db2": {
"z": {
"c": "string",
"d": "int",
}
},
},
"c2": {
"db2": {
"m": {
"e": "string",
"f": "int",
}
}
},
},
)
schema = ensure_schema(
{
"x": {
"a": "uint64",
}
}
)
self.assertEqual(schema.column_names(table("x")), ["a"])
schema = MappingSchema()
schema.add_table(table("x"), {"a": "string"})
self.assertEqual(
schema.schema,
{
"x": {
"a": "string",
}
},
)
schema.add_table(table("y"), df_types.StructType([df_types.StructField("b", df_types.StringType())]))
self.assertEqual(
schema.schema,
{
"x": {
"a": "string",
},
"y": {
"b": "string",
},
},
)
def test_schema_add_table_with_and_without_mapping(self):
schema = MappingSchema()
schema.add_table("test")
self.assertEqual(schema.column_names("test"), [])
schema.add_table("test", {"x": "string"})
self.assertEqual(schema.column_names("test"), ["x"])
schema.add_table("test", {"x": "string", "y": "int"})
self.assertEqual(schema.column_names("test"), ["x", "y"])
schema.add_table("test")
self.assertEqual(schema.column_names("test"), ["x", "y"])