600 lines
16 KiB
Go
600 lines
16 KiB
Go
// Copyright 2016 Qiang Xue. All rights reserved.
|
|
// Use of this source code is governed by a MIT-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package dbx
|
|
|
|
import (
|
|
ss "database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
type City struct {
|
|
ID int
|
|
Name string
|
|
}
|
|
|
|
func TestNewQuery(t *testing.T) {
|
|
db := getDB()
|
|
sql := "SELECT * FROM users WHERE id={:id}"
|
|
q := NewQuery(db, db.sqlDB, sql)
|
|
assert.Equal(t, q.SQL(), sql, "q.SQL()")
|
|
assert.Equal(t, q.rawSQL, "SELECT * FROM users WHERE id=?", "q.RawSQL()")
|
|
|
|
assert.Equal(t, len(q.Params()), 0, "len(q.Params())@1")
|
|
q.Bind(Params{"id": 1})
|
|
assert.Equal(t, len(q.Params()), 1, "len(q.Params())@2")
|
|
}
|
|
|
|
func TestQuery_Execute(t *testing.T) {
|
|
db := getPreparedDB()
|
|
defer db.Close()
|
|
|
|
result, err := db.NewQuery("INSERT INTO item (name) VALUES ('test')").Execute()
|
|
if assert.Nil(t, err) {
|
|
rows, _ := result.RowsAffected()
|
|
assert.Equal(t, rows, int64(1), "Result.RowsAffected()")
|
|
lastID, _ := result.LastInsertId()
|
|
assert.Equal(t, lastID, int64(6), "Result.LastInsertId()")
|
|
}
|
|
}
|
|
|
|
type Customer struct {
|
|
scanned bool
|
|
|
|
ID int
|
|
Email string
|
|
Status int
|
|
Name string
|
|
Address ss.NullString
|
|
}
|
|
|
|
func (m Customer) TableName() string {
|
|
return "customer"
|
|
}
|
|
|
|
func (m *Customer) PostScan() error {
|
|
m.scanned = true
|
|
return nil
|
|
}
|
|
|
|
type CustomerPtr struct {
|
|
ID *int `db:"pk"`
|
|
Email *string
|
|
Status *int
|
|
Name string
|
|
Address *string
|
|
}
|
|
|
|
func (m CustomerPtr) TableName() string {
|
|
return "customer"
|
|
}
|
|
|
|
type CustomerNull struct {
|
|
ID ss.NullInt64 `db:"pk,id"`
|
|
Email ss.NullString
|
|
Status *ss.NullInt64
|
|
Name string
|
|
Address ss.NullString
|
|
}
|
|
|
|
func (m CustomerNull) TableName() string {
|
|
return "customer"
|
|
}
|
|
|
|
type CustomerEmbedded struct {
|
|
Id int
|
|
Email *string
|
|
InnerCustomer
|
|
}
|
|
|
|
func (m CustomerEmbedded) TableName() string {
|
|
return "customer"
|
|
}
|
|
|
|
type CustomerEmbedded2 struct {
|
|
ID int
|
|
Email *string
|
|
Inner InnerCustomer
|
|
}
|
|
|
|
type InnerCustomer struct {
|
|
Status ss.NullInt64
|
|
Name *string
|
|
Address ss.NullString
|
|
}
|
|
|
|
func TestQuery_Rows(t *testing.T) {
|
|
db := getPreparedDB()
|
|
defer db.Close()
|
|
|
|
var (
|
|
sql string
|
|
err error
|
|
)
|
|
|
|
// Query.All()
|
|
var customers []Customer
|
|
sql = `SELECT * FROM customer ORDER BY id`
|
|
err = db.NewQuery(sql).All(&customers)
|
|
if assert.Nil(t, err) {
|
|
assert.Equal(t, len(customers), 3, "len(customers)")
|
|
assert.Equal(t, customers[2].ID, 3, "customers[2].ID")
|
|
assert.Equal(t, customers[2].Email, `user3@example.com`, "customers[2].Email")
|
|
assert.Equal(t, customers[2].Status, 2, "customers[2].Status")
|
|
assert.Equal(t, customers[0].scanned, true, "customers[0].scanned")
|
|
assert.Equal(t, customers[1].scanned, true, "customers[1].scanned")
|
|
assert.Equal(t, customers[2].scanned, true, "customers[2].scanned")
|
|
}
|
|
|
|
// Query.All() with slice of pointers
|
|
var customersPtrSlice []*Customer
|
|
sql = `SELECT * FROM customer ORDER BY id`
|
|
err = db.NewQuery(sql).All(&customersPtrSlice)
|
|
if assert.Nil(t, err) {
|
|
assert.Equal(t, len(customersPtrSlice), 3, "len(customersPtrSlice)")
|
|
assert.Equal(t, customersPtrSlice[2].ID, 3, "customersPtrSlice[2].ID")
|
|
assert.Equal(t, customersPtrSlice[2].Email, `user3@example.com`, "customersPtrSlice[2].Email")
|
|
assert.Equal(t, customersPtrSlice[2].Status, 2, "customersPtrSlice[2].Status")
|
|
assert.Equal(t, customersPtrSlice[0].scanned, true, "customersPtrSlice[0].scanned")
|
|
assert.Equal(t, customersPtrSlice[1].scanned, true, "customersPtrSlice[1].scanned")
|
|
assert.Equal(t, customersPtrSlice[2].scanned, true, "customersPtrSlice[2].scanned")
|
|
}
|
|
|
|
var customers2 []NullStringMap
|
|
err = db.NewQuery(sql).All(&customers2)
|
|
if assert.Nil(t, err) {
|
|
assert.Equal(t, len(customers2), 3, "len(customers2)")
|
|
assert.Equal(t, customers2[1]["id"].String, "2", "customers2[1][id]")
|
|
assert.Equal(t, customers2[1]["email"].String, `user2@example.com`, "customers2[1][email]")
|
|
assert.Equal(t, customers2[1]["status"].String, "1", "customers2[1][status]")
|
|
}
|
|
err = db.NewQuery(sql).All(customers)
|
|
assert.NotNil(t, err)
|
|
|
|
var customers3 []string
|
|
err = db.NewQuery(sql).All(&customers3)
|
|
assert.NotNil(t, err)
|
|
|
|
var customers4 string
|
|
err = db.NewQuery(sql).All(&customers4)
|
|
assert.NotNil(t, err)
|
|
|
|
var customers5 []Customer
|
|
err = db.NewQuery(`SELECT * FROM customer WHERE id=999`).All(&customers5)
|
|
if assert.Nil(t, err) {
|
|
assert.NotNil(t, customers5)
|
|
assert.Zero(t, len(customers5))
|
|
}
|
|
|
|
// One
|
|
var customer Customer
|
|
sql = `SELECT * FROM customer WHERE id={:id}`
|
|
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customer)
|
|
if assert.Nil(t, err) {
|
|
assert.Equal(t, customer.ID, 2, "customer.ID")
|
|
assert.Equal(t, customer.Email, `user2@example.com`, "customer.Email")
|
|
assert.Equal(t, customer.Status, 1, "customer.Status")
|
|
}
|
|
|
|
var customerPtr2 CustomerPtr
|
|
sql = `SELECT id, email, address FROM customer WHERE id=2`
|
|
rows2, err := db.sqlDB.Query(sql)
|
|
defer rows2.Close()
|
|
assert.Nil(t, err)
|
|
rows2.Next()
|
|
err = rows2.Scan(&customerPtr2.ID, &customerPtr2.Email, &customerPtr2.Address)
|
|
if assert.Nil(t, err) {
|
|
assert.Equal(t, *customerPtr2.ID, 2, "customer.ID")
|
|
assert.Equal(t, *customerPtr2.Email, `user2@example.com`)
|
|
assert.Nil(t, customerPtr2.Address)
|
|
}
|
|
|
|
// struct fields are pointers
|
|
var customerPtr CustomerPtr
|
|
sql = `SELECT * FROM customer WHERE id={:id}`
|
|
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerPtr)
|
|
if assert.Nil(t, err) {
|
|
assert.Equal(t, *customerPtr.ID, 2, "customer.ID")
|
|
assert.Equal(t, *customerPtr.Email, `user2@example.com`, "customer.Email")
|
|
assert.Equal(t, *customerPtr.Status, 1, "customer.Status")
|
|
}
|
|
|
|
// struct fields are null types
|
|
var customerNull CustomerNull
|
|
sql = `SELECT * FROM customer WHERE id={:id}`
|
|
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerNull)
|
|
if assert.Nil(t, err) {
|
|
assert.Equal(t, customerNull.ID.Int64, int64(2), "customer.ID")
|
|
assert.Equal(t, customerNull.Email.String, `user2@example.com`, "customer.Email")
|
|
assert.Equal(t, customerNull.Status.Int64, int64(1), "customer.Status")
|
|
}
|
|
|
|
// embedded with anonymous struct
|
|
var customerEmbedded CustomerEmbedded
|
|
sql = `SELECT * FROM customer WHERE id={:id}`
|
|
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerEmbedded)
|
|
if assert.Nil(t, err) {
|
|
assert.Equal(t, customerEmbedded.Id, 2, "customer.ID")
|
|
assert.Equal(t, *customerEmbedded.Email, `user2@example.com`, "customer.Email")
|
|
assert.Equal(t, customerEmbedded.Status.Int64, int64(1), "customer.Status")
|
|
}
|
|
|
|
// embedded with named struct
|
|
var customerEmbedded2 CustomerEmbedded2
|
|
sql = `SELECT id, email, status as "inner.status" FROM customer WHERE id={:id}`
|
|
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerEmbedded2)
|
|
if assert.Nil(t, err) {
|
|
assert.Equal(t, customerEmbedded2.ID, 2, "customer.ID")
|
|
assert.Equal(t, *customerEmbedded2.Email, `user2@example.com`, "customer.Email")
|
|
assert.Equal(t, customerEmbedded2.Inner.Status.Int64, int64(1), "customer.Status")
|
|
}
|
|
|
|
customer2 := NullStringMap{}
|
|
sql = `SELECT * FROM customer WHERE id={:id}`
|
|
err = db.NewQuery(sql).Bind(Params{"id": 1}).One(customer2)
|
|
if assert.Nil(t, err) {
|
|
assert.Equal(t, customer2["id"].String, "1", "customer2[id]")
|
|
assert.Equal(t, customer2["email"].String, `user1@example.com`, "customer2[email]")
|
|
assert.Equal(t, customer2["status"].String, "1", "customer2[status]")
|
|
}
|
|
|
|
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(customer)
|
|
assert.NotNil(t, err)
|
|
|
|
var customer3 NullStringMap
|
|
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(customer3)
|
|
assert.NotNil(t, err)
|
|
|
|
err = db.NewQuery(sql).Bind(Params{"id": 1}).One(&customer3)
|
|
if assert.Nil(t, err) {
|
|
assert.Equal(t, customer3["id"].String, "1", "customer3[id]")
|
|
}
|
|
|
|
// Rows
|
|
sql = `SELECT * FROM customer ORDER BY id DESC`
|
|
rows, err := db.NewQuery(sql).Rows()
|
|
if assert.Nil(t, err) {
|
|
s := ""
|
|
for rows.Next() {
|
|
rows.ScanStruct(&customer)
|
|
s += customer.Email + ","
|
|
}
|
|
assert.Equal(t, s, "user3@example.com,user2@example.com,user1@example.com,", "Rows().Next()")
|
|
}
|
|
|
|
// FieldMapper
|
|
var a struct {
|
|
MyID string `db:"id"`
|
|
name string
|
|
}
|
|
sql = `SELECT * FROM customer WHERE id=2`
|
|
err = db.NewQuery(sql).One(&a)
|
|
if assert.Nil(t, err) {
|
|
assert.Equal(t, a.MyID, "2", "a.MyID")
|
|
// unexported field is not populated
|
|
assert.Equal(t, a.name, "", "a.name")
|
|
}
|
|
|
|
// prepared statement
|
|
sql = `SELECT * FROM customer WHERE id={:id}`
|
|
q := db.NewQuery(sql).Prepare()
|
|
q.Bind(Params{"id": 1}).One(&customer)
|
|
assert.Equal(t, customer.ID, 1, "prepared@1")
|
|
err = q.Bind(Params{"id": 20}).One(&customer)
|
|
assert.Equal(t, err, ss.ErrNoRows, "prepared@2")
|
|
q.Bind(Params{"id": 3}).One(&customer)
|
|
assert.Equal(t, customer.ID, 3, "prepared@3")
|
|
|
|
sql = `SELECT name FROM customer WHERE id={:id}`
|
|
var name string
|
|
q = db.NewQuery(sql).Prepare()
|
|
q.Bind(Params{"id": 1}).Row(&name)
|
|
assert.Equal(t, name, "user1", "prepared2@1")
|
|
err = q.Bind(Params{"id": 20}).Row(&name)
|
|
assert.Equal(t, err, ss.ErrNoRows, "prepared2@2")
|
|
q.Bind(Params{"id": 3}).Row(&name)
|
|
assert.Equal(t, name, "user3", "prepared2@3")
|
|
|
|
// Query.LastError
|
|
sql = `SELECT * FROM a`
|
|
q = db.NewQuery(sql).Prepare()
|
|
customer.ID = 100
|
|
err = q.Bind(Params{"id": 1}).One(&customer)
|
|
assert.NotEqual(t, err, nil, "LastError@0")
|
|
assert.Equal(t, customer.ID, 100, "LastError@1")
|
|
assert.Equal(t, q.LastError, nil, "LastError@2")
|
|
|
|
// Query.Column
|
|
sql = `SELECT name, id FROM customer ORDER BY id`
|
|
var names []string
|
|
err = db.NewQuery(sql).Column(&names)
|
|
if assert.Nil(t, err) && assert.Equal(t, 3, len(names)) {
|
|
assert.Equal(t, "user1", names[0])
|
|
assert.Equal(t, "user2", names[1])
|
|
assert.Equal(t, "user3", names[2])
|
|
}
|
|
err = db.NewQuery(sql).Column(names)
|
|
assert.NotNil(t, err)
|
|
}
|
|
|
|
func TestQuery_logSQL(t *testing.T) {
|
|
db := getDB()
|
|
q := db.NewQuery("SELECT * FROM users WHERE type={:type} AND id={:id} AND bytes={:bytes}").Bind(Params{
|
|
"id": 1,
|
|
"type": "a",
|
|
"bytes": []byte("test"),
|
|
})
|
|
expected := "SELECT * FROM users WHERE type='a' AND id=1 AND bytes=0x74657374"
|
|
assert.Equal(t, q.logSQL(), expected, "logSQL()")
|
|
}
|
|
|
|
func TestReplacePlaceholders(t *testing.T) {
|
|
tests := []struct {
|
|
ID string
|
|
Placeholders []string
|
|
Params Params
|
|
ExpectedParams string
|
|
HasError bool
|
|
}{
|
|
{"t1", nil, nil, "null", false},
|
|
{"t2", []string{"id", "name"}, Params{"id": 1, "name": "xyz"}, `[1,"xyz"]`, false},
|
|
{"t3", []string{"id", "name"}, Params{"id": 1}, `null`, true},
|
|
{"t4", []string{"id", "name"}, Params{"id": 1, "name": "xyz", "age": 30}, `[1,"xyz"]`, false},
|
|
}
|
|
for _, test := range tests {
|
|
params, err := replacePlaceholders(test.Placeholders, test.Params)
|
|
result, _ := json.Marshal(params)
|
|
assert.Equal(t, string(result), test.ExpectedParams, "params@"+test.ID)
|
|
assert.Equal(t, err != nil, test.HasError, "error@"+test.ID)
|
|
}
|
|
}
|
|
|
|
func TestIssue6(t *testing.T) {
|
|
db := getPreparedDB()
|
|
q := db.Select("*").From("customer").Where(HashExp{"id": 1})
|
|
var customer Customer
|
|
assert.Equal(t, q.One(&customer), nil)
|
|
assert.Equal(t, 1, customer.ID)
|
|
}
|
|
|
|
type User struct {
|
|
ID int64
|
|
Email string
|
|
Created time.Time
|
|
Updated *time.Time
|
|
}
|
|
|
|
func TestIssue13(t *testing.T) {
|
|
db := getPreparedDB()
|
|
var user User
|
|
err := db.Select().From("user").Where(HashExp{"id": 1}).One(&user)
|
|
if assert.Nil(t, err) {
|
|
assert.NotZero(t, user.Created)
|
|
assert.Nil(t, user.Updated)
|
|
}
|
|
|
|
now := time.Now()
|
|
|
|
user2 := User{
|
|
Email: "now@example.com",
|
|
Created: now,
|
|
}
|
|
err = db.Model(&user2).Insert()
|
|
if assert.Nil(t, err) {
|
|
assert.NotZero(t, user2.ID)
|
|
}
|
|
|
|
user3 := User{
|
|
Email: "now@example.com",
|
|
Created: now,
|
|
Updated: &now,
|
|
}
|
|
err = db.Model(&user3).Insert()
|
|
if assert.Nil(t, err) {
|
|
assert.NotZero(t, user2.ID)
|
|
}
|
|
}
|
|
|
|
func TestQueryWithExecHook(t *testing.T) {
|
|
db := getPreparedDB()
|
|
defer db.Close()
|
|
|
|
// error return
|
|
{
|
|
err := db.NewQuery("select * from user").
|
|
WithExecHook(func(q *Query, op func() error) error {
|
|
return errors.New("test")
|
|
}).
|
|
Row()
|
|
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
// Row()
|
|
{
|
|
calls := 0
|
|
err := db.NewQuery("select * from user").
|
|
WithExecHook(func(q *Query, op func() error) error {
|
|
calls++
|
|
return nil
|
|
}).
|
|
Row()
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, 1, calls, "Row()")
|
|
}
|
|
|
|
// One()
|
|
{
|
|
calls := 0
|
|
err := db.NewQuery("select * from user").
|
|
WithExecHook(func(q *Query, op func() error) error {
|
|
calls++
|
|
return nil
|
|
}).
|
|
One(nil)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, 1, calls, "One()")
|
|
}
|
|
|
|
// All()
|
|
{
|
|
calls := 0
|
|
err := db.NewQuery("select * from user").
|
|
WithExecHook(func(q *Query, op func() error) error {
|
|
calls++
|
|
return nil
|
|
}).
|
|
All(nil)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, 1, calls, "All()")
|
|
}
|
|
|
|
// Column()
|
|
{
|
|
calls := 0
|
|
err := db.NewQuery("select * from user").
|
|
WithExecHook(func(q *Query, op func() error) error {
|
|
calls++
|
|
return nil
|
|
}).
|
|
Column(nil)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, 1, calls, "Column()")
|
|
}
|
|
|
|
// Execute()
|
|
{
|
|
calls := 0
|
|
_, err := db.NewQuery("select * from user").
|
|
WithExecHook(func(q *Query, op func() error) error {
|
|
calls++
|
|
return nil
|
|
}).
|
|
Execute()
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, 1, calls, "Execute()")
|
|
}
|
|
|
|
// op call
|
|
{
|
|
calls := 0
|
|
var id int
|
|
err := db.NewQuery("select id from user where id = 2").
|
|
WithExecHook(func(q *Query, op func() error) error {
|
|
calls++
|
|
return op()
|
|
}).
|
|
Row(&id)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, 1, calls, "op hook calls")
|
|
assert.Equal(t, 2, id, "id mismatch")
|
|
}
|
|
}
|
|
|
|
func TestQueryWithOneHook(t *testing.T) {
|
|
db := getPreparedDB()
|
|
defer db.Close()
|
|
|
|
// error return
|
|
{
|
|
err := db.NewQuery("select * from user").
|
|
WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
|
|
return errors.New("test")
|
|
}).
|
|
One(nil)
|
|
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
// hooks call order
|
|
{
|
|
hookCalls := []string{}
|
|
err := db.NewQuery("select * from user").
|
|
WithExecHook(func(q *Query, op func() error) error {
|
|
hookCalls = append(hookCalls, "exec")
|
|
return op()
|
|
}).
|
|
WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
|
|
hookCalls = append(hookCalls, "one")
|
|
return nil
|
|
}).
|
|
One(nil)
|
|
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, hookCalls, []string{"exec", "one"})
|
|
}
|
|
|
|
// op call
|
|
{
|
|
calls := 0
|
|
other := User{}
|
|
err := db.NewQuery("select id from user where id = 2").
|
|
WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
|
|
calls++
|
|
return op(&other)
|
|
}).
|
|
One(nil)
|
|
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, 1, calls, "hook calls")
|
|
assert.Equal(t, int64(2), other.ID, "replaced scan struct")
|
|
}
|
|
}
|
|
|
|
func TestQueryWithAllHook(t *testing.T) {
|
|
db := getPreparedDB()
|
|
defer db.Close()
|
|
|
|
// error return
|
|
{
|
|
err := db.NewQuery("select * from user").
|
|
WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
|
|
return errors.New("test")
|
|
}).
|
|
All(nil)
|
|
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
// hooks call order
|
|
{
|
|
hookCalls := []string{}
|
|
err := db.NewQuery("select * from user").
|
|
WithExecHook(func(q *Query, op func() error) error {
|
|
hookCalls = append(hookCalls, "exec")
|
|
return op()
|
|
}).
|
|
WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
|
|
hookCalls = append(hookCalls, "all")
|
|
return nil
|
|
}).
|
|
All(nil)
|
|
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, hookCalls, []string{"exec", "all"})
|
|
}
|
|
|
|
// op call
|
|
{
|
|
calls := 0
|
|
other := []User{}
|
|
err := db.NewQuery("select id from user order by id asc").
|
|
WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
|
|
calls++
|
|
return op(&other)
|
|
}).
|
|
All(nil)
|
|
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, 1, calls, "hook calls")
|
|
assert.Equal(t, 2, len(other), "users length")
|
|
assert.Equal(t, int64(1), other[0].ID, "user 1 id check")
|
|
assert.Equal(t, int64(2), other[1].ID, "user 2 id check")
|
|
}
|
|
}
|