1
0
Fork 0

Merging upstream version 23.7.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:30:28 +01:00
parent ebba7c6a18
commit d26905e4af
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
187 changed files with 86502 additions and 71397 deletions

View file

@ -298,7 +298,9 @@ class TestOptimizer(unittest.TestCase):
self.check_file(
"qualify_columns", qualify_columns, execute=True, schema=self.schema, set_dialect=True
)
self.check_file("qualify_columns_ddl", qualify_columns, schema=self.schema)
self.check_file(
"qualify_columns_ddl", qualify_columns, schema=self.schema, set_dialect=True
)
def test_qualify_columns__with_invisible(self):
schema = MappingSchema(self.schema, {"x": {"a"}, "y": {"b"}, "z": {"b"}})
@ -340,6 +342,9 @@ class TestOptimizer(unittest.TestCase):
def test_simplify(self):
self.check_file("simplify", simplify, set_dialect=True)
expression = parse_one("SELECT a, c, b FROM table1 WHERE 1 = 1")
self.assertEqual(simplify(simplify(expression.find(exp.Where))).sql(), "WHERE TRUE")
expression = parse_one("TRUE AND TRUE AND TRUE")
self.assertEqual(exp.true(), optimizer.simplify.simplify(expression))
self.assertEqual(exp.true(), optimizer.simplify.simplify(expression.this))
@ -359,15 +364,18 @@ class TestOptimizer(unittest.TestCase):
self.assertEqual("CONCAT('a', x, 'bc')", simplified_safe_concat.sql())
anon_unquoted_str = parse_one("anonymous(x, y)")
self.assertEqual(optimizer.simplify.gen(anon_unquoted_str), "ANONYMOUS x,y")
self.assertEqual(optimizer.simplify.gen(anon_unquoted_str), "ANONYMOUS(x,y)")
query = parse_one("SELECT x FROM t")
self.assertEqual(optimizer.simplify.gen(query), optimizer.simplify.gen(query.copy()))
anon_unquoted_identifier = exp.Anonymous(
this=exp.to_identifier("anonymous"), expressions=[exp.column("x"), exp.column("y")]
)
self.assertEqual(optimizer.simplify.gen(anon_unquoted_identifier), "ANONYMOUS x,y")
self.assertEqual(optimizer.simplify.gen(anon_unquoted_identifier), "ANONYMOUS(x,y)")
anon_quoted = parse_one('"anonymous"(x, y)')
self.assertEqual(optimizer.simplify.gen(anon_quoted), '"anonymous" x,y')
self.assertEqual(optimizer.simplify.gen(anon_quoted), '"anonymous"(x,y)')
with self.assertRaises(ValueError) as e:
anon_invalid = exp.Anonymous(this=5)
@ -375,6 +383,28 @@ class TestOptimizer(unittest.TestCase):
self.assertIn("Anonymous.this expects a str or an Identifier, got 'int'.", str(e.exception))
sql = parse_one(
"""
WITH cte AS (select 1 union select 2), cte2 AS (
SELECT ROW() OVER (PARTITION BY y) FROM (
(select 1) limit 10
)
)
SELECT
*,
a + 1,
a div 1,
filter("B", (x, y) -> x + y)
FROM (z AS z CROSS JOIN z) AS f(a) LEFT JOIN a.b.c.d.e.f.g USING(n) ORDER BY 1
"""
)
self.assertEqual(
optimizer.simplify.gen(sql),
"""
SELECT :with,WITH :expressions,CTE :this,UNION :this,SELECT :expressions,1,:expression,SELECT :expressions,2,:distinct,True,:alias, AS cte,CTE :this,SELECT :expressions,WINDOW :this,ROW(),:partition_by,y,:over,OVER,:from,FROM ((SELECT :expressions,1):limit,LIMIT :expression,10),:alias, AS cte2,:expressions,STAR,a + 1,a DIV 1,FILTER("B",LAMBDA :this,x + y,:expressions,x,y),:from,FROM (z AS z:joins,JOIN :this,z,:kind,CROSS) AS f(a),:joins,JOIN :this,a.b.c.d.e.f.g,:side,LEFT,:using,n,:order,ORDER :expressions,ORDERED :this,1,:nulls_first,True
""".strip(),
)
def test_unnest_subqueries(self):
self.check_file(
"unnest_subqueries",
@ -475,6 +505,18 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
)
def test_scope(self):
ast = parse_one("SELECT IF(a IN UNNEST(b), 1, 0) AS c FROM t", dialect="bigquery")
self.assertEqual(build_scope(ast).columns, [exp.column("a"), exp.column("b")])
many_unions = parse_one(" UNION ALL ".join(["SELECT x FROM t"] * 10000))
scopes_using_traverse = list(build_scope(many_unions).traverse())
scopes_using_traverse_scope = traverse_scope(many_unions)
self.assertEqual(len(scopes_using_traverse), len(scopes_using_traverse_scope))
assert all(
x.expression is y.expression
for x, y in zip(scopes_using_traverse, scopes_using_traverse_scope)
)
sql = """
WITH q AS (
SELECT x.b FROM x
@ -522,7 +564,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(
{
node.sql()
for node, *_ in walk_in_scope(expression.find(exp.Where))
for node in walk_in_scope(expression.find(exp.Where))
if isinstance(node, exp.Column)
},
{"s.b"},
@ -667,6 +709,14 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expressions[0].type.this, exp.DataType.Type.BIGINT)
self.assertEqual(expressions[1].type.this, exp.DataType.Type.DOUBLE)
expressions = annotate_types(
parse_one("SELECT SUM(2 / 3), CAST(2 AS DECIMAL) / 3", dialect="mysql")
).expressions
self.assertEqual(expressions[0].type.this, exp.DataType.Type.DOUBLE)
self.assertEqual(expressions[0].this.type.this, exp.DataType.Type.DOUBLE)
self.assertEqual(expressions[1].type.this, exp.DataType.Type.DECIMAL)
def test_bracket_annotation(self):
expression = annotate_types(parse_one("SELECT A[:]")).expressions[0]
@ -1056,6 +1106,34 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.selects[1].type, exp.DataType.build("STRUCT<c int>"))
self.assertEqual(expression.selects[2].type, exp.DataType.build("int"))
self.assertEqual(
annotate_types(
optimizer.qualify.qualify(
parse_one(
"SELECT x FROM UNNEST(GENERATE_DATE_ARRAY('2021-01-01', current_date(), interval 1 day)) AS x"
)
)
)
.selects[0]
.type,
exp.DataType.build("date"),
)
def test_map_annotation(self):
# ToMap annotation
expression = annotate_types(parse_one("SELECT MAP {'x': 1}", read="duckdb"))
self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, INT)"))
# Map annotation
expression = annotate_types(
parse_one("SELECT MAP(['key1', 'key2', 'key3'], [10, 20, 30])", read="duckdb")
)
self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, INT)"))
# VarMap annotation
expression = annotate_types(parse_one("SELECT MAP('a', 'b')", read="spark"))
self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, VARCHAR)"))
def test_recursive_cte(self):
query = parse_one(
"""