138 lines
3.5 KiB
Python
138 lines
3.5 KiB
Python
|
import pytest
|
||
|
from sqlparse import parse
|
||
|
from pgcli.packages.parseutils.ctes import (
|
||
|
token_start_pos,
|
||
|
extract_ctes,
|
||
|
extract_column_names as _extract_column_names,
|
||
|
)
|
||
|
|
||
|
|
||
|
def extract_column_names(sql):
    """Return the column names extracted from the first statement of *sql*.

    Test convenience wrapper: parses the raw SQL text with sqlparse and hands
    the first parsed statement to the library's ``extract_column_names``.
    """
    statements = parse(sql)
    first_statement = statements[0]
    return _extract_column_names(first_statement)
|
||
|
|
||
|
|
||
|
def test_token_str_pos():
    """token_start_pos gives the character offset of the final token.

    Checked both with a plain space and with a trailing newline before the
    last token, so that newlines are counted as ordinary characters.
    """
    for prefix in ("SELECT * FROM ", "SELECT * FROM \n"):
        statement = parse(prefix + "xxx")[0]
        last_index = statement.token_index(statement.tokens[-1])
        assert token_start_pos(statement.tokens, last_index) == len(prefix)
|
||
|
|
||
|
|
||
|
def test_single_column_name_extraction():
    """A lone, unaliased column is reported under its own name."""
    names = extract_column_names("SELECT abc FROM xxx")
    assert names == ("abc",)
|
||
|
|
||
|
|
||
|
def test_aliased_single_column_name_extraction():
    """An implicitly aliased column (no AS) is reported by its alias."""
    names = extract_column_names("SELECT abc def FROM xxx")
    assert names == ("def",)
|
||
|
|
||
|
|
||
|
def test_aliased_expression_name_extraction():
    """A literal expression with an alias is reported by that alias."""
    names = extract_column_names("SELECT 99 abc FROM xxx")
    assert names == ("abc",)
|
||
|
|
||
|
|
||
|
def test_multiple_column_name_extraction():
    """Comma-separated columns are reported in select-list order."""
    names = extract_column_names("SELECT abc, def FROM xxx")
    assert names == ("abc", "def")
|
||
|
|
||
|
|
||
|
def test_missing_column_name_handled_gracefully():
    """Unnamed select items (bare literals) are skipped, not fatal.

    Covers an unnamed item at the end of the select list and one sitting
    between two named columns.
    """
    cases = (
        ("SELECT abc, 99 FROM xxx", ("abc",)),
        ("SELECT abc, 99, def FROM xxx", ("abc", "def")),
    )
    for query, expected in cases:
        assert extract_column_names(query) == expected
|
||
|
|
||
|
|
||
|
def test_aliased_multiple_column_name_extraction():
    """Each aliased column in a list is reported by its own alias."""
    names = extract_column_names("SELECT abc def, ghi jkl FROM xxx")
    assert names == ("def", "jkl")
|
||
|
|
||
|
|
||
|
def test_table_qualified_column_name_extraction():
    """Table-qualified columns (tbl.col) are reported without the qualifier."""
    names = extract_column_names("SELECT abc.def, ghi.jkl FROM xxx")
    assert names == ("def", "jkl")
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "sql",
    (
        "INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y",
        "DELETE FROM foo WHERE x > y RETURNING x, y",
        "UPDATE foo SET x = 9 RETURNING x, y",
    ),
)
def test_extract_column_names_from_returning_clause(sql):
    """DML statements expose the columns listed in their RETURNING clause."""
    expected = ("x", "y")
    assert extract_column_names(sql) == expected
|
||
|
|
||
|
|
||
|
def test_simple_cte_extraction():
    """A single CTE is extracted with its columns and parenthesised span.

    The span offsets are located in the query text itself, so they always
    refer to the parenthesised CTE body.
    """
    cte_body = "(SELECT abc FROM xxx)"
    sql = "WITH a AS " + cte_body + " SELECT * FROM a"
    begin = sql.index(cte_body)
    end = begin + len(cte_body)

    ctes, remainder = extract_ctes(sql)

    assert tuple(ctes) == (("a", ("abc",), begin, end),)
    assert remainder.strip() == "SELECT * FROM a"
|
||
|
|
||
|
|
||
|
def test_cte_extraction_around_comments():
    """A leading line comment does not disturb CTE extraction or offsets.

    The reported span must still be measured from the very start of the
    text, comment included.
    """
    cte_body = "(SELECT abc def FROM x)"
    sql = "--blah blah blah\nWITH a AS " + cte_body + "\nSELECT * FROM a"
    begin = sql.index(cte_body)
    end = begin + len(cte_body)

    ctes, remainder = extract_ctes(sql)

    assert tuple(ctes) == (("a", ("def",), begin, end),)
    assert remainder.strip() == "SELECT * FROM a"
|
||
|
|
||
|
|
||
|
def test_multiple_cte_extraction():
    """Two comma-separated CTEs are both extracted, each with its own span.

    Each span covers the parenthesised CTE body and is measured from the
    start of the whole query. Offsets are located in the query text itself
    rather than re-typed as prefix strings, so they cannot drift out of sync
    with the SQL.
    """
    x_body = "(SELECT abc, def FROM x)"
    y_body = "(SELECT ghi, jkl FROM y)"
    sql = "WITH\nx AS " + x_body + ",\ny AS " + y_body + "\nSELECT * FROM a, b"

    start1 = sql.index(x_body)
    stop1 = start1 + len(x_body)
    start2 = sql.index(y_body)
    stop2 = start2 + len(y_body)

    ctes, remainder = extract_ctes(sql)

    assert tuple(ctes) == (
        ("x", ("abc", "def"), start1, stop1),
        ("y", ("ghi", "jkl"), start2, stop2),
    )
    # The single-CTE tests also verify the main query; previously this case
    # unpacked `remainder` but never checked it.
    assert remainder.strip() == "SELECT * FROM a, b"
|