Merging upstream version 10.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:53:05 +01:00
parent 528822bfd4
commit b7d21c45b7
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
98 changed files with 4080 additions and 1666 deletions

View file

@@ -1,9 +1,9 @@
import sys
import typing as t
import unittest
import warnings
import sqlglot
from sqlglot.helper import PYTHON_VERSION
from tests.helpers import SKIP_INTEGRATION
if t.TYPE_CHECKING:
@@ -11,7 +11,8 @@ if t.TYPE_CHECKING:
@unittest.skipIf(
SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set"
SKIP_INTEGRATION or PYTHON_VERSION > (3, 10),
"Skipping Integration Tests since `SKIP_INTEGRATION` is set",
)
class DataFrameValidator(unittest.TestCase):
spark = None
@@ -36,7 +37,12 @@ class DataFrameValidator(unittest.TestCase):
# This is for test `test_branching_root_dataframes`
config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate()
cls.spark = (
SparkSession.builder.master("local[*]")
.appName("Unit-tests")
.config(conf=config)
.getOrCreate()
)
cls.spark.sparkContext.setLogLevel("ERROR")
cls.sqlglot = SqlglotSparkSession()
cls.spark_employee_schema = types.StructType(
@@ -50,7 +56,9 @@ class DataFrameValidator(unittest.TestCase):
)
cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
[
sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField(
"employee_id", sqlglotSparkTypes.IntegerType(), False
),
sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
@@ -64,8 +72,12 @@ class DataFrameValidator(unittest.TestCase):
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema)
cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema)
cls.df_employee = cls.spark.createDataFrame(
data=employee_data, schema=cls.spark_employee_schema
)
cls.dfs_employee = cls.sqlglot.createDataFrame(
data=employee_data, schema=cls.sqlglot_employee_schema
)
cls.df_employee.createOrReplaceTempView("employee")
cls.spark_store_schema = types.StructType(
@@ -80,7 +92,9 @@ class DataFrameValidator(unittest.TestCase):
[
sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField(
"district_id", sqlglotSparkTypes.IntegerType(), False
),
sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
]
)
@@ -89,7 +103,9 @@ class DataFrameValidator(unittest.TestCase):
(2, "Arrow", 2, 2000),
]
cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema)
cls.dfs_store = cls.sqlglot.createDataFrame(
data=store_data, schema=cls.sqlglot_store_schema
)
cls.df_store.createOrReplaceTempView("store")
cls.spark_district_schema = types.StructType(
@@ -101,17 +117,27 @@ class DataFrameValidator(unittest.TestCase):
)
cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
[
sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField(
"district_id", sqlglotSparkTypes.IntegerType(), False
),
sqlglotSparkTypes.StructField(
"district_name", sqlglotSparkTypes.StringType(), False
),
sqlglotSparkTypes.StructField(
"manager_name", sqlglotSparkTypes.StringType(), False
),
]
)
district_data = [
(1, "Temple", "Dogen"),
(2, "Lighthouse", "Jacob"),
]
cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema)
cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema)
cls.df_district = cls.spark.createDataFrame(
data=district_data, schema=cls.spark_district_schema
)
cls.dfs_district = cls.sqlglot.createDataFrame(
data=district_data, schema=cls.sqlglot_district_schema
)
cls.df_district.createOrReplaceTempView("district")
sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
sqlglot.schema.add_table("store", cls.sqlglot_store_schema)

View file

@@ -41,22 +41,32 @@ class TestDataframeFunc(DataFrameValidator):
def test_alias_with_select(self):
df_employee = self.df_spark_employee.alias("df_employee").select(
self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname
self.df_spark_employee["employee_id"],
F.col("df_employee.fname"),
self.df_spark_employee.lname,
)
dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select(
self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname
self.df_sqlglot_employee["employee_id"],
SF.col("dfs_employee.fname"),
self.df_sqlglot_employee.lname,
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_case_when_otherwise(self):
df = self.df_spark_employee.select(
F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60"))
F.when(
(F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
F.lit("between 40 and 60"),
)
.when(F.col("age") < F.lit(40), "less than 40")
.otherwise("greater than 60")
)
dfs = self.df_sqlglot_employee.select(
SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60"))
SF.when(
(SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
SF.lit("between 40 and 60"),
)
.when(SF.col("age") < SF.lit(40), "less than 40")
.otherwise("greater than 60")
)
@@ -65,15 +75,17 @@ class TestDataframeFunc(DataFrameValidator):
def test_case_when_no_otherwise(self):
df = self.df_spark_employee.select(
F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when(
F.col("age") < F.lit(40), "less than 40"
)
F.when(
(F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
F.lit("between 40 and 60"),
).when(F.col("age") < F.lit(40), "less than 40")
)
dfs = self.df_sqlglot_employee.select(
SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when(
SF.col("age") < SF.lit(40), "less than 40"
)
SF.when(
(SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
SF.lit("between 40 and 60"),
).when(SF.col("age") < SF.lit(40), "less than 40")
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -84,7 +96,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_where_clause_multiple_and(self):
df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")))
df_employee = self.df_spark_employee.where(
(F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))
)
dfs_employee = self.df_sqlglot_employee.where(
(SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack"))
)
@@ -106,7 +120,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_where_clause_multiple_or(self):
df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")))
df_employee = self.df_spark_employee.where(
(F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))
)
dfs_employee = self.df_sqlglot_employee.where(
(SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate"))
)
@@ -172,28 +188,43 @@ class TestDataframeFunc(DataFrameValidator):
dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37))
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0))
dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0))
df_employee = self.df_spark_employee.where(
self.df_spark_employee["age"] % F.lit(5) == F.lit(0)
)
dfs_employee = self.df_sqlglot_employee.where(
self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28))
dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28))
df_employee = self.df_spark_employee.where(
self.df_spark_employee["age"] + F.lit(5) > F.lit(28)
)
dfs_employee = self.df_sqlglot_employee.where(
self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28))
dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28))
df_employee = self.df_spark_employee.where(
self.df_spark_employee["age"] - F.lit(5) > F.lit(28)
)
dfs_employee = self.df_sqlglot_employee.where(
self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
df_employee = self.df_spark_employee.where(
self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2)
)
dfs_employee = self.df_sqlglot_employee.where(
self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2)
self.df_sqlglot_employee["age"] * SF.lit(0.5)
== self.df_sqlglot_employee["age"] / SF.lit(2)
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_join_inner(self):
df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select(
df_joined = self.df_spark_employee.join(
self.df_spark_store, on=["store_id"], how="inner"
).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
F.col("lname"),
@@ -202,7 +233,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store.store_name,
self.df_spark_store["num_sales"],
)
dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select(
dfs_joined = self.df_sqlglot_employee.join(
self.df_sqlglot_store, on=["store_id"], how="inner"
).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
SF.col("lname"),
@@ -214,17 +247,27 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_inner_no_select(self):
df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join(
self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner"
df_joined = self.df_spark_employee.select(
F.col("store_id"), F.col("fname"), F.col("lname")
).join(
self.df_spark_store.select(F.col("store_id"), F.col("store_name")),
on=["store_id"],
how="inner",
)
dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join(
self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner"
dfs_joined = self.df_sqlglot_employee.select(
SF.col("store_id"), SF.col("fname"), SF.col("lname")
).join(
self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")),
on=["store_id"],
how="inner",
)
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_inner_equality_single(self):
df_joined = self.df_spark_employee.join(
self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner"
self.df_spark_store,
on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
how="inner",
).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
@@ -235,7 +278,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store["num_sales"],
)
dfs_joined = self.df_sqlglot_employee.join(
self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner"
self.df_sqlglot_store,
on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
how="inner",
).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
@@ -343,7 +388,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_full_outer(self):
df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select(
df_joined = self.df_spark_employee.join(
self.df_spark_store, on=["store_id"], how="full_outer"
).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
F.col("lname"),
@@ -352,7 +399,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store.store_name,
self.df_spark_store["num_sales"],
)
dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select(
dfs_joined = self.df_sqlglot_employee.join(
self.df_sqlglot_store, on=["store_id"], how="full_outer"
).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
SF.col("lname"),
@@ -365,7 +414,9 @@ class TestDataframeFunc(DataFrameValidator):
def test_triple_join(self):
df = (
self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id)
self.df_employee.join(
self.df_store, on=self.df_employee.employee_id == self.df_store.store_id
)
.join(self.df_district, on=self.df_store.store_id == self.df_district.district_id)
.select(
self.df_employee.employee_id,
@@ -377,7 +428,9 @@ class TestDataframeFunc(DataFrameValidator):
)
)
dfs = (
self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id)
self.dfs_employee.join(
self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id
)
.join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id)
.select(
self.dfs_employee.employee_id,
@@ -391,13 +444,13 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_join_select_and_select_start(self):
df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join(
self.df_spark_store, "store_id", "inner"
)
df = self.df_spark_employee.select(
F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")
).join(self.df_spark_store, "store_id", "inner")
dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join(
self.df_sqlglot_store, "store_id", "inner"
)
dfs = self.df_sqlglot_employee.select(
SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
).join(self.df_sqlglot_store, "store_id", "inner")
self.compare_spark_with_sqlglot(df, dfs)
@@ -485,13 +538,17 @@ class TestDataframeFunc(DataFrameValidator):
dfs_unioned = (
self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"))
.unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")))
.unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")))
.unionAll(
self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))
)
)
self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)
def test_union_by_name(self):
df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName(
df = self.df_spark_employee.select(
F.col("employee_id"), F.col("fname"), F.col("lname")
).unionByName(
self.df_spark_store.select(
F.col("store_name").alias("lname"),
F.col("store_id").alias("employee_id"),
@@ -499,7 +556,9 @@ class TestDataframeFunc(DataFrameValidator):
)
)
dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName(
dfs = self.df_sqlglot_employee.select(
SF.col("employee_id"), SF.col("fname"), SF.col("lname")
).unionByName(
self.df_sqlglot_store.select(
SF.col("store_name").alias("lname"),
SF.col("store_id").alias("employee_id"),
@@ -537,10 +596,16 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_order_by_default(self):
df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id"))
df = (
self.df_spark_store.groupBy(F.col("district_id"))
.agg(F.min("num_sales"))
.orderBy(F.col("district_id"))
)
dfs = (
self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id"))
self.df_sqlglot_store.groupBy(SF.col("district_id"))
.agg(SF.min("num_sales"))
.orderBy(SF.col("district_id"))
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -594,13 +659,17 @@ class TestDataframeFunc(DataFrameValidator):
df = (
self.df_spark_store.groupBy(F.col("district_id"))
.agg(F.min("num_sales").alias("total_sales"))
.orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last())
.orderBy(
F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()
)
)
dfs = (
self.df_sqlglot_store.groupBy(SF.col("district_id"))
.agg(SF.min("num_sales").alias("total_sales"))
.orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last())
.orderBy(
SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()
)
)
self.compare_spark_with_sqlglot(df, dfs)
@@ -609,81 +678,87 @@ class TestDataframeFunc(DataFrameValidator):
df = (
self.df_spark_store.groupBy(F.col("district_id"))
.agg(F.min("num_sales").alias("total_sales"))
.orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first())
.orderBy(
F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()
)
)
dfs = (
self.df_sqlglot_store.groupBy(SF.col("district_id"))
.agg(SF.min("num_sales").alias("total_sales"))
.orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first())
.orderBy(
SF.when(
SF.col("district_id") == SF.lit(1), SF.col("district_id")
).desc_nulls_first()
)
)
self.compare_spark_with_sqlglot(df, dfs)
def test_intersect(self):
df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
)
df_employee_duplicate = self.df_spark_employee.select(
F.col("employee_id"), F.col("store_id")
).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
)
df_store_duplicate = self.df_spark_store.select(
F.col("store_id"), F.col("district_id")
).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.intersect(df_store_duplicate)
dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
)
dfs_employee_duplicate = self.df_sqlglot_employee.select(
SF.col("employee_id"), SF.col("store_id")
).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
)
dfs_store_duplicate = self.df_sqlglot_store.select(
SF.col("store_id"), SF.col("district_id")
).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate)
self.compare_spark_with_sqlglot(df, dfs)
def test_intersect_all(self):
df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
)
df_employee_duplicate = self.df_spark_employee.select(
F.col("employee_id"), F.col("store_id")
).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
)
df_store_duplicate = self.df_spark_store.select(
F.col("store_id"), F.col("district_id")
).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.intersectAll(df_store_duplicate)
dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
)
dfs_employee_duplicate = self.df_sqlglot_employee.select(
SF.col("employee_id"), SF.col("store_id")
).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
)
dfs_store_duplicate = self.df_sqlglot_store.select(
SF.col("store_id"), SF.col("district_id")
).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate)
self.compare_spark_with_sqlglot(df, dfs)
def test_except_all(self):
df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
)
df_employee_duplicate = self.df_spark_employee.select(
F.col("employee_id"), F.col("store_id")
).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
)
df_store_duplicate = self.df_spark_store.select(
F.col("store_id"), F.col("district_id")
).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.exceptAll(df_store_duplicate)
dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
)
dfs_employee_duplicate = self.df_sqlglot_employee.select(
SF.col("employee_id"), SF.col("store_id")
).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
)
dfs_store_duplicate = self.df_sqlglot_store.select(
SF.col("store_id"), SF.col("district_id")
).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate)
@@ -721,7 +796,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_drop_na_default(self):
df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna()
df = self.df_spark_employee.select(
F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
).dropna()
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -746,7 +823,9 @@ class TestDataframeFunc(DataFrameValidator):
).dropna(how="any", thresh=2)
dfs = self.df_sqlglot_employee.select(
SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
SF.lit(None),
SF.lit(1),
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
).dropna(how="any", thresh=2)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -757,13 +836,17 @@ class TestDataframeFunc(DataFrameValidator):
).dropna(thresh=1, subset="the_age")
dfs = self.df_sqlglot_employee.select(
SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
SF.lit(None),
SF.lit(1),
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
).dropna(thresh=1, subset="the_age")
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_dropna_na_function(self):
df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop()
df = self.df_spark_employee.select(
F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
).na.drop()
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -772,7 +855,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_fillna_default(self):
df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100)
df = self.df_spark_employee.select(
F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
).fillna(100)
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -798,7 +883,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_fillna_na_func(self):
df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100)
df = self.df_spark_employee.select(
F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
).na.fill(100)
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -807,7 +894,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_replace_basic(self):
df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100)
df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
to_replace=37, value=100
)
dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
to_replace=37, value=100
@@ -827,9 +916,13 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_replace_mapping(self):
df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100})
df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
{37: 100}
)
dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100})
dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
{37: 100}
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -849,9 +942,9 @@ class TestDataframeFunc(DataFrameValidator):
to_replace=37, value=100
)
dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace(
to_replace=37, value=100
)
dfs = self.df_sqlglot_employee.select(
SF.col("age"), SF.lit(37).alias("test_col")
).na.replace(to_replace=37, value=100)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -881,16 +974,18 @@ class TestDataframeFunc(DataFrameValidator):
"first_name", "first_name_again"
)
dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed(
"first_name", "first_name_again"
)
dfs = self.df_sqlglot_employee.select(
SF.col("fname").alias("first_name")
).withColumnRenamed("first_name", "first_name_again")
self.compare_spark_with_sqlglot(df, dfs)
def test_drop_column_single(self):
df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age")
dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age")
dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop(
"age"
)
self.compare_spark_with_sqlglot(df, dfs)
@@ -906,7 +1001,9 @@ class TestDataframeFunc(DataFrameValidator):
df_sqlglot_employee_cols = self.df_sqlglot_employee.select(
SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
)
df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))
df_sqlglot_store_cols = self.df_sqlglot_store.select(
SF.col("store_id"), SF.col("store_name")
)
dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop(
df_sqlglot_employee_cols.age,
)

View file

@@ -23,6 +23,14 @@ class TestSessionFunc(DataFrameValidator):
ON
e.store_id = s.store_id
"""
df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id")))
dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id")))
df = (
self.spark.sql(query)
.groupBy(F.col("store_id"))
.agg(F.countDistinct(F.col("employee_id")))
)
dfs = (
self.sqlglot.sql(query)
.groupBy(SF.col("store_id"))
.agg(SF.countDistinct(SF.col("employee_id")))
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

View file

@@ -25,11 +25,17 @@ class DataFrameSQLValidator(unittest.TestCase):
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema)
self.df_employee = self.spark.createDataFrame(
data=employee_data, schema=self.employee_schema
)
def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False):
def compare_sql(
self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
):
actual_sqls = df.sql(pretty=pretty)
expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements
expected_statements = (
[expected_statements] if isinstance(expected_statements, str) else expected_statements
)
self.assertEqual(len(expected_statements), len(actual_sqls))
for expected, actual in zip(expected_statements, actual_sqls):
self.assertEqual(expected, actual)

View file

@@ -26,12 +26,14 @@ class TestDataframeColumn(unittest.TestCase):
def test_and(self):
self.assertEqual(
"cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql()
"cola = colb AND colc = cold",
((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql(),
)
def test_or(self):
self.assertEqual(
"cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql()
"cola = colb OR colc = cold",
((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql(),
)
def test_mod(self):
@@ -112,7 +114,9 @@ class TestDataframeColumn(unittest.TestCase):
def test_when_otherwise(self):
self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql())
self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql())
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()
)
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
(F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(),
@@ -148,7 +152,9 @@ class TestDataframeColumn(unittest.TestCase):
self.assertEqual(
"cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) "
"AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)",
F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(),
F.col("cola")
.between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1))
.sql(),
)
def test_over(self):

View file

@@ -9,7 +9,9 @@ class TestDataframe(DataFrameSQLValidator):
self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))
def test_columns(self):
self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns)
self.assertEqual(
["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns
)
def test_cache(self):
df = self.df_employee.select("fname").cache()

View file

@@ -925,12 +925,17 @@ class TestFunctions(unittest.TestCase):
col = SF.window(SF.col("cola"), "10 minutes")
self.assertEqual("WINDOW(cola, '10 minutes')", col.sql())
col_all_values = SF.window("cola", "2 minutes 30 seconds", "30 seconds", "15 seconds")
self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql())
self.assertEqual(
"WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql()
)
col_no_start_time = SF.window("cola", "2 minutes 30 seconds", "30 seconds")
self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql())
self.assertEqual(
"WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql()
)
col_no_slide = SF.window("cola", "2 minutes 30 seconds", startTime="15 seconds")
self.assertEqual(
"WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')", col_no_slide.sql()
"WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')",
col_no_slide.sql(),
)
def test_session_window(self):
@@ -1359,9 +1364,13 @@ class TestFunctions(unittest.TestCase):
def test_from_json(self):
col_str = SF.from_json("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
self.assertEqual(
"FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
)
col = SF.from_json(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
self.assertEqual(
"FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
)
col_no_option = SF.from_json("cola", "cola INT")
self.assertEqual("FROM_JSON(cola, 'cola INT')", col_no_option.sql())
@@ -1375,7 +1384,9 @@ class TestFunctions(unittest.TestCase):
def test_schema_of_json(self):
col_str = SF.schema_of_json("cola", dict(timestampFormat="dd/MM/yyyy"))
self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
self.assertEqual(
"SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
)
col = SF.schema_of_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy"))
self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
col_no_option = SF.schema_of_json("cola")
@@ -1429,7 +1440,10 @@ class TestFunctions(unittest.TestCase):
col = SF.array_sort(SF.col("cola"))
self.assertEqual("ARRAY_SORT(cola)", col.sql())
col_comparator = SF.array_sort(
"cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x))
"cola",
lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(
SF.length(y) - SF.length(x)
),
)
self.assertEqual(
"ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)",
@@ -1504,9 +1518,13 @@ class TestFunctions(unittest.TestCase):
def test_from_csv(self):
col_str = SF.from_csv("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
self.assertEqual(
"FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
)
col = SF.from_csv(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
self.assertEqual(
"FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
)
col_no_option = SF.from_csv("cola", "cola INT")
self.assertEqual("FROM_CSV(cola, 'cola INT')", col_no_option.sql())
@@ -1535,7 +1553,9 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql())
col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count)
self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql())
self.assertEqual(
"TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()
)
def test_exists(self):
col_str = SF.exists("cola", lambda x: x % 2 == 0)
@@ -1558,10 +1578,13 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql())
col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i))
self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql())
col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count))
col_custom_names = SF.filter(
"cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count)
)
self.assertEqual(
"FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql()
"FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)",
col_custom_names.sql(),
)
def test_zip_with(self):
@@ -1570,7 +1593,9 @@ class TestFunctions(unittest.TestCase):
col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y))
self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql())
col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r))
self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql())
self.assertEqual(
"ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()
)
def test_transform_keys(self):
col_str = SF.transform_keys("cola", lambda k, v: SF.upper(k))
@@ -1586,7 +1611,9 @@ class TestFunctions(unittest.TestCase):
col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v))
self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql())
col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value))
self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql())
self.assertEqual(
"TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()
)
def test_map_filter(self):
col_str = SF.map_filter("cola", lambda k, v: k > v)

View file

@@ -21,9 +21,7 @@ class TestDataframeSession(DataFrameSQLValidator):
def test_cdf_no_schema(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
expected = (
"SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
)
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
self.compare_sql(df, expected)
def test_cdf_row_mixed_primitives(self):
@@ -77,7 +75,8 @@ class TestDataframeSession(DataFrameSQLValidator):
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
self.assertIn(
"SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False)
"SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`",
df.sql(pretty=False),
)
@mock.patch("sqlglot.schema", MappingSchema())
@@ -104,9 +103,7 @@ class TestDataframeSession(DataFrameSQLValidator):
query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
expected = (
"INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
)
expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
self.compare_sql(df, expected)
def test_session_create_builder_patterns(self):

View file

@@ -53,7 +53,10 @@ class TestDataframeTypes(unittest.TestCase):
self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString())
def test_map(self):
self.assertEqual("map<int, string>", types.MapType(types.IntegerType(), types.StringType()).simpleString())
self.assertEqual(
"map<int, string>",
types.MapType(types.IntegerType(), types.StringType()).simpleString(),
)
def test_struct_field(self):
self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString())

View file

@@ -39,22 +39,38 @@ class TestDataframeWindow(unittest.TestCase):
def test_window_rows_unbounded(self):
rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql())
rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql())
rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
self.assertEqual(
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql()
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
rows_between_unbounded_start.sql(),
)
rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
self.assertEqual(
"OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
rows_between_unbounded_end.sql(),
)
rows_between_unbounded_both = Window.rowsBetween(
Window.unboundedPreceding, Window.unboundedFollowing
)
self.assertEqual(
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
rows_between_unbounded_both.sql(),
)
def test_window_range_unbounded(self):
range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
self.assertEqual(
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql()
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
range_between_unbounded_start.sql(),
)
range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql())
range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)
self.assertEqual(
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql()
"OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
range_between_unbounded_end.sql(),
)
range_between_unbounded_both = Window.rangeBetween(
Window.unboundedPreceding, Window.unboundedFollowing
)
self.assertEqual(
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
range_between_unbounded_both.sql(),
)

View file

@@ -157,6 +157,14 @@ class TestBigQuery(Validator):
},
)
self.validate_all(
"DIV(x, y)",
write={
"bigquery": "DIV(x, y)",
"duckdb": "CAST(x / y AS INT)",
},
)
self.validate_identity(
"SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)"
)
@@ -284,4 +292,6 @@ class TestBigQuery(Validator):
"CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'"
)
self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)")
self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t")
self.validate_identity(
"CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t"
)

View file

@@ -18,7 +18,6 @@ class TestClickhouse(Validator):
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
},
)
self.validate_all(
"CAST(1 AS NULLABLE(Int64))",
write={
@@ -31,3 +30,7 @@ class TestClickhouse(Validator):
"clickhouse": "CAST(1 AS Nullable(DateTime64(6, 'UTC')))",
},
)
self.validate_all(
"SELECT x #! comment",
write={"": "SELECT x /* comment */"},
)

View file

@@ -22,7 +22,8 @@ class TestDatabricks(Validator):
},
)
self.validate_all(
"SELECT DATEDIFF('end', 'start')", write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"}
"SELECT DATEDIFF('end', 'start')",
write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"},
)
self.validate_all(
"SELECT DATE_ADD('2020-01-01', 1)",

View file

@@ -1,20 +1,18 @@
import unittest
from sqlglot import (
Dialect,
Dialects,
ErrorLevel,
UnsupportedError,
parse_one,
transpile,
)
from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one
class Validator(unittest.TestCase):
dialect = None
def validate_identity(self, sql):
self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql)
def parse_one(self, sql):
return parse_one(sql, read=self.dialect)
def validate_identity(self, sql, write_sql=None):
expression = self.parse_one(sql)
self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect))
return expression
def validate_all(self, sql, read=None, write=None, pretty=False):
"""
@@ -28,12 +26,14 @@ class Validator(unittest.TestCase):
read (dict): Mapping of dialect -> SQL
write (dict): Mapping of dialect -> SQL
"""
expression = parse_one(sql, read=self.dialect)
expression = self.parse_one(sql)
for read_dialect, read_sql in (read or {}).items():
with self.subTest(f"{read_dialect} -> {sql}"):
self.assertEqual(
parse_one(read_sql, read_dialect).sql(self.dialect, unsupported_level=ErrorLevel.IGNORE),
parse_one(read_sql, read_dialect).sql(
self.dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty
),
sql,
)
@@ -83,10 +83,6 @@ class TestDialect(Validator):
)
self.validate_all(
"CAST(a AS BINARY(4))",
read={
"presto": "CAST(a AS VARBINARY(4))",
"sqlite": "CAST(a AS VARBINARY(4))",
},
write={
"bigquery": "CAST(a AS BINARY(4))",
"clickhouse": "CAST(a AS BINARY(4))",
@@ -103,6 +99,24 @@ class TestDialect(Validator):
"starrocks": "CAST(a AS BINARY(4))",
},
)
self.validate_all(
"CAST(a AS VARBINARY(4))",
write={
"bigquery": "CAST(a AS VARBINARY(4))",
"clickhouse": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS VARBINARY(4))",
"mysql": "CAST(a AS VARBINARY(4))",
"hive": "CAST(a AS BINARY(4))",
"oracle": "CAST(a AS BLOB(4))",
"postgres": "CAST(a AS BYTEA(4))",
"presto": "CAST(a AS VARBINARY(4))",
"redshift": "CAST(a AS VARBYTE(4))",
"snowflake": "CAST(a AS VARBINARY(4))",
"sqlite": "CAST(a AS BLOB(4))",
"spark": "CAST(a AS BINARY(4))",
"starrocks": "CAST(a AS VARBINARY(4))",
},
)
self.validate_all(
"CAST(MAP('a', '1') AS MAP(TEXT, TEXT))",
write={
@@ -472,45 +486,57 @@ class TestDialect(Validator):
},
)
self.validate_all(
"DATE_TRUNC(x, 'day')",
"DATE_TRUNC('day', x)",
write={
"mysql": "DATE(x)",
"starrocks": "DATE(x)",
},
)
self.validate_all(
"DATE_TRUNC(x, 'week')",
"DATE_TRUNC('week', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
"starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
},
)
self.validate_all(
"DATE_TRUNC(x, 'month')",
"DATE_TRUNC('month', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
"starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
},
)
self.validate_all(
"DATE_TRUNC(x, 'quarter')",
"DATE_TRUNC('quarter', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
"starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
},
)
self.validate_all(
"DATE_TRUNC(x, 'year')",
"DATE_TRUNC('year', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
"starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
},
)
self.validate_all(
"DATE_TRUNC(x, 'millenium')",
"DATE_TRUNC('millenium', x)",
write={
"mysql": UnsupportedError,
"starrocks": UnsupportedError,
},
)
self.validate_all(
"DATE_TRUNC('year', x)",
read={
"starrocks": "DATE_TRUNC('year', x)",
},
write={
"starrocks": "DATE_TRUNC('year', x)",
},
)
self.validate_all(
"DATE_TRUNC(x, year)",
read={
"bigquery": "DATE_TRUNC(x, year)",
},
write={
"bigquery": "DATE_TRUNC(x, year)",
},
)
self.validate_all(
@@ -564,6 +590,22 @@ class TestDialect(Validator):
"spark": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
},
)
self.validate_all(
"TIMESTAMP '2022-01-01'",
write={
"mysql": "CAST('2022-01-01' AS TIMESTAMP)",
"starrocks": "CAST('2022-01-01' AS DATETIME)",
"hive": "CAST('2022-01-01' AS TIMESTAMP)",
},
)
self.validate_all(
"TIMESTAMP('2022-01-01')",
write={
"mysql": "TIMESTAMP('2022-01-01')",
"starrocks": "TIMESTAMP('2022-01-01')",
"hive": "TIMESTAMP('2022-01-01')",
},
)
for unit in ("DAY", "MONTH", "YEAR"):
self.validate_all(
@@ -1002,7 +1044,10 @@ class TestDialect(Validator):
)
def test_limit(self):
self.validate_all("SELECT * FROM data LIMIT 10, 20", write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"})
self.validate_all(
"SELECT * FROM data LIMIT 10, 20",
write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"},
)
self.validate_all(
"SELECT x FROM y LIMIT 10",
write={
@@ -1132,3 +1177,56 @@ class TestDialect(Validator):
"sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
},
)
def test_nullsafe_eq(self):
self.validate_all(
"SELECT a IS NOT DISTINCT FROM b",
read={
"mysql": "SELECT a <=> b",
"postgres": "SELECT a IS NOT DISTINCT FROM b",
},
write={
"mysql": "SELECT a <=> b",
"postgres": "SELECT a IS NOT DISTINCT FROM b",
},
)
def test_nullsafe_neq(self):
self.validate_all(
"SELECT a IS DISTINCT FROM b",
read={
"postgres": "SELECT a IS DISTINCT FROM b",
},
write={
"mysql": "SELECT NOT a <=> b",
"postgres": "SELECT a IS DISTINCT FROM b",
},
)
def test_hash_comments(self):
self.validate_all(
"SELECT 1 /* arbitrary content,,, until end-of-line */",
read={
"mysql": "SELECT 1 # arbitrary content,,, until end-of-line",
"bigquery": "SELECT 1 # arbitrary content,,, until end-of-line",
"clickhouse": "SELECT 1 #! arbitrary content,,, until end-of-line",
},
)
self.validate_all(
"""/* comment1 */
SELECT
x, -- comment2
y -- comment3""",
read={
"mysql": """SELECT # comment1
x, # comment2
y # comment3""",
"bigquery": """SELECT # comment1
x, # comment2
y # comment3""",
"clickhouse": """SELECT # comment1
x, # comment2
y # comment3""",
},
pretty=True,
)

View file

@@ -1,3 +1,4 @@
from sqlglot import expressions as exp
from tests.dialects.test_dialect import Validator
@@ -20,6 +21,52 @@ class TestMySQL(Validator):
self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')")
self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')")
self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')")
self.validate_identity("@@GLOBAL.max_connections")
# SET Commands
self.validate_identity("SET @var_name = expr")
self.validate_identity("SET @name = 43")
self.validate_identity("SET @total_tax = (SELECT SUM(tax) FROM taxable_transactions)")
self.validate_identity("SET GLOBAL max_connections = 1000")
self.validate_identity("SET @@GLOBAL.max_connections = 1000")
self.validate_identity("SET SESSION sql_mode = 'TRADITIONAL'")
self.validate_identity("SET LOCAL sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@SESSION.sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@LOCAL.sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@sql_mode = 'TRADITIONAL'")
self.validate_identity("SET sql_mode = 'TRADITIONAL'")
self.validate_identity("SET PERSIST max_connections = 1000")
self.validate_identity("SET @@PERSIST.max_connections = 1000")
self.validate_identity("SET PERSIST_ONLY back_log = 100")
self.validate_identity("SET @@PERSIST_ONLY.back_log = 100")
self.validate_identity("SET @@SESSION.max_join_size = DEFAULT")
self.validate_identity("SET @@SESSION.max_join_size = @@GLOBAL.max_join_size")
self.validate_identity("SET @x = 1, SESSION sql_mode = ''")
self.validate_identity(
"SET GLOBAL sort_buffer_size = 1000000, SESSION sort_buffer_size = 1000000"
)
self.validate_identity(
"SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000"
)
self.validate_identity("SET GLOBAL max_connections = 1000, sort_buffer_size = 1000000")
self.validate_identity("SET @@GLOBAL.sort_buffer_size = 50000, sort_buffer_size = 1000000")
self.validate_identity("SET CHARACTER SET 'utf8'")
self.validate_identity("SET CHARACTER SET utf8")
self.validate_identity("SET CHARACTER SET DEFAULT")
self.validate_identity("SET NAMES 'utf8'")
self.validate_identity("SET NAMES DEFAULT")
self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'")
self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci")
self.validate_identity("SET autocommit = ON")
def test_escape(self):
self.validate_all(
r"'a \' b '' '",
write={
"mysql": r"'a '' b '' '",
"spark": r"'a \' b \' '",
},
)
def test_introducers(self):
self.validate_all(
@@ -115,14 +162,6 @@ class TestMySQL(Validator):
},
)
def test_hash_comments(self):
self.validate_all(
"SELECT 1 # arbitrary content,,, until end-of-line",
write={
"mysql": "SELECT 1",
},
)
def test_mysql(self):
self.validate_all(
"GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
@@ -174,3 +213,242 @@ COMMENT='客户账户表'"""
},
pretty=True,
)
def test_show_simple(self):
for key, write_key in [
("BINARY LOGS", "BINARY LOGS"),
("MASTER LOGS", "BINARY LOGS"),
("STORAGE ENGINES", "ENGINES"),
("ENGINES", "ENGINES"),
("EVENTS", "EVENTS"),
("MASTER STATUS", "MASTER STATUS"),
("PLUGINS", "PLUGINS"),
("PRIVILEGES", "PRIVILEGES"),
("PROFILES", "PROFILES"),
("REPLICAS", "REPLICAS"),
("SLAVE HOSTS", "REPLICAS"),
]:
show = self.validate_identity(f"SHOW {key}", f"SHOW {write_key}")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, write_key)
def test_show_events(self):
for key in ["BINLOG", "RELAYLOG"]:
show = self.validate_identity(f"SHOW {key} EVENTS")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, f"{key} EVENTS")
show = self.validate_identity(f"SHOW {key} EVENTS IN 'log' FROM 1 LIMIT 2, 3")
self.assertEqual(show.text("log"), "log")
self.assertEqual(show.text("position"), "1")
self.assertEqual(show.text("limit"), "3")
self.assertEqual(show.text("offset"), "2")
show = self.validate_identity(f"SHOW {key} EVENTS LIMIT 1")
self.assertEqual(show.text("limit"), "1")
self.assertIsNone(show.args.get("offset"))
def test_show_like_or_where(self):
for key, write_key in [
("CHARSET", "CHARACTER SET"),
("CHARACTER SET", "CHARACTER SET"),
("COLLATION", "COLLATION"),
("DATABASES", "DATABASES"),
("FUNCTION STATUS", "FUNCTION STATUS"),
("PROCEDURE STATUS", "PROCEDURE STATUS"),
("GLOBAL STATUS", "GLOBAL STATUS"),
("SESSION STATUS", "STATUS"),
("STATUS", "STATUS"),
("GLOBAL VARIABLES", "GLOBAL VARIABLES"),
("SESSION VARIABLES", "VARIABLES"),
("VARIABLES", "VARIABLES"),
]:
expected_name = write_key.strip("GLOBAL").strip()
template = "SHOW {}"
show = self.validate_identity(template.format(key), template.format(write_key))
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, expected_name)
template = "SHOW {} LIKE '%foo%'"
show = self.validate_identity(template.format(key), template.format(write_key))
self.assertIsInstance(show, exp.Show)
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")
template = "SHOW {} WHERE Column_name LIKE '%foo%'"
show = self.validate_identity(template.format(key), template.format(write_key))
self.assertIsInstance(show, exp.Show)
self.assertIsInstance(show.args["where"], exp.Where)
self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'")
def test_show_columns(self):
show = self.validate_identity("SHOW COLUMNS FROM tbl_name")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "COLUMNS")
self.assertEqual(show.text("target"), "tbl_name")
self.assertFalse(show.args["full"])
show = self.validate_identity("SHOW FULL COLUMNS FROM tbl_name FROM db_name LIKE '%foo%'")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.text("target"), "tbl_name")
self.assertTrue(show.args["full"])
self.assertEqual(show.text("db"), "db_name")
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")
def test_show_name(self):
for key in [
"CREATE DATABASE",
"CREATE EVENT",
"CREATE FUNCTION",
"CREATE PROCEDURE",
"CREATE TABLE",
"CREATE TRIGGER",
"CREATE VIEW",
"FUNCTION CODE",
"PROCEDURE CODE",
]:
show = self.validate_identity(f"SHOW {key} foo")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, key)
self.assertEqual(show.text("target"), "foo")
def test_show_grants(self):
show = self.validate_identity(f"SHOW GRANTS FOR foo")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "GRANTS")
self.assertEqual(show.text("target"), "foo")
def test_show_engine(self):
show = self.validate_identity("SHOW ENGINE foo STATUS")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "ENGINE")
self.assertEqual(show.text("target"), "foo")
self.assertFalse(show.args["mutex"])
show = self.validate_identity("SHOW ENGINE foo MUTEX")
self.assertEqual(show.name, "ENGINE")
self.assertEqual(show.text("target"), "foo")
self.assertTrue(show.args["mutex"])
def test_show_errors(self):
for key in ["ERRORS", "WARNINGS"]:
show = self.validate_identity(f"SHOW {key}")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, key)
show = self.validate_identity(f"SHOW {key} LIMIT 2, 3")
self.assertEqual(show.text("limit"), "3")
self.assertEqual(show.text("offset"), "2")
def test_show_index(self):
show = self.validate_identity("SHOW INDEX FROM foo")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "INDEX")
self.assertEqual(show.text("target"), "foo")
show = self.validate_identity("SHOW INDEX FROM foo FROM bar")
self.assertEqual(show.text("db"), "bar")
def test_show_db_like_or_where_sql(self):
for key in [
"OPEN TABLES",
"TABLE STATUS",
"TRIGGERS",
]:
show = self.validate_identity(f"SHOW {key}")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, key)
show = self.validate_identity(f"SHOW {key} FROM db_name")
self.assertEqual(show.name, key)
self.assertEqual(show.text("db"), "db_name")
show = self.validate_identity(f"SHOW {key} LIKE '%foo%'")
self.assertEqual(show.name, key)
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")
show = self.validate_identity(f"SHOW {key} WHERE Column_name LIKE '%foo%'")
self.assertEqual(show.name, key)
self.assertIsInstance(show.args["where"], exp.Where)
self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'")
def test_show_processlist(self):
show = self.validate_identity("SHOW PROCESSLIST")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "PROCESSLIST")
self.assertFalse(show.args["full"])
show = self.validate_identity("SHOW FULL PROCESSLIST")
self.assertEqual(show.name, "PROCESSLIST")
self.assertTrue(show.args["full"])
def test_show_profile(self):
show = self.validate_identity("SHOW PROFILE")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "PROFILE")
show = self.validate_identity("SHOW PROFILE BLOCK IO")
self.assertEqual(show.args["types"][0].name, "BLOCK IO")
show = self.validate_identity(
"SHOW PROFILE BLOCK IO, PAGE FAULTS FOR QUERY 1 OFFSET 2 LIMIT 3"
)
self.assertEqual(show.args["types"][0].name, "BLOCK IO")
self.assertEqual(show.args["types"][1].name, "PAGE FAULTS")
self.assertEqual(show.text("query"), "1")
self.assertEqual(show.text("offset"), "2")
self.assertEqual(show.text("limit"), "3")
def test_show_replica_status(self):
show = self.validate_identity("SHOW REPLICA STATUS")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "REPLICA STATUS")
show = self.validate_identity("SHOW SLAVE STATUS", "SHOW REPLICA STATUS")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "REPLICA STATUS")
show = self.validate_identity("SHOW REPLICA STATUS FOR CHANNEL channel_name")
self.assertEqual(show.text("channel"), "channel_name")
def test_show_tables(self):
show = self.validate_identity("SHOW TABLES")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "TABLES")
show = self.validate_identity("SHOW FULL TABLES FROM db_name LIKE '%foo%'")
self.assertTrue(show.args["full"])
self.assertEqual(show.text("db"), "db_name")
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")
def test_set_variable(self):
cmd = self.parse_one("SET SESSION x = 1")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "SESSION")
self.assertIsInstance(item.this, exp.EQ)
self.assertEqual(item.this.left.name, "x")
self.assertEqual(item.this.right.name, "1")
cmd = self.parse_one("SET @@GLOBAL.x = @@GLOBAL.y")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "")
self.assertIsInstance(item.this, exp.EQ)
self.assertIsInstance(item.this.left, exp.SessionParameter)
self.assertIsInstance(item.this.right, exp.SessionParameter)
cmd = self.parse_one("SET NAMES 'charset_name' COLLATE 'collation_name'")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "NAMES")
self.assertEqual(item.name, "charset_name")
self.assertEqual(item.text("collate"), "collation_name")
cmd = self.parse_one("SET CHARSET DEFAULT")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "CHARACTER SET")
self.assertEqual(item.this.name, "DEFAULT")
cmd = self.parse_one("SET x = 1, y = 2")
self.assertEqual(len(cmd.expressions), 2)
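
For reference, a minimal sketch (not part of the diff) of the MySQL parse surface these SHOW/SET tests exercise, assuming the sqlglot 10.x API as merged here:

    from sqlglot import exp, parse_one

    # SHOW statements parse into exp.Show with named args, per the tests above
    show = parse_one("SHOW FULL TABLES FROM db_name LIKE '%foo%'", read="mysql")
    assert isinstance(show, exp.Show)
    assert show.name == "TABLES"
    assert show.args["full"]
    assert show.text("db") == "db_name"
    assert show.text("like") == "%foo%"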

View file

@ -8,7 +8,9 @@ class TestPostgres(Validator):
def test_ddl(self):
self.validate_all(
"CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
write={"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"},
write={
"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"
},
)
self.validate_all(
"CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)",
@ -59,15 +61,27 @@ class TestPostgres(Validator):
def test_postgres(self):
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END")
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END")
self.validate_identity('SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')')
self.validate_identity(
"SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END"
)
self.validate_identity(
"SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END"
)
self.validate_identity(
'SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')'
)
self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')")
self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')")
self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))")
self.validate_identity(
"SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')"
)
self.validate_identity(
"SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))"
)
self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')")
self.validate_identity(
"SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')"
)
self.validate_identity("COMMENT ON TABLE mytable IS 'this'")
self.validate_identity("SELECT e'\\xDEADBEEF'")
self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
@ -75,7 +89,7 @@ class TestPostgres(Validator):
self.validate_all(
"CREATE TABLE x (a UUID, b BYTEA)",
write={
"duckdb": "CREATE TABLE x (a UUID, b BINARY)",
"duckdb": "CREATE TABLE x (a UUID, b VARBINARY)",
"presto": "CREATE TABLE x (a UUID, b VARBINARY)",
"hive": "CREATE TABLE x (a UUID, b BINARY)",
"spark": "CREATE TABLE x (a UUID, b BINARY)",
@ -153,7 +167,9 @@ class TestPostgres(Validator):
)
self.validate_all(
"SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss",
read={"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"},
read={
"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"
},
)
self.validate_all(
"SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL",
@ -169,11 +185,15 @@ class TestPostgres(Validator):
)
self.validate_all(
"SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL",
read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL"},
read={
"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL"
},
)
self.validate_all(
"SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL",
read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL"},
read={
"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL"
},
)
self.validate_all(
"'[1,2,3]'::json->2",
@ -184,7 +204,8 @@ class TestPostgres(Validator):
write={"postgres": """CAST('{"a":1,"b":2}' AS JSON)->'b'"""},
)
self.validate_all(
"""'{"x": {"y": 1}}'::json->'x'->'y'""", write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""}
"""'{"x": {"y": 1}}'::json->'x'->'y'""",
write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""},
)
self.validate_all(
"""'{"x": {"y": 1}}'::json->'x'::json->'y'""",

View file

@ -61,4 +61,6 @@ class TestRedshift(Validator):
"SELECT caldate + INTERVAL '1 second' AS dateplus FROM date WHERE caldate = '12-31-2008'"
)
self.validate_identity("CREATE TABLE datetable (start_date DATE, end_date DATE)")
self.validate_identity("SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'")
self.validate_identity(
"SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
)

View file

@ -336,7 +336,8 @@ class TestSnowflake(Validator):
def test_table_literal(self):
# All examples from https://docs.snowflake.com/en/sql-reference/literals-table.html
self.validate_all(
r"""SELECT * FROM TABLE('MYTABLE')""", write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""}
r"""SELECT * FROM TABLE('MYTABLE')""",
write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""},
)
self.validate_all(
@ -352,15 +353,123 @@ class TestSnowflake(Validator):
write={"snowflake": r"""SELECT * FROM TABLE('MYDB. "MYSCHEMA"."MYTABLE"')"""},
)
self.validate_all(r"""SELECT * FROM TABLE($MYVAR)""", write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""})
self.validate_all(r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""})
self.validate_all(
r"""SELECT * FROM TABLE($MYVAR)""",
write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""},
)
self.validate_all(
r"""SELECT * FROM TABLE(:BINDING)""", write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""}
r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""}
)
self.validate_all(
r"""SELECT * FROM TABLE(:BINDING)""",
write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""},
)
self.validate_all(
r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10""",
write={"snowflake": r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10"""},
)
def test_flatten(self):
self.validate_all(
"""
select
dag_report.acct_id,
dag_report.report_date,
dag_report.report_uuid,
dag_report.airflow_name,
dag_report.dag_id,
f.value::varchar as operator
from cs.telescope.dag_report,
table(flatten(input=>split(operators, ','))) f
""",
write={
"snowflake": """SELECT
dag_report.acct_id,
dag_report.report_date,
dag_report.report_uuid,
dag_report.airflow_name,
dag_report.dag_id,
CAST(f.value AS VARCHAR) AS operator
FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS f"""
},
pretty=True,
)
# All examples from https://docs.snowflake.com/en/sql-reference/functions/flatten.html#syntax
self.validate_all(
"SELECT * FROM TABLE(FLATTEN(input => parse_json('[1, ,77]'))) f",
write={
"snowflake": "SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[1, ,77]'))) AS f"
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), outer => true)) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), outer => TRUE)) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), path => 'b')) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), path => 'b')) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'))) f""",
write={"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'))) AS f"""},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'), outer => true)) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'), outer => TRUE)) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true)) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE)) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true, mode => 'object')) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE, mode => 'object')) AS f"""
},
)
self.validate_all(
"""
SELECT id as "ID",
f.value AS "Contact",
f1.value:type AS "Type",
f1.value:content AS "Details"
FROM persons p,
lateral flatten(input => p.c, path => 'contact') f,
lateral flatten(input => f.value:business) f1
""",
write={
"snowflake": """SELECT
id AS "ID",
f.value AS "Contact",
f1.value['type'] AS "Type",
f1.value['content'] AS "Details"
FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL FLATTEN(input => f.value['business']) f1""",
},
pretty=True,
)
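
A minimal sketch of the behavior the FLATTEN tests pin down (illustrative only, assuming the sqlglot 10.x snowflake dialect): table functions and => named arguments survive a snowflake-to-snowflake round trip, with AS added to the lateral alias:

    import sqlglot

    sql = "SELECT * FROM TABLE(FLATTEN(input => parse_json('[1, ,77]'))) f"
    sqlglot.transpile(sql, read="snowflake", write="snowflake")[0]
    # "SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[1, ,77]'))) AS f"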

View file

@ -284,4 +284,6 @@ TBLPROPERTIES (
)
def test_iif(self):
self.validate_all("SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"})
self.validate_all(
"SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"}
)

View file

@ -6,3 +6,6 @@ class TestMySQL(Validator):
def test_identity(self):
self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")
def test_time(self):
self.validate_identity("TIMESTAMP('2022-01-01')")

View file

@ -278,12 +278,19 @@ class TestTSQL(Validator):
def test_add_date(self):
self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')")
self.validate_all(
"SELECT DATEADD(year, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"}
"SELECT DATEADD(year, 1, '2017/08/25')",
write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"},
)
self.validate_all(
"SELECT DATEADD(qq, 1, '2017/08/25')",
write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"},
)
self.validate_all("SELECT DATEADD(qq, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"})
self.validate_all(
"SELECT DATEADD(wk, 1, '2017/08/25')",
write={"spark": "SELECT DATE_ADD('2017/08/25', 7)", "databricks": "SELECT DATEADD(week, 1, '2017/08/25')"},
write={
"spark": "SELECT DATE_ADD('2017/08/25', 7)",
"databricks": "SELECT DATEADD(week, 1, '2017/08/25')",
},
)
def test_date_diff(self):
@ -370,13 +377,21 @@ class TestTSQL(Validator):
"SELECT FORMAT(1000000.01,'###,###.###')",
write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"},
)
self.validate_all("SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"})
self.validate_all(
"SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"}
)
self.validate_all(
"SELECT FORMAT('01-01-1991', 'dd.mm.yyyy')",
write={"spark": "SELECT DATE_FORMAT('01-01-1991', 'dd.mm.yyyy')"},
)
self.validate_all(
"SELECT FORMAT(date_col, 'dd.mm.yyyy')", write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"}
"SELECT FORMAT(date_col, 'dd.mm.yyyy')",
write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"},
)
self.validate_all(
"SELECT FORMAT(date_col, 'm')",
write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"},
)
self.validate_all(
"SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}
)
self.validate_all("SELECT FORMAT(date_col, 'm')", write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"})
self.validate_all("SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"})

View file

@ -523,6 +523,8 @@ DROP VIEW a.b
DROP VIEW IF EXISTS a
DROP VIEW IF EXISTS a.b
SHOW TABLES
USE db
ROLLBACK
EXPLAIN SELECT * FROM x
INSERT INTO x SELECT * FROM y
INSERT INTO x (SELECT * FROM y)
@ -569,3 +571,13 @@ SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)
SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)
SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)
SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)
SELECT CAST(x AS INT) /* comment */ FROM foo
SELECT a /* x */, b /* x */
SELECT * FROM foo /* x */, bla /* x */
SELECT 1 /* comment */ + 1
SELECT 1 /* c1 */ + 2 /* c2 */
SELECT 1 /* c1 */ + 2 /* c2 */ + 3 /* c3 */
SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */
SELECT x FROM a.b.c /* x */, e.f.g /* x */
SELECT FOO(x /* c */) /* FOO */, b /* b */
SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */
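
These identity fixtures assert that comments now ride along on expressions; a minimal sketch of the round trip (illustrative, assuming the 10.x API):

    from sqlglot import parse_one

    parse_one("SELECT 1 /* c1 */ + 2 /* c2 */").sql()
    # 'SELECT 1 /* c1 */ + 2 /* c2 */'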

View file

@ -104,6 +104,16 @@ SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_
SELECT AGGREGATE(ARRAY(a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + x.a) AS sum_agg FROM x AS x;
# dialect: starrocks
# execute: false
SELECT DATE_TRUNC('week', a) AS a FROM x;
SELECT DATE_TRUNC('week', x.a) AS a FROM x AS x;
# dialect: bigquery
# execute: false
SELECT DATE_TRUNC(a, MONTH) AS a FROM x;
SELECT DATE_TRUNC(x.a, MONTH) AS a FROM x AS x;
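
A hedged sketch of what these new fixtures exercise; the rule import and the plain-dict schema are assumptions based on the test harness, and the column type is a made-up placeholder:

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_columns import qualify_columns

    expr = parse_one("SELECT DATE_TRUNC(a, MONTH) AS a FROM x", read="bigquery")
    # "date" below is an illustrative column type, not taken from the fixtures
    qualify_columns(expr, schema={"x": {"a": "date"}}).sql(dialect="bigquery")
    # 'SELECT DATE_TRUNC(x.a, MONTH) AS a FROM x AS x'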
--------------------------------------
-- Derived tables
--------------------------------------

View file

@ -79,6 +79,15 @@ NULL;
NULL = NULL;
NULL;
NULL <=> NULL;
TRUE;
a IS NOT DISTINCT FROM a;
TRUE;
NULL IS DISTINCT FROM NULL;
FALSE;
NOT (NOT TRUE);
TRUE;
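
A minimal sketch of the simplify rule these fixtures cover (illustrative only):

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    simplify(parse_one("NULL <=> NULL")).sql()   # 'TRUE'
    simplify(parse_one("NOT (NOT TRUE)")).sql()  # 'TRUE'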

View file

@ -287,3 +287,31 @@ SELECT
"fffffff"
)
);
/*
multi
line
comment
*/
SELECT * FROM foo;
/*
multi
line
comment
*/
SELECT
*
FROM foo;
SELECT x FROM a.b.c /*x*/, e.f.g /*x*/;
SELECT
x
FROM a.b.c /* x */, e.f.g /* x */;
SELECT x FROM (SELECT * FROM bla /*x*/WHERE id = 1) /*x*/;
SELECT
x
FROM (
SELECT
*
FROM bla /* x */
WHERE
id = 1
) /* x */;
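
An illustrative sketch of the pretty-printer behavior these fixtures describe (comments are re-attached after their nodes, multi-line comments kept verbatim):

    from sqlglot import parse_one

    print(parse_one("SELECT x FROM (SELECT * FROM bla /*x*/ WHERE id = 1) /*x*/").sql(pretty=True))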

View file

@ -100,15 +100,21 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT OUTER JOIN tbl2",
),
(
lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer"),
lambda: select("x")
.from_("tbl")
.join(exp.Table(this="tbl2"), join_type="left outer"),
"SELECT x FROM tbl LEFT OUTER JOIN tbl2",
),
(
lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
lambda: select("x")
.from_("tbl")
.join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
"SELECT x FROM tbl LEFT OUTER JOIN tbl2 AS foo",
),
(
lambda: select("x").from_("tbl").join(select("y").from_("tbl2"), join_type="left outer"),
lambda: select("x")
.from_("tbl")
.join(select("y").from_("tbl2"), join_type="left outer"),
"SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2)",
),
(
@ -131,7 +137,9 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased",
),
(
lambda: select("x").from_("tbl").join(parse_one("left join x", into=exp.Join), on="a=b"),
lambda: select("x")
.from_("tbl")
.join(parse_one("left join x", into=exp.Join), on="a=b"),
"SELECT x FROM tbl LEFT JOIN x ON a = b",
),
(
@ -139,7 +147,9 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT JOIN x ON a = b",
),
(
lambda: select("x").from_("tbl").join("select b from tbl2", on="a=b", join_type="left"),
lambda: select("x")
.from_("tbl")
.join("select b from tbl2", on="a=b", join_type="left"),
"SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) ON a = b",
),
(
@ -162,7 +172,10 @@ class TestBuild(unittest.TestCase):
(
lambda: select("x", "y", "z")
.from_("merged_df")
.join("vte_diagnosis_df", using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")]),
.join(
"vte_diagnosis_df",
using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")],
),
"SELECT x, y, z FROM merged_df JOIN vte_diagnosis_df USING (patient_id, encounter_id)",
),
(
@ -222,7 +235,10 @@ class TestBuild(unittest.TestCase):
"SELECT x, y, z, a FROM tbl ORDER BY x, y, z, a",
),
(
lambda: select("x", "y", "z", "a").from_("tbl").cluster_by("x, y", "z").cluster_by("a"),
lambda: select("x", "y", "z", "a")
.from_("tbl")
.cluster_by("x, y", "z")
.cluster_by("a"),
"SELECT x, y, z, a FROM tbl CLUSTER BY x, y, z, a",
),
(
@ -239,7 +255,9 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
lambda: select("x").from_("tbl").with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
lambda: select("x")
.from_("tbl")
.with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
"WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
@ -247,7 +265,9 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
lambda: select("x").from_("tbl").with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
lambda: select("x")
.from_("tbl")
.with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
"WITH tbl(x, y) AS (SELECT x, y FROM tbl2) SELECT x FROM tbl",
),
(
@ -258,7 +278,10 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2), tbl2 AS (SELECT x FROM tbl3) SELECT x FROM tbl",
),
(
lambda: select("x").from_("tbl").with_("tbl", as_=select("x", "y").from_("tbl2")).select("y"),
lambda: select("x")
.from_("tbl")
.with_("tbl", as_=select("x", "y").from_("tbl2"))
.select("y"),
"WITH tbl AS (SELECT x, y FROM tbl2) SELECT x, y FROM tbl",
),
(
@ -266,35 +289,59 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").group_by("x"),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.group_by("x"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl GROUP BY x",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").order_by("x"),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.order_by("x"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl ORDER BY x",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").limit(10),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.limit(10),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl LIMIT 10",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").offset(10),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.offset(10),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl OFFSET 10",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").join("tbl3"),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.join("tbl3"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl JOIN tbl3",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").distinct(),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.distinct(),
"WITH tbl AS (SELECT x FROM tbl2) SELECT DISTINCT x FROM tbl",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").where("x > 10"),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.where("x > 10"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl WHERE x > 10",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").having("x > 20"),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.having("x > 20"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl HAVING x > 20",
),
(lambda: select("x").from_("tbl").subquery(), "(SELECT x FROM tbl)"),
@ -354,7 +401,9 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM (SELECT x FROM tbl) AS foo WHERE x > 0",
),
(
lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select("x"),
lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select(
"x"
),
"SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned",
),
(
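
A minimal sketch of the builder API reformatted in this hunk (mirroring the expected SQL above; illustrative only):

    from sqlglot import select

    select("x").from_("tbl").join("tbl2", join_type="left outer").sql()
    # 'SELECT x FROM tbl LEFT OUTER JOIN tbl2'

    select("x").from_("tbl").with_("tbl", as_="SELECT x FROM tbl2", recursive=True).sql()
    # 'WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl'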

View file

@ -33,7 +33,10 @@ class TestExecutor(unittest.TestCase):
)
cls.cache = {}
cls.sqls = [(sql, expected) for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")]
cls.sqls = [
(sql, expected)
for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")
]
@classmethod
def tearDownClass(cls):
@ -63,7 +66,9 @@ class TestExecutor(unittest.TestCase):
def test_execute_tpch(self):
def to_csv(expression):
if isinstance(expression, exp.Table):
return parse_one(f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}")
return parse_one(
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
)
return expression
for sql, _ in self.sqls[0:3]:
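
For context, a hedged sketch of the executor entry point this test drives; the table name and rows are invented placeholders, not fixture data:

    from sqlglot.executor import execute

    # "t" and its rows are illustrative only
    result = execute(
        "SELECT a, COUNT(*) AS cnt FROM t GROUP BY a",
        tables={"t": [{"a": "x"}, {"a": "x"}, {"a": "y"}]},
    )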

View file

@ -30,7 +30,9 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)"))
self.assertEqual(exp.Table(pivots=[]), exp.Table())
self.assertNotEqual(exp.Table(pivots=[None]), exp.Table())
self.assertEqual(exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False))
self.assertEqual(
exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False)
)
def test_find(self):
expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y")
@ -89,7 +91,9 @@ class TestExpressions(unittest.TestCase):
self.assertIsNone(column.find_ancestor(exp.Join))
def test_alias_or_name(self):
expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
expression = parse_one(
"SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
)
self.assertEqual(
[e.alias_or_name for e in expression.expressions],
["a", "B", "e", "*", "zz", "z"],
@ -166,7 +170,9 @@ class TestExpressions(unittest.TestCase):
"SELECT * FROM foo WHERE ? > 100",
)
self.assertEqual(
exp.replace_placeholders(parse_one("select * from :name WHERE ? > 100"), another_name="bla").sql(),
exp.replace_placeholders(
parse_one("select * from :name WHERE ? > 100"), another_name="bla"
).sql(),
"SELECT * FROM :name WHERE ? > 100",
)
self.assertEqual(
@ -183,7 +189,9 @@ class TestExpressions(unittest.TestCase):
)
def test_named_selects(self):
expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
expression = parse_one(
"SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
)
self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"])
expression = parse_one(
@ -367,7 +375,9 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(len(list(expression.walk())), 9)
self.assertEqual(len(list(expression.walk(bfs=False))), 9)
self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk()))
self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False)))
self.assertTrue(
all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False))
)
def test_functions(self):
self.assertIsInstance(parse_one("ABS(a)"), exp.Abs)
@ -512,14 +522,21 @@ class TestExpressions(unittest.TestCase):
),
exp.Properties(
expressions=[
exp.FileFormatProperty(this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")),
exp.FileFormatProperty(
this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")
),
exp.PartitionedByProperty(
this=exp.Literal.string("PARTITIONED_BY"),
value=exp.Tuple(expressions=[exp.to_identifier("a"), exp.to_identifier("b")]),
value=exp.Tuple(
expressions=[exp.to_identifier("a"), exp.to_identifier("b")]
),
),
exp.AnonymousProperty(
this=exp.Literal.string("custom"), value=exp.Literal.number(1)
),
exp.AnonymousProperty(this=exp.Literal.string("custom"), value=exp.Literal.number(1)),
exp.TableFormatProperty(
this=exp.Literal.string("TABLE_FORMAT"), value=exp.to_identifier("test_format")
this=exp.Literal.string("TABLE_FORMAT"),
value=exp.to_identifier("test_format"),
),
exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL),
exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE),
@ -538,7 +555,10 @@ class TestExpressions(unittest.TestCase):
((1, "2", None), "(1, '2', NULL)"),
([1, "2", None], "ARRAY(1, '2', NULL)"),
({"x": None}, "MAP('x', NULL)"),
(datetime.datetime(2022, 10, 1, 1, 1, 1), "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')"),
(
datetime.datetime(2022, 10, 1, 1, 1, 1),
"TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')",
),
(
datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc),
"TIME_STR_TO_TIME('2022-10-01 01:01:01.000000+0000')",
@ -548,30 +568,48 @@ class TestExpressions(unittest.TestCase):
with self.subTest(value):
self.assertEqual(exp.convert(value).sql(), expected)
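
A minimal sketch of exp.convert, which these cases parametrize (outputs copied from the fixtures above):

    from sqlglot import exp

    exp.convert((1, "2", None)).sql()  # "(1, '2', NULL)"
    exp.convert({"x": None}).sql()     # "MAP('x', NULL)"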
def test_annotation_alias(self):
sql = "SELECT a, b AS B, c # comment, d AS D # another_comment FROM foo"
def test_comment_alias(self):
sql = """
SELECT
a,
b AS B,
c, /*comment*/
d AS D, -- another comment
CAST(x AS INT) -- final comment
FROM foo
"""
expression = parse_one(sql)
self.assertEqual(
[e.alias_or_name for e in expression.expressions],
["a", "B", "c", "D"],
["a", "B", "c", "D", "x"],
)
self.assertEqual(expression.sql(), "SELECT a, b AS B, c, d AS D")
self.assertEqual(expression.expressions[2].name, "comment")
self.assertEqual(
expression.sql(pretty=True, annotations=False),
expression.sql(),
"SELECT a, b AS B, c /* comment */, d AS D /* another comment */, CAST(x AS INT) /* final comment */ FROM foo",
)
self.assertEqual(
expression.sql(comments=False),
"SELECT a, b AS B, c, d AS D, CAST(x AS INT) FROM foo",
)
self.assertEqual(
expression.sql(pretty=True, comments=False),
"""SELECT
a,
b AS B,
c,
d AS D""",
d AS D,
CAST(x AS INT)
FROM foo""",
)
self.assertEqual(
expression.sql(pretty=True),
"""SELECT
a,
b AS B,
c # comment,
d AS D # another_comment FROM foo""",
c, -- comment
d AS D, -- another comment
CAST(x AS INT) -- final comment
FROM foo""",
)
def test_to_table(self):
@ -605,5 +643,9 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(expression, exp.Union)
self.assertEqual(expression.named_selects, ["cola", "colb"])
self.assertEqual(
expression.selects, [exp.Column(this=exp.to_identifier("cola")), exp.Column(this=exp.to_identifier("colb"))]
expression.selects,
[
exp.Column(this=exp.to_identifier("cola")),
exp.Column(this=exp.to_identifier("colb")),
],
)

View file

@ -67,7 +67,9 @@ class TestOptimizer(unittest.TestCase):
}
def check_file(self, file, func, pretty=False, execute=False, **kwargs):
for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1):
for i, (meta, sql, expected) in enumerate(
load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1
):
title = meta.get("title") or f"{i}, {sql}"
dialect = meta.get("dialect")
leave_tables_isolated = meta.get("leave_tables_isolated")
@ -90,7 +92,9 @@ class TestOptimizer(unittest.TestCase):
if string_to_bool(should_execute):
with self.subTest(f"(execute) {title}"):
df1 = self.conn.execute(sqlglot.transpile(sql, read=dialect, write="duckdb")[0]).df()
df1 = self.conn.execute(
sqlglot.transpile(sql, read=dialect, write="duckdb")[0]
).df()
df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df()
assert_frame_equal(df1, df2)
@ -268,7 +272,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
self.assertEqual(scopes[2].expression.sql(), "(VALUES (1, 'test')) AS tab(cola, colb)")
self.assertEqual(
scopes[3].expression.sql(), "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)"
scopes[3].expression.sql(),
"SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
)
self.assertEqual(scopes[4].expression.sql(), "SELECT y.c AS b FROM y")
self.assertEqual(scopes[5].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
@ -287,7 +292,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
# Check that we can walk in scope from an arbitrary node
self.assertEqual(
{node.sql() for node, *_ in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column)},
{
node.sql()
for node, *_ in walk_in_scope(expression.find(exp.Where))
if isinstance(node, exp.Column)
},
{"s.b"},
)
@ -324,7 +333,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)
def test_cache_annotation(self):
expression = annotate_types(parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1"))
expression = annotate_types(
parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
)
self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT)
def test_binary_annotation(self):
@ -384,7 +395,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
"""
expression = annotate_types(parse_one(sql), schema=schema)
self.assertEqual(expression.expressions[0].type, exp.DataType.Type.TEXT) # tbl.cola + tbl.colb + 'foo' AS col
self.assertEqual(
expression.expressions[0].type, exp.DataType.Type.TEXT
) # tbl.cola + tbl.colb + 'foo' AS col
outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo'
self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT)
@ -396,7 +409,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT)
cte_select = expression.args["with"].expressions[0].this
self.assertEqual(cte_select.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola + 'bla' AS cola
self.assertEqual(
cte_select.expressions[0].type, exp.DataType.Type.VARCHAR
) # x.cola + 'bla' AS cola
self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb
cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla'
@ -405,7 +420,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR)
# Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
for d, t in zip(cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]):
for d, t in zip(
cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]
):
self.assertEqual(d.this.expressions[0].this.type, t)
def test_function_annotation(self):
@ -421,6 +438,19 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb)
self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb
sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x"
case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR)
case_expr = case_expr_alias.this
self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR)
self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR)
case_ifs_expr = case_expr.args["ifs"][0]
self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR)
self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR)
def test_unknown_annotation(self):
schema = {"x": {"cola": "VARCHAR"}}
sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x"
@ -431,8 +461,12 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
concat_expr = concat_expr_alias.this
self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN)
self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola
self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) # SOME_ANONYMOUS_FUNC(x.cola)
self.assertEqual(concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola (arg)
self.assertEqual(
concat_expr.right.type, exp.DataType.Type.UNKNOWN
) # SOME_ANONYMOUS_FUNC(x.cola)
self.assertEqual(
concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR
) # x.cola (arg)
def test_null_annotation(self):
expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this

View file

@ -23,8 +23,6 @@ class TestParser(unittest.TestCase):
def test_float(self):
self.assertEqual(parse_one(".2"), parse_one("0.2"))
self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)"))
def test_table(self):
tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
@ -33,7 +31,9 @@ class TestParser(unittest.TestCase):
def test_select(self):
self.assertIsNotNone(parse_one("select 1 natural"))
self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"])
self.assertIsNotNone(parse_one("select * from x where a = (select 1) order by x.y").args["order"])
self.assertIsNotNone(
parse_one("select * from x where a = (select 1) order by x.y").args["order"]
)
self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1)
self.assertEqual(
parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),
@ -125,26 +125,70 @@ class TestParser(unittest.TestCase):
def test_var(self):
self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'")
def test_annotations(self):
def test_comments(self):
expression = parse_one(
"""
SELECT
a #annotation1,
b as B #annotation2:testing ,
"test#annotation",c#annotation3, d #annotation4,
e #,
f # space
--comment1
SELECT /* this won't be used */
a, --comment2
b as B, --comment3:testing
"test--annotation",
c, --comment4 --foo
e, --
f -- space
FROM foo
"""
)
assert expression.expressions[0].name == "annotation1"
assert expression.expressions[1].name == "annotation2:testing"
assert expression.expressions[2].name == "test#annotation"
assert expression.expressions[3].name == "annotation3"
assert expression.expressions[4].name == "annotation4"
assert expression.expressions[5].name == ""
assert expression.expressions[6].name == "space"
self.assertEqual(expression.comment, "comment1")
self.assertEqual(expression.expressions[0].comment, "comment2")
self.assertEqual(expression.expressions[1].comment, "comment3:testing")
self.assertEqual(expression.expressions[2].comment, None)
self.assertEqual(expression.expressions[3].comment, "comment4 --foo")
self.assertEqual(expression.expressions[4].comment, "")
self.assertEqual(expression.expressions[5].comment, " space")
def test_type_literals(self):
self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)"))
self.assertEqual(
parse_one("TIMESTAMP '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP)"
)
self.assertEqual(
parse_one("TIMESTAMP(1) '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP(1))"
)
self.assertEqual(
parse_one("TIMESTAMP WITH TIME ZONE '2022-01-01'").sql(),
"CAST('2022-01-01' AS TIMESTAMPTZ)",
)
self.assertEqual(
parse_one("TIMESTAMP WITH LOCAL TIME ZONE '2022-01-01'").sql(),
"CAST('2022-01-01' AS TIMESTAMPLTZ)",
)
self.assertEqual(
parse_one("TIMESTAMP WITHOUT TIME ZONE '2022-01-01'").sql(),
"CAST('2022-01-01' AS TIMESTAMP)",
)
self.assertEqual(
parse_one("TIMESTAMP(1) WITH TIME ZONE '2022-01-01'").sql(),
"CAST('2022-01-01' AS TIMESTAMPTZ(1))",
)
self.assertEqual(
parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE '2022-01-01'").sql(),
"CAST('2022-01-01' AS TIMESTAMPLTZ(1))",
)
self.assertEqual(
parse_one("TIMESTAMP(1) WITHOUT TIME ZONE '2022-01-01'").sql(),
"CAST('2022-01-01' AS TIMESTAMP(1))",
)
self.assertEqual(parse_one("TIMESTAMP(1) WITH TIME ZONE").sql(), "TIMESTAMPTZ(1)")
self.assertEqual(parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE").sql(), "TIMESTAMPLTZ(1)")
self.assertEqual(parse_one("TIMESTAMP(1) WITHOUT TIME ZONE").sql(), "TIMESTAMP(1)")
self.assertEqual(parse_one("""JSON '{"x":"y"}'""").sql(), """CAST('{"x":"y"}' AS JSON)""")
self.assertIsInstance(parse_one("TIMESTAMP(1)"), exp.Func)
self.assertIsInstance(parse_one("TIMESTAMP('2022-01-01')"), exp.Func)
self.assertIsInstance(parse_one("TIMESTAMP()"), exp.Func)
self.assertIsInstance(parse_one("map.x"), exp.Column)
def test_pretty_config_override(self):
self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x")

View file

@ -1,281 +1,141 @@
import unittest
from sqlglot import table
from sqlglot.dataframe.sql import types as df_types
from sqlglot import exp, to_table
from sqlglot.errors import SchemaError
from sqlglot.schema import MappingSchema, ensure_schema
class TestSchema(unittest.TestCase):
def assert_column_names(self, schema, *table_results):
for table, result in table_results:
with self.subTest(f"{table} -> {result}"):
self.assertEqual(schema.column_names(to_table(table)), result)
def assert_column_names_raises(self, schema, *tables):
for table in tables:
with self.subTest(table):
with self.assertRaises(SchemaError):
schema.column_names(to_table(table))
def test_schema(self):
schema = ensure_schema(
{
"x": {
"a": "uint64",
}
}
)
self.assertEqual(
schema.column_names(
table(
"x",
)
),
["a"],
)
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db", catalog="c"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db"))
with self.assertRaises(ValueError):
schema.column_names(table("x2"))
with self.assertRaises(ValueError):
schema.add_table(table("y", db="db"), {"b": "string"})
with self.assertRaises(ValueError):
schema.add_table(table("y", db="db", catalog="c"), {"b": "string"})
schema.add_table(table("y"), {"b": "string"})
schema_with_y = {
"x": {
"a": "uint64",
},
"y": {
"b": "string",
},
}
self.assertEqual(schema.schema, schema_with_y)
new_schema = schema.copy()
new_schema.add_table(table("z"), {"c": "string"})
self.assertEqual(schema.schema, schema_with_y)
self.assertEqual(
new_schema.schema,
{
"x": {
"a": "uint64",
},
"y": {
"b": "string",
},
"z": {
"c": "string",
},
},
)
schema.add_table(table("m"), {"d": "string"})
schema.add_table(table("n"), {"e": "string"})
schema_with_m_n = {
"x": {
"a": "uint64",
},
"y": {
"b": "string",
},
"m": {
"d": "string",
},
"n": {
"e": "string",
},
}
self.assertEqual(schema.schema, schema_with_m_n)
new_schema = schema.copy()
new_schema.add_table(table("o"), {"f": "string"})
new_schema.add_table(table("p"), {"g": "string"})
self.assertEqual(schema.schema, schema_with_m_n)
self.assertEqual(
new_schema.schema,
{
"x": {
"a": "uint64",
},
"y": {
"b": "string",
},
"m": {
"d": "string",
},
"n": {
"e": "string",
},
"o": {
"f": "string",
},
"p": {
"g": "string",
"b": "uint64",
"c": "uint64",
},
},
)
self.assert_column_names(
schema,
("x", ["a"]),
("y", ["b", "c"]),
("z.x", ["a"]),
("z.x.y", ["b", "c"]),
)
self.assert_column_names_raises(
schema,
"z",
"z.z",
"z.z.z",
)
def test_schema_db(self):
schema = ensure_schema(
{
"db": {
"x": {
"a": "uint64",
}
}
}
)
self.assertEqual(schema.column_names(table("x", db="db")), ["a"])
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db", catalog="c"))
with self.assertRaises(ValueError):
schema.column_names(table("x"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db2"))
with self.assertRaises(ValueError):
schema.column_names(table("x2", db="db"))
with self.assertRaises(ValueError):
schema.add_table(table("y"), {"b": "string"})
with self.assertRaises(ValueError):
schema.add_table(table("y", db="db", catalog="c"), {"b": "string"})
schema.add_table(table("y", db="db"), {"b": "string"})
self.assertEqual(
schema.schema,
{
"db": {
"d1": {
"x": {
"a": "uint64",
},
"y": {
"b": "string",
"b": "uint64",
},
}
},
"d2": {
"x": {
"c": "uint64",
},
},
},
)
self.assert_column_names(
schema,
("d1.x", ["a"]),
("d2.x", ["c"]),
("y", ["b"]),
("d1.y", ["b"]),
("z.d1.y", ["b"]),
)
self.assert_column_names_raises(
schema,
"x",
"z.x",
"z.y",
)
def test_schema_catalog(self):
schema = ensure_schema(
{
"c": {
"db": {
"x": {
"a": "uint64",
}
}
}
}
)
self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"])
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db"))
with self.assertRaises(ValueError):
schema.column_names(table("x"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db", catalog="c2"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db2"))
with self.assertRaises(ValueError):
schema.column_names(table("x2", db="db"))
with self.assertRaises(ValueError):
schema.add_table(table("x"), {"b": "string"})
with self.assertRaises(ValueError):
schema.add_table(table("x", db="db"), {"b": "string"})
schema.add_table(table("y", db="db", catalog="c"), {"a": "string", "b": "int"})
self.assertEqual(
schema.schema,
{
"c": {
"db": {
"c1": {
"d1": {
"x": {
"a": "uint64",
},
"y": {
"a": "string",
"b": "int",
"b": "uint64",
},
}
}
},
)
schema.add_table(table("z", db="db2", catalog="c"), {"c": "string", "d": "int"})
self.assertEqual(
schema.schema,
{
"c": {
"db": {
"x": {
"a": "uint64",
},
"y": {
"a": "string",
"b": "int",
},
},
"db2": {
"z": {
"c": "string",
"d": "int",
}
},
}
},
)
schema.add_table(table("m", db="db2", catalog="c2"), {"e": "string", "f": "int"})
self.assertEqual(
schema.schema,
{
"c": {
"db": {
"x": {
"a": "uint64",
"c": "uint64",
},
"y": {
"a": "string",
"b": "int",
},
},
"db2": {
"z": {
"c": "string",
"d": "int",
}
},
},
"c2": {
"db2": {
"m": {
"e": "string",
"f": "int",
}
}
"d1": {
"y": {
"d": "uint64",
},
"z": {
"e": "uint64",
},
},
"d2": {
"z": {
"f": "uint64",
},
},
},
},
)
schema = ensure_schema(
{
"x": {
"a": "uint64",
}
}
)
self.assertEqual(schema.column_names(table("x")), ["a"])
schema = MappingSchema()
schema.add_table(table("x"), {"a": "string"})
self.assertEqual(
schema.schema,
{
"x": {
"a": "string",
}
},
self.assert_column_names(
schema,
("x", ["a"]),
("d1.x", ["a"]),
("c1.d1.x", ["a"]),
("c1.d1.y", ["b"]),
("c1.d1.z", ["c"]),
("c2.d1.y", ["d"]),
("c2.d1.z", ["e"]),
("d2.z", ["f"]),
("c2.d2.z", ["f"]),
)
schema.add_table(table("y"), df_types.StructType([df_types.StructField("b", df_types.StringType())]))
self.assertEqual(
schema.schema,
{
"x": {
"a": "string",
},
"y": {
"b": "string",
},
},
self.assert_column_names_raises(
schema,
"q",
"d2.x",
"y",
"z",
"d1.y",
"d1.z",
"a.b.c",
)
def test_schema_add_table_with_and_without_mapping(self):
@ -288,3 +148,34 @@ class TestSchema(unittest.TestCase):
self.assertEqual(schema.column_names("test"), ["x", "y"])
schema.add_table("test")
self.assertEqual(schema.column_names("test"), ["x", "y"])
def test_schema_get_column_type(self):
schema = MappingSchema({"a": {"b": "varchar"}})
self.assertEqual(schema.get_column_type("a", "b"), exp.DataType.Type.VARCHAR)
self.assertEqual(
schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")),
exp.DataType.Type.VARCHAR,
)
self.assertEqual(
schema.get_column_type("a", exp.Column(this="b")), exp.DataType.Type.VARCHAR
)
self.assertEqual(
schema.get_column_type(exp.Table(this="a"), "b"), exp.DataType.Type.VARCHAR
)
schema = MappingSchema({"a": {"b": {"c": "varchar"}}})
self.assertEqual(
schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")),
exp.DataType.Type.VARCHAR,
)
self.assertEqual(
schema.get_column_type(exp.Table(this="b", db="a"), "c"), exp.DataType.Type.VARCHAR
)
schema = MappingSchema({"a": {"b": {"c": {"d": "varchar"}}}})
self.assertEqual(
schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")),
exp.DataType.Type.VARCHAR,
)
self.assertEqual(
schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d"),
exp.DataType.Type.VARCHAR,
)
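
An illustrative sketch of the schema lookup API these tests cover:

    from sqlglot import exp, to_table
    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"a": {"b": "varchar"}})
    schema.column_names(to_table("a"))  # ['b']
    schema.get_column_type("a", "b")    # exp.DataType.Type.VARCHAR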

18
tests/test_tokens.py Normal file
View file

@ -0,0 +1,18 @@
import unittest
from sqlglot.tokens import Tokenizer
class TestTokens(unittest.TestCase):
def test_comment_attachment(self):
tokenizer = Tokenizer()
sql_comment = [
("/*comment*/ foo", "comment"),
("/*comment*/ foo --test", "comment"),
("--comment\nfoo --test", "comment"),
("foo --comment", "comment"),
("foo", None),
]
for sql, comment in sql_comment:
self.assertEqual(tokenizer.tokenize(sql)[0].comment, comment)

View file

@ -49,6 +49,12 @@ class TestTranspile(unittest.TestCase):
leading_comma=True,
pretty=True,
)
self.validate(
"SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ",
"SELECT\n FOO -- x\n , BAR -- y\n , BAZ",
leading_comma=True,
pretty=True,
)
# without pretty, this should be a no-op
self.validate(
"SELECT FOO, BAR, BAZ",
@ -63,24 +69,61 @@ class TestTranspile(unittest.TestCase):
self.validate("SELECT 3>=3", "SELECT 3 >= 3")
def test_comments(self):
self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo")
self.validate("SELECT 1 /* inline */ FROM foo -- comment", "SELECT 1 FROM foo")
self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo")
self.validate(
"SELECT 1 /* inline */ FROM foo -- comment",
"SELECT 1 /* inline */ FROM foo /* comment */",
)
self.validate(
"SELECT FUN(x) /*x*/, [1,2,3] /*y*/", "SELECT FUN(x) /* x */, ARRAY(1, 2, 3) /* y */"
)
self.validate(
"""
SELECT 1 -- comment
FROM foo -- comment
""",
"SELECT 1 FROM foo",
"SELECT 1 /* comment */ FROM foo /* comment */",
)
self.validate(
"""
SELECT 1 /* big comment
like this */
FROM foo -- comment
""",
"SELECT 1 FROM foo",
"""SELECT 1 /* big comment
like this */ FROM foo /* comment */""",
)
self.validate(
"select x from foo -- x",
"SELECT x FROM foo /* x */",
)
self.validate(
"""
/* multi
line
comment
*/
SELECT
tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
CAST(x AS INT), # comment 3
y -- comment 4
FROM
bar /* comment 5 */,
tbl # comment 6
""",
"""/* multi
line
comment
*/
SELECT
tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
CAST(x AS INT), -- comment 3
y -- comment 4
FROM bar /* comment 5 */, tbl /* comment 6 */""",
read="mysql",
pretty=True,
)
def test_types(self):
@ -146,6 +189,16 @@ class TestTranspile(unittest.TestCase):
def test_ignore_nulls(self):
self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)")
def test_with(self):
self.validate(
"WITH a AS (SELECT 1) WITH b AS (SELECT 2) SELECT *",
"WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *",
)
self.validate(
"WITH a AS (SELECT 1), WITH b AS (SELECT 2) SELECT *",
"WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *",
)
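
A minimal sketch of the new WITH-merging behavior (expected output copied from the test above):

    import sqlglot

    sqlglot.transpile("WITH a AS (SELECT 1) WITH b AS (SELECT 2) SELECT *")[0]
    # 'WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *'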
def test_time(self):
self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)")
self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)")