Adding upstream version 25.24.5.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
a0663ae805
commit
7af32ea9ec
80 changed files with 61531 additions and 59444 deletions
|
@ -54,6 +54,18 @@ def simplify(expression, **kwargs):
|
|||
return optimizer.simplify.simplify(expression, constant_propagation=True, **kwargs)
|
||||
|
||||
|
||||
def annotate_functions(expression, **kwargs):
    """Annotate *expression* with types and return its first projection.

    Keyword args:
        dialect: dialect name/instance whose ``ANNOTATORS`` drive inference
            (resolved via ``Dialect.get_or_raise``).
        schema: optional schema mapping used to resolve column types.

    Returns:
        The first expression of the annotated tree (e.g. the first SELECT
        projection), carrying its inferred ``type``.
    """
    from sqlglot.dialects import Dialect

    # Resolve the dialect's annotators, then type the whole tree in one pass.
    annotated = annotate_types(
        expression,
        annotators=Dialect.get_or_raise(kwargs.get("dialect")).ANNOTATORS,
        schema=kwargs.get("schema"),
    )
    return annotated.expressions[0]
||||
class TestOptimizer(unittest.TestCase):
|
||||
maxDiff = None
|
||||
|
||||
|
@ -787,6 +799,28 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
|
|||
with self.subTest(title):
|
||||
self.assertEqual(result.type.sql(), exp.DataType.build(expected).sql())
|
||||
|
||||
def test_annotate_funcs(self):
    """Check function return-type annotation against per-dialect SQL fixtures.

    Each fixture line provides an expression, an expected type, and optionally
    a comma-separated list of dialects; the expression is wrapped in a SELECT
    and annotated once per dialect.
    """
    test_schema = {
        "tbl": {"bin_col": "BINARY", "str_col": "STRING", "bignum_col": "BIGNUMERIC"}
    }

    for i, (meta, sql, expected) in enumerate(
        load_sql_fixture_pairs("optimizer/annotate_functions.sql"), start=1
    ):
        title = meta.get("title") or f"{i}, {sql}"
        # Fix: keep the comma-separated metadata string in its own name
        # instead of shadowing it with the per-dialect loop variable below.
        dialects = meta.get("dialect") or ""
        sql = f"SELECT {sql} FROM tbl"

        for dialect in dialects.split(", "):
            result = parse_and_optimize(
                annotate_functions, sql, dialect, schema=test_schema, dialect=dialect
            )

            with self.subTest(title):
                self.assertEqual(
                    result.type.sql(dialect), exp.DataType.build(expected).sql(dialect)
                )
||||
def test_cast_type_annotation(self):
    """A CAST target type (including precision arguments) is propagated to the node."""
    typed = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))"))
    self.assertEqual(exp.DataType.Type.TIMESTAMPTZ, typed.type.this)
|
@ -1377,26 +1411,3 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
|
|||
self.assertEqual(4, normalization_distance(gen_expr(2), max_=100))
|
||||
self.assertEqual(18, normalization_distance(gen_expr(3), max_=100))
|
||||
self.assertEqual(110, normalization_distance(gen_expr(10), max_=100))
|
||||
|
||||
def test_custom_annotators(self):
    """In the Spark dialect hierarchy, SUBSTRING's result type follows its input type."""
    # In Spark hierarchy, SUBSTRING result type is dependent on input expr type
    for dialect in ("spark2", "spark", "databricks"):
        for expr_type_pair in (
            ("col", "STRING"),
            ("col", "BINARY"),
            ("'str_literal'", "STRING"),
            ("CAST('str_literal' AS BINARY)", "BINARY"),
        ):
            with self.subTest(
                f"Testing {dialect}'s SUBSTRING() result type for {expr_type_pair}"
            ):
                # Fix: don't shadow the `type` builtin with a local name.
                expr, expr_type = expr_type_pair
                ast = parse_one(f"SELECT substring({expr}, 2, 3) AS x FROM tbl", read=dialect)

                subst_type = (
                    optimizer.optimize(ast, schema={"tbl": {"col": expr_type}}, dialect=dialect)
                    .expressions[0]
                    .type
                )

                self.assertEqual(subst_type.sql(dialect), exp.DataType.build(expr_type).sql(dialect))
|
Loading…
Add table
Add a link
Reference in a new issue