Adding upstream version 10.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 24752785d9
commit 1e860cc299

98 changed files with 4080 additions and 1666 deletions
@@ -1,9 +1,9 @@
-import sys
 import typing as t
 import unittest
 import warnings

 import sqlglot
+from sqlglot.helper import PYTHON_VERSION
 from tests.helpers import SKIP_INTEGRATION

 if t.TYPE_CHECKING:
@@ -11,7 +11,8 @@ if t.TYPE_CHECKING:


 @unittest.skipIf(
-    SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set"
+    SKIP_INTEGRATION or PYTHON_VERSION > (3, 10),
+    "Skipping Integration Tests since `SKIP_INTEGRATION` is set",
 )
 class DataFrameValidator(unittest.TestCase):
     spark = None
@@ -36,7 +37,12 @@ class DataFrameValidator(unittest.TestCase):

         # This is for test `test_branching_root_dataframes`
         config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
-        cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate()
+        cls.spark = (
+            SparkSession.builder.master("local[*]")
+            .appName("Unit-tests")
+            .config(conf=config)
+            .getOrCreate()
+        )
         cls.spark.sparkContext.setLogLevel("ERROR")
         cls.sqlglot = SqlglotSparkSession()
         cls.spark_employee_schema = types.StructType(
@@ -50,7 +56,9 @@ class DataFrameValidator(unittest.TestCase):
         )
         cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
             [
-                sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False),
+                sqlglotSparkTypes.StructField(
+                    "employee_id", sqlglotSparkTypes.IntegerType(), False
+                ),
                 sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
                 sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
                 sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
@@ -64,8 +72,12 @@ class DataFrameValidator(unittest.TestCase):
             (4, "Claire", "Littleton", 27, 2),
             (5, "Hugo", "Reyes", 29, 100),
         ]
-        cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema)
-        cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema)
+        cls.df_employee = cls.spark.createDataFrame(
+            data=employee_data, schema=cls.spark_employee_schema
+        )
+        cls.dfs_employee = cls.sqlglot.createDataFrame(
+            data=employee_data, schema=cls.sqlglot_employee_schema
+        )
         cls.df_employee.createOrReplaceTempView("employee")

         cls.spark_store_schema = types.StructType(
@@ -80,7 +92,9 @@ class DataFrameValidator(unittest.TestCase):
             [
                 sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
                 sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
-                sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
+                sqlglotSparkTypes.StructField(
+                    "district_id", sqlglotSparkTypes.IntegerType(), False
+                ),
                 sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
             ]
         )
@@ -89,7 +103,9 @@ class DataFrameValidator(unittest.TestCase):
             (2, "Arrow", 2, 2000),
         ]
         cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
-        cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema)
+        cls.dfs_store = cls.sqlglot.createDataFrame(
+            data=store_data, schema=cls.sqlglot_store_schema
+        )
         cls.df_store.createOrReplaceTempView("store")

         cls.spark_district_schema = types.StructType(
@@ -101,17 +117,27 @@ class DataFrameValidator(unittest.TestCase):
         )
         cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
             [
-                sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
-                sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False),
-                sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False),
+                sqlglotSparkTypes.StructField(
+                    "district_id", sqlglotSparkTypes.IntegerType(), False
+                ),
+                sqlglotSparkTypes.StructField(
+                    "district_name", sqlglotSparkTypes.StringType(), False
+                ),
+                sqlglotSparkTypes.StructField(
+                    "manager_name", sqlglotSparkTypes.StringType(), False
+                ),
             ]
         )
         district_data = [
             (1, "Temple", "Dogen"),
             (2, "Lighthouse", "Jacob"),
         ]
-        cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema)
-        cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema)
+        cls.df_district = cls.spark.createDataFrame(
+            data=district_data, schema=cls.spark_district_schema
+        )
+        cls.dfs_district = cls.sqlglot.createDataFrame(
+            data=district_data, schema=cls.sqlglot_district_schema
+        )
         cls.df_district.createOrReplaceTempView("district")
         sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
         sqlglot.schema.add_table("store", cls.sqlglot_store_schema)
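The import change above swaps a local `sys.version_info[:2]` check for a shared constant. For context, `sqlglot.helper.PYTHON_VERSION` is presumably just that tuple captured once (a sketch under that assumption):

    import sys

    # Hypothetical definition of the helper the skipIf decorator now uses:
    # a (major, minor) tuple such as (3, 10).
    PYTHON_VERSION = sys.version_info[:2]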
@@ -41,22 +41,32 @@ class TestDataframeFunc(DataFrameValidator):

     def test_alias_with_select(self):
         df_employee = self.df_spark_employee.alias("df_employee").select(
-            self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname
+            self.df_spark_employee["employee_id"],
+            F.col("df_employee.fname"),
+            self.df_spark_employee.lname,
         )
         dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select(
-            self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname
+            self.df_sqlglot_employee["employee_id"],
+            SF.col("dfs_employee.fname"),
+            self.df_sqlglot_employee.lname,
         )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

     def test_case_when_otherwise(self):
         df = self.df_spark_employee.select(
-            F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60"))
+            F.when(
+                (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+                F.lit("between 40 and 60"),
+            )
             .when(F.col("age") < F.lit(40), "less than 40")
             .otherwise("greater than 60")
         )

         dfs = self.df_sqlglot_employee.select(
-            SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60"))
+            SF.when(
+                (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+                SF.lit("between 40 and 60"),
+            )
             .when(SF.col("age") < SF.lit(40), "less than 40")
             .otherwise("greater than 60")
         )
@@ -65,15 +75,17 @@ class TestDataframeFunc(DataFrameValidator):

     def test_case_when_no_otherwise(self):
         df = self.df_spark_employee.select(
-            F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when(
-                F.col("age") < F.lit(40), "less than 40"
-            )
+            F.when(
+                (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+                F.lit("between 40 and 60"),
+            ).when(F.col("age") < F.lit(40), "less than 40")
         )

         dfs = self.df_sqlglot_employee.select(
-            SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when(
-                SF.col("age") < SF.lit(40), "less than 40"
-            )
+            SF.when(
+                (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+                SF.lit("between 40 and 60"),
+            ).when(SF.col("age") < SF.lit(40), "less than 40")
         )

         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -84,7 +96,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

     def test_where_clause_multiple_and(self):
-        df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")))
+        df_employee = self.df_spark_employee.where(
+            (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))
+        )
         dfs_employee = self.df_sqlglot_employee.where(
             (SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack"))
         )
@@ -106,7 +120,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

     def test_where_clause_multiple_or(self):
-        df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")))
+        df_employee = self.df_spark_employee.where(
+            (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))
+        )
         dfs_employee = self.df_sqlglot_employee.where(
             (SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate"))
         )
@@ -172,28 +188,43 @@ class TestDataframeFunc(DataFrameValidator):
         dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37))
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

-        df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0))
-        dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0))
+        df_employee = self.df_spark_employee.where(
+            self.df_spark_employee["age"] % F.lit(5) == F.lit(0)
+        )
+        dfs_employee = self.df_sqlglot_employee.where(
+            self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)
+        )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

-        df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28))
-        dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28))
+        df_employee = self.df_spark_employee.where(
+            self.df_spark_employee["age"] + F.lit(5) > F.lit(28)
+        )
+        dfs_employee = self.df_sqlglot_employee.where(
+            self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)
+        )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

-        df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28))
-        dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28))
+        df_employee = self.df_spark_employee.where(
+            self.df_spark_employee["age"] - F.lit(5) > F.lit(28)
+        )
+        dfs_employee = self.df_sqlglot_employee.where(
+            self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)
+        )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

         df_employee = self.df_spark_employee.where(
             self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2)
         )
         dfs_employee = self.df_sqlglot_employee.where(
-            self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2)
+            self.df_sqlglot_employee["age"] * SF.lit(0.5)
+            == self.df_sqlglot_employee["age"] / SF.lit(2)
         )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

     def test_join_inner(self):
-        df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select(
+        df_joined = self.df_spark_employee.join(
+            self.df_spark_store, on=["store_id"], how="inner"
+        ).select(
             self.df_spark_employee.employee_id,
             self.df_spark_employee["fname"],
             F.col("lname"),
@@ -202,7 +233,9 @@ class TestDataframeFunc(DataFrameValidator):
             self.df_spark_store.store_name,
             self.df_spark_store["num_sales"],
         )
-        dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select(
+        dfs_joined = self.df_sqlglot_employee.join(
+            self.df_sqlglot_store, on=["store_id"], how="inner"
+        ).select(
             self.df_sqlglot_employee.employee_id,
             self.df_sqlglot_employee["fname"],
             SF.col("lname"),
@@ -214,17 +247,27 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_joined, dfs_joined)

     def test_join_inner_no_select(self):
-        df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join(
-            self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner"
+        df_joined = self.df_spark_employee.select(
+            F.col("store_id"), F.col("fname"), F.col("lname")
+        ).join(
+            self.df_spark_store.select(F.col("store_id"), F.col("store_name")),
+            on=["store_id"],
+            how="inner",
         )
-        dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner"
+        dfs_joined = self.df_sqlglot_employee.select(
+            SF.col("store_id"), SF.col("fname"), SF.col("lname")
+        ).join(
+            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")),
+            on=["store_id"],
+            how="inner",
         )
         self.compare_spark_with_sqlglot(df_joined, dfs_joined)

     def test_join_inner_equality_single(self):
         df_joined = self.df_spark_employee.join(
-            self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner"
+            self.df_spark_store,
+            on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
+            how="inner",
         ).select(
             self.df_spark_employee.employee_id,
             self.df_spark_employee["fname"],
@@ -235,7 +278,9 @@ class TestDataframeFunc(DataFrameValidator):
             self.df_spark_store["num_sales"],
         )
         dfs_joined = self.df_sqlglot_employee.join(
-            self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner"
+            self.df_sqlglot_store,
+            on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
+            how="inner",
         ).select(
             self.df_sqlglot_employee.employee_id,
             self.df_sqlglot_employee["fname"],
@@ -343,7 +388,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_joined, dfs_joined)

     def test_join_full_outer(self):
-        df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select(
+        df_joined = self.df_spark_employee.join(
+            self.df_spark_store, on=["store_id"], how="full_outer"
+        ).select(
             self.df_spark_employee.employee_id,
             self.df_spark_employee["fname"],
             F.col("lname"),
@@ -352,7 +399,9 @@ class TestDataframeFunc(DataFrameValidator):
             self.df_spark_store.store_name,
             self.df_spark_store["num_sales"],
         )
-        dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select(
+        dfs_joined = self.df_sqlglot_employee.join(
+            self.df_sqlglot_store, on=["store_id"], how="full_outer"
+        ).select(
             self.df_sqlglot_employee.employee_id,
             self.df_sqlglot_employee["fname"],
             SF.col("lname"),
@@ -365,7 +414,9 @@ class TestDataframeFunc(DataFrameValidator):

     def test_triple_join(self):
         df = (
-            self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id)
+            self.df_employee.join(
+                self.df_store, on=self.df_employee.employee_id == self.df_store.store_id
+            )
             .join(self.df_district, on=self.df_store.store_id == self.df_district.district_id)
             .select(
                 self.df_employee.employee_id,
@@ -377,7 +428,9 @@ class TestDataframeFunc(DataFrameValidator):
             )
         )
         dfs = (
-            self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id)
+            self.dfs_employee.join(
+                self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id
+            )
             .join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id)
             .select(
                 self.dfs_employee.employee_id,
@@ -391,13 +444,13 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)

     def test_join_select_and_select_start(self):
-        df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join(
-            self.df_spark_store, "store_id", "inner"
-        )
+        df = self.df_spark_employee.select(
+            F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")
+        ).join(self.df_spark_store, "store_id", "inner")

-        dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join(
-            self.df_sqlglot_store, "store_id", "inner"
-        )
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
+        ).join(self.df_sqlglot_store, "store_id", "inner")

         self.compare_spark_with_sqlglot(df, dfs)

@@ -485,13 +538,17 @@ class TestDataframeFunc(DataFrameValidator):
         dfs_unioned = (
             self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"))
             .unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")))
-            .unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")))
+            .unionAll(
+                self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))
+            )
         )

         self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)

     def test_union_by_name(self):
-        df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName(
+        df = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("fname"), F.col("lname")
+        ).unionByName(
             self.df_spark_store.select(
                 F.col("store_name").alias("lname"),
                 F.col("store_id").alias("employee_id"),
@@ -499,7 +556,9 @@ class TestDataframeFunc(DataFrameValidator):
             )
         )

-        dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName(
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("fname"), SF.col("lname")
+        ).unionByName(
             self.df_sqlglot_store.select(
                 SF.col("store_name").alias("lname"),
                 SF.col("store_id").alias("employee_id"),
@@ -537,10 +596,16 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)

     def test_order_by_default(self):
-        df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id"))
+        df = (
+            self.df_spark_store.groupBy(F.col("district_id"))
+            .agg(F.min("num_sales"))
+            .orderBy(F.col("district_id"))
+        )

         dfs = (
-            self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id"))
+            self.df_sqlglot_store.groupBy(SF.col("district_id"))
+            .agg(SF.min("num_sales"))
+            .orderBy(SF.col("district_id"))
         )

         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

@@ -594,13 +659,17 @@ class TestDataframeFunc(DataFrameValidator):
         df = (
             self.df_spark_store.groupBy(F.col("district_id"))
             .agg(F.min("num_sales").alias("total_sales"))
-            .orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last())
+            .orderBy(
+                F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()
+            )
         )

         dfs = (
             self.df_sqlglot_store.groupBy(SF.col("district_id"))
             .agg(SF.min("num_sales").alias("total_sales"))
-            .orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last())
+            .orderBy(
+                SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()
+            )
         )

         self.compare_spark_with_sqlglot(df, dfs)

@@ -609,81 +678,87 @@ class TestDataframeFunc(DataFrameValidator):
         df = (
             self.df_spark_store.groupBy(F.col("district_id"))
             .agg(F.min("num_sales").alias("total_sales"))
-            .orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first())
+            .orderBy(
+                F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()
+            )
         )

         dfs = (
             self.df_sqlglot_store.groupBy(SF.col("district_id"))
             .agg(SF.min("num_sales").alias("total_sales"))
-            .orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first())
+            .orderBy(
+                SF.when(
+                    SF.col("district_id") == SF.lit(1), SF.col("district_id")
+                ).desc_nulls_first()
+            )
         )

         self.compare_spark_with_sqlglot(df, dfs)

     def test_intersect(self):
-        df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
-            self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
-        )
+        df_employee_duplicate = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("store_id")
+        ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))

-        df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
-            self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
-        )
+        df_store_duplicate = self.df_spark_store.select(
+            F.col("store_id"), F.col("district_id")
+        ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))

         df = df_employee_duplicate.intersect(df_store_duplicate)

-        dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
-            self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
-        )
+        dfs_employee_duplicate = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("store_id")
+        ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))

-        dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
-        )
+        dfs_store_duplicate = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("district_id")
+        ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))

         dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate)

         self.compare_spark_with_sqlglot(df, dfs)

     def test_intersect_all(self):
-        df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
-            self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
-        )
+        df_employee_duplicate = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("store_id")
+        ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))

-        df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
-            self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
-        )
+        df_store_duplicate = self.df_spark_store.select(
+            F.col("store_id"), F.col("district_id")
+        ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))

         df = df_employee_duplicate.intersectAll(df_store_duplicate)

-        dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
-            self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
-        )
+        dfs_employee_duplicate = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("store_id")
+        ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))

-        dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
-        )
+        dfs_store_duplicate = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("district_id")
+        ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))

         dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate)

         self.compare_spark_with_sqlglot(df, dfs)

     def test_except_all(self):
-        df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
-            self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
-        )
+        df_employee_duplicate = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("store_id")
+        ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))

-        df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
-            self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
-        )
+        df_store_duplicate = self.df_spark_store.select(
+            F.col("store_id"), F.col("district_id")
+        ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))

         df = df_employee_duplicate.exceptAll(df_store_duplicate)

-        dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
-            self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
-        )
+        dfs_employee_duplicate = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("store_id")
+        ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))

-        dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
-        )
+        dfs_store_duplicate = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("district_id")
+        ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))

         dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate)

@@ -721,7 +796,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)

     def test_drop_na_default(self):
-        df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna()
+        df = self.df_spark_employee.select(
+            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+        ).dropna()

         dfs = self.df_sqlglot_employee.select(
             SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -746,7 +823,9 @@ class TestDataframeFunc(DataFrameValidator):
         ).dropna(how="any", thresh=2)

         dfs = self.df_sqlglot_employee.select(
-            SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
+            SF.lit(None),
+            SF.lit(1),
+            SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
         ).dropna(how="any", thresh=2)

         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -757,13 +836,17 @@ class TestDataframeFunc(DataFrameValidator):
         ).dropna(thresh=1, subset="the_age")

         dfs = self.df_sqlglot_employee.select(
-            SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
+            SF.lit(None),
+            SF.lit(1),
+            SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
         ).dropna(thresh=1, subset="the_age")

         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

     def test_dropna_na_function(self):
-        df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop()
+        df = self.df_spark_employee.select(
+            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+        ).na.drop()

         dfs = self.df_sqlglot_employee.select(
             SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -772,7 +855,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)

     def test_fillna_default(self):
-        df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100)
+        df = self.df_spark_employee.select(
+            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+        ).fillna(100)

         dfs = self.df_sqlglot_employee.select(
             SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -798,7 +883,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

     def test_fillna_na_func(self):
-        df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100)
+        df = self.df_spark_employee.select(
+            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+        ).na.fill(100)

         dfs = self.df_sqlglot_employee.select(
             SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -807,7 +894,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)

     def test_replace_basic(self):
-        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100)
+        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
+            to_replace=37, value=100
+        )

         dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
             to_replace=37, value=100
@@ -827,9 +916,13 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

     def test_replace_mapping(self):
-        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100})
+        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
+            {37: 100}
+        )

-        dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100})
+        dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
+            {37: 100}
+        )

         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

@@ -849,9 +942,9 @@ class TestDataframeFunc(DataFrameValidator):
             to_replace=37, value=100
         )

-        dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace(
-            to_replace=37, value=100
-        )
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("age"), SF.lit(37).alias("test_col")
+        ).na.replace(to_replace=37, value=100)

         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

@@ -881,16 +974,18 @@ class TestDataframeFunc(DataFrameValidator):
             "first_name", "first_name_again"
         )

-        dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed(
-            "first_name", "first_name_again"
-        )
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("fname").alias("first_name")
+        ).withColumnRenamed("first_name", "first_name_again")

         self.compare_spark_with_sqlglot(df, dfs)

     def test_drop_column_single(self):
         df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age")

-        dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age")
+        dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop(
+            "age"
+        )

         self.compare_spark_with_sqlglot(df, dfs)

@@ -906,7 +1001,9 @@ class TestDataframeFunc(DataFrameValidator):
         df_sqlglot_employee_cols = self.df_sqlglot_employee.select(
             SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
         )
-        df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))
+        df_sqlglot_store_cols = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("store_name")
+        )
         dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop(
             df_sqlglot_employee_cols.age,
         )
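All of the tests above funnel into `compare_spark_with_sqlglot`, whose body is not part of this diff. A hypothetical sketch of what such a helper does, assuming the sqlglot DataFrame exposes `.sql()` returning the generated statements:

    def compare_spark_with_sqlglot(self, df, dfs, skip_schema_compare=False):
        # Hypothetical body: render the sqlglot-built DataFrame to SQL, run it
        # on the real SparkSession, and compare schemas and rows.
        df_from_sqlglot = self.spark.sql(dfs.sql(pretty=False)[0])
        if not skip_schema_compare:
            self.assertEqual(df.schema, df_from_sqlglot.schema)
        self.assertEqual(df.collect(), df_from_sqlglot.collect())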
@@ -23,6 +23,14 @@ class TestSessionFunc(DataFrameValidator):
         ON
             e.store_id = s.store_id
         """
-        df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id")))
-        dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id")))
+        df = (
+            self.spark.sql(query)
+            .groupBy(F.col("store_id"))
+            .agg(F.countDistinct(F.col("employee_id")))
+        )
+        dfs = (
+            self.sqlglot.sql(query)
+            .groupBy(SF.col("store_id"))
+            .agg(SF.countDistinct(SF.col("employee_id")))
+        )
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -25,11 +25,17 @@ class DataFrameSQLValidator(unittest.TestCase):
             (4, "Claire", "Littleton", 27, 2),
             (5, "Hugo", "Reyes", 29, 100),
         ]
-        self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema)
+        self.df_employee = self.spark.createDataFrame(
+            data=employee_data, schema=self.employee_schema
+        )

-    def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False):
+    def compare_sql(
+        self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
+    ):
         actual_sqls = df.sql(pretty=pretty)
-        expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements
+        expected_statements = (
+            [expected_statements] if isinstance(expected_statements, str) else expected_statements
+        )
         self.assertEqual(len(expected_statements), len(actual_sqls))
         for expected, actual in zip(expected_statements, actual_sqls):
             self.assertEqual(expected, actual)
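Usage of `compare_sql`, taken from the session tests later in this diff: a sqlglot DataFrame renders to one or more SQL statements via `df.sql(...)`, and the helper asserts each against the expected string.

    df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
    expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
    self.compare_sql(df, expected)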
@@ -26,12 +26,14 @@ class TestDataframeColumn(unittest.TestCase):

     def test_and(self):
         self.assertEqual(
-            "cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql()
+            "cola = colb AND colc = cold",
+            ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql(),
         )

     def test_or(self):
         self.assertEqual(
-            "cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql()
+            "cola = colb OR colc = cold",
+            ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql(),
         )

     def test_mod(self):
@@ -112,7 +114,9 @@ class TestDataframeColumn(unittest.TestCase):

     def test_when_otherwise(self):
         self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql())
-        self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql())
+        self.assertEqual(
+            "CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()
+        )
         self.assertEqual(
             "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
             (F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(),
@@ -148,7 +152,9 @@ class TestDataframeColumn(unittest.TestCase):
         self.assertEqual(
             "cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) "
             "AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)",
-            F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(),
+            F.col("cola")
+            .between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1))
+            .sql(),
         )

     def test_over(self):
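The pattern these column tests exercise, as a standalone snippet (assuming the `sqlglot.dataframe.sql` import path used throughout this test suite): column expressions compose with Python operators and render to SQL text through `.sql()`.

    from sqlglot.dataframe.sql import functions as F

    expr = (F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))
    print(expr.sql())  # cola = colb AND colc = cold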
@@ -9,7 +9,9 @@ class TestDataframe(DataFrameSQLValidator):
         self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))

     def test_columns(self):
-        self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns)
+        self.assertEqual(
+            ["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns
+        )

     def test_cache(self):
         df = self.df_employee.select("fname").cache()
@@ -925,12 +925,17 @@ class TestFunctions(unittest.TestCase):
         col = SF.window(SF.col("cola"), "10 minutes")
         self.assertEqual("WINDOW(cola, '10 minutes')", col.sql())
         col_all_values = SF.window("cola", "2 minutes 30 seconds", "30 seconds", "15 seconds")
-        self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql())
+        self.assertEqual(
+            "WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql()
+        )
         col_no_start_time = SF.window("cola", "2 minutes 30 seconds", "30 seconds")
-        self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql())
+        self.assertEqual(
+            "WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql()
+        )
         col_no_slide = SF.window("cola", "2 minutes 30 seconds", startTime="15 seconds")
         self.assertEqual(
-            "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')", col_no_slide.sql()
+            "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')",
+            col_no_slide.sql(),
         )

     def test_session_window(self):
@@ -1359,9 +1364,13 @@ class TestFunctions(unittest.TestCase):

     def test_from_json(self):
         col_str = SF.from_json("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
-        self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+        self.assertEqual(
+            "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+        )
         col = SF.from_json(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
-        self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
+        self.assertEqual(
+            "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
+        )
         col_no_option = SF.from_json("cola", "cola INT")
         self.assertEqual("FROM_JSON(cola, 'cola INT')", col_no_option.sql())

@@ -1375,7 +1384,9 @@ class TestFunctions(unittest.TestCase):

     def test_schema_of_json(self):
         col_str = SF.schema_of_json("cola", dict(timestampFormat="dd/MM/yyyy"))
-        self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+        self.assertEqual(
+            "SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+        )
         col = SF.schema_of_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy"))
         self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
         col_no_option = SF.schema_of_json("cola")
@@ -1429,7 +1440,10 @@ class TestFunctions(unittest.TestCase):
         col = SF.array_sort(SF.col("cola"))
         self.assertEqual("ARRAY_SORT(cola)", col.sql())
         col_comparator = SF.array_sort(
-            "cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x))
+            "cola",
+            lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(
+                SF.length(y) - SF.length(x)
+            ),
         )
         self.assertEqual(
             "ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)",
@@ -1504,9 +1518,13 @@ class TestFunctions(unittest.TestCase):

     def test_from_csv(self):
         col_str = SF.from_csv("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
-        self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+        self.assertEqual(
+            "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+        )
         col = SF.from_csv(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
-        self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
+        self.assertEqual(
+            "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
+        )
         col_no_option = SF.from_csv("cola", "cola INT")
         self.assertEqual("FROM_CSV(cola, 'cola INT')", col_no_option.sql())

@@ -1535,7 +1553,9 @@ class TestFunctions(unittest.TestCase):
         self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql())
         col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count)

-        self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql())
+        self.assertEqual(
+            "TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()
+        )

     def test_exists(self):
         col_str = SF.exists("cola", lambda x: x % 2 == 0)
@@ -1558,10 +1578,13 @@ class TestFunctions(unittest.TestCase):
         self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql())
         col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i))
         self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql())
-        col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count))
+        col_custom_names = SF.filter(
+            "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count)
+        )

         self.assertEqual(
-            "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql()
+            "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)",
+            col_custom_names.sql(),
         )

     def test_zip_with(self):
@@ -1570,7 +1593,9 @@ class TestFunctions(unittest.TestCase):
         col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y))
         self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql())
         col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r))
-        self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql())
+        self.assertEqual(
+            "ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()
+        )

     def test_transform_keys(self):
         col_str = SF.transform_keys("cola", lambda k, v: SF.upper(k))
@@ -1586,7 +1611,9 @@ class TestFunctions(unittest.TestCase):
         col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v))
         self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql())
         col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value))
-        self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql())
+        self.assertEqual(
+            "TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()
+        )

     def test_map_filter(self):
         col_str = SF.map_filter("cola", lambda k, v: k > v)
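A standalone illustration of the higher-order-function behavior checked above: the Python lambda's parameter names are carried into the generated SQL lambda (again assuming the `sqlglot.dataframe.sql` import path).

    from sqlglot.dataframe.sql import functions as SF

    col = SF.filter("cola", lambda x: SF.month(SF.to_date(x)) > 6)
    print(col.sql())  # FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)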
@@ -21,9 +21,7 @@ class TestDataframeSession(DataFrameSQLValidator):

     def test_cdf_no_schema(self):
         df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
-        expected = (
-            "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
-        )
+        expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
         self.compare_sql(df, expected)

     def test_cdf_row_mixed_primitives(self):
@@ -77,7 +75,8 @@ class TestDataframeSession(DataFrameSQLValidator):
         sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
         df = self.spark.sql(query)
         self.assertIn(
-            "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False)
+            "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`",
+            df.sql(pretty=False),
         )

     @mock.patch("sqlglot.schema", MappingSchema())
@@ -104,9 +103,7 @@ class TestDataframeSession(DataFrameSQLValidator):
         query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
         sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
         df = self.spark.sql(query)
-        expected = (
-            "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
-        )
+        expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
         self.compare_sql(df, expected)

     def test_session_create_builder_patterns(self):
@@ -53,7 +53,10 @@ class TestDataframeTypes(unittest.TestCase):
         self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString())

     def test_map(self):
-        self.assertEqual("map<int, string>", types.MapType(types.IntegerType(), types.StringType()).simpleString())
+        self.assertEqual(
+            "map<int, string>",
+            types.MapType(types.IntegerType(), types.StringType()).simpleString(),
+        )

     def test_struct_field(self):
         self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString())
@@ -39,22 +39,38 @@ class TestDataframeWindow(unittest.TestCase):

     def test_window_rows_unbounded(self):
         rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
-        self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql())
-        rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
-        self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql())
-        rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
         self.assertEqual(
-            "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql()
+            "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
+            rows_between_unbounded_start.sql(),
         )
+        rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
+        self.assertEqual(
+            "OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
+            rows_between_unbounded_end.sql(),
+        )
+        rows_between_unbounded_both = Window.rowsBetween(
+            Window.unboundedPreceding, Window.unboundedFollowing
+        )
+        self.assertEqual(
+            "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
+            rows_between_unbounded_both.sql(),
+        )

     def test_window_range_unbounded(self):
         range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
         self.assertEqual(
-            "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql()
+            "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
+            range_between_unbounded_start.sql(),
         )
         range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
-        self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql())
-        range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)
         self.assertEqual(
-            "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql()
+            "OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
+            range_between_unbounded_end.sql(),
+        )
+        range_between_unbounded_both = Window.rangeBetween(
+            Window.unboundedPreceding, Window.unboundedFollowing
+        )
+        self.assertEqual(
+            "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
+            range_between_unbounded_both.sql(),
         )
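For reference, the behavior these window tests pin down, as a standalone snippet (assuming `Window` is importable from `sqlglot.dataframe.sql` as in this suite):

    from sqlglot.dataframe.sql import Window

    spec = Window.rowsBetween(Window.unboundedPreceding, 2)
    print(spec.sql())  # OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)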
@@ -157,6 +157,14 @@ class TestBigQuery(Validator):
             },
         )

+        self.validate_all(
+            "DIV(x, y)",
+            write={
+                "bigquery": "DIV(x, y)",
+                "duckdb": "CAST(x / y AS INT)",
+            },
+        )
+
         self.validate_identity(
             "SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)"
         )
@@ -284,4 +292,6 @@ class TestBigQuery(Validator):
             "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'"
         )
         self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)")
-        self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t")
+        self.validate_identity(
+            "CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t"
+        )
@@ -18,7 +18,6 @@ class TestClickhouse(Validator):
                 "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
             },
         )
-
         self.validate_all(
             "CAST(1 AS NULLABLE(Int64))",
             write={
@@ -31,3 +30,7 @@ class TestClickhouse(Validator):
                 "clickhouse": "CAST(1 AS Nullable(DateTime64(6, 'UTC')))",
             },
         )
+        self.validate_all(
+            "SELECT x #! comment",
+            write={"": "SELECT x /* comment */"},
+        )
@@ -22,7 +22,8 @@ class TestDatabricks(Validator):
             },
         )
         self.validate_all(
-            "SELECT DATEDIFF('end', 'start')", write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"}
+            "SELECT DATEDIFF('end', 'start')",
+            write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"},
         )
         self.validate_all(
             "SELECT DATE_ADD('2020-01-01', 1)",
@@ -1,20 +1,18 @@
 import unittest

-from sqlglot import (
-    Dialect,
-    Dialects,
-    ErrorLevel,
-    UnsupportedError,
-    parse_one,
-    transpile,
-)
+from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one


 class Validator(unittest.TestCase):
     dialect = None

-    def validate_identity(self, sql):
-        self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql)
+    def parse_one(self, sql):
+        return parse_one(sql, read=self.dialect)
+
+    def validate_identity(self, sql, write_sql=None):
+        expression = self.parse_one(sql)
+        self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect))
+        return expression

     def validate_all(self, sql, read=None, write=None, pretty=False):
         """
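How the reworked helper reads at a call site (illustrative): `validate_identity` now returns the parsed expression, and the new `write_sql` argument lets a test assert a normalized rendering instead of an exact round trip.

    # Illustrative only; `exp` is sqlglot.expressions, imported as in the
    # MySQL tests further down this diff.
    expression = self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')")
    assert isinstance(expression, exp.Select)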
@@ -28,12 +26,14 @@ class Validator(unittest.TestCase):
             read (dict): Mapping of dialect -> SQL
             write (dict): Mapping of dialect -> SQL
         """
-        expression = parse_one(sql, read=self.dialect)
+        expression = self.parse_one(sql)

         for read_dialect, read_sql in (read or {}).items():
             with self.subTest(f"{read_dialect} -> {sql}"):
                 self.assertEqual(
-                    parse_one(read_sql, read_dialect).sql(self.dialect, unsupported_level=ErrorLevel.IGNORE),
+                    parse_one(read_sql, read_dialect).sql(
+                        self.dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty
+                    ),
                     sql,
                 )

@@ -83,10 +83,6 @@ class TestDialect(Validator):
         )
         self.validate_all(
             "CAST(a AS BINARY(4))",
-            read={
-                "presto": "CAST(a AS VARBINARY(4))",
-                "sqlite": "CAST(a AS VARBINARY(4))",
-            },
             write={
                 "bigquery": "CAST(a AS BINARY(4))",
                 "clickhouse": "CAST(a AS BINARY(4))",
@@ -103,6 +99,24 @@ class TestDialect(Validator):
                 "starrocks": "CAST(a AS BINARY(4))",
             },
         )
+        self.validate_all(
+            "CAST(a AS VARBINARY(4))",
+            write={
+                "bigquery": "CAST(a AS VARBINARY(4))",
+                "clickhouse": "CAST(a AS VARBINARY(4))",
+                "duckdb": "CAST(a AS VARBINARY(4))",
+                "mysql": "CAST(a AS VARBINARY(4))",
+                "hive": "CAST(a AS BINARY(4))",
+                "oracle": "CAST(a AS BLOB(4))",
+                "postgres": "CAST(a AS BYTEA(4))",
+                "presto": "CAST(a AS VARBINARY(4))",
+                "redshift": "CAST(a AS VARBYTE(4))",
+                "snowflake": "CAST(a AS VARBINARY(4))",
+                "sqlite": "CAST(a AS BLOB(4))",
+                "spark": "CAST(a AS BINARY(4))",
+                "starrocks": "CAST(a AS VARBINARY(4))",
+            },
+        )
         self.validate_all(
             "CAST(MAP('a', '1') AS MAP(TEXT, TEXT))",
             write={
@@ -472,45 +486,57 @@ class TestDialect(Validator):
             },
         )
         self.validate_all(
-            "DATE_TRUNC(x, 'day')",
+            "DATE_TRUNC('day', x)",
             write={
                 "mysql": "DATE(x)",
                 "starrocks": "DATE(x)",
             },
         )
         self.validate_all(
-            "DATE_TRUNC(x, 'week')",
+            "DATE_TRUNC('week', x)",
             write={
                 "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
                 "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
             },
         )
         self.validate_all(
-            "DATE_TRUNC(x, 'month')",
+            "DATE_TRUNC('month', x)",
             write={
                 "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
                 "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
             },
         )
         self.validate_all(
-            "DATE_TRUNC(x, 'quarter')",
+            "DATE_TRUNC('quarter', x)",
             write={
                 "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
                 "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
             },
         )
         self.validate_all(
-            "DATE_TRUNC(x, 'year')",
+            "DATE_TRUNC('year', x)",
             write={
                 "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
                 "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
             },
         )
         self.validate_all(
-            "DATE_TRUNC(x, 'millenium')",
+            "DATE_TRUNC('millenium', x)",
             write={
                 "mysql": UnsupportedError,
                 "starrocks": UnsupportedError,
             },
         )
+        self.validate_all(
+            "DATE_TRUNC('year', x)",
+            read={
+                "starrocks": "DATE_TRUNC('year', x)",
+            },
+            write={
+                "starrocks": "DATE_TRUNC('year', x)",
+            },
+        )
+        self.validate_all(
+            "DATE_TRUNC(x, year)",
+            read={
+                "bigquery": "DATE_TRUNC(x, year)",
+            },
+            write={
+                "bigquery": "DATE_TRUNC(x, year)",
+            },
+        )
         self.validate_all(
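The argument swap above reflects sqlglot's canonical DATE_TRUNC(unit, expression) form; the same expectation as a standalone call to the public transpile API (sketch):

    import sqlglot

    # 'day' truncation lowers to MySQL's DATE(), per the write= expectations above.
    print(sqlglot.transpile("DATE_TRUNC('day', x)", write="mysql")[0])  # DATE(x)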
@@ -564,6 +590,22 @@ class TestDialect(Validator):
                 "spark": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
             },
         )
+        self.validate_all(
+            "TIMESTAMP '2022-01-01'",
+            write={
+                "mysql": "CAST('2022-01-01' AS TIMESTAMP)",
+                "starrocks": "CAST('2022-01-01' AS DATETIME)",
+                "hive": "CAST('2022-01-01' AS TIMESTAMP)",
+            },
+        )
+        self.validate_all(
+            "TIMESTAMP('2022-01-01')",
+            write={
+                "mysql": "TIMESTAMP('2022-01-01')",
+                "starrocks": "TIMESTAMP('2022-01-01')",
+                "hive": "TIMESTAMP('2022-01-01')",
+            },
+        )

         for unit in ("DAY", "MONTH", "YEAR"):
             self.validate_all(
@ -1002,7 +1044,10 @@ class TestDialect(Validator):
|
|||
)
|
||||
|
||||
def test_limit(self):
|
||||
self.validate_all("SELECT * FROM data LIMIT 10, 20", write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"})
|
||||
self.validate_all(
|
||||
"SELECT * FROM data LIMIT 10, 20",
|
||||
write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"},
|
||||
)
|
||||
self.validate_all(
|
||||
"SELECT x FROM y LIMIT 10",
|
||||
write={
|
||||
|
@ -1132,3 +1177,56 @@ class TestDialect(Validator):
|
|||
"sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
|
||||
},
|
||||
)
|
||||
|
||||
def test_nullsafe_eq(self):
|
||||
self.validate_all(
|
||||
"SELECT a IS NOT DISTINCT FROM b",
|
||||
read={
|
||||
"mysql": "SELECT a <=> b",
|
||||
"postgres": "SELECT a IS NOT DISTINCT FROM b",
|
||||
},
|
||||
write={
|
||||
"mysql": "SELECT a <=> b",
|
||||
"postgres": "SELECT a IS NOT DISTINCT FROM b",
|
||||
},
|
||||
)
|
||||
|
||||
def test_nullsafe_neq(self):
|
||||
self.validate_all(
|
||||
"SELECT a IS DISTINCT FROM b",
|
||||
read={
|
||||
"postgres": "SELECT a IS DISTINCT FROM b",
|
||||
},
|
||||
write={
|
||||
"mysql": "SELECT NOT a <=> b",
|
||||
"postgres": "SELECT a IS DISTINCT FROM b",
|
||||
},
|
||||
)
|
||||
|
||||
def test_hash_comments(self):
|
||||
self.validate_all(
|
||||
"SELECT 1 /* arbitrary content,,, until end-of-line */",
|
||||
read={
|
||||
"mysql": "SELECT 1 # arbitrary content,,, until end-of-line",
|
||||
"bigquery": "SELECT 1 # arbitrary content,,, until end-of-line",
|
||||
"clickhouse": "SELECT 1 #! arbitrary content,,, until end-of-line",
|
||||
},
|
||||
)
|
||||
self.validate_all(
|
||||
"""/* comment1 */
|
||||
SELECT
|
||||
x, -- comment2
|
||||
y -- comment3""",
|
||||
read={
|
||||
"mysql": """SELECT # comment1
|
||||
x, # comment2
|
||||
y # comment3""",
|
||||
"bigquery": """SELECT # comment1
|
||||
x, # comment2
|
||||
y # comment3""",
|
||||
"clickhouse": """SELECT # comment1
|
||||
x, # comment2
|
||||
y # comment3""",
|
||||
},
|
||||
pretty=True,
|
||||
)
|
||||
|
|
|
@@ -1,3 +1,4 @@
from sqlglot import expressions as exp
from tests.dialects.test_dialect import Validator


@@ -20,6 +21,52 @@ class TestMySQL(Validator):
self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')")
self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')")
self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')")
self.validate_identity("@@GLOBAL.max_connections")

# SET Commands
self.validate_identity("SET @var_name = expr")
self.validate_identity("SET @name = 43")
self.validate_identity("SET @total_tax = (SELECT SUM(tax) FROM taxable_transactions)")
self.validate_identity("SET GLOBAL max_connections = 1000")
self.validate_identity("SET @@GLOBAL.max_connections = 1000")
self.validate_identity("SET SESSION sql_mode = 'TRADITIONAL'")
self.validate_identity("SET LOCAL sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@SESSION.sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@LOCAL.sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@sql_mode = 'TRADITIONAL'")
self.validate_identity("SET sql_mode = 'TRADITIONAL'")
self.validate_identity("SET PERSIST max_connections = 1000")
self.validate_identity("SET @@PERSIST.max_connections = 1000")
self.validate_identity("SET PERSIST_ONLY back_log = 100")
self.validate_identity("SET @@PERSIST_ONLY.back_log = 100")
self.validate_identity("SET @@SESSION.max_join_size = DEFAULT")
self.validate_identity("SET @@SESSION.max_join_size = @@GLOBAL.max_join_size")
self.validate_identity("SET @x = 1, SESSION sql_mode = ''")
self.validate_identity(
"SET GLOBAL sort_buffer_size = 1000000, SESSION sort_buffer_size = 1000000"
)
self.validate_identity(
"SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000"
)
self.validate_identity("SET GLOBAL max_connections = 1000, sort_buffer_size = 1000000")
self.validate_identity("SET @@GLOBAL.sort_buffer_size = 50000, sort_buffer_size = 1000000")
self.validate_identity("SET CHARACTER SET 'utf8'")
self.validate_identity("SET CHARACTER SET utf8")
self.validate_identity("SET CHARACTER SET DEFAULT")
self.validate_identity("SET NAMES 'utf8'")
self.validate_identity("SET NAMES DEFAULT")
self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'")
self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci")
self.validate_identity("SET autocommit = ON")

def test_escape(self):
self.validate_all(
r"'a \' b '' '",
write={
"mysql": r"'a '' b '' '",
"spark": r"'a \' b \' '",
},
)

def test_introducers(self):
self.validate_all(
@@ -115,14 +162,6 @@ class TestMySQL(Validator):
},
)

def test_hash_comments(self):
self.validate_all(
"SELECT 1 # arbitrary content,,, until end-of-line",
write={
"mysql": "SELECT 1",
},
)

def test_mysql(self):
self.validate_all(
"GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
@@ -174,3 +213,242 @@ COMMENT='客户账户表'"""
},
pretty=True,
)

def test_show_simple(self):
for key, write_key in [
("BINARY LOGS", "BINARY LOGS"),
("MASTER LOGS", "BINARY LOGS"),
("STORAGE ENGINES", "ENGINES"),
("ENGINES", "ENGINES"),
("EVENTS", "EVENTS"),
("MASTER STATUS", "MASTER STATUS"),
("PLUGINS", "PLUGINS"),
("PRIVILEGES", "PRIVILEGES"),
("PROFILES", "PROFILES"),
("REPLICAS", "REPLICAS"),
("SLAVE HOSTS", "REPLICAS"),
]:
show = self.validate_identity(f"SHOW {key}", f"SHOW {write_key}")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, write_key)

def test_show_events(self):
for key in ["BINLOG", "RELAYLOG"]:
show = self.validate_identity(f"SHOW {key} EVENTS")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, f"{key} EVENTS")

show = self.validate_identity(f"SHOW {key} EVENTS IN 'log' FROM 1 LIMIT 2, 3")
self.assertEqual(show.text("log"), "log")
self.assertEqual(show.text("position"), "1")
self.assertEqual(show.text("limit"), "3")
self.assertEqual(show.text("offset"), "2")

show = self.validate_identity(f"SHOW {key} EVENTS LIMIT 1")
self.assertEqual(show.text("limit"), "1")
self.assertIsNone(show.args.get("offset"))

def test_show_like_or_where(self):
for key, write_key in [
("CHARSET", "CHARACTER SET"),
("CHARACTER SET", "CHARACTER SET"),
("COLLATION", "COLLATION"),
("DATABASES", "DATABASES"),
("FUNCTION STATUS", "FUNCTION STATUS"),
("PROCEDURE STATUS", "PROCEDURE STATUS"),
("GLOBAL STATUS", "GLOBAL STATUS"),
("SESSION STATUS", "STATUS"),
("STATUS", "STATUS"),
("GLOBAL VARIABLES", "GLOBAL VARIABLES"),
("SESSION VARIABLES", "VARIABLES"),
("VARIABLES", "VARIABLES"),
]:
expected_name = write_key.strip("GLOBAL").strip()
template = "SHOW {}"
show = self.validate_identity(template.format(key), template.format(write_key))
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, expected_name)

template = "SHOW {} LIKE '%foo%'"
show = self.validate_identity(template.format(key), template.format(write_key))
self.assertIsInstance(show, exp.Show)
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")

template = "SHOW {} WHERE Column_name LIKE '%foo%'"
show = self.validate_identity(template.format(key), template.format(write_key))
self.assertIsInstance(show, exp.Show)
self.assertIsInstance(show.args["where"], exp.Where)
self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'")

def test_show_columns(self):
show = self.validate_identity("SHOW COLUMNS FROM tbl_name")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "COLUMNS")
self.assertEqual(show.text("target"), "tbl_name")
self.assertFalse(show.args["full"])

show = self.validate_identity("SHOW FULL COLUMNS FROM tbl_name FROM db_name LIKE '%foo%'")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.text("target"), "tbl_name")
self.assertTrue(show.args["full"])
self.assertEqual(show.text("db"), "db_name")
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")

def test_show_name(self):
for key in [
"CREATE DATABASE",
"CREATE EVENT",
"CREATE FUNCTION",
"CREATE PROCEDURE",
"CREATE TABLE",
"CREATE TRIGGER",
"CREATE VIEW",
"FUNCTION CODE",
"PROCEDURE CODE",
]:
show = self.validate_identity(f"SHOW {key} foo")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, key)
self.assertEqual(show.text("target"), "foo")

def test_show_grants(self):
show = self.validate_identity(f"SHOW GRANTS FOR foo")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "GRANTS")
self.assertEqual(show.text("target"), "foo")

def test_show_engine(self):
show = self.validate_identity("SHOW ENGINE foo STATUS")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "ENGINE")
self.assertEqual(show.text("target"), "foo")
self.assertFalse(show.args["mutex"])

show = self.validate_identity("SHOW ENGINE foo MUTEX")
self.assertEqual(show.name, "ENGINE")
self.assertEqual(show.text("target"), "foo")
self.assertTrue(show.args["mutex"])

def test_show_errors(self):
for key in ["ERRORS", "WARNINGS"]:
show = self.validate_identity(f"SHOW {key}")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, key)

show = self.validate_identity(f"SHOW {key} LIMIT 2, 3")
self.assertEqual(show.text("limit"), "3")
self.assertEqual(show.text("offset"), "2")

def test_show_index(self):
show = self.validate_identity("SHOW INDEX FROM foo")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "INDEX")
self.assertEqual(show.text("target"), "foo")

show = self.validate_identity("SHOW INDEX FROM foo FROM bar")
self.assertEqual(show.text("db"), "bar")

def test_show_db_like_or_where_sql(self):
for key in [
"OPEN TABLES",
"TABLE STATUS",
"TRIGGERS",
]:
show = self.validate_identity(f"SHOW {key}")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, key)

show = self.validate_identity(f"SHOW {key} FROM db_name")
self.assertEqual(show.name, key)
self.assertEqual(show.text("db"), "db_name")

show = self.validate_identity(f"SHOW {key} LIKE '%foo%'")
self.assertEqual(show.name, key)
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")

show = self.validate_identity(f"SHOW {key} WHERE Column_name LIKE '%foo%'")
self.assertEqual(show.name, key)
self.assertIsInstance(show.args["where"], exp.Where)
self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'")

def test_show_processlist(self):
show = self.validate_identity("SHOW PROCESSLIST")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "PROCESSLIST")
self.assertFalse(show.args["full"])

show = self.validate_identity("SHOW FULL PROCESSLIST")
self.assertEqual(show.name, "PROCESSLIST")
self.assertTrue(show.args["full"])

def test_show_profile(self):
show = self.validate_identity("SHOW PROFILE")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "PROFILE")

show = self.validate_identity("SHOW PROFILE BLOCK IO")
self.assertEqual(show.args["types"][0].name, "BLOCK IO")

show = self.validate_identity(
"SHOW PROFILE BLOCK IO, PAGE FAULTS FOR QUERY 1 OFFSET 2 LIMIT 3"
)
self.assertEqual(show.args["types"][0].name, "BLOCK IO")
self.assertEqual(show.args["types"][1].name, "PAGE FAULTS")
self.assertEqual(show.text("query"), "1")
self.assertEqual(show.text("offset"), "2")
self.assertEqual(show.text("limit"), "3")

def test_show_replica_status(self):
show = self.validate_identity("SHOW REPLICA STATUS")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "REPLICA STATUS")

show = self.validate_identity("SHOW SLAVE STATUS", "SHOW REPLICA STATUS")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "REPLICA STATUS")

show = self.validate_identity("SHOW REPLICA STATUS FOR CHANNEL channel_name")
self.assertEqual(show.text("channel"), "channel_name")

def test_show_tables(self):
show = self.validate_identity("SHOW TABLES")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "TABLES")

show = self.validate_identity("SHOW FULL TABLES FROM db_name LIKE '%foo%'")
self.assertTrue(show.args["full"])
self.assertEqual(show.text("db"), "db_name")
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")

def test_set_variable(self):
cmd = self.parse_one("SET SESSION x = 1")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "SESSION")
self.assertIsInstance(item.this, exp.EQ)
self.assertEqual(item.this.left.name, "x")
self.assertEqual(item.this.right.name, "1")

cmd = self.parse_one("SET @@GLOBAL.x = @@GLOBAL.y")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "")
self.assertIsInstance(item.this, exp.EQ)
self.assertIsInstance(item.this.left, exp.SessionParameter)
self.assertIsInstance(item.this.right, exp.SessionParameter)

cmd = self.parse_one("SET NAMES 'charset_name' COLLATE 'collation_name'")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "NAMES")
self.assertEqual(item.name, "charset_name")
self.assertEqual(item.text("collate"), "collation_name")

cmd = self.parse_one("SET CHARSET DEFAULT")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "CHARACTER SET")
self.assertEqual(item.this.name, "DEFAULT")

cmd = self.parse_one("SET x = 1, y = 2")
self.assertEqual(len(cmd.expressions), 2)

@@ -8,7 +8,9 @@ class TestPostgres(Validator):
def test_ddl(self):
self.validate_all(
"CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
write={"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"},
write={
"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"
},
)
self.validate_all(
"CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)",
@@ -59,15 +61,27 @@ class TestPostgres(Validator):

def test_postgres(self):
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END")
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END")
self.validate_identity('SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')')
self.validate_identity(
"SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END"
)
self.validate_identity(
"SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END"
)
self.validate_identity(
'SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')'
)
self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')")
self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')")
self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))")
self.validate_identity(
"SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')"
)
self.validate_identity(
"SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))"
)
self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')")
self.validate_identity(
"SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')"
)
self.validate_identity("COMMENT ON TABLE mytable IS 'this'")
self.validate_identity("SELECT e'\\xDEADBEEF'")
self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
@@ -75,7 +89,7 @@ class TestPostgres(Validator):
self.validate_all(
"CREATE TABLE x (a UUID, b BYTEA)",
write={
"duckdb": "CREATE TABLE x (a UUID, b BINARY)",
"duckdb": "CREATE TABLE x (a UUID, b VARBINARY)",
"presto": "CREATE TABLE x (a UUID, b VARBINARY)",
"hive": "CREATE TABLE x (a UUID, b BINARY)",
"spark": "CREATE TABLE x (a UUID, b BINARY)",
@@ -153,7 +167,9 @@ class TestPostgres(Validator):
)
self.validate_all(
"SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss",
read={"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"},
read={
"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"
},
)
self.validate_all(
"SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL",
@@ -169,11 +185,15 @@ class TestPostgres(Validator):
)
self.validate_all(
"SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL",
read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL"},
read={
"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL"
},
)
self.validate_all(
"SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL",
read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL"},
read={
"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL"
},
)
self.validate_all(
"'[1,2,3]'::json->2",
@@ -184,7 +204,8 @@ class TestPostgres(Validator):
write={"postgres": """CAST('{"a":1,"b":2}' AS JSON)->'b'"""},
)
self.validate_all(
"""'{"x": {"y": 1}}'::json->'x'->'y'""", write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""}
"""'{"x": {"y": 1}}'::json->'x'->'y'""",
write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""},
)
self.validate_all(
"""'{"x": {"y": 1}}'::json->'x'::json->'y'""",

@@ -61,4 +61,6 @@ class TestRedshift(Validator):
"SELECT caldate + INTERVAL '1 second' AS dateplus FROM date WHERE caldate = '12-31-2008'"
)
self.validate_identity("CREATE TABLE datetable (start_date DATE, end_date DATE)")
self.validate_identity("SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'")
self.validate_identity(
"SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
)

@@ -336,7 +336,8 @@ class TestSnowflake(Validator):
def test_table_literal(self):
# All examples from https://docs.snowflake.com/en/sql-reference/literals-table.html
self.validate_all(
r"""SELECT * FROM TABLE('MYTABLE')""", write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""}
r"""SELECT * FROM TABLE('MYTABLE')""",
write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""},
)

self.validate_all(
@@ -352,15 +353,123 @@ class TestSnowflake(Validator):
write={"snowflake": r"""SELECT * FROM TABLE('MYDB. "MYSCHEMA"."MYTABLE"')"""},
)

self.validate_all(r"""SELECT * FROM TABLE($MYVAR)""", write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""})

self.validate_all(r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""})
self.validate_all(
r"""SELECT * FROM TABLE($MYVAR)""",
write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""},
)

self.validate_all(
r"""SELECT * FROM TABLE(:BINDING)""", write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""}
r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""}
)

self.validate_all(
r"""SELECT * FROM TABLE(:BINDING)""",
write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""},
)

self.validate_all(
r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10""",
write={"snowflake": r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10"""},
)

def test_flatten(self):
self.validate_all(
"""
select
dag_report.acct_id,
dag_report.report_date,
dag_report.report_uuid,
dag_report.airflow_name,
dag_report.dag_id,
f.value::varchar as operator
from cs.telescope.dag_report,
table(flatten(input=>split(operators, ','))) f
""",
write={
"snowflake": """SELECT
dag_report.acct_id,
dag_report.report_date,
dag_report.report_uuid,
dag_report.airflow_name,
dag_report.dag_id,
CAST(f.value AS VARCHAR) AS operator
FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS f"""
},
pretty=True,
)

# All examples from https://docs.snowflake.com/en/sql-reference/functions/flatten.html#syntax
self.validate_all(
"SELECT * FROM TABLE(FLATTEN(input => parse_json('[1, ,77]'))) f",
write={
"snowflake": "SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[1, ,77]'))) AS f"
},
)

self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), outer => true)) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), outer => TRUE)) AS f"""
},
)

self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), path => 'b')) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), path => 'b')) AS f"""
},
)

self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'))) f""",
write={"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'))) AS f"""},
)

self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'), outer => true)) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'), outer => TRUE)) AS f"""
},
)

self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) AS f"""
},
)

self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true)) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE)) AS f"""
},
)

self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true, mode => 'object')) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE, mode => 'object')) AS f"""
},
)

self.validate_all(
"""
SELECT id as "ID",
f.value AS "Contact",
f1.value:type AS "Type",
f1.value:content AS "Details"
FROM persons p,
lateral flatten(input => p.c, path => 'contact') f,
lateral flatten(input => f.value:business) f1
""",
write={
"snowflake": """SELECT
id AS "ID",
f.value AS "Contact",
f1.value['type'] AS "Type",
f1.value['content'] AS "Details"
FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL FLATTEN(input => f.value['business']) f1""",
},
pretty=True,
)

@@ -284,4 +284,6 @@ TBLPROPERTIES (
)

def test_iif(self):
self.validate_all("SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"})
self.validate_all(
"SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"}
)

@@ -6,3 +6,6 @@ class TestMySQL(Validator):

def test_identity(self):
self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")

def test_time(self):
self.validate_identity("TIMESTAMP('2022-01-01')")

@@ -278,12 +278,19 @@ class TestTSQL(Validator):
def test_add_date(self):
self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')")
self.validate_all(
"SELECT DATEADD(year, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"}
"SELECT DATEADD(year, 1, '2017/08/25')",
write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"},
)
self.validate_all(
"SELECT DATEADD(qq, 1, '2017/08/25')",
write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"},
)
self.validate_all("SELECT DATEADD(qq, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"})
self.validate_all(
"SELECT DATEADD(wk, 1, '2017/08/25')",
write={"spark": "SELECT DATE_ADD('2017/08/25', 7)", "databricks": "SELECT DATEADD(week, 1, '2017/08/25')"},
write={
"spark": "SELECT DATE_ADD('2017/08/25', 7)",
"databricks": "SELECT DATEADD(week, 1, '2017/08/25')",
},
)

def test_date_diff(self):
@@ -370,13 +377,21 @@ class TestTSQL(Validator):
"SELECT FORMAT(1000000.01,'###,###.###')",
write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"},
)
self.validate_all("SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"})
self.validate_all(
"SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"}
)
self.validate_all(
"SELECT FORMAT('01-01-1991', 'dd.mm.yyyy')",
write={"spark": "SELECT DATE_FORMAT('01-01-1991', 'dd.mm.yyyy')"},
)
self.validate_all(
"SELECT FORMAT(date_col, 'dd.mm.yyyy')", write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"}
"SELECT FORMAT(date_col, 'dd.mm.yyyy')",
write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"},
)
self.validate_all(
"SELECT FORMAT(date_col, 'm')",
write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"},
)
self.validate_all(
"SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}
)
self.validate_all("SELECT FORMAT(date_col, 'm')", write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"})
self.validate_all("SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"})

tests/fixtures/identity.sql (vendored, 12 changes)
@@ -523,6 +523,8 @@ DROP VIEW a.b
DROP VIEW IF EXISTS a
DROP VIEW IF EXISTS a.b
SHOW TABLES
USE db
ROLLBACK
EXPLAIN SELECT * FROM x
INSERT INTO x SELECT * FROM y
INSERT INTO x (SELECT * FROM y)
@@ -569,3 +571,13 @@ SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)
SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)
SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)
SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)
SELECT CAST(x AS INT) /* comment */ FROM foo
SELECT a /* x */, b /* x */
SELECT * FROM foo /* x */, bla /* x */
SELECT 1 /* comment */ + 1
SELECT 1 /* c1 */ + 2 /* c2 */
SELECT 1 /* c1 */ + 2 /* c2 */ + 3 /* c3 */
SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */
SELECT x FROM a.b.c /* x */, e.f.g /* x */
SELECT FOO(x /* c */) /* FOO */, b /* b */
SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */

tests/fixtures/optimizer/qualify_columns.sql (vendored, 10 changes)
@@ -104,6 +104,16 @@ SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_
SELECT AGGREGATE(ARRAY(a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + x.a) AS sum_agg FROM x AS x;

# dialect: starrocks
# execute: false
SELECT DATE_TRUNC('week', a) AS a FROM x;
SELECT DATE_TRUNC('week', x.a) AS a FROM x AS x;

# dialect: bigquery
# execute: false
SELECT DATE_TRUNC(a, MONTH) AS a FROM x;
SELECT DATE_TRUNC(x.a, MONTH) AS a FROM x AS x;

--------------------------------------
-- Derived tables
--------------------------------------

tests/fixtures/optimizer/simplify.sql (vendored, 9 changes)
@@ -79,6 +79,15 @@ NULL;
NULL = NULL;
NULL;

NULL <=> NULL;
TRUE;

a IS NOT DISTINCT FROM a;
TRUE;

NULL IS DISTINCT FROM NULL;
FALSE;

NOT (NOT TRUE);
TRUE;

tests/fixtures/pretty.sql (vendored, 28 changes)
@@ -287,3 +287,31 @@ SELECT
"fffffff"
)
);
/*
multi
line
comment
*/
SELECT * FROM foo;
/*
multi
line
comment
*/
SELECT
*
FROM foo;
SELECT x FROM a.b.c /*x*/, e.f.g /*x*/;
SELECT
x
FROM a.b.c /* x */, e.f.g /* x */;
SELECT x FROM (SELECT * FROM bla /*x*/WHERE id = 1) /*x*/;
SELECT
x
FROM (
SELECT
*
FROM bla /* x */
WHERE
id = 1
) /* x */;

@@ -100,15 +100,21 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT OUTER JOIN tbl2",
),
(
lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer"),
lambda: select("x")
.from_("tbl")
.join(exp.Table(this="tbl2"), join_type="left outer"),
"SELECT x FROM tbl LEFT OUTER JOIN tbl2",
),
(
lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
lambda: select("x")
.from_("tbl")
.join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
"SELECT x FROM tbl LEFT OUTER JOIN tbl2 AS foo",
),
(
lambda: select("x").from_("tbl").join(select("y").from_("tbl2"), join_type="left outer"),
lambda: select("x")
.from_("tbl")
.join(select("y").from_("tbl2"), join_type="left outer"),
"SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2)",
),
(
@@ -131,7 +137,9 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased",
),
(
lambda: select("x").from_("tbl").join(parse_one("left join x", into=exp.Join), on="a=b"),
lambda: select("x")
.from_("tbl")
.join(parse_one("left join x", into=exp.Join), on="a=b"),
"SELECT x FROM tbl LEFT JOIN x ON a = b",
),
(
@@ -139,7 +147,9 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT JOIN x ON a = b",
),
(
lambda: select("x").from_("tbl").join("select b from tbl2", on="a=b", join_type="left"),
lambda: select("x")
.from_("tbl")
.join("select b from tbl2", on="a=b", join_type="left"),
"SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) ON a = b",
),
(
@@ -162,7 +172,10 @@ class TestBuild(unittest.TestCase):
(
lambda: select("x", "y", "z")
.from_("merged_df")
.join("vte_diagnosis_df", using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")]),
.join(
"vte_diagnosis_df",
using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")],
),
"SELECT x, y, z FROM merged_df JOIN vte_diagnosis_df USING (patient_id, encounter_id)",
),
(
@@ -222,7 +235,10 @@ class TestBuild(unittest.TestCase):
"SELECT x, y, z, a FROM tbl ORDER BY x, y, z, a",
),
(
lambda: select("x", "y", "z", "a").from_("tbl").cluster_by("x, y", "z").cluster_by("a"),
lambda: select("x", "y", "z", "a")
.from_("tbl")
.cluster_by("x, y", "z")
.cluster_by("a"),
"SELECT x, y, z, a FROM tbl CLUSTER BY x, y, z, a",
),
(
@@ -239,7 +255,9 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
lambda: select("x").from_("tbl").with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
lambda: select("x")
.from_("tbl")
.with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
"WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
@@ -247,7 +265,9 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
lambda: select("x").from_("tbl").with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
lambda: select("x")
.from_("tbl")
.with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
"WITH tbl(x, y) AS (SELECT x, y FROM tbl2) SELECT x FROM tbl",
),
(
@@ -258,7 +278,10 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2), tbl2 AS (SELECT x FROM tbl3) SELECT x FROM tbl",
),
(
lambda: select("x").from_("tbl").with_("tbl", as_=select("x", "y").from_("tbl2")).select("y"),
lambda: select("x")
.from_("tbl")
.with_("tbl", as_=select("x", "y").from_("tbl2"))
.select("y"),
"WITH tbl AS (SELECT x, y FROM tbl2) SELECT x, y FROM tbl",
),
(
@@ -266,35 +289,59 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").group_by("x"),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.group_by("x"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl GROUP BY x",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").order_by("x"),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.order_by("x"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl ORDER BY x",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").limit(10),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.limit(10),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl LIMIT 10",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").offset(10),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.offset(10),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl OFFSET 10",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").join("tbl3"),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.join("tbl3"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl JOIN tbl3",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").distinct(),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.distinct(),
"WITH tbl AS (SELECT x FROM tbl2) SELECT DISTINCT x FROM tbl",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").where("x > 10"),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.where("x > 10"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl WHERE x > 10",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").having("x > 20"),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.having("x > 20"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl HAVING x > 20",
),
(lambda: select("x").from_("tbl").subquery(), "(SELECT x FROM tbl)"),
@@ -354,7 +401,9 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM (SELECT x FROM tbl) AS foo WHERE x > 0",
),
(
lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select("x"),
lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select(
"x"
),
"SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned",
),
(

@@ -33,7 +33,10 @@ class TestExecutor(unittest.TestCase):
)

cls.cache = {}
cls.sqls = [(sql, expected) for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")]
cls.sqls = [
(sql, expected)
for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")
]

@classmethod
def tearDownClass(cls):
@@ -63,7 +66,9 @@ class TestExecutor(unittest.TestCase):
def test_execute_tpch(self):
def to_csv(expression):
if isinstance(expression, exp.Table):
return parse_one(f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}")
return parse_one(
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
)
return expression

for sql, _ in self.sqls[0:3]:

@@ -30,7 +30,9 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)"))
self.assertEqual(exp.Table(pivots=[]), exp.Table())
self.assertNotEqual(exp.Table(pivots=[None]), exp.Table())
self.assertEqual(exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False))
self.assertEqual(
exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False)
)

def test_find(self):
expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y")
@@ -89,7 +91,9 @@ class TestExpressions(unittest.TestCase):
self.assertIsNone(column.find_ancestor(exp.Join))

def test_alias_or_name(self):
expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
expression = parse_one(
"SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
)
self.assertEqual(
[e.alias_or_name for e in expression.expressions],
["a", "B", "e", "*", "zz", "z"],
@@ -166,7 +170,9 @@ class TestExpressions(unittest.TestCase):
"SELECT * FROM foo WHERE ? > 100",
)
self.assertEqual(
exp.replace_placeholders(parse_one("select * from :name WHERE ? > 100"), another_name="bla").sql(),
exp.replace_placeholders(
parse_one("select * from :name WHERE ? > 100"), another_name="bla"
).sql(),
"SELECT * FROM :name WHERE ? > 100",
)
self.assertEqual(
@@ -183,7 +189,9 @@ class TestExpressions(unittest.TestCase):
)

def test_named_selects(self):
expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
expression = parse_one(
"SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
)
self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"])

expression = parse_one(
@@ -367,7 +375,9 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(len(list(expression.walk())), 9)
self.assertEqual(len(list(expression.walk(bfs=False))), 9)
self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk()))
self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False)))
self.assertTrue(
all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False))
)

def test_functions(self):
self.assertIsInstance(parse_one("ABS(a)"), exp.Abs)
@@ -512,14 +522,21 @@ class TestExpressions(unittest.TestCase):
),
exp.Properties(
expressions=[
exp.FileFormatProperty(this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")),
exp.FileFormatProperty(
this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")
),
exp.PartitionedByProperty(
this=exp.Literal.string("PARTITIONED_BY"),
value=exp.Tuple(expressions=[exp.to_identifier("a"), exp.to_identifier("b")]),
value=exp.Tuple(
expressions=[exp.to_identifier("a"), exp.to_identifier("b")]
),
),
exp.AnonymousProperty(
this=exp.Literal.string("custom"), value=exp.Literal.number(1)
),
exp.AnonymousProperty(this=exp.Literal.string("custom"), value=exp.Literal.number(1)),
exp.TableFormatProperty(
this=exp.Literal.string("TABLE_FORMAT"), value=exp.to_identifier("test_format")
this=exp.Literal.string("TABLE_FORMAT"),
value=exp.to_identifier("test_format"),
),
exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL),
exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE),
@@ -538,7 +555,10 @@ class TestExpressions(unittest.TestCase):
((1, "2", None), "(1, '2', NULL)"),
([1, "2", None], "ARRAY(1, '2', NULL)"),
({"x": None}, "MAP('x', NULL)"),
(datetime.datetime(2022, 10, 1, 1, 1, 1), "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')"),
(
datetime.datetime(2022, 10, 1, 1, 1, 1),
"TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')",
),
(
datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc),
"TIME_STR_TO_TIME('2022-10-01 01:01:01.000000+0000')",
@@ -548,30 +568,48 @@ class TestExpressions(unittest.TestCase):
with self.subTest(value):
self.assertEqual(exp.convert(value).sql(), expected)

def test_annotation_alias(self):
sql = "SELECT a, b AS B, c # comment, d AS D # another_comment FROM foo"
def test_comment_alias(self):
sql = """
SELECT
a,
b AS B,
c, /*comment*/
d AS D, -- another comment
CAST(x AS INT) -- final comment
FROM foo
"""
expression = parse_one(sql)
self.assertEqual(
[e.alias_or_name for e in expression.expressions],
["a", "B", "c", "D"],
["a", "B", "c", "D", "x"],
)
self.assertEqual(expression.sql(), "SELECT a, b AS B, c, d AS D")
self.assertEqual(expression.expressions[2].name, "comment")
self.assertEqual(
expression.sql(pretty=True, annotations=False),
expression.sql(),
"SELECT a, b AS B, c /* comment */, d AS D /* another comment */, CAST(x AS INT) /* final comment */ FROM foo",
)
self.assertEqual(
expression.sql(comments=False),
"SELECT a, b AS B, c, d AS D, CAST(x AS INT) FROM foo",
)
self.assertEqual(
expression.sql(pretty=True, comments=False),
"""SELECT
a,
b AS B,
c,
d AS D""",
d AS D,
CAST(x AS INT)
FROM foo""",
)
self.assertEqual(
expression.sql(pretty=True),
"""SELECT
a,
b AS B,
c # comment,
d AS D # another_comment FROM foo""",
c, -- comment
d AS D, -- another comment
CAST(x AS INT) -- final comment
FROM foo""",
)

def test_to_table(self):
@@ -605,5 +643,9 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(expression, exp.Union)
self.assertEqual(expression.named_selects, ["cola", "colb"])
self.assertEqual(
expression.selects, [exp.Column(this=exp.to_identifier("cola")), exp.Column(this=exp.to_identifier("colb"))]
expression.selects,
[
exp.Column(this=exp.to_identifier("cola")),
exp.Column(this=exp.to_identifier("colb")),
],
)

@@ -67,7 +67,9 @@ class TestOptimizer(unittest.TestCase):
}

def check_file(self, file, func, pretty=False, execute=False, **kwargs):
for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1):
for i, (meta, sql, expected) in enumerate(
load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1
):
title = meta.get("title") or f"{i}, {sql}"
dialect = meta.get("dialect")
leave_tables_isolated = meta.get("leave_tables_isolated")
@@ -90,7 +92,9 @@ class TestOptimizer(unittest.TestCase):

if string_to_bool(should_execute):
with self.subTest(f"(execute) {title}"):
df1 = self.conn.execute(sqlglot.transpile(sql, read=dialect, write="duckdb")[0]).df()
df1 = self.conn.execute(
sqlglot.transpile(sql, read=dialect, write="duckdb")[0]
).df()
df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df()
assert_frame_equal(df1, df2)

@@ -268,7 +272,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
self.assertEqual(scopes[2].expression.sql(), "(VALUES (1, 'test')) AS tab(cola, colb)")
self.assertEqual(
scopes[3].expression.sql(), "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)"
scopes[3].expression.sql(),
"SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
)
self.assertEqual(scopes[4].expression.sql(), "SELECT y.c AS b FROM y")
self.assertEqual(scopes[5].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
@@ -287,7 +292,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')

# Check that we can walk in scope from an arbitrary node
self.assertEqual(
{node.sql() for node, *_ in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column)},
{
node.sql()
for node, *_ in walk_in_scope(expression.find(exp.Where))
if isinstance(node, exp.Column)
},
{"s.b"},
)

@@ -324,7 +333,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)

def test_cache_annotation(self):
expression = annotate_types(parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1"))
expression = annotate_types(
parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
)
self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT)

def test_binary_annotation(self):
@@ -384,7 +395,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
"""

expression = annotate_types(parse_one(sql), schema=schema)
self.assertEqual(expression.expressions[0].type, exp.DataType.Type.TEXT) # tbl.cola + tbl.colb + 'foo' AS col
self.assertEqual(
expression.expressions[0].type, exp.DataType.Type.TEXT
) # tbl.cola + tbl.colb + 'foo' AS col

outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo'
self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT)
@@ -396,7 +409,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT)

cte_select = expression.args["with"].expressions[0].this
self.assertEqual(cte_select.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola + 'bla' AS cola
self.assertEqual(
cte_select.expressions[0].type, exp.DataType.Type.VARCHAR
) # x.cola + 'bla' AS cola
self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb

cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla'
@@ -405,7 +420,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR)

# Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
for d, t in zip(cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]):
for d, t in zip(
cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]
):
self.assertEqual(d.this.expressions[0].this.type, t)

def test_function_annotation(self):
@@ -421,6 +438,19 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb)
self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb

sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x"

case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR)

case_expr = case_expr_alias.this
self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR)
self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR)

case_ifs_expr = case_expr.args["ifs"][0]
self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR)
self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR)

def test_unknown_annotation(self):
schema = {"x": {"cola": "VARCHAR"}}
sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x"
@@ -431,8 +461,12 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
concat_expr = concat_expr_alias.this
self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN)
self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola
self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) # SOME_ANONYMOUS_FUNC(x.cola)
self.assertEqual(concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola (arg)
self.assertEqual(
concat_expr.right.type, exp.DataType.Type.UNKNOWN
) # SOME_ANONYMOUS_FUNC(x.cola)
self.assertEqual(
concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR
) # x.cola (arg)

def test_null_annotation(self):
expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this

@ -23,8 +23,6 @@ class TestParser(unittest.TestCase):
|
|||
|
||||
def test_float(self):
|
||||
self.assertEqual(parse_one(".2"), parse_one("0.2"))
|
||||
self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
|
||||
self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)"))
|
||||
|
||||
def test_table(self):
|
||||
tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
|
||||
|
@ -33,7 +31,9 @@ class TestParser(unittest.TestCase):
|
|||
def test_select(self):
|
||||
self.assertIsNotNone(parse_one("select 1 natural"))
|
||||
self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"])
|
||||
self.assertIsNotNone(parse_one("select * from x where a = (select 1) order by x.y").args["order"])
|
||||
self.assertIsNotNone(
|
||||
parse_one("select * from x where a = (select 1) order by x.y").args["order"]
|
||||
)
|
||||
self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1)
|
||||
self.assertEqual(
|
||||
parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),
|
||||
|
@ -125,26 +125,70 @@ class TestParser(unittest.TestCase):
|
|||
def test_var(self):
|
||||
self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'")
|
||||
|
||||
def test_annotations(self):
|
||||
def test_comments(self):
|
||||
expression = parse_one(
|
||||
"""
|
||||
SELECT
|
||||
a #annotation1,
|
||||
b as B #annotation2:testing ,
|
||||
"test#annotation",c#annotation3, d #annotation4,
|
||||
e #,
|
||||
f # space
|
||||
--comment1
|
||||
SELECT /* this won't be used */
|
||||
a, --comment2
|
||||
b as B, --comment3:testing
|
||||
"test--annotation",
|
||||
c, --comment4 --foo
|
||||
e, --
|
||||
f -- space
|
||||
FROM foo
|
||||
"""
|
||||
)
|
||||
|
||||
assert expression.expressions[0].name == "annotation1"
|
||||
assert expression.expressions[1].name == "annotation2:testing"
|
||||
assert expression.expressions[2].name == "test#annotation"
|
||||
assert expression.expressions[3].name == "annotation3"
|
||||
assert expression.expressions[4].name == "annotation4"
|
||||
assert expression.expressions[5].name == ""
|
||||
assert expression.expressions[6].name == "space"
|
||||
self.assertEqual(expression.comment, "comment1")
|
||||
self.assertEqual(expression.expressions[0].comment, "comment2")
|
||||
self.assertEqual(expression.expressions[1].comment, "comment3:testing")
|
||||
self.assertEqual(expression.expressions[2].comment, None)
|
||||
self.assertEqual(expression.expressions[3].comment, "comment4 --foo")
|
||||
self.assertEqual(expression.expressions[4].comment, "")
|
||||
self.assertEqual(expression.expressions[5].comment, " space")
|
||||
|
||||
    def test_type_literals(self):
        self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
        self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)"))
        self.assertEqual(
            parse_one("TIMESTAMP '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP)"
        )
        self.assertEqual(
            parse_one("TIMESTAMP(1) '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP(1))"
        )
        self.assertEqual(
            parse_one("TIMESTAMP WITH TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMPTZ)",
        )
        self.assertEqual(
            parse_one("TIMESTAMP WITH LOCAL TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMPLTZ)",
        )
        self.assertEqual(
            parse_one("TIMESTAMP WITHOUT TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMP)",
        )
        self.assertEqual(
            parse_one("TIMESTAMP(1) WITH TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMPTZ(1))",
        )
        self.assertEqual(
            parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMPLTZ(1))",
        )
        self.assertEqual(
            parse_one("TIMESTAMP(1) WITHOUT TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMP(1))",
        )
        self.assertEqual(parse_one("TIMESTAMP(1) WITH TIME ZONE").sql(), "TIMESTAMPTZ(1)")
        self.assertEqual(parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE").sql(), "TIMESTAMPLTZ(1)")
        self.assertEqual(parse_one("TIMESTAMP(1) WITHOUT TIME ZONE").sql(), "TIMESTAMP(1)")
        self.assertEqual(parse_one("""JSON '{"x":"y"}'""").sql(), """CAST('{"x":"y"}' AS JSON)""")
        self.assertIsInstance(parse_one("TIMESTAMP(1)"), exp.Func)
        self.assertIsInstance(parse_one("TIMESTAMP('2022-01-01')"), exp.Func)
        self.assertIsInstance(parse_one("TIMESTAMP()"), exp.Func)
        self.assertIsInstance(parse_one("map.x"), exp.Column)

    def test_pretty_config_override(self):
        self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x")
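The pattern these new assertions pin down, as a two-line sketch: a type keyword followed by a string literal parses as a prefix type literal and is generated back as an explicit CAST.

from sqlglot import parse_one

print(parse_one("TIMESTAMP '2022-01-01'").sql())
# CAST('2022-01-01' AS TIMESTAMP)
print(parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE '2022-01-01'").sql())
# CAST('2022-01-01' AS TIMESTAMPLTZ(1))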
@ -1,281 +1,141 @@
import unittest

from sqlglot import table
from sqlglot.dataframe.sql import types as df_types
from sqlglot import exp, to_table
from sqlglot.errors import SchemaError
from sqlglot.schema import MappingSchema, ensure_schema


class TestSchema(unittest.TestCase):
    def assert_column_names(self, schema, *table_results):
        for table, result in table_results:
            with self.subTest(f"{table} -> {result}"):
                self.assertEqual(schema.column_names(to_table(table)), result)

    def assert_column_names_raises(self, schema, *tables):
        for table in tables:
            with self.subTest(table):
                with self.assertRaises(SchemaError):
                    schema.column_names(to_table(table))
    def test_schema(self):
        schema = ensure_schema(
            {
                "x": {
                    "a": "uint64",
                }
            }
        )
        self.assertEqual(
            schema.column_names(
                table(
                    "x",
                )
            ),
            ["a"],
        )
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db", catalog="c"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x2"))

        with self.assertRaises(ValueError):
            schema.add_table(table("y", db="db"), {"b": "string"})
        with self.assertRaises(ValueError):
            schema.add_table(table("y", db="db", catalog="c"), {"b": "string"})

        schema.add_table(table("y"), {"b": "string"})
        schema_with_y = {
            "x": {
                "a": "uint64",
            },
            "y": {
                "b": "string",
            },
        }
        self.assertEqual(schema.schema, schema_with_y)

        new_schema = schema.copy()
        new_schema.add_table(table("z"), {"c": "string"})
        self.assertEqual(schema.schema, schema_with_y)
        self.assertEqual(
            new_schema.schema,
            {
                "x": {
                    "a": "uint64",
                },
                "y": {
                    "b": "string",
                },
                "z": {
                    "c": "string",
                },
            },
        )
        schema.add_table(table("m"), {"d": "string"})
        schema.add_table(table("n"), {"e": "string"})
        schema_with_m_n = {
            "x": {
                "a": "uint64",
            },
            "y": {
                "b": "string",
            },
            "m": {
                "d": "string",
            },
            "n": {
                "e": "string",
            },
        }
        self.assertEqual(schema.schema, schema_with_m_n)
        new_schema = schema.copy()
        new_schema.add_table(table("o"), {"f": "string"})
        new_schema.add_table(table("p"), {"g": "string"})
        self.assertEqual(schema.schema, schema_with_m_n)
        self.assertEqual(
            new_schema.schema,
            {
                "x": {
                    "a": "uint64",
                },
                "y": {
                    "b": "string",
                },
                "m": {
                    "d": "string",
                },
                "n": {
                    "e": "string",
                },
                "o": {
                    "f": "string",
                },
                "p": {
                    "g": "string",
                    "b": "uint64",
                    "c": "uint64",
                },
            },
        )

        self.assert_column_names(
            schema,
            ("x", ["a"]),
            ("y", ["b", "c"]),
            ("z.x", ["a"]),
            ("z.x.y", ["b", "c"]),
        )

        self.assert_column_names_raises(
            schema,
            "z",
            "z.z",
            "z.z.z",
        )
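A standalone sketch of the behavior this test exercises: ensure_schema builds a MappingSchema from a plain nested dict, and add_table mutates only the instance it is called on, so copies stay isolated.

from sqlglot import table
from sqlglot.schema import ensure_schema

schema = ensure_schema({"x": {"a": "uint64"}})
print(schema.column_names(table("x")))  # ['a']

schema.add_table(table("y"), {"b": "string"})
new_schema = schema.copy()
new_schema.add_table(table("z"), {"c": "string"})
print("z" in schema.schema)      # False - the original is unchanged
print("z" in new_schema.schema)  # True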
    def test_schema_db(self):
        schema = ensure_schema(
            {
                "db": {
                    "x": {
                        "a": "uint64",
                    }
                }
            }
        )
        self.assertEqual(schema.column_names(table("x", db="db")), ["a"])
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db", catalog="c"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db2"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x2", db="db"))

        with self.assertRaises(ValueError):
            schema.add_table(table("y"), {"b": "string"})
        with self.assertRaises(ValueError):
            schema.add_table(table("y", db="db", catalog="c"), {"b": "string"})

        schema.add_table(table("y", db="db"), {"b": "string"})
        self.assertEqual(
            schema.schema,
            {
                "db": {
                    "d1": {
                        "x": {
                            "a": "uint64",
                        },
                        "y": {
                            "b": "string",
                            "b": "uint64",
                        },
                    }
                },
                "d2": {
                    "x": {
                        "c": "uint64",
                    },
                },
            },
        )

        self.assert_column_names(
            schema,
            ("d1.x", ["a"]),
            ("d2.x", ["c"]),
            ("y", ["b"]),
            ("d1.y", ["b"]),
            ("z.d1.y", ["b"]),
        )

        self.assert_column_names_raises(
            schema,
            "x",
            "z.x",
            "z.y",
        )
    def test_schema_catalog(self):
        schema = ensure_schema(
            {
                "c": {
                    "db": {
                        "x": {
                            "a": "uint64",
                        }
                    }
                }
            }
        )
        self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"])
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db", catalog="c2"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db2"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x2", db="db"))

        with self.assertRaises(ValueError):
            schema.add_table(table("x"), {"b": "string"})
        with self.assertRaises(ValueError):
            schema.add_table(table("x", db="db"), {"b": "string"})

        schema.add_table(table("y", db="db", catalog="c"), {"a": "string", "b": "int"})
        self.assertEqual(
            schema.schema,
            {
                "c": {
                    "db": {
                        "c1": {
                            "d1": {
                                "x": {
                                    "a": "uint64",
                                },
                                "y": {
                                    "a": "string",
                                    "b": "int",
                                    "b": "uint64",
                                },
                            }
                        }
                    },
            )
        schema.add_table(table("z", db="db2", catalog="c"), {"c": "string", "d": "int"})
        self.assertEqual(
            schema.schema,
            {
                "c": {
                    "db": {
                        "x": {
                            "a": "uint64",
                        },
                        "y": {
                            "a": "string",
                            "b": "int",
                        },
                    },
                    "db2": {
                        "z": {
                            "c": "string",
                            "d": "int",
                        }
                    },
                }
            },
        )
        schema.add_table(table("m", db="db2", catalog="c2"), {"e": "string", "f": "int"})
        self.assertEqual(
            schema.schema,
            {
                "c": {
                    "db": {
                        "x": {
                            "a": "uint64",
                            "c": "uint64",
                        },
                        "y": {
                            "a": "string",
                            "b": "int",
                        },
                    },
                    "db2": {
                        "z": {
                            "c": "string",
                            "d": "int",
                        }
                    },
                },
                "c2": {
                    "db2": {
                        "m": {
                            "e": "string",
                            "f": "int",
                        }
                    }
                }
                "d1": {
                    "y": {
                        "d": "uint64",
                    },
                    "z": {
                        "e": "uint64",
                    },
                },
                "d2": {
                    "z": {
                        "f": "uint64",
                    },
                },
            },
        )

        schema = ensure_schema(
            {
                "x": {
                    "a": "uint64",
                }
            }
        )
        self.assertEqual(schema.column_names(table("x")), ["a"])

        schema = MappingSchema()
        schema.add_table(table("x"), {"a": "string"})
        self.assertEqual(
            schema.schema,
            {
                "x": {
                    "a": "string",
                }
            },
        self.assert_column_names(
            schema,
            ("x", ["a"]),
            ("d1.x", ["a"]),
            ("c1.d1.x", ["a"]),
            ("c1.d1.y", ["b"]),
            ("c1.d1.z", ["c"]),
            ("c2.d1.y", ["d"]),
            ("c2.d1.z", ["e"]),
            ("d2.z", ["f"]),
            ("c2.d2.z", ["f"]),
        )
        schema.add_table(
            table("y"), df_types.StructType([df_types.StructField("b", df_types.StringType())])
        )
        self.assertEqual(
            schema.schema,
            {
                "x": {
                    "a": "string",
                },
                "y": {
                    "b": "string",
                },
            },

        self.assert_column_names_raises(
            schema,
            "q",
            "d2.x",
            "y",
            "z",
            "d1.y",
            "d1.z",
            "a.b.c",
        )
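The catalog-level behavior this hunk rewrites, as a short grounded sketch: lookups and additions must be qualified to the same depth as the schema mapping itself.

from sqlglot import table
from sqlglot.schema import ensure_schema

schema = ensure_schema({"c": {"db": {"x": {"a": "uint64"}}}})
print(schema.column_names(table("x", db="db", catalog="c")))  # ['a']

# New tables need both db and catalog here; unqualified adds raise ValueError.
schema.add_table(table("y", db="db", catalog="c"), {"a": "string", "b": "int"})
print(schema.column_names(table("y", db="db", catalog="c")))  # ['a', 'b']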

    def test_schema_add_table_with_and_without_mapping(self):
@ -288,3 +148,34 @@ class TestSchema(unittest.TestCase):
        self.assertEqual(schema.column_names("test"), ["x", "y"])
        schema.add_table("test")
        self.assertEqual(schema.column_names("test"), ["x", "y"])

    def test_schema_get_column_type(self):
        schema = MappingSchema({"a": {"b": "varchar"}})
        self.assertEqual(schema.get_column_type("a", "b"), exp.DataType.Type.VARCHAR)
        self.assertEqual(
            schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")),
            exp.DataType.Type.VARCHAR,
        )
        self.assertEqual(
            schema.get_column_type("a", exp.Column(this="b")), exp.DataType.Type.VARCHAR
        )
        self.assertEqual(
            schema.get_column_type(exp.Table(this="a"), "b"), exp.DataType.Type.VARCHAR
        )
        schema = MappingSchema({"a": {"b": {"c": "varchar"}}})
        self.assertEqual(
            schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")),
            exp.DataType.Type.VARCHAR,
        )
        self.assertEqual(
            schema.get_column_type(exp.Table(this="b", db="a"), "c"), exp.DataType.Type.VARCHAR
        )
        schema = MappingSchema({"a": {"b": {"c": {"d": "varchar"}}}})
        self.assertEqual(
            schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")),
            exp.DataType.Type.VARCHAR,
        )
        self.assertEqual(
            schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d"),
            exp.DataType.Type.VARCHAR,
        )
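As the assertions show, get_column_type accepts strings and expression nodes interchangeably for both the table and the column. A minimal usage sketch:

from sqlglot import exp
from sqlglot.schema import MappingSchema

schema = MappingSchema({"a": {"b": {"c": "varchar"}}})

# Equivalent lookups with string and expression arguments.
print(schema.get_column_type(exp.Table(this="b", db="a"), "c"))
# Type.VARCHAR
print(schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")))
# Type.VARCHAR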
18  tests/test_tokens.py  Normal file

@ -0,0 +1,18 @@
import unittest

from sqlglot.tokens import Tokenizer


class TestTokens(unittest.TestCase):
    def test_comment_attachment(self):
        tokenizer = Tokenizer()
        sql_comment = [
            ("/*comment*/ foo", "comment"),
            ("/*comment*/ foo --test", "comment"),
            ("--comment\nfoo --test", "comment"),
            ("foo --comment", "comment"),
            ("foo", None),
        ]

        for sql, comment in sql_comment:
            self.assertEqual(tokenizer.tokenize(sql)[0].comment, comment)
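This new test file covers the tokenizer side of comment attachment: a comment is stored on the adjacent token's .comment field rather than discarded. A quick sketch mirroring the cases above:

from sqlglot.tokens import Tokenizer

tokens = Tokenizer().tokenize("/*comment*/ foo --test")
print(tokens[0].comment)  # comment

tokens = Tokenizer().tokenize("foo")
print(tokens[0].comment)  # None - nothing to attach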
@ -49,6 +49,12 @@ class TestTranspile(unittest.TestCase):
            leading_comma=True,
            pretty=True,
        )
        self.validate(
            "SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ",
            "SELECT\n  FOO -- x\n  , BAR -- y\n  , BAZ",
            leading_comma=True,
            pretty=True,
        )
        # without pretty, this should be a no-op
        self.validate(
            "SELECT FOO, BAR, BAZ",
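A usage sketch of the option under test, assuming generator settings such as leading_comma are forwarded through sqlglot.transpile the way the validate helper here relies on:

import sqlglot

sql = sqlglot.transpile("SELECT FOO, BAR, BAZ", pretty=True, leading_comma=True)[0]
print(sql)
# SELECT
#   FOO
#   , BAR
#   , BAZ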
@ -63,24 +69,61 @@ class TestTranspile(unittest.TestCase):
        self.validate("SELECT 3>=3", "SELECT 3 >= 3")

    def test_comments(self):
        self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo")
        self.validate("SELECT 1 /* inline */ FROM foo -- comment", "SELECT 1 FROM foo")

        self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
        self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
        self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo")
        self.validate(
            "SELECT 1 /* inline */ FROM foo -- comment",
            "SELECT 1 /* inline */ FROM foo /* comment */",
        )
        self.validate(
            "SELECT FUN(x) /*x*/, [1,2,3] /*y*/", "SELECT FUN(x) /* x */, ARRAY(1, 2, 3) /* y */"
        )
        self.validate(
            """
            SELECT 1 -- comment
            FROM foo -- comment
            """,
            "SELECT 1 FROM foo",
            "SELECT 1 /* comment */ FROM foo /* comment */",
        )

        self.validate(
            """
            SELECT 1 /* big comment
            like this */
            FROM foo -- comment
            """,
            "SELECT 1 FROM foo",
            """SELECT 1 /* big comment
like this */ FROM foo /* comment */""",
        )
        self.validate(
            "select x from foo --  x",
            "SELECT x FROM foo /* x */",
        )
        self.validate(
            """
            /* multi
               line
               comment
            */
            SELECT
              tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
              CAST(x AS INT), # comment 3
              y -- comment 4
            FROM
              bar /* comment 5 */,
              tbl # comment 6
            """,
            """/* multi
   line
   comment
*/
SELECT
  tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
  CAST(x AS INT), -- comment 3
  y -- comment 4
FROM bar /* comment 5 */, tbl /* comment 6 */""",
            read="mysql",
            pretty=True,
        )
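The headline change in this hunk: comments now survive transpilation instead of being dropped, with trailing -- comments re-emitted as /* */ blocks. A grounded sketch:

import sqlglot

print(sqlglot.transpile("SELECT 1 /* inline */ FROM foo -- comment")[0])
# SELECT 1 /* inline */ FROM foo /* comment */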

    def test_types(self):
@ -146,6 +189,16 @@ class TestTranspile(unittest.TestCase):
    def test_ignore_nulls(self):
        self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)")

    def test_with(self):
        self.validate(
            "WITH a AS (SELECT 1) WITH b AS (SELECT 2) SELECT *",
            "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *",
        )
        self.validate(
            "WITH a AS (SELECT 1), WITH b AS (SELECT 2) SELECT *",
            "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *",
        )

    def test_time(self):
        self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)")
        self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)")
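The new test_with cases show error-tolerant normalization of malformed WITH syntax. A direct sketch of the same round-trip:

import sqlglot

# Consecutive (or comma-separated) WITH clauses are merged into a single one.
print(sqlglot.transpile("WITH a AS (SELECT 1) WITH b AS (SELECT 2) SELECT *")[0])
# WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *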