1
0
Fork 0
golang-github-pocketbase-dbx/db_test.go
Daniel Baumann 02cacc5b45
Adding upstream version 1.11.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-05-22 10:43:26 +02:00

380 lines
9 KiB
Go

// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"context"
"database/sql"
"errors"
"io/ioutil"
"strings"
"testing"
// @todo change to sqlite
_ "github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
)
// Connection settings and fixture location shared by the tests in this file.
const (
// TestDSN is the data source name the tests pass to Open, MustOpen and
// sql.Open together with the "mysql" driver name.
TestDSN = "travis:@/pocketbase_dbx_test?parseTime=true"
// FixtureFile is the path of the SQL script that getPreparedDB reads,
// splits on ";" and executes one statement at a time.
FixtureFile = "testdata/mysql.sql"
)
// TestDB_NewFromDB checks that NewFromDB, given an already opened *sql.DB,
// returns a DB whose sqlDB and FieldMapper fields are populated.
func TestDB_NewFromDB(t *testing.T) {
	sqlDB, err := sql.Open("mysql", TestDSN)
	if !assert.Nil(t, err) {
		return
	}
	// Release the handle when the test ends; it was previously left open.
	defer sqlDB.Close()

	db := NewFromDB(sqlDB, "mysql")
	assert.NotNil(t, db.sqlDB)
	assert.NotNil(t, db.FieldMapper)
}
// TestDB_Open exercises Open with a registered and an unregistered driver
// name, and checks the values produced by Clone and WithContext.
func TestDB_Open(t *testing.T) {
	db, err := Open("mysql", TestDSN)
	assert.Nil(t, err)
	if assert.NotNil(t, db) {
		// Close the handle opened above when the test returns; it was
		// previously left open.
		defer db.Close()

		assert.NotNil(t, db.sqlDB)
		assert.NotNil(t, db.FieldMapper)

		// Clone must yield a distinct value that keeps the driver name.
		db2 := db.Clone()
		assert.NotEqual(t, db, db2)
		assert.Equal(t, db.driverName, db2.driverName)

		// WithContext must yield a distinct value carrying the given context.
		ctx := context.Background()
		db3 := db.WithContext(ctx)
		assert.Equal(t, ctx, db3.ctx)
		assert.Equal(t, ctx, db3.Context())
		assert.NotEqual(t, db, db3)
	}

	// "xyz" is not a driver name used anywhere else in this file; Open is
	// expected to report an error for it.
	_, err = Open("xyz", TestDSN)
	assert.NotNil(t, err)
}
// TestDB_MustOpen expects MustOpen to succeed with TestDSN and to return an
// error for the DSN "unknown:x@/test".
func TestDB_MustOpen(t *testing.T) {
	const badDSN = "unknown:x@/test"

	_, okErr := MustOpen("mysql", TestDSN)
	assert.Nil(t, okErr)

	_, badErr := MustOpen("mysql", badDSN)
	assert.NotNil(t, badErr)
}
// TestDB_Close expects Close on a freshly opened DB to return nil.
func TestDB_Close(t *testing.T) {
	closeErr := getDB().Close()
	assert.Nil(t, closeErr)
}
// TestDB_DriverName expects DriverName to echo the driver name the DB was
// opened with.
func TestDB_DriverName(t *testing.T) {
	got := getDB().DriverName()
	assert.Equal(t, "mysql", got)
}
// TestDB_QuoteTableName runs QuoteTableName over a table of inputs and
// compares each result with the expected quoted form.
func TestDB_QuoteTableName(t *testing.T) {
	type quoteCase struct {
		raw    string
		quoted string
	}
	cases := []quoteCase{
		{raw: "users", quoted: "`users`"},
		{raw: "`users`", quoted: "`users`"},
		{raw: "(select)", quoted: "(select)"},
		{raw: "{{users}}", quoted: "{{users}}"},
		{raw: "public.db1.users", quoted: "`public`.`db1`.`users`"},
	}

	db := getDB()
	for i := range cases {
		c := cases[i]
		got := db.QuoteTableName(c.raw)
		// The raw input doubles as the failure message.
		assert.Equal(t, c.quoted, got, c.raw)
	}
}
// TestDB_QuoteColumnName runs QuoteColumnName over a table of inputs and
// compares each result with the expected quoted form.
func TestDB_QuoteColumnName(t *testing.T) {
	type quoteCase struct {
		raw    string
		quoted string
	}
	cases := []quoteCase{
		{raw: "*", quoted: "*"},
		{raw: "users.*", quoted: "`users`.*"},
		{raw: "name", quoted: "`name`"},
		{raw: "`name`", quoted: "`name`"},
		{raw: "(select)", quoted: "(select)"},
		{raw: "{{name}}", quoted: "{{name}}"},
		{raw: "[[name]]", quoted: "[[name]]"},
		{raw: "public.db1.users", quoted: "`public`.`db1`.`users`"},
	}

	db := getDB()
	for i := range cases {
		c := cases[i]
		got := db.QuoteColumnName(c.raw)
		// The raw input doubles as the failure message.
		assert.Equal(t, c.quoted, got, c.raw)
	}
}
// TestDB_ProcessSQL feeds SQL templates through processSQL with the MySQL,
// PostgreSQL and OCI builders installed and compares the rewritten SQL with
// the expected text for each; the placeholder names are compared using the
// MySQL run.
func TestDB_ProcessSQL(t *testing.T) {
	type sqlCase struct {
		label      string   // failure message
		template   string   // original SQL
		wantMySQL  string   // expected MySQL version
		wantPgSQL  string   // expected PostgreSQL version
		wantOCI    string   // expected OCI version
		wantParams []string // expected params
	}
	cases := []sqlCase{
		{
			label:      "normal case",
			template:   `INSERT INTO employee (id, name, age) VALUES ({:id}, {:name}, {:age})`,
			wantMySQL:  `INSERT INTO employee (id, name, age) VALUES (?, ?, ?)`,
			wantPgSQL:  `INSERT INTO employee (id, name, age) VALUES ($1, $2, $3)`,
			wantOCI:    `INSERT INTO employee (id, name, age) VALUES (:p1, :p2, :p3)`,
			wantParams: []string{"id", "name", "age"},
		},
		{
			label:      "the same placeholder is used twice",
			template:   `SELECT * FROM employee WHERE first_name LIKE {:keyword} OR last_name LIKE {:keyword}`,
			wantMySQL:  `SELECT * FROM employee WHERE first_name LIKE ? OR last_name LIKE ?`,
			wantPgSQL:  `SELECT * FROM employee WHERE first_name LIKE $1 OR last_name LIKE $2`,
			wantOCI:    `SELECT * FROM employee WHERE first_name LIKE :p1 OR last_name LIKE :p2`,
			wantParams: []string{"keyword", "keyword"},
		},
		{
			label:      "non-matching placeholder",
			template:   `SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE {:keyword}`,
			wantMySQL:  `SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE ?`,
			wantPgSQL:  `SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE $1`,
			wantOCI:    `SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE :p1`,
			wantParams: []string{"keyword"},
		},
		{
			label:      "quote table/column names",
			template:   `SELECT * FROM {{public.user}} WHERE [[user.id]]=1`,
			wantMySQL:  "SELECT * FROM `public`.`user` WHERE `user`.`id`=1",
			wantPgSQL:  "SELECT * FROM \"public\".\"user\" WHERE \"user\".\"id\"=1",
			wantOCI:    "SELECT * FROM \"public\".\"user\" WHERE \"user\".\"id\"=1",
			wantParams: nil,
		},
	}

	// One DB per dialect; only the Builder differs between them.
	mysqlDB := getDB()
	mysqlDB.Builder = NewMysqlBuilder(nil, nil)
	pgsqlDB := getDB()
	pgsqlDB.Builder = NewPgsqlBuilder(nil, nil)
	ociDB := getDB()
	ociDB.Builder = NewOciBuilder(nil, nil)

	for i := range cases {
		c := cases[i]

		gotMySQL, gotParams := mysqlDB.processSQL(c.template)
		assert.Equal(t, c.wantMySQL, gotMySQL, c.label)

		gotPgSQL, _ := pgsqlDB.processSQL(c.template)
		assert.Equal(t, c.wantPgSQL, gotPgSQL, c.label)

		gotOCI, _ := ociDB.processSQL(c.template)
		assert.Equal(t, c.wantOCI, gotOCI, c.label)

		assert.Equal(t, c.wantParams, gotParams, c.label)
	}
}
// TestDB_Begin checks commit and rollback behaviour for transactions obtained
// in three ways: DB.Begin, DB.Wrap around a *sql.Tx, and DB.BeginTx.
//
// Each way runs as its own subtest. The current MAX(id) of the item table is
// re-read at the start of every subtest so that the "name1"/"name2" checks
// look at the rows inserted by that subtest rather than at rows left behind by
// an earlier one.
func TestDB_Begin(t *testing.T) {
	tests := []struct {
		desc   string
		makeTx func(db *DB) (*Tx, error)
	}{
		{
			desc: "Begin",
			makeTx: func(db *DB) (*Tx, error) {
				return db.Begin()
			},
		},
		{
			desc: "Wrap",
			makeTx: func(db *DB) (*Tx, error) {
				sqlTx, err := db.DB().Begin()
				if err != nil {
					return nil, err
				}
				return db.Wrap(sqlTx), nil
			},
		},
		{
			desc: "BeginTx",
			makeTx: func(db *DB) (*Tx, error) {
				return db.BeginTx(context.Background(), nil)
			},
		},
	}

	db := getPreparedDB()

	for _, test := range tests {
		t.Run(test.desc, func(t *testing.T) {
			var lastID int
			if err := db.NewQuery("SELECT MAX(id) FROM item").Row(&lastID); err != nil {
				t.Fatalf("reading MAX(id): %v", err)
			}

			// Two inserts in one transaction; both are expected to be committed.
			tx, err := test.makeTx(db)
			if err != nil {
				t.Fatalf("starting insert transaction: %v", err)
			}
			_, err1 := tx.Insert("item", Params{
				"name": "name1",
			}).Execute()
			_, err2 := tx.Insert("item", Params{
				"name": "name2",
			}).Execute()
			if err1 == nil && err2 == nil {
				assert.NoError(t, tx.Commit())
			} else {
				t.Errorf("Unexpected TX rollback: %v, %v", err1, err2)
				// Best-effort cleanup; the failure is already reported above.
				_ = tx.Rollback()
			}

			// A fresh variable per lookup keeps a failed Row call from being
			// masked by a value read earlier.
			q := db.NewQuery("SELECT name FROM item WHERE id={:id}")
			var first, second string
			if assert.NoError(t, q.Bind(Params{"id": lastID + 1}).Row(&first)) {
				assert.Equal(t, "name1", first)
			}
			if assert.NoError(t, q.Bind(Params{"id": lastID + 2}).Row(&second)) {
				assert.Equal(t, "name2", second)
			}

			// The second DELETE targets "items"; the test expects at least one
			// of the two statements to fail so that the transaction is rolled
			// back.
			tx, err = test.makeTx(db)
			if err != nil {
				t.Fatalf("starting delete transaction: %v", err)
			}
			_, err3 := tx.NewQuery("DELETE FROM item WHERE id=7").Execute()
			_, err4 := tx.NewQuery("DELETE FROM items WHERE id=7").Execute()
			if err3 == nil && err4 == nil {
				t.Error("Unexpected TX commit")
				// Finish the transaction so it does not stay open.
				_ = tx.Commit()
			} else {
				assert.NoError(t, tx.Rollback())
			}
		})
	}
}
// TestDB_Transactional covers four uses of DB.Transactional:
//
//  1. two successful inserts, whose rows must be readable afterwards;
//  2. a callback that returns an error from a failing statement;
//  3. a callback that calls tx.Rollback itself and still returns the error;
//  4. a callback that calls tx.Rollback itself and returns nil.
//
// In cases 2-4 the row with id=2 is expected to still hold "Go in Action".
func TestDB_Transactional(t *testing.T) {
	db := getPreparedDB()

	var (
		maxID   int
		current string
	)
	db.NewQuery("SELECT MAX(id) FROM item").Row(&maxID)

	// Reads the name stored under id=2 and compares it with the expected value.
	assertItem2Intact := func() {
		db.NewQuery("SELECT name FROM item WHERE id=2").Row(&current)
		assert.Equal(t, "Go in Action", current)
	}

	// Case 1: both inserts succeed.
	err := db.Transactional(func(tx *Tx) error {
		for _, itemName := range []string{"name1", "name2"} {
			if _, execErr := tx.Insert("item", Params{
				"name": itemName,
			}).Execute(); execErr != nil {
				return execErr
			}
		}
		return nil
	})
	if assert.Nil(t, err) {
		byID := db.NewQuery("SELECT name FROM item WHERE id={:id}")
		byID.Bind(Params{"id": maxID + 1}).Row(&current)
		assert.Equal(t, "name1", current)
		byID.Bind(Params{"id": maxID + 2}).Row(&current)
		assert.Equal(t, "name2", current)
	}

	// Case 2: the callback hands the statement error back to Transactional.
	err = db.Transactional(func(tx *Tx) error {
		if _, execErr := tx.NewQuery("DELETE FROM item WHERE id=2").Execute(); execErr != nil {
			return execErr
		}
		_, execErr := tx.NewQuery("DELETE FROM items WHERE id=2").Execute()
		return execErr
	})
	if assert.NotNil(t, err) {
		assertItem2Intact()
	}

	// Case 3: Rollback called within Transactional and return error
	err = db.Transactional(func(tx *Tx) error {
		if _, execErr := tx.NewQuery("DELETE FROM item WHERE id=2").Execute(); execErr != nil {
			return execErr
		}
		_, execErr := tx.NewQuery("DELETE FROM items WHERE id=2").Execute()
		if execErr != nil {
			tx.Rollback()
		}
		return execErr
	})
	if assert.NotNil(t, err) {
		assertItem2Intact()
	}

	// Case 4: Rollback called within Transactional without returning error
	err = db.Transactional(func(tx *Tx) error {
		if _, execErr := tx.NewQuery("DELETE FROM item WHERE id=2").Execute(); execErr != nil {
			return execErr
		}
		if _, execErr := tx.NewQuery("DELETE FROM items WHERE id=2").Execute(); execErr != nil {
			tx.Rollback()
		}
		return nil
	})
	if assert.Nil(t, err) {
		assertItem2Intact()
	}
}
// TestErrors_Error checks the string produced by Errors.Error for zero, one
// and two wrapped errors.
func TestErrors_Error(t *testing.T) {
	cases := []struct {
		errs Errors
		want string
	}{
		{errs: Errors{}, want: ""},
		{errs: Errors{errors.New("a")}, want: "a"},
		{errs: Errors{errors.New("a"), errors.New("b")}, want: "a\nb"},
	}

	for i := range cases {
		errs := cases[i].errs
		assert.Equal(t, cases[i].want, errs.Error())
	}
}
// getDB opens a DB for TestDSN with the "mysql" driver and panics if Open
// reports an error.
func getDB() *DB {
	conn, openErr := Open("mysql", TestDSN)
	if openErr == nil {
		return conn
	}
	panic(openErr)
}
// getPreparedDB returns a DB from getDB after executing the statements found
// in FixtureFile. The file content is split on ";" and fragments that are
// blank after trimming are skipped. It panics if the file cannot be read or
// if any statement fails.
func getPreparedDB() *DB {
	db := getDB()

	raw, readErr := ioutil.ReadFile(FixtureFile)
	if readErr != nil {
		panic(readErr)
	}

	for _, stmt := range strings.Split(string(raw), ";") {
		if strings.TrimSpace(stmt) != "" {
			if _, execErr := db.NewQuery(stmt).Execute(); execErr != nil {
				panic(execErr)
			}
		}
	}

	return db
}
// ArtistDAO is a test model whose type name differs from its table name.
// Naming according to issue 49 ( https://github.com/pocketbase/dbx/issues/49 )
type ArtistDAO struct {
	nickname string
}

// TableName returns the fixed table name "artists".
func (ArtistDAO) TableName() string {
	const table = "artists"
	return table
}
// Test_TableNameWithPrefix installs a TableMapper that prepends "tbl_" to the
// result of GetTableName and expects "tbl_artists" for an ArtistDAO value.
func Test_TableNameWithPrefix(t *testing.T) {
	const prefix = "tbl_"

	db := NewFromDB(nil, "mysql")
	db.TableMapper = func(model interface{}) string {
		return prefix + GetTableName(model)
	}

	got := db.TableMapper(ArtistDAO{})
	assert.Equal(t, "tbl_artists", got)
}