Merging upstream version 20.1.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent d4fe7bdb16
commit 90988d8258

127 changed files with 73384 additions and 73067 deletions

tests/test_optimizer.py
@@ -103,6 +103,10 @@ class TestOptimizer(unittest.TestCase):
                 "d": "TEXT",
                 "e": "TEXT",
             },
+            "temporal": {
+                "d": "DATE",
+                "t": "DATETIME",
+            },
         }
 
     def check_file(self, file, func, pretty=False, execute=False, set_dialect=False, **kwargs):
@@ -179,6 +183,18 @@ class TestOptimizer(unittest.TestCase):
         )
 
     def test_qualify_tables(self):
+        self.assertEqual(
+            optimizer.qualify_tables.qualify_tables(
+                parse_one("select a from b"), catalog="catalog"
+            ).sql(),
+            "SELECT a FROM b AS b",
+        )
+
+        self.assertEqual(
+            optimizer.qualify_tables.qualify_tables(parse_one("select a from b"), db='"DB"').sql(),
+            'SELECT a FROM "DB".b AS b',
+        )
+
         self.check_file(
             "qualify_tables",
             optimizer.qualify_tables.qualify_tables,
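For context, a minimal standalone sketch of the behavior the two new assertions pin down: an unqualified table is aliased to itself, a bare catalog is not applied without a db part, and a db argument is prepended as given. The expected outputs are taken directly from the assertions above; this assumes sqlglot 20.1.0 is installed.

from sqlglot import parse_one
from sqlglot.optimizer.qualify_tables import qualify_tables

# catalog alone does not prefix the table, but the self-alias is added
print(qualify_tables(parse_one("select a from b"), catalog="catalog").sql())
# SELECT a FROM b AS b

# a db argument is applied, and a pre-quoted name passes through unchanged
print(qualify_tables(parse_one("select a from b"), db='"DB"').sql())
# SELECT a FROM "DB".b AS b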
@@ -282,6 +298,13 @@ class TestOptimizer(unittest.TestCase):
 
         self.assertEqual(optimizer.normalize_identifiers.normalize_identifiers("a%").sql(), '"a%"')
 
+    def test_quote_identifiers(self):
+        self.check_file(
+            "quote_identifiers",
+            optimizer.qualify_columns.quote_identifiers,
+            set_dialect=True,
+        )
+
     def test_pushdown_projection(self):
         self.check_file("pushdown_projections", pushdown_projections, schema=self.schema)
 
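A short sketch of the two helpers this hunk touches. The normalize_identifiers result is exactly what the assertion above checks; the quote_identifiers call and its output are an assumption beyond what the hunk shows (the hunk only confirms the function accepts a dialect via set_dialect=True).

from sqlglot import parse_one
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import quote_identifiers

# "a%" is not a valid unquoted identifier, so normalization quotes it
print(normalize_identifiers("a%").sql())  # "a%"

# quote_identifiers wraps every identifier in the dialect's quote characters
print(quote_identifiers(parse_one("select x.a from x"), dialect="duckdb").sql())
# SELECT "x"."a" FROM "x"  (assumed output)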
@@ -300,8 +323,8 @@ class TestOptimizer(unittest.TestCase):
         safe_concat = parse_one("CONCAT('a', x, 'b', 'c')")
         simplified_safe_concat = optimizer.simplify.simplify(safe_concat)
 
-        self.assertEqual(simplified_concat.args["safe"], False)
-        self.assertEqual(simplified_safe_concat.args["safe"], True)
+        self.assertIs(type(simplified_concat), exp.Concat)
+        self.assertIs(type(simplified_safe_concat), exp.SafeConcat)
 
         self.assertEqual("CONCAT('a', x, 'bc')", simplified_concat.sql(dialect="presto"))
         self.assertEqual("CONCAT('a', x, 'bc')", simplified_safe_concat.sql())
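For context, the folding behavior these assertions cover, as a standalone sketch (output taken from the assertions above; sqlglot 20.1.0 assumed):

from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

# adjacent string literals fold together; the non-literal argument x
# blocks folding across it, so only 'b' and 'c' merge
print(simplify(parse_one("CONCAT('a', x, 'b', 'c')")).sql())
# CONCAT('a', x, 'bc')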
@@ -561,6 +584,19 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT)
         self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT)
 
+        for numeric_type in ("BIGINT", "DOUBLE", "INT"):
+            query = f"SELECT '1' + CAST(x AS {numeric_type})"
+            expression = annotate_types(parse_one(query)).expressions[0]
+            self.assertEqual(expression.type, exp.DataType.build(numeric_type))
+
+    def test_typeddiv_annotation(self):
+        expressions = annotate_types(
+            parse_one("SELECT 2 / 3, 2 / 3.0", dialect="presto")
+        ).expressions
+
+        self.assertEqual(expressions[0].type.this, exp.DataType.Type.BIGINT)
+        self.assertEqual(expressions[1].type.this, exp.DataType.Type.DOUBLE)
+
     def test_bracket_annotation(self):
         expression = annotate_types(parse_one("SELECT A[:]")).expressions[0]
 
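A sketch of the two annotation rules added here, restating what the new tests assert (sqlglot 20.1.0 assumed):

from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types

# a string literal in arithmetic with a cast expression is coerced
# to the numeric type, per the new loop above
e = annotate_types(parse_one("SELECT '1' + CAST(x AS DOUBLE)")).expressions[0]
print(e.type)  # DOUBLE

# Presto's / is integral for integer operands, per test_typeddiv_annotation
es = annotate_types(parse_one("SELECT 2 / 3, 2 / 3.0", dialect="presto")).expressions
print(es[0].type.this, es[1].type.this)  # Type.BIGINT Type.DOUBLE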
@@ -609,45 +645,60 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
                 "b": "DATETIME",
             }
         }
 
-        for sql, expected_type in [
+        for sql, expected_type, *expected_sql in [
             (
                 "SELECT '2023-01-01' + INTERVAL '1' DAY",
                 exp.DataType.Type.DATE,
+                "SELECT CAST('2023-01-01' AS DATE) + INTERVAL '1' DAY",
             ),
             (
                 "SELECT '2023-01-01' + INTERVAL '1' HOUR",
                 exp.DataType.Type.DATETIME,
+                "SELECT CAST('2023-01-01' AS DATETIME) + INTERVAL '1' HOUR",
             ),
             (
                 "SELECT '2023-01-01 00:00:01' + INTERVAL '1' HOUR",
                 exp.DataType.Type.DATETIME,
+                "SELECT CAST('2023-01-01 00:00:01' AS DATETIME) + INTERVAL '1' HOUR",
             ),
             ("SELECT 'nonsense' + INTERVAL '1' DAY", exp.DataType.Type.UNKNOWN),
             ("SELECT x.a + INTERVAL '1' DAY FROM x AS x", exp.DataType.Type.DATE),
-            (
-                "SELECT x.a + INTERVAL '1' HOUR FROM x AS x",
-                exp.DataType.Type.DATETIME,
-            ),
+            ("SELECT x.a + INTERVAL '1' HOUR FROM x AS x", exp.DataType.Type.DATETIME),
             ("SELECT x.b + INTERVAL '1' DAY FROM x AS x", exp.DataType.Type.DATETIME),
             ("SELECT x.b + INTERVAL '1' HOUR FROM x AS x", exp.DataType.Type.DATETIME),
+            (
+                "SELECT DATE_ADD('2023-01-01', 1, 'DAY')",
+                exp.DataType.Type.DATE,
+                "SELECT DATE_ADD(CAST('2023-01-01' AS DATE), 1, 'DAY')",
+            ),
+            (
+                "SELECT DATE_ADD('2023-01-01 00:00:00', 1, 'DAY')",
+                exp.DataType.Type.DATETIME,
+                "SELECT DATE_ADD(CAST('2023-01-01 00:00:00' AS DATETIME), 1, 'DAY')",
+            ),
             ("SELECT DATE_ADD(x.a, 1, 'DAY') FROM x AS x", exp.DataType.Type.DATE),
-            (
-                "SELECT DATE_ADD(x.a, 1, 'HOUR') FROM x AS x",
-                exp.DataType.Type.DATETIME,
-            ),
+            ("SELECT DATE_ADD(x.a, 1, 'HOUR') FROM x AS x", exp.DataType.Type.DATETIME),
             ("SELECT DATE_ADD(x.b, 1, 'DAY') FROM x AS x", exp.DataType.Type.DATETIME),
             ("SELECT DATE_TRUNC('DAY', x.a) FROM x AS x", exp.DataType.Type.DATE),
             ("SELECT DATE_TRUNC('DAY', x.b) FROM x AS x", exp.DataType.Type.DATETIME),
+            (
+                "SELECT DATE_TRUNC('SECOND', x.a) FROM x AS x",
+                exp.DataType.Type.DATETIME,
+            ),
+            (
+                "SELECT DATE_TRUNC('DAY', '2023-01-01') FROM x AS x",
+                exp.DataType.Type.DATE,
+            ),
             (
                 "SELECT DATEDIFF('2023-01-01', '2023-01-02', DAY) FROM x AS x",
                 exp.DataType.Type.INT,
             ),
         ]:
             with self.subTest(sql):
                 expression = annotate_types(parse_one(sql), schema=schema)
                 self.assertEqual(expected_type, expression.expressions[0].type.this)
-                self.assertEqual(sql, expression.sql())
+                self.assertEqual(expected_sql[0] if expected_sql else sql, expression.sql())
 
     def test_lateral_annotation(self):
         expression = optimizer.optimize(
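This hunk's main change: when a string literal takes part in date/time arithmetic, annotate_types now also wraps the literal in a CAST, which is why the test tuples grew an optional expected_sql column. A minimal sketch, with the output taken verbatim from the first new tuple (sqlglot 20.1.0 assumed):

from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types

e = annotate_types(parse_one("SELECT '2023-01-01' + INTERVAL '1' DAY"))
print(e.sql())                      # SELECT CAST('2023-01-01' AS DATE) + INTERVAL '1' DAY
print(e.expressions[0].type.this)   # Type.DATE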
@@ -843,6 +894,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
             ("MAX", "cold"): exp.DataType.Type.DATE,
             ("COUNT", "colb"): exp.DataType.Type.BIGINT,
             ("STDDEV", "cola"): exp.DataType.Type.DOUBLE,
+            ("ABS", "cola"): exp.DataType.Type.SMALLINT,
+            ("ABS", "colb"): exp.DataType.Type.FLOAT,
         }
 
         for (func, col), target_type in tests.items():
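The two new cases assert that ABS is annotated with its argument's own type rather than being widened. A sketch under that reading; the table and column names here are hypothetical stand-ins, since the test's real schema sits outside this hunk:

from sqlglot import parse_one, exp
from sqlglot.optimizer.annotate_types import annotate_types

# hypothetical schema mirroring the ("ABS", "cola") -> SMALLINT case above
e = annotate_types(
    parse_one("SELECT ABS(cola) FROM tbl"),
    schema={"tbl": {"cola": "SMALLINT"}},
).expressions[0]
print(e.type.this is exp.DataType.Type.SMALLINT)  # True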
@@ -989,10 +1042,3 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         query = parse_one("select a.b:c from d", read="snowflake")
         qualified = optimizer.qualify.qualify(query)
         self.assertEqual(qualified.expressions[0].alias, "c")
-
-    def test_qualify_tables_no_schema(self):
-        query = parse_one("select a from b")
-        self.assertEqual(
-            optimizer.qualify_tables.qualify_tables(query, catalog="catalog").sql(),
-            "SELECT a FROM b AS b",
-        )
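The removed test_qualify_tables_no_schema duplicated the catalog assertion that moved into test_qualify_tables earlier in this commit. The retained context lines cover Snowflake's variant access syntax; as a standalone sketch (alias value taken from the assertion above):

from sqlglot import parse_one
from sqlglot.optimizer.qualify import qualify

q = qualify(parse_one("select a.b:c from d", read="snowflake"))
print(q.expressions[0].alias)  # c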