93 lines
4.1 KiB
Python
93 lines
4.1 KiB
Python
from sqlglot import ParseError, transpile
|
|
from tests.dialects.test_dialect import Validator
|
|
|
|
|
|
class TestPostgres(Validator):
    """Tests for the Postgres dialect: DDL round trips and cross-dialect output."""

    dialect = "postgres"

    def test_ddl(self):
        """Check CREATE TABLE constraint handling for the Postgres dialect.

        Column-level and table-level UNIQUE, CHECK and named CONSTRAINT clauses
        are expected to be written back to Postgres unchanged, and CHECK
        conditions that are not wrapped in parentheses are expected to raise
        ``ParseError``.
        """
        # Every statement here must regenerate verbatim, so one string serves
        # as both the input and the expected "postgres" output.
        unchanged_statements = (
            "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
            "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)",
            "CREATE TABLE products (product_no INT, name TEXT, price DECIMAL, UNIQUE (product_no, name))",
            # Parenthesised so the implicit literal concatenation is explicit
            # and cannot be confused with separate tuple items.
            (
                "CREATE TABLE products ("
                "product_no INT UNIQUE,"
                " name TEXT,"
                " price DECIMAL CHECK (price > 0),"
                " discounted_price DECIMAL CONSTRAINT positive_discount CHECK (discounted_price > 0),"
                " CHECK (product_no > 1),"
                " CONSTRAINT valid_discount CHECK (price > discounted_price))"
            ),
        )
        for statement in unchanged_statements:
            self.validate_all(statement, write={"postgres": statement})

        # Both the column-constraint and the table-constraint form of CHECK
        # are expected to be rejected when the condition lacks parentheses.
        malformed_checks = (
            "CREATE TABLE products (price DECIMAL CHECK price > 0)",
            "CREATE TABLE products (price DECIMAL, CHECK price > 1)",
        )
        for statement in malformed_checks:
            with self.assertRaises(ParseError):
                transpile(statement, read="postgres")

    def test_postgres(self):
        """Check how Postgres SQL is read from and written to other dialects."""
        # The SQLite AUTOINCREMENT form corresponds to the SERIAL form here,
        # in both the read and the write direction.
        sqlite_autoincrement = "CREATE TABLE x (a INTEGER AUTOINCREMENT)"
        self.validate_all(
            "CREATE TABLE x (a INT SERIAL)",
            read={"sqlite": sqlite_autoincrement},
            write={"sqlite": sqlite_autoincrement},
        )

        # UUID passes through untouched; BYTEA becomes VARBINARY for presto
        # and BINARY for hive and spark.
        hive_style_binary = "CREATE TABLE x (a UUID, b BINARY)"
        self.validate_all(
            "CREATE TABLE x (a UUID, b BYTEA)",
            write={
                "presto": "CREATE TABLE x (a UUID, b VARBINARY)",
                "hive": hive_style_binary,
                "spark": hive_style_binary,
            },
        )

        # A bare "ROWS n PRECEDING" frame is expanded to the explicit
        # "BETWEEN ... AND CURRENT ROW" form.
        self.validate_all(
            "SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS 1 PRECEDING)",
            write={
                "postgres": "SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)",
            },
        )

        # "FETCH 1 ROW" is normalised to the same FETCH FIRST clause for every
        # target dialect listed.
        fetch_first = "SELECT * FROM x FETCH FIRST 1 ROWS ONLY"
        self.validate_all(
            "SELECT * FROM x FETCH 1 ROW",
            write={
                target: fetch_first
                for target in ("postgres", "presto", "hive", "spark")
            },
        )

        # The Postgres output drops the explicit ASC and NULLS FIRST / NULLS
        # LAST clauses (presumably because they match its default null
        # ordering). Presto keeps NULLS FIRST only on the DESC key, and
        # hive/spark spell out NULLS LAST on the ascending keys.
        hive_style_order = "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST"
        self.validate_all(
            "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
            write={
                "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname",
                "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname",
                "hive": hive_style_order,
                "spark": hive_style_order,
            },
        )
|