1
0
Fork 0
sqlglot/tests/dialects/test_postgres.py

165 lines
7.9 KiB
Python
Raw Normal View History

from sqlglot import ParseError, transpile
from tests.dialects.test_dialect import Validator
class TestPostgres(Validator):
    """Postgres dialect test cases driven through the shared Validator helpers."""

    dialect = "postgres"

    def test_ddl(self):
        """Exercise CREATE TABLE constraint syntax.

        Each well-formed statement is handed to ``validate_all`` with an
        identical expected ``postgres`` output; CHECK constraints written
        without parentheses must make ``transpile`` raise ``ParseError``.
        """
        # Built from adjacent literals so each column/constraint sits on its own line.
        checked_products = (
            "CREATE TABLE products ("
            "product_no INT UNIQUE,"
            " name TEXT,"
            " price DECIMAL CHECK (price > 0),"
            " discounted_price DECIMAL CONSTRAINT positive_discount CHECK (discounted_price > 0),"
            " CHECK (product_no > 1),"
            " CONSTRAINT valid_discount CHECK (price > discounted_price))"
        )
        unchanged_statements = (
            "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
            "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)",
            "CREATE TABLE products (product_no INT, name TEXT, price DECIMAL, UNIQUE (product_no, name))",
            checked_products,
            "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)",
            "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)",
        )
        for ddl in unchanged_statements:
            self.validate_all(ddl, write={"postgres": ddl})

        # CHECK expressions missing their parentheses.
        malformed_statements = (
            "CREATE TABLE products (price DECIMAL CHECK price > 0)",
            "CREATE TABLE products (price DECIMAL, CHECK price > 1)",
        )
        for ddl in malformed_statements:
            with self.assertRaises(ParseError):
                transpile(ddl, read="postgres")

    def test_postgres(self):
        """Exercise SUBSTRING/TRIM forms, window frames, FETCH, NULLS ordering and LATERAL.

        The first group goes through ``validate_identity``; the remaining
        cases go through ``validate_all`` with per-dialect ``read``/``write``
        expectations.
        """
        identity_statements = (
            "SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END",
            "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END",
            "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END",
            'SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')',
            "SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')",
            "SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')",
            "SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))",
            "SELECT TRIM(' X' FROM ' XXX ')",
            "SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)",
        )
        for statement in identity_statements:
            self.validate_identity(statement)

        # Statements that appear both as the input and as the postgres "read" form.
        lateral_subquery = "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"
        lateral_join = "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL"

        # Each entry is (sql, keyword arguments forwarded to validate_all).
        # dict.fromkeys keeps the listed dialect order when every dialect
        # shares the same expected text.
        transpile_cases = (
            (
                "CREATE TABLE x (a UUID, b BYTEA)",
                {
                    "write": {
                        "presto": "CREATE TABLE x (a UUID, b VARBINARY)",
                        "hive": "CREATE TABLE x (a UUID, b BINARY)",
                        "spark": "CREATE TABLE x (a UUID, b BINARY)",
                    }
                },
            ),
            (
                "SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS 1 PRECEDING)",
                {
                    "write": {
                        "postgres": "SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)",
                    }
                },
            ),
            (
                "SELECT * FROM x FETCH 1 ROW",
                {
                    "write": dict.fromkeys(
                        ("postgres", "presto", "hive", "spark"),
                        "SELECT * FROM x FETCH FIRST 1 ROWS ONLY",
                    )
                },
            ),
            (
                "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
                {
                    "write": {
                        "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname",
                        "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname",
                        "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
                        "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
                    }
                },
            ),
            (
                "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END",
                {
                    "write": dict.fromkeys(
                        ("hive", "spark"),
                        "SELECT CASE WHEN SUBSTRING('abcdefg', 1, 2) IN ('ab') THEN 1 ELSE 0 END",
                    )
                },
            ),
            (
                "SELECT * FROM x WHERE SUBSTRING(col1 FROM 3 + LENGTH(col1) - 10 FOR 10) IN (col2)",
                {
                    "write": dict.fromkeys(
                        ("hive", "spark"),
                        "SELECT * FROM x WHERE SUBSTRING(col1, 3 + LENGTH(col1) - 10, 10) IN (col2)",
                    )
                },
            ),
            (
                "SELECT SUBSTRING(CAST(2022 AS CHAR(4)) || LPAD(CAST(3 AS CHAR(2)), 2, '0') FROM 3 FOR 4)",
                {
                    "read": {
                        "postgres": "SELECT SUBSTRING(2022::CHAR(4) || LPAD(3::CHAR(2), 2, '0') FROM 3 FOR 4)",
                    }
                },
            ),
            (
                "SELECT TRIM(BOTH ' XXX ')",
                {
                    "write": dict.fromkeys(
                        ("mysql", "postgres", "hive"),
                        "SELECT TRIM(' XXX ')",
                    )
                },
            ),
            (
                "TRIM(LEADING FROM ' XXX ')",
                {
                    "write": dict.fromkeys(
                        ("mysql", "postgres", "hive", "presto"),
                        "LTRIM(' XXX ')",
                    )
                },
            ),
            (
                "TRIM(TRAILING FROM ' XXX ')",
                {
                    "write": dict.fromkeys(
                        ("mysql", "postgres", "hive", "presto"),
                        "RTRIM(' XXX ')",
                    )
                },
            ),
            (lateral_subquery, {"read": {"postgres": lateral_subquery}}),
            (lateral_join, {"read": {"postgres": lateral_join}}),
            (
                "SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id",
                {
                    "read": {
                        "postgres": "SELECT p1.id, p2.id, v1, v2 FROM polygons p1, polygons p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id != p2.id",
                    }
                },
            ),
        )
        for sql, expectations in transpile_cases:
            self.validate_all(sql, **expectations)