1
0
Fork 0

Adding upstream version 1.11.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-05-22 10:43:26 +02:00
parent 40df376a7f
commit 02cacc5b45
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
37 changed files with 7317 additions and 0 deletions

26
.gitignore vendored Normal file
View file

@ -0,0 +1,26 @@
.idea
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof

22
.travis.yml Normal file
View file

@ -0,0 +1,22 @@
dist: bionic
language: go
go:
- 1.13.x
services:
- mysql
install:
- go get golang.org/x/tools/cmd/cover
- go get github.com/mattn/goveralls
- go get golang.org/x/lint/golint
before_script:
- mysql -e 'CREATE DATABASE pocketbase_dbx_test;';
script:
- test -z "`gofmt -l -d .`"
- go test -v -covermode=count -coverprofile=coverage.out
- $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci

17
LICENSE Normal file
View file

@ -0,0 +1,17 @@
The MIT License (MIT)
Copyright (c) 2016, Qiang Xue
Permission is hereby granted, free of charge, to any person obtaining a copy of this software
and associated documentation files (the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute,
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or
substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

748
README.md Normal file
View file

@ -0,0 +1,748 @@
dbx
[![Go Report Card](https://goreportcard.com/badge/github.com/pocketbase/dbx)](https://goreportcard.com/report/github.com/pocketbase/dbx)
[![GoDoc](https://godoc.org/github.com/pocketbase/dbx?status.svg)](https://pkg.go.dev/github.com/pocketbase/dbx)
================================================================================
> ⚠️ This is a maintained fork of [go-ozzo/ozzo-dbx](https://github.com/go-ozzo/ozzo-dbx) (see [#103](https://github.com/go-ozzo/ozzo-dbx/issues/103)).
>
> Currently, the changes are primarily related to better SQLite support and some other minor improvements, implementing [#99](https://github.com/go-ozzo/ozzo-dbx/pull/99), [#100](https://github.com/go-ozzo/ozzo-dbx/pull/100) and [#102](https://github.com/go-ozzo/ozzo-dbx/pull/102).
## Summary
- [Description](#description)
- [Requirements](#requirements)
- [Installation](#installation)
- [Supported Databases](#supported-databases)
- [Getting Started](#getting-started)
- [Connecting to Database](#connecting-to-database)
- [Executing Queries](#executing-queries)
- [Binding Parameters](#binding-parameters)
- [Building Queries](#building-queries)
- [Building SELECT Queries](#building-select-queries)
- [Building Query Conditions](#building-query-conditions)
- [Building Data Manipulation Queries](#building-data-manipulation-queries)
- [Building Schema Manipulation Queries](#building-schema-manipulation-queries)
- [CRUD Operations](#crud-operations)
- [Create](#create)
- [Read](#read)
- [Update](#update)
- [Delete](#delete)
- [Null Handling](#null-handling)
- [Quoting Table and Column Names](#quoting-table-and-column-names)
- [Using Transactions](#using-transactions)
- [Logging Executed SQL Statements](#logging-executed-sql-statements)
- [Supporting New Databases](#supporting-new-databases)
## Description
`dbx` is a Go package that enhances the standard `database/sql` package by providing powerful data retrieval methods
as well as DB-agnostic query building capabilities. `dbx` is not an ORM. It has the following features:
* Populating data into structs and NullString maps
* Named parameter binding
* DB-agnostic query building methods, including SELECT queries, data manipulation queries, and schema manipulation queries
* Inserting, updating, and deleting model structs
* Powerful query condition building
* Open architecture allowing addition of new database support or customization of existing support
* Logging executed SQL statements
* Supporting major relational databases
For an example on how this library is used in an application, please refer to [go-rest-api](https://github.com/qiangxue/go-rest-api) which is a starter kit for building RESTful APIs in Go.
## Requirements
Go 1.13 or above.
## Installation
Run the following command to install the package:
```
go get github.com/pocketbase/dbx
```
In addition, install the specific DB driver package for the kind of database to be used. Please refer to
[SQL database drivers](https://github.com/golang/go/wiki/SQLDrivers) for a complete list. For example, if you are
using MySQL, you may install the following package:
```sh
go get github.com/go-sql-driver/mysql
```
and import it in your main code like the following:
```go
import _ "github.com/go-sql-driver/mysql"
```
## Supported Databases
The following databases are fully supported out of box:
* SQLite
* MySQL
* PostgreSQL
* MS SQL Server (2012 or above)
* Oracle
For other databases, the query building feature may not work as expected. You can create a custom builder to
solve the problem. Please see the last section for more details.
## Getting Started
The following code snippet shows how you can use this package in order to access data from a MySQL database.
```go
package main
import (
"github.com/pocketbase/dbx"
_ "github.com/go-sql-driver/mysql"
)
func main() {
db, _ := dbx.Open("mysql", "user:pass@/example")
// create a new query
q := db.NewQuery("SELECT id, name FROM users LIMIT 10")
// fetch all rows into a struct array
var users []struct {
ID, Name string
}
q.All(&users)
// fetch a single row into a struct
var user struct {
ID, Name string
}
q.One(&user)
// fetch a single row into a string map
data := dbx.NullStringMap{}
q.One(data)
// fetch row by row
rows2, _ := q.Rows()
for rows2.Next() {
rows2.ScanStruct(&user)
// rows.ScanMap(data)
// rows.Scan(&id, &name)
}
}
```
And the following example shows how to use the query building capability of this package.
```go
package main
import (
"github.com/pocketbase/dbx"
_ "github.com/go-sql-driver/mysql"
)
func main() {
db, _ := dbx.Open("mysql", "user:pass@/example")
// build a SELECT query
// SELECT `id`, `name` FROM `users` WHERE `name` LIKE '%Charles%' ORDER BY `id`
q := db.Select("id", "name").
From("users").
Where(dbx.Like("name", "Charles")).
OrderBy("id")
// fetch all rows into a struct array
var users []struct {
ID, Name string
}
q.All(&users)
// build an INSERT query
// INSERT INTO `users` (`name`) VALUES ('James')
db.Insert("users", dbx.Params{
"name": "James",
}).Execute()
}
```
## Connecting to Database
To connect to a database, call `dbx.Open()` in the same way as you would do with the `Open()` method in `database/sql`.
```go
db, err := dbx.Open("mysql", "user:pass@hostname/db_name")
```
The method returns a `dbx.DB` instance which can be used to create and execute DB queries. Note that the method
does not really establish a connection until a query is made using the returned `dbx.DB` instance. It also
does not check the correctness of the data source name either. Call `dbx.MustOpen()` to make sure the data
source name is correct.
## Executing Queries
To execute a SQL statement, first create a `dbx.Query` instance by calling `DB.NewQuery()` with the SQL statement
to be executed. And then call `Query.Execute()` to execute the query if the query is not meant to retrieving data.
For example,
```go
q := db.NewQuery("UPDATE users SET status=1 WHERE id=100")
result, err := q.Execute()
```
If the SQL statement does retrieve data (e.g. a SELECT statement), one of the following methods should be called,
which will execute the query and populate the result into the specified variable(s).
* `Query.All()`: populate all rows of the result into a slice of structs or `NullString` maps.
* `Query.One()`: populate the first row of the result into a struct or a `NullString` map.
* `Query.Column()`: populate the first column of the result into a slice.
* `Query.Row()`: populate the first row of the result into a list of variables, one for each returning column.
* `Query.Rows()`: returns a `dbx.Rows` instance to allow retrieving data row by row.
For example,
```go
type User struct {
ID int
Name string
}
var (
users []User
user User
row dbx.NullStringMap
id int
name string
err error
)
q := db.NewQuery("SELECT id, name FROM users LIMIT 10")
// populate all rows into a User slice
err = q.All(&users)
fmt.Println(users[0].ID, users[0].Name)
// populate the first row into a User struct
err = q.One(&user)
fmt.Println(user.ID, user.Name)
// populate the first row into a NullString map
err = q.One(&row)
fmt.Println(row["id"], row["name"])
var ids []int
err = q.Column(&ids)
fmt.Println(ids)
// populate the first row into id and name
err = q.Row(&id, &name)
// populate data row by row
rows, _ := q.Rows()
for rows.Next() {
_ = rows.ScanMap(&row)
}
```
When populating a struct, the following rules are used to determine which columns should go into which struct fields:
* Only exported struct fields can be populated.
* A field receives data if its name is mapped to a column according to the field mapping function `Query.FieldMapper`.
The default field mapping function separates words in a field name by underscores and turns them into lower case.
For example, a field name `FirstName` will be mapped to the column name `first_name`, and `MyID` to `my_id`.
* If a field has a `db` tag, the tag value will be used as the corresponding column name. If the `db` tag is a dash `-`,
it means the field should NOT be populated.
* For anonymous fields that are of struct type, they will be expanded and their component fields will be populated
according to the rules described above.
* For named fields that are of struct type, they will also be expanded. But their component fields will be prefixed
with the struct names when being populated.
An exception to the above struct expansion is that when a struct type implements `sql.Scanner` or when it is `time.Time`.
In this case, the field will be populated as a whole by the DB driver. Also, if a field is a pointer to some type,
the field will be allocated memory and populated with the query result if it is not null.
The following example shows how fields are populated according to the rules above:
```go
type User struct {
id int
Type int `db:"-"`
MyName string `db:"name"`
Profile
Address Address `db:"addr"`
}
type Profile struct {
Age int
}
type Address struct {
City string
}
```
* `User.id`: not populated because the field is not exported;
* `User.Type`: not populated because the `db` tag is `-`;
* `User.MyName`: to be populated from the `name` column, according to the `db` tag;
* `Profile.Age`: to be populated from the `age` column, since `Profile` is an anonymous field;
* `Address.City`: to be populated from the `addr.city` column, since `Address` is a named field of struct type
and its fields will be prefixed with `addr.` according to the `db` tag.
Note that if a column in the result does not have a corresponding struct field, it will be ignored. Similarly,
if a struct field does not have a corresponding column in the result, it will not be populated.
## Binding Parameters
A SQL statement is usually parameterized with dynamic values. For example, you may want to select the user record
according to the user ID received from the client. Parameter binding should be used in this case, and it is almost
always preferred to prevent from SQL injection attacks. Unlike `database/sql` which does anonymous parameter binding,
`dbx` uses named parameter binding. *Anonymous parameter binding is not supported*, as it will mess up with named
parameters. For example,
```go
q := db.NewQuery("SELECT id, name FROM users WHERE id={:id}")
q.Bind(dbx.Params{"id": 100})
err := q.One(&user)
```
The above example will select the user record whose `id` is 100. The method `Query.Bind()` binds a set
of named parameters to a SQL statement which contains parameter placeholders in the format of `{:ParamName}`.
If a SQL statement needs to be executed multiple times with different parameter values, it may be prepared
to improve the performance. For example,
```go
q := db.NewQuery("SELECT id, name FROM users WHERE id={:id}")
q.Prepare()
defer q.Close()
q.Bind(dbx.Params{"id": 100})
err := q.One(&user)
q.Bind(dbx.Params{"id": 200})
err = q.One(&user)
// ...
```
## Cancelable Queries
Queries are cancelable when they are used with `context.Context`. In particular, by calling `Query.WithContext()` you
can associate a context with a query and use the context to cancel the query while it is running. For example,
```go
q := db.NewQuery("SELECT id, name FROM users")
err := q.WithContext(ctx).All(&users)
```
## Building Queries
Instead of writing plain SQLs, `dbx` allows you to build SQLs programmatically, which often leads to cleaner,
more secure, and DB-agnostic code. You can build three types of queries: the SELECT queries, the data manipulation
queries, and the schema manipulation queries.
### Building SELECT Queries
Building a SELECT query starts by calling `DB.Select()`. You can build different clauses of a SELECT query using
the corresponding query building methods. For example,
```go
db, _ := dbx.Open("mysql", "user:pass@/example")
err := db.Select("id", "name").
From("users").
Where(dbx.HashExp{"id": 100}).
One(&user)
```
The above code will generate and execute the following SQL statement:
```sql
SELECT `id`, `name` FROM `users` WHERE `id`={:p0}
```
Notice how the table and column names are properly quoted according to the currently using database type.
And parameter binding is used to populate the value of `p0` in the `WHERE` clause.
Every SQL keyword has a corresponding query building method. For example, `SELECT` corresponds to `Select()`,
`FROM` corresponds to `From()`, `WHERE` corresponds to `Where()`, and so on. You can chain these method calls
together, just like you would do when writing a plain SQL. Each of these methods returns the query instance
(of type `dbx.SelectQuery`) that is being built. Once you finish building a query, you may call methods such as
`One()`, `All()` to execute the query and populate data into variables. You may also explicitly call `Build()`
to build the query and turn it into a `dbx.Query` instance which may allow you to get the SQL statement and do
other interesting work.
### Building Query Conditions
`dbx` supports very flexible and powerful query condition building which can be used to build SQL clauses
such as `WHERE`, `HAVING`, etc. For example,
```go
// id=100
dbx.NewExp("id={:id}", dbx.Params{"id": 100})
// id=100 AND status=1
dbx.HashExp{"id": 100, "status": 1}
// status=1 OR age>30
dbx.Or(dbx.HashExp{"status": 1}, dbx.NewExp("age>30"))
// name LIKE '%admin%' AND name LIKE '%example%'
dbx.Like("name", "admin", "example")
```
When building a query condition expression, its parameter values will be populated using parameter binding, which
prevents SQL injection from happening. Also if an expression involves column names, they will be properly quoted.
The following condition building functions are available:
* `dbx.NewExp()`: creating a condition using the given expression string and binding parameters. For example,
`dbx.NewExp("id={:id}", dbx.Params{"id":100})` would create the expression `id=100`.
* `dbx.HashExp`: a map type that represents name-value pairs concatenated by `AND` operators. For example,
`dbx.HashExp{"id":100, "status":1}` would create `id=100 AND status=1`.
* `dbx.Not()`: creating a `NOT` expression by prepending `NOT` to the given expression.
* `dbx.And()`: creating an `AND` expression by concatenating the given expressions with the `AND` operators.
* `dbx.Or()`: creating an `OR` expression by concatenating the given expressions with the `OR` operators.
* `dbx.In()`: creating an `IN` expression for the specified column and the range of values.
For example, `dbx.In("age", 30, 40, 50)` would create the expression `age IN (30, 40, 50)`.
Note that if the value range is empty, it will generate an expression representing a false value.
* `dbx.NotIn()`: creating an `NOT IN` expression. This is very similar to `dbx.In()`.
* `dbx.Like()`: creating a `LIKE` expression for the specified column and the range of values. For example,
`dbx.Like("title", "golang", "framework")` would create the expression `title LIKE "%golang%" AND title LIKE "%framework%"`.
You can further customize a LIKE expression by calling `Escape()` and/or `Match()` functions of the resulting expression.
Note that if the value range is empty, it will generate an empty expression.
* `dbx.NotLike()`: creating a `NOT LIKE` expression. This is very similar to `dbx.Like()`.
* `dbx.OrLike()`: creating a `LIKE` expression but concatenating different `LIKE` sub-expressions using `OR` instead of `AND`.
* `dbx.OrNotLike()`: creating a `NOT LIKE` expression and concatenating different `NOT LIKE` sub-expressions using `OR` instead of `AND`.
* `dbx.Exists()`: creating an `EXISTS` expression by prepending `EXISTS` to the given expression.
* `dbx.NotExists()`: creating a `NOT EXISTS` expression by prepending `NOT EXISTS` to the given expression.
* `dbx.Between()`: creating a `BETWEEN` expression. For example, `dbx.Between("age", 30, 40)` would create the
expression `age BETWEEN 30 AND 40`.
* `dbx.NotBetween()`: creating a `NOT BETWEEN` expression. For example
You may also create other convenient functions to help building query conditions, as long as the functions return
an object implementing the `dbx.Expression` interface.
### Building Data Manipulation Queries
Data manipulation queries are those changing the data in the database, such as INSERT, UPDATE, DELETE statements.
Such queries can be built by calling the corresponding methods of `DB`. For example,
```go
db, _ := dbx.Open("mysql", "user:pass@/example")
// INSERT INTO `users` (`name`, `email`) VALUES ({:p0}, {:p1})
err := db.Insert("users", dbx.Params{
"name": "James",
"email": "james@example.com",
}).Execute()
// UPDATE `users` SET `status`={:p0} WHERE `id`={:p1}
err = db.Update("users", dbx.Params{"status": 1}, dbx.HashExp{"id": 100}).Execute()
// DELETE FROM `users` WHERE `status`={:p0}
err = db.Delete("users", dbx.HashExp{"status": 2}).Execute()
```
When building data manipulation queries, remember to call `Execute()` at the end to execute the queries.
### Building Schema Manipulation Queries
Schema manipulation queries are those changing the database schema, such as creating a new table, adding a new column.
These queries can be built by calling the corresponding methods of `DB`. For example,
```go
db, _ := dbx.Open("mysql", "user:pass@/example")
// CREATE TABLE `users` (`id` int primary key, `name` varchar(255))
q := db.CreateTable("users", map[string]string{
"id": "int primary key",
"name": "varchar(255)",
})
err := q.Execute()
```
## CRUD Operations
Although `dbx` is not an ORM, it does provide a very convenient way to do typical CRUD (Create, Read, Update, Delete)
operations without the need of writing plain SQL statements.
To use the CRUD feature, first define a struct type for a table. By default, a struct is associated with a table
whose name is the snake case version of the struct type name. For example, a struct named `MyCustomer`
corresponds to the table name `my_customer`. You may explicitly specify the table name for a struct by implementing
the `dbx.TableModel` interface. For example,
```go
type MyCustomer struct{}
func (c MyCustomer) TableName() string {
return "customer"
}
```
Note that the `TableName` method should be defined with a value receiver instead of a pointer receiver.
If the struct has a field named `ID` or `Id`, by default the field will be treated as the primary key field.
If you want to use a different field as the primary key, tag it with `db:"pk"`. You may tag multiple fields
for composite primary keys. Note that if you also want to explicitly specify the column name for a primary key field,
you should use the tag format `db:"pk,col_name"`.
You can give a common prefix or suffix to your table names by defining your own table name mapping via
`DB.TableMapFunc`. For example, the following code prefixes `tbl_` to all table names.
```go
db.TableMapper = func(a interface{}) string {
return "tbl_" + GetTableName(a)
}
```
### Create
To create (insert) a new row using a model, call the `ModelQuery.Insert()` method. For example,
```go
type Customer struct {
ID int
Name string
Email string
Status int
}
db, _ := dbx.Open("mysql", "user:pass@/example")
customer := Customer{
Name: "example",
Email: "test@example.com",
}
// INSERT INTO customer (name, email, status) VALUES ('example', 'test@example.com', 0)
err := db.Model(&customer).Insert()
```
This will insert a row using the values from *all* public fields (except the primary key field if it is empty) in the struct.
If a primary key field is zero (a integer zero or a nil pointer), it is assumed to be auto-incremental and
will be automatically filled with the last insertion ID after a successful insertion.
You can explicitly specify the fields that should be inserted by passing the list of the field names to the `Insert()` method.
You can also exclude certain fields from being inserted by calling `Exclude()` before calling `Insert()`. For example,
```go
db, _ := dbx.Open("mysql", "user:pass@/example")
// insert only Name and Email fields
err := db.Model(&customer).Insert("Name", "Email")
// insert all public fields except Status
err = db.Model(&customer).Exclude("Status").Insert()
// insert only Name
err = db.Model(&customer).Exclude("Status").Insert("Name", "Status")
```
### Read
To read a model by a given primary key value, call `SelectQuery.Model()`.
```go
db, _ := dbx.Open("mysql", "user:pass@/example")
var customer Customer
// SELECT * FROM customer WHERE id=100
err := db.Select().Model(100, &customer)
// SELECT name, email FROM customer WHERE status=1 AND id=100
err = db.Select("name", "email").Where(dbx.HashExp{"status": 1}).Model(100, &customer)
```
Note that `SelectQuery.Model()` does not support composite primary keys. You should use `SelectQuery.One()` in this case.
For example,
```go
db, _ := dbx.Open("mysql", "user:pass@/example")
var orderItem OrderItem
// SELECT * FROM order_item WHERE order_id=100 AND item_id=20
err := db.Select().Where(dbx.HashExp{"order_id": 100, "item_id": 20}).One(&orderItem)
```
In the above queries, we do not call `From()` to specify which table to select data from. This is because the select
query automatically sets the table according to the model struct being populated. If the struct implements `TableModel`,
the value returned by its `TableName()` method will be used as the table name. Otherwise, the snake case version
of the struct type name will be the table name.
You may also call `SelectQuery.All()` to read a list of model structs. Similarly, you do not need to call `From()`
if the table name can be inferred from the model structs.
### Update
To update a model, call the `ModelQuery.Update()` method. Like `Insert()`, by default, the `Update()` method will
update *all* public fields except primary key fields of the model. You can explicitly specify which fields can
be updated and which cannot in the same way as described for the `Insert()` method. For example,
```go
db, _ := dbx.Open("mysql", "user:pass@/example")
// update all public fields of customer
err := db.Model(&customer).Update()
// update only Status
err = db.Model(&customer).Update("Status")
// update all public fields except Status
err = db.Model(&customer).Exclude("Status").Update()
```
Note that the `Update()` method assumes that the primary keys are immutable. It uses the primary key value of the model
to look for the row that should be updated. An error will be returned if a model does not have a primary key.
### Delete
To delete a model, call the `ModelQuery.Delete()` method. The method deletes the row using the primary key value
specified by the model. If the model does not have a primary key, an error will be returned. For example,
```go
db, _ := dbx.Open("mysql", "user:pass@/example")
err := db.Model(&customer).Delete()
```
### Null Handling
To represent a nullable database value, you can use a pointer type. If the pointer is nil, it means the corresponding
database value is null.
Another option to represent a database null is to use `sql.NullXyz` types. For example, if a string column is nullable,
you may use `sql.NullString`. The `NullString.Valid` field indicates whether the value is a null or not, and
`NullString.String` returns the string value when it is not null. Because `sql.NulLXyz` types do not handle JSON
marshalling, you may use the [null package](https://github.com/guregu/null), instead.
Below is an example of handling nulls:
```go
type Customer struct {
ID int
Email string
FirstName *string // use pointer to represent null
LastName sql.NullString // use sql.NullString to represent null
}
```
## Quoting Table and Column Names
Databases vary in quoting table and column names. To allow writing DB-agnostic SQLs, `dbx` introduces a special
syntax in quoting table and column names. A word enclosed within `{{` and `}}` is treated as a table name and will
be quoted according to the particular DB driver. Similarly, a word enclosed within `[[` and `]]` is treated as a
column name and will be quoted accordingly as well. For example, when working with a MySQL database, the following
query will be properly quoted:
```go
// SELECT * FROM `users` WHERE `status`=1
q := db.NewQuery("SELECT * FROM {{users}} WHERE [[status]]=1")
```
Note that if a table or column name contains a prefix, it will still be properly quoted. For example, `{{public.users}}`
will be quoted as `"public"."users"` for PostgreSQL.
## Using Transactions
You can use all aforementioned query execution and building methods with transaction. For example,
```go
db, _ := dbx.Open("mysql", "user:pass@/example")
tx, _ := db.Begin()
_, err1 := tx.Insert("users", dbx.Params{
"name": "user1",
}).Execute()
_, err2 := tx.Insert("users", dbx.Params{
"name": "user2",
}).Execute()
if err1 == nil && err2 == nil {
tx.Commit()
} else {
tx.Rollback()
}
```
You may use `DB.Transactional()` to simplify your transactional code without explicitly committing or rolling back
transactions. The method will start a transaction and automatically roll back the transaction if the callback
returns an error. Otherwise it will
automatically commit the transaction.
```go
db, _ := dbx.Open("mysql", "user:pass@/example")
err := db.Transactional(func(tx *dbx.Tx) error {
var err error
_, err = tx.Insert("users", dbx.Params{
"name": "user1",
}).Execute()
if err != nil {
return err
}
_, err = tx.Insert("users", dbx.Params{
"name": "user2",
}).Execute()
return err
})
fmt.Println(err)
```
## Logging Executed SQL Statements
You can log and instrument DB queries by installing loggers with a DB connection. There are three kinds of loggers you
can install:
* `DB.LogFunc`: this is called each time when a SQL statement is queried or executed. The function signature is the
same as that of `fmt.Printf`, which makes it very easy to use.
* `DB.QueryLogFunc`: this is called each time when querying with a SQL statement.
* `DB.ExecLogFunc`: this is called when executing a SQL statement.
The following example shows how you can make use of these loggers.
```go
package main
import (
"context"
"database/sql"
"log"
"time"
"github.com/pocketbase/dbx"
)
func main() {
db, _ := dbx.Open("mysql", "user:pass@/example")
// simple logging
db.LogFunc = log.Printf
// or you can use the following more flexible logging
db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
log.Printf("[%.2fms] Query SQL: %v", float64(t.Milliseconds()), sql)
}
db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
log.Printf("[%.2fms] Execute SQL: %v", float64(t.Milliseconds()), sql)
}
// ...
}
```
## Supporting New Databases
While `dbx` provides out-of-box query building support for most major relational databases, its open architecture
allows you to add support for new databases. The effort of adding support for a new database involves:
* Create a struct that implements the `QueryBuilder` interface. You may use `BaseQueryBuilder` directly or extend it
via composition.
* Create a struct that implements the `Builder` interface. You may extend `BaseBuilder` via composition.
* Write an `init()` function to register the new builder in `dbx.BuilderFuncMap`.

402
builder.go Normal file
View file

@ -0,0 +1,402 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"errors"
"fmt"
"sort"
"strings"
)
// Builder supports building SQL statements in a DB-agnostic way.
// Builder mainly provides two sets of query building methods: those building SELECT statements
// and those manipulating DB data or schema (e.g. INSERT statements, CREATE TABLE statements).
type Builder interface {
// NewQuery creates a new Query object with the given SQL statement.
// The SQL statement may contain parameter placeholders which can be bound with actual parameter
// values before the statement is executed.
NewQuery(string) *Query
// Select returns a new SelectQuery object that can be used to build a SELECT statement.
// The parameters to this method should be the list column names to be selected.
// A column name may have an optional alias name. For example, Select("id", "my_name AS name").
Select(...string) *SelectQuery
// ModelQuery returns a new ModelQuery object that can be used to perform model insertion, update, and deletion.
// The parameter to this method should be a pointer to the model struct that needs to be inserted, updated, or deleted.
Model(interface{}) *ModelQuery
// GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID.
GeneratePlaceholder(int) string
// Quote quotes a string so that it can be embedded in a SQL statement as a string value.
Quote(string) string
// QuoteSimpleTableName quotes a simple table name.
// A simple table name does not contain any schema prefix.
QuoteSimpleTableName(string) string
// QuoteSimpleColumnName quotes a simple column name.
// A simple column name does not contain any table prefix.
QuoteSimpleColumnName(string) string
// QueryBuilder returns the query builder supporting the current DB.
QueryBuilder() QueryBuilder
// Insert creates a Query that represents an INSERT SQL statement.
// The keys of cols are the column names, while the values of cols are the corresponding column
// values to be inserted.
Insert(table string, cols Params) *Query
// Upsert creates a Query that represents an UPSERT SQL statement.
// Upsert inserts a row into the table if the primary key or unique index is not found.
// Otherwise it will update the row with the new values.
// The keys of cols are the column names, while the values of cols are the corresponding column
// values to be inserted.
Upsert(table string, cols Params, constraints ...string) *Query
// Update creates a Query that represents an UPDATE SQL statement.
// The keys of cols are the column names, while the values of cols are the corresponding new column
// values. If the "where" expression is nil, the UPDATE SQL statement will have no WHERE clause
// (be careful in this case as the SQL statement will update ALL rows in the table).
Update(table string, cols Params, where Expression) *Query
// Delete creates a Query that represents a DELETE SQL statement.
// If the "where" expression is nil, the DELETE SQL statement will have no WHERE clause
// (be careful in this case as the SQL statement will delete ALL rows in the table).
Delete(table string, where Expression) *Query
// CreateTable creates a Query that represents a CREATE TABLE SQL statement.
// The keys of cols are the column names, while the values of cols are the corresponding column types.
// The optional "options" parameters will be appended to the generated SQL statement.
CreateTable(table string, cols map[string]string, options ...string) *Query
// RenameTable creates a Query that can be used to rename a table.
RenameTable(oldName, newName string) *Query
// DropTable creates a Query that can be used to drop a table.
DropTable(table string) *Query
// TruncateTable creates a Query that can be used to truncate a table.
TruncateTable(table string) *Query
// AddColumn creates a Query that can be used to add a column to a table.
AddColumn(table, col, typ string) *Query
// DropColumn creates a Query that can be used to drop a column from a table.
DropColumn(table, col string) *Query
// RenameColumn creates a Query that can be used to rename a column in a table.
RenameColumn(table, oldName, newName string) *Query
// AlterColumn creates a Query that can be used to change the definition of a table column.
AlterColumn(table, col, typ string) *Query
// AddPrimaryKey creates a Query that can be used to specify primary key(s) for a table.
// The "name" parameter specifies the name of the primary key constraint.
AddPrimaryKey(table, name string, cols ...string) *Query
// DropPrimaryKey creates a Query that can be used to remove the named primary key constraint from a table.
DropPrimaryKey(table, name string) *Query
// AddForeignKey creates a Query that can be used to add a foreign key constraint to a table.
// The length of cols and refCols must be the same as they refer to the primary and referential columns.
// The optional "options" parameters will be appended to the SQL statement. They can be used to
// specify options such as "ON DELETE CASCADE".
AddForeignKey(table, name string, cols, refCols []string, refTable string, options ...string) *Query
// DropForeignKey creates a Query that can be used to remove the named foreign key constraint from a table.
DropForeignKey(table, name string) *Query
// CreateIndex creates a Query that can be used to create an index for a table.
CreateIndex(table, name string, cols ...string) *Query
// CreateUniqueIndex creates a Query that can be used to create a unique index for a table.
CreateUniqueIndex(table, name string, cols ...string) *Query
// DropIndex creates a Query that can be used to remove the named index from a table.
DropIndex(table, name string) *Query
}
// BaseBuilder provides a basic implementation of the Builder interface.
type BaseBuilder struct {
db *DB
executor Executor
}
// NewBaseBuilder creates a new BaseBuilder instance.
func NewBaseBuilder(db *DB, executor Executor) *BaseBuilder {
return &BaseBuilder{db, executor}
}
// DB returns the DB instance that this builder is associated with.
func (b *BaseBuilder) DB() *DB {
return b.db
}
// Executor returns the executor object (a DB instance or a transaction) for executing SQL statements.
func (b *BaseBuilder) Executor() Executor {
return b.executor
}
// NewQuery creates a new Query object with the given SQL statement.
// The SQL statement may contain parameter placeholders which can be bound with actual parameter
// values before the statement is executed.
func (b *BaseBuilder) NewQuery(sql string) *Query {
return NewQuery(b.db, b.executor, sql)
}
// GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID.
func (b *BaseBuilder) GeneratePlaceholder(int) string {
return "?"
}
// Quote quotes a string so that it can be embedded in a SQL statement as a string value.
func (b *BaseBuilder) Quote(s string) string {
return "'" + strings.Replace(s, "'", "''", -1) + "'"
}
// QuoteSimpleTableName quotes a simple table name.
// A simple table name does not contain any schema prefix.
func (b *BaseBuilder) QuoteSimpleTableName(s string) string {
if strings.Contains(s, `"`) {
return s
}
return `"` + s + `"`
}
// QuoteSimpleColumnName quotes a simple column name.
// A simple column name does not contain any table prefix.
func (b *BaseBuilder) QuoteSimpleColumnName(s string) string {
if strings.Contains(s, `"`) || s == "*" {
return s
}
return `"` + s + `"`
}
// Insert creates a Query that represents an INSERT SQL statement.
// The keys of cols are the column names, while the values of cols are the corresponding column
// values to be inserted.
func (b *BaseBuilder) Insert(table string, cols Params) *Query {
names := make([]string, 0, len(cols))
for name := range cols {
names = append(names, name)
}
sort.Strings(names)
params := Params{}
columns := make([]string, 0, len(names))
values := make([]string, 0, len(names))
for _, name := range names {
columns = append(columns, b.db.QuoteColumnName(name))
value := cols[name]
if e, ok := value.(Expression); ok {
values = append(values, e.Build(b.db, params))
} else {
values = append(values, fmt.Sprintf("{:p%v}", len(params)))
params[fmt.Sprintf("p%v", len(params))] = value
}
}
var sql string
if len(names) == 0 {
sql = fmt.Sprintf("INSERT INTO %v DEFAULT VALUES", b.db.QuoteTableName(table))
} else {
sql = fmt.Sprintf("INSERT INTO %v (%v) VALUES (%v)",
b.db.QuoteTableName(table),
strings.Join(columns, ", "),
strings.Join(values, ", "),
)
}
return b.NewQuery(sql).Bind(params)
}
// Upsert creates a Query that represents an UPSERT SQL statement.
// Upsert inserts a row into the table if the primary key or unique index is not found.
// Otherwise it will update the row with the new values.
// The keys of cols are the column names, while the values of cols are the corresponding column
// values to be inserted.
func (b *BaseBuilder) Upsert(table string, cols Params, constraints ...string) *Query {
q := b.NewQuery("")
q.LastError = errors.New("Upsert is not supported")
return q
}
// Update creates a Query that represents an UPDATE SQL statement.
// The keys of cols are the column names, while the values of cols are the corresponding new column
// values. If the "where" expression is nil, the UPDATE SQL statement will have no WHERE clause
// (be careful in this case as the SQL statement will update ALL rows in the table).
func (b *BaseBuilder) Update(table string, cols Params, where Expression) *Query {
names := make([]string, 0, len(cols))
for name := range cols {
names = append(names, name)
}
sort.Strings(names)
params := Params{}
lines := make([]string, 0, len(names))
for _, name := range names {
value := cols[name]
name = b.db.QuoteColumnName(name)
if e, ok := value.(Expression); ok {
lines = append(lines, name+"="+e.Build(b.db, params))
} else {
lines = append(lines, fmt.Sprintf("%v={:p%v}", name, len(params)))
params[fmt.Sprintf("p%v", len(params))] = value
}
}
sql := fmt.Sprintf("UPDATE %v SET %v", b.db.QuoteTableName(table), strings.Join(lines, ", "))
if where != nil {
w := where.Build(b.db, params)
if w != "" {
sql += " WHERE " + w
}
}
return b.NewQuery(sql).Bind(params)
}
// Delete creates a Query that represents a DELETE SQL statement.
// If the "where" expression is nil, the DELETE SQL statement will have no WHERE clause
// (be careful in this case as the SQL statement will delete ALL rows in the table).
func (b *BaseBuilder) Delete(table string, where Expression) *Query {
sql := "DELETE FROM " + b.db.QuoteTableName(table)
params := Params{}
if where != nil {
w := where.Build(b.db, params)
if w != "" {
sql += " WHERE " + w
}
}
return b.NewQuery(sql).Bind(params)
}
// CreateTable creates a Query that represents a CREATE TABLE SQL statement.
// The keys of cols are the column names, while the values of cols are the corresponding column types.
// The optional "options" parameters will be appended to the generated SQL statement.
func (b *BaseBuilder) CreateTable(table string, cols map[string]string, options ...string) *Query {
names := []string{}
for name := range cols {
names = append(names, name)
}
sort.Strings(names)
columns := []string{}
for _, name := range names {
columns = append(columns, b.db.QuoteColumnName(name)+" "+cols[name])
}
sql := fmt.Sprintf("CREATE TABLE %v (%v)", b.db.QuoteTableName(table), strings.Join(columns, ", "))
for _, opt := range options {
sql += " " + opt
}
return b.NewQuery(sql)
}
// RenameTable creates a Query that can be used to rename a table.
func (b *BaseBuilder) RenameTable(oldName, newName string) *Query {
sql := fmt.Sprintf("RENAME TABLE %v TO %v", b.db.QuoteTableName(oldName), b.db.QuoteTableName(newName))
return b.NewQuery(sql)
}
// DropTable creates a Query that can be used to drop a table.
func (b *BaseBuilder) DropTable(table string) *Query {
sql := "DROP TABLE " + b.db.QuoteTableName(table)
return b.NewQuery(sql)
}
// TruncateTable creates a Query that can be used to truncate a table.
func (b *BaseBuilder) TruncateTable(table string) *Query {
sql := "TRUNCATE TABLE " + b.db.QuoteTableName(table)
return b.NewQuery(sql)
}
// AddColumn creates a Query that can be used to add a column to a table.
func (b *BaseBuilder) AddColumn(table, col, typ string) *Query {
sql := fmt.Sprintf("ALTER TABLE %v ADD %v %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(col), typ)
return b.NewQuery(sql)
}
// DropColumn creates a Query that can be used to drop a column from a table.
func (b *BaseBuilder) DropColumn(table, col string) *Query {
sql := fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(col))
return b.NewQuery(sql)
}
// RenameColumn creates a Query that can be used to rename a column in a table.
func (b *BaseBuilder) RenameColumn(table, oldName, newName string) *Query {
sql := fmt.Sprintf("ALTER TABLE %v RENAME COLUMN %v TO %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(oldName), b.db.QuoteColumnName(newName))
return b.NewQuery(sql)
}
// AlterColumn creates a Query that can be used to change the definition of a table column.
func (b *BaseBuilder) AlterColumn(table, col, typ string) *Query {
col = b.db.QuoteColumnName(col)
sql := fmt.Sprintf("ALTER TABLE %v CHANGE %v %v %v", b.db.QuoteTableName(table), col, col, typ)
return b.NewQuery(sql)
}
// AddPrimaryKey creates a Query that can be used to specify primary key(s) for a table.
// The "name" parameter specifies the name of the primary key constraint.
func (b *BaseBuilder) AddPrimaryKey(table, name string, cols ...string) *Query {
sql := fmt.Sprintf("ALTER TABLE %v ADD CONSTRAINT %v PRIMARY KEY (%v)",
b.db.QuoteTableName(table),
b.db.QuoteColumnName(name),
b.quoteColumns(cols))
return b.NewQuery(sql)
}
// DropPrimaryKey creates a Query that can be used to remove the named primary key constraint from a table.
func (b *BaseBuilder) DropPrimaryKey(table, name string) *Query {
sql := fmt.Sprintf("ALTER TABLE %v DROP CONSTRAINT %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(name))
return b.NewQuery(sql)
}
// AddForeignKey creates a Query that can be used to add a foreign key constraint to a table.
// The length of cols and refCols must be the same as they refer to the primary and referential columns.
// The optional "options" parameters will be appended to the SQL statement. They can be used to
// specify options such as "ON DELETE CASCADE".
func (b *BaseBuilder) AddForeignKey(table, name string, cols, refCols []string, refTable string, options ...string) *Query {
sql := fmt.Sprintf("ALTER TABLE %v ADD CONSTRAINT %v FOREIGN KEY (%v) REFERENCES %v (%v)",
b.db.QuoteTableName(table),
b.db.QuoteColumnName(name),
b.quoteColumns(cols),
b.db.QuoteTableName(refTable),
b.quoteColumns(refCols))
for _, opt := range options {
sql += " " + opt
}
return b.NewQuery(sql)
}
// DropForeignKey creates a Query that can be used to remove the named foreign key constraint from a table.
func (b *BaseBuilder) DropForeignKey(table, name string) *Query {
sql := fmt.Sprintf("ALTER TABLE %v DROP CONSTRAINT %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(name))
return b.NewQuery(sql)
}
// CreateIndex creates a Query that can be used to create an index for a table.
func (b *BaseBuilder) CreateIndex(table, name string, cols ...string) *Query {
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v)",
b.db.QuoteColumnName(name),
b.db.QuoteTableName(table),
b.quoteColumns(cols))
return b.NewQuery(sql)
}
// CreateUniqueIndex creates a Query that can be used to create a unique index for a table.
func (b *BaseBuilder) CreateUniqueIndex(table, name string, cols ...string) *Query {
sql := fmt.Sprintf("CREATE UNIQUE INDEX %v ON %v (%v)",
b.db.QuoteColumnName(name),
b.db.QuoteTableName(table),
b.quoteColumns(cols))
return b.NewQuery(sql)
}
// DropIndex creates a Query that can be used to remove the named index from a table.
func (b *BaseBuilder) DropIndex(table, name string) *Query {
sql := fmt.Sprintf("DROP INDEX %v ON %v", b.db.QuoteColumnName(name), b.db.QuoteTableName(table))
return b.NewQuery(sql)
}
// quoteColumns quotes a list of columns and concatenates them with commas.
func (b *BaseBuilder) quoteColumns(cols []string) string {
s := ""
for i, col := range cols {
if i == 0 {
s = b.db.QuoteColumnName(col)
} else {
s += ", " + b.db.QuoteColumnName(col)
}
}
return s
}

115
builder_mssql.go Normal file
View file

@ -0,0 +1,115 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"fmt"
"strings"
)
// MssqlBuilder is the builder for SQL Server databases.
type MssqlBuilder struct {
*BaseBuilder
qb *MssqlQueryBuilder
}
var _ Builder = &MssqlBuilder{}
// MssqlQueryBuilder is the query builder for SQL Server databases.
type MssqlQueryBuilder struct {
*BaseQueryBuilder
}
// NewMssqlBuilder creates a new MssqlBuilder instance.
func NewMssqlBuilder(db *DB, executor Executor) Builder {
return &MssqlBuilder{
NewBaseBuilder(db, executor),
&MssqlQueryBuilder{NewBaseQueryBuilder(db)},
}
}
// QueryBuilder returns the query builder supporting the current DB.
func (b *MssqlBuilder) QueryBuilder() QueryBuilder {
return b.qb
}
// Select returns a new SelectQuery object that can be used to build a SELECT statement.
// The parameters to this method should be the list column names to be selected.
// A column name may have an optional alias name. For example, Select("id", "my_name AS name").
func (b *MssqlBuilder) Select(cols ...string) *SelectQuery {
return NewSelectQuery(b, b.db).Select(cols...)
}
// Model returns a new ModelQuery object that can be used to perform model-based DB operations.
// The model passed to this method should be a pointer to a model struct.
func (b *MssqlBuilder) Model(model interface{}) *ModelQuery {
return NewModelQuery(model, b.db.FieldMapper, b.db, b)
}
// QuoteSimpleTableName quotes a simple table name.
// A simple table name does not contain any schema prefix.
func (b *MssqlBuilder) QuoteSimpleTableName(s string) string {
if strings.Contains(s, `[`) {
return s
}
return `[` + s + `]`
}
// QuoteSimpleColumnName quotes a simple column name.
// A simple column name does not contain any table prefix.
func (b *MssqlBuilder) QuoteSimpleColumnName(s string) string {
if strings.Contains(s, `[`) || s == "*" {
return s
}
return `[` + s + `]`
}
// RenameTable creates a Query that can be used to rename a table.
func (b *MssqlBuilder) RenameTable(oldName, newName string) *Query {
sql := fmt.Sprintf("sp_name '%v', '%v'", oldName, newName)
return b.NewQuery(sql)
}
// RenameColumn creates a Query that can be used to rename a column in a table.
func (b *MssqlBuilder) RenameColumn(table, oldName, newName string) *Query {
sql := fmt.Sprintf("sp_name '%v.%v', '%v', 'COLUMN'", table, oldName, newName)
return b.NewQuery(sql)
}
// AlterColumn creates a Query that can be used to change the definition of a table column.
func (b *MssqlBuilder) AlterColumn(table, col, typ string) *Query {
col = b.db.QuoteColumnName(col)
sql := fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", b.db.QuoteTableName(table), col, typ)
return b.NewQuery(sql)
}
// BuildOrderByAndLimit generates the ORDER BY and LIMIT clauses.
func (q *MssqlQueryBuilder) BuildOrderByAndLimit(sql string, cols []string, limit int64, offset int64) string {
orderBy := q.BuildOrderBy(cols)
if limit < 0 && offset < 0 {
if orderBy == "" {
return sql
}
return sql + "\n" + orderBy
}
// only SQL SERVER 2012 or newer are supported by this method
if orderBy == "" {
// ORDER BY clause is required when FETCH and OFFSET are in the SQL
orderBy = "ORDER BY (SELECT NULL)"
}
sql += "\n" + orderBy
// http://technet.microsoft.com/en-us/library/gg699618.aspx
if offset < 0 {
offset = 0
}
sql += "\n" + fmt.Sprintf("OFFSET %v ROWS", offset)
if limit >= 0 {
sql += "\n" + fmt.Sprintf("FETCH NEXT %v ROWS ONLY", limit)
}
return sql
}

73
builder_mssql_test.go Normal file
View file

@ -0,0 +1,73 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestMssqlBuilder_QuoteSimpleTableName(t *testing.T) {
b := getMssqlBuilder()
assert.Equal(t, b.QuoteSimpleTableName(`abc`), "[abc]", "t1")
assert.Equal(t, b.QuoteSimpleTableName("[abc]"), "[abc]", "t2")
assert.Equal(t, b.QuoteSimpleTableName(`{{abc}}`), "[{{abc}}]", "t3")
assert.Equal(t, b.QuoteSimpleTableName(`a.bc`), "[a.bc]", "t4")
}
func TestMssqlBuilder_QuoteSimpleColumnName(t *testing.T) {
b := getMssqlBuilder()
assert.Equal(t, b.QuoteSimpleColumnName(`abc`), "[abc]", "t1")
assert.Equal(t, b.QuoteSimpleColumnName("[abc]"), "[abc]", "t2")
assert.Equal(t, b.QuoteSimpleColumnName(`{{abc}}`), "[{{abc}}]", "t3")
assert.Equal(t, b.QuoteSimpleColumnName(`a.bc`), "[a.bc]", "t4")
assert.Equal(t, b.QuoteSimpleColumnName(`*`), `*`, "t5")
}
func TestMssqlBuilder_RenameTable(t *testing.T) {
b := getMssqlBuilder()
q := b.RenameTable("users", "user")
assert.Equal(t, q.SQL(), `sp_name 'users', 'user'`, "t1")
}
func TestMssqlBuilder_RenameColumn(t *testing.T) {
b := getMssqlBuilder()
q := b.RenameColumn("users", "name", "username")
assert.Equal(t, q.SQL(), `sp_name 'users.name', 'username', 'COLUMN'`, "t1")
}
func TestMssqlBuilder_AlterColumn(t *testing.T) {
b := getMssqlBuilder()
q := b.AlterColumn("users", "name", "int")
assert.Equal(t, q.SQL(), `ALTER TABLE [users] ALTER COLUMN [name] int`, "t1")
}
func TestMssqlQueryBuilder_BuildOrderByAndLimit(t *testing.T) {
qb := getMssqlBuilder().QueryBuilder()
sql := qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, 10, 2)
expected := "SELECT *\nORDER BY [name]\nOFFSET 2 ROWS\nFETCH NEXT 10 ROWS ONLY"
assert.Equal(t, sql, expected, "t1")
sql = qb.BuildOrderByAndLimit("SELECT *", nil, -1, -1)
expected = "SELECT *"
assert.Equal(t, sql, expected, "t2")
sql = qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, -1, -1)
expected = "SELECT *\nORDER BY [name]"
assert.Equal(t, sql, expected, "t3")
sql = qb.BuildOrderByAndLimit("SELECT *", nil, 10, -1)
expected = "SELECT *\nORDER BY (SELECT NULL)\nOFFSET 0 ROWS\nFETCH NEXT 10 ROWS ONLY"
assert.Equal(t, sql, expected, "t4")
}
func getMssqlBuilder() Builder {
db := getDB()
b := NewMssqlBuilder(db, db.sqlDB)
db.Builder = b
return b
}

133
builder_mysql.go Normal file
View file

@ -0,0 +1,133 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"fmt"
"regexp"
"sort"
"strings"
)
// MysqlBuilder is the builder for MySQL databases.
type MysqlBuilder struct {
*BaseBuilder
qb *BaseQueryBuilder
}
var _ Builder = &MysqlBuilder{}
// NewMysqlBuilder creates a new MysqlBuilder instance.
func NewMysqlBuilder(db *DB, executor Executor) Builder {
return &MysqlBuilder{
NewBaseBuilder(db, executor),
NewBaseQueryBuilder(db),
}
}
// QueryBuilder returns the query builder supporting the current DB.
func (b *MysqlBuilder) QueryBuilder() QueryBuilder {
return b.qb
}
// Select returns a new SelectQuery object that can be used to build a SELECT statement.
// The parameters to this method should be the list column names to be selected.
// A column name may have an optional alias name. For example, Select("id", "my_name AS name").
func (b *MysqlBuilder) Select(cols ...string) *SelectQuery {
return NewSelectQuery(b, b.db).Select(cols...)
}
// Model returns a new ModelQuery object that can be used to perform model-based DB operations.
// The model passed to this method should be a pointer to a model struct.
func (b *MysqlBuilder) Model(model interface{}) *ModelQuery {
return NewModelQuery(model, b.db.FieldMapper, b.db, b)
}
// QuoteSimpleTableName quotes a simple table name.
// A simple table name does not contain any schema prefix.
func (b *MysqlBuilder) QuoteSimpleTableName(s string) string {
if strings.ContainsAny(s, "`") {
return s
}
return "`" + s + "`"
}
// QuoteSimpleColumnName quotes a simple column name.
// A simple column name does not contain any table prefix.
func (b *MysqlBuilder) QuoteSimpleColumnName(s string) string {
if strings.Contains(s, "`") || s == "*" {
return s
}
return "`" + s + "`"
}
// Upsert creates a Query that represents an UPSERT SQL statement.
// Upsert inserts a row into the table if the primary key or unique index is not found.
// Otherwise it will update the row with the new values.
// The keys of cols are the column names, while the values of cols are the corresponding column
// values to be inserted.
func (b *MysqlBuilder) Upsert(table string, cols Params, constraints ...string) *Query {
q := b.Insert(table, cols)
names := []string{}
for name := range cols {
names = append(names, name)
}
sort.Strings(names)
lines := []string{}
for _, name := range names {
value := cols[name]
name = b.db.QuoteColumnName(name)
if e, ok := value.(Expression); ok {
lines = append(lines, name+"="+e.Build(b.db, q.params))
} else {
lines = append(lines, fmt.Sprintf("%v={:p%v}", name, len(q.params)))
q.params[fmt.Sprintf("p%v", len(q.params))] = value
}
}
q.sql += " ON DUPLICATE KEY UPDATE " + strings.Join(lines, ", ")
return q
}
var mysqlColumnRegexp = regexp.MustCompile("(?m)^\\s*[`\"](.*?)[`\"]\\s+(.*?),?$")
// RenameColumn creates a Query that can be used to rename a column in a table.
func (b *MysqlBuilder) RenameColumn(table, oldName, newName string) *Query {
qt := b.db.QuoteTableName(table)
sql := fmt.Sprintf("ALTER TABLE %v CHANGE %v %v", qt, b.db.QuoteColumnName(oldName), b.db.QuoteColumnName(newName))
var info struct {
SQL string `db:"Create Table"`
}
if err := b.db.NewQuery("SHOW CREATE TABLE " + qt).One(&info); err != nil {
return b.db.NewQuery(sql)
}
if matches := mysqlColumnRegexp.FindAllStringSubmatch(info.SQL, -1); matches != nil {
for _, match := range matches {
if match[1] == oldName {
sql += " " + match[2]
break
}
}
}
return b.db.NewQuery(sql)
}
// DropPrimaryKey creates a Query that can be used to remove the named primary key constraint from a table.
func (b *MysqlBuilder) DropPrimaryKey(table, name string) *Query {
sql := fmt.Sprintf("ALTER TABLE %v DROP PRIMARY KEY", b.db.QuoteTableName(table))
return b.db.NewQuery(sql)
}
// DropForeignKey creates a Query that can be used to remove the named foreign key constraint from a table.
func (b *MysqlBuilder) DropForeignKey(table, name string) *Query {
sql := fmt.Sprintf("ALTER TABLE %v DROP FOREIGN KEY %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(name))
return b.db.NewQuery(sql)
}

69
builder_mysql_test.go Normal file
View file

@ -0,0 +1,69 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestMysqlBuilder_QuoteSimpleTableName(t *testing.T) {
b := getMysqlBuilder()
assert.Equal(t, b.QuoteSimpleTableName(`abc`), "`abc`", "t1")
assert.Equal(t, b.QuoteSimpleTableName("`abc`"), "`abc`", "t2")
assert.Equal(t, b.QuoteSimpleTableName(`{{abc}}`), "`{{abc}}`", "t3")
assert.Equal(t, b.QuoteSimpleTableName(`a.bc`), "`a.bc`", "t4")
}
func TestMysqlBuilder_QuoteSimpleColumnName(t *testing.T) {
b := getMysqlBuilder()
assert.Equal(t, b.QuoteSimpleColumnName(`abc`), "`abc`", "t1")
assert.Equal(t, b.QuoteSimpleColumnName("`abc`"), "`abc`", "t2")
assert.Equal(t, b.QuoteSimpleColumnName(`{{abc}}`), "`{{abc}}`", "t3")
assert.Equal(t, b.QuoteSimpleColumnName(`a.bc`), "`a.bc`", "t4")
assert.Equal(t, b.QuoteSimpleColumnName(`*`), `*`, "t5")
}
func TestMysqlBuilder_Upsert(t *testing.T) {
getPreparedDB()
b := getMysqlBuilder()
q := b.Upsert("users", Params{
"name": "James",
"age": 30,
})
assert.Equal(t, q.SQL(), "INSERT INTO `users` (`age`, `name`) VALUES ({:p0}, {:p1}) ON DUPLICATE KEY UPDATE `age`={:p2}, `name`={:p3}", "t1")
assert.Equal(t, q.Params()["p0"], 30, "t2")
assert.Equal(t, q.Params()["p1"], "James", "t3")
assert.Equal(t, q.Params()["p2"], 30, "t2")
assert.Equal(t, q.Params()["p3"], "James", "t3")
}
func TestMysqlBuilder_RenameColumn(t *testing.T) {
b := getMysqlBuilder()
q := b.RenameColumn("users", "name", "username")
assert.Equal(t, q.SQL(), "ALTER TABLE `users` CHANGE `name` `username`")
q = b.RenameColumn("customer", "email", "e")
assert.Equal(t, q.SQL(), "ALTER TABLE `customer` CHANGE `email` `e` varchar(128) NOT NULL")
}
func TestMysqlBuilder_DropPrimaryKey(t *testing.T) {
b := getMysqlBuilder()
q := b.DropPrimaryKey("users", "pk")
assert.Equal(t, q.SQL(), "ALTER TABLE `users` DROP PRIMARY KEY", "t1")
}
func TestMysqlBuilder_DropForeignKey(t *testing.T) {
b := getMysqlBuilder()
q := b.DropForeignKey("users", "fk")
assert.Equal(t, q.SQL(), "ALTER TABLE `users` DROP FOREIGN KEY `fk`", "t1")
}
func getMysqlBuilder() Builder {
db := getDB()
b := NewMysqlBuilder(db, db.sqlDB)
db.Builder = b
return b
}

98
builder_oci.go Normal file
View file

@ -0,0 +1,98 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"fmt"
)
// OciBuilder is the builder for Oracle databases.
type OciBuilder struct {
*BaseBuilder
qb *OciQueryBuilder
}
var _ Builder = &OciBuilder{}
// OciQueryBuilder is the query builder for Oracle databases.
type OciQueryBuilder struct {
*BaseQueryBuilder
}
// NewOciBuilder creates a new OciBuilder instance.
func NewOciBuilder(db *DB, executor Executor) Builder {
return &OciBuilder{
NewBaseBuilder(db, executor),
&OciQueryBuilder{NewBaseQueryBuilder(db)},
}
}
// Select returns a new SelectQuery object that can be used to build a SELECT statement.
// The parameters to this method should be the list column names to be selected.
// A column name may have an optional alias name. For example, Select("id", "my_name AS name").
func (b *OciBuilder) Select(cols ...string) *SelectQuery {
return NewSelectQuery(b, b.db).Select(cols...)
}
// Model returns a new ModelQuery object that can be used to perform model-based DB operations.
// The model passed to this method should be a pointer to a model struct.
func (b *OciBuilder) Model(model interface{}) *ModelQuery {
return NewModelQuery(model, b.db.FieldMapper, b.db, b)
}
// GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID.
func (b *OciBuilder) GeneratePlaceholder(i int) string {
return fmt.Sprintf(":p%v", i)
}
// QueryBuilder returns the query builder supporting the current DB.
func (b *OciBuilder) QueryBuilder() QueryBuilder {
return b.qb
}
// DropIndex creates a Query that can be used to remove the named index from a table.
func (b *OciBuilder) DropIndex(table, name string) *Query {
sql := fmt.Sprintf("DROP INDEX %v", b.db.QuoteColumnName(name))
return b.NewQuery(sql)
}
// RenameTable creates a Query that can be used to rename a table.
func (b *OciBuilder) RenameTable(oldName, newName string) *Query {
sql := fmt.Sprintf("ALTER TABLE %v RENAME TO %v", b.db.QuoteTableName(oldName), b.db.QuoteTableName(newName))
return b.NewQuery(sql)
}
// AlterColumn creates a Query that can be used to change the definition of a table column.
func (b *OciBuilder) AlterColumn(table, col, typ string) *Query {
col = b.db.QuoteColumnName(col)
sql := fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", b.db.QuoteTableName(table), col, typ)
return b.NewQuery(sql)
}
// BuildOrderByAndLimit generates the ORDER BY and LIMIT clauses.
func (q *OciQueryBuilder) BuildOrderByAndLimit(sql string, cols []string, limit int64, offset int64) string {
if orderBy := q.BuildOrderBy(cols); orderBy != "" {
sql += "\n" + orderBy
}
c := ""
if offset > 0 {
c = fmt.Sprintf("rowNumId > %v", offset)
}
if limit >= 0 {
if c != "" {
c += " AND "
}
c += fmt.Sprintf("rowNum <= %v", limit)
}
if c == "" {
return sql
}
return `WITH USER_SQL AS (` + sql + `),
PAGINATION AS (SELECT USER_SQL.*, rownum as rowNumId FROM USER_SQL)
SELECT * FROM PAGINATION WHERE ` + c
}

56
builder_oci_test.go Normal file
View file

@ -0,0 +1,56 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestOciBuilder_DropIndex(t *testing.T) {
b := getOciBuilder()
q := b.DropIndex("users", "idx")
assert.Equal(t, q.SQL(), `DROP INDEX "idx"`, "t1")
}
func TestOciBuilder_RenameTable(t *testing.T) {
b := getOciBuilder()
q := b.RenameTable("users", "user")
assert.Equal(t, q.SQL(), `ALTER TABLE "users" RENAME TO "user"`, "t1")
}
func TestOciBuilder_AlterColumn(t *testing.T) {
b := getOciBuilder()
q := b.AlterColumn("users", "name", "int")
assert.Equal(t, q.SQL(), `ALTER TABLE "users" MODIFY "name" int`, "t1")
}
func TestOciQueryBuilder_BuildOrderByAndLimit(t *testing.T) {
qb := getOciBuilder().QueryBuilder()
sql := qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, 10, 2)
expected := "WITH USER_SQL AS (SELECT *\nORDER BY \"name\"),\n\tPAGINATION AS (SELECT USER_SQL.*, rownum as rowNumId FROM USER_SQL)\nSELECT * FROM PAGINATION WHERE rowNumId > 2 AND rowNum <= 10"
assert.Equal(t, sql, expected, "t1")
sql = qb.BuildOrderByAndLimit("SELECT *", nil, -1, -1)
expected = "SELECT *"
assert.Equal(t, sql, expected, "t2")
sql = qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, -1, -1)
expected = "SELECT *\nORDER BY \"name\""
assert.Equal(t, sql, expected, "t3")
sql = qb.BuildOrderByAndLimit("SELECT *", nil, 10, -1)
expected = "WITH USER_SQL AS (SELECT *),\n\tPAGINATION AS (SELECT USER_SQL.*, rownum as rowNumId FROM USER_SQL)\nSELECT * FROM PAGINATION WHERE rowNum <= 10"
assert.Equal(t, sql, expected, "t4")
}
func getOciBuilder() Builder {
db := getDB()
b := NewOciBuilder(db, db.sqlDB)
db.Builder = b
return b
}

105
builder_pgsql.go Normal file
View file

@ -0,0 +1,105 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"fmt"
"sort"
"strings"
)
// PgsqlBuilder is the builder for PostgreSQL databases.
type PgsqlBuilder struct {
*BaseBuilder
qb *BaseQueryBuilder
}
var _ Builder = &PgsqlBuilder{}
// NewPgsqlBuilder creates a new PgsqlBuilder instance.
func NewPgsqlBuilder(db *DB, executor Executor) Builder {
return &PgsqlBuilder{
NewBaseBuilder(db, executor),
NewBaseQueryBuilder(db),
}
}
// Select returns a new SelectQuery object that can be used to build a SELECT statement.
// The parameters to this method should be the list column names to be selected.
// A column name may have an optional alias name. For example, Select("id", "my_name AS name").
func (b *PgsqlBuilder) Select(cols ...string) *SelectQuery {
return NewSelectQuery(b, b.db).Select(cols...)
}
// Model returns a new ModelQuery object that can be used to perform model-based DB operations.
// The model passed to this method should be a pointer to a model struct.
func (b *PgsqlBuilder) Model(model interface{}) *ModelQuery {
return NewModelQuery(model, b.db.FieldMapper, b.db, b)
}
// GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID.
func (b *PgsqlBuilder) GeneratePlaceholder(i int) string {
return fmt.Sprintf("$%v", i)
}
// QueryBuilder returns the query builder supporting the current DB.
func (b *PgsqlBuilder) QueryBuilder() QueryBuilder {
return b.qb
}
// Upsert creates a Query that represents an UPSERT SQL statement.
// Upsert inserts a row into the table if the primary key or unique index is not found.
// Otherwise it will update the row with the new values.
// The keys of cols are the column names, while the values of cols are the corresponding column
// values to be inserted.
func (b *PgsqlBuilder) Upsert(table string, cols Params, constraints ...string) *Query {
q := b.Insert(table, cols)
names := []string{}
for name := range cols {
names = append(names, name)
}
sort.Strings(names)
lines := []string{}
for _, name := range names {
value := cols[name]
name = b.db.QuoteColumnName(name)
if e, ok := value.(Expression); ok {
lines = append(lines, name+"="+e.Build(b.db, q.params))
} else {
lines = append(lines, fmt.Sprintf("%v={:p%v}", name, len(q.params)))
q.params[fmt.Sprintf("p%v", len(q.params))] = value
}
}
if len(constraints) > 0 {
c := b.quoteColumns(constraints)
q.sql += " ON CONFLICT (" + c + ") DO UPDATE SET " + strings.Join(lines, ", ")
} else {
q.sql += " ON CONFLICT DO UPDATE SET " + strings.Join(lines, ", ")
}
return b.NewQuery(q.sql).Bind(q.params)
}
// DropIndex creates a Query that can be used to remove the named index from a table.
func (b *PgsqlBuilder) DropIndex(table, name string) *Query {
sql := fmt.Sprintf("DROP INDEX %v", b.db.QuoteColumnName(name))
return b.NewQuery(sql)
}
// RenameTable creates a Query that can be used to rename a table.
func (b *PgsqlBuilder) RenameTable(oldName, newName string) *Query {
sql := fmt.Sprintf("ALTER TABLE %v RENAME TO %v", b.db.QuoteTableName(oldName), b.db.QuoteTableName(newName))
return b.NewQuery(sql)
}
// AlterColumn creates a Query that can be used to change the definition of a table column.
func (b *PgsqlBuilder) AlterColumn(table, col, typ string) *Query {
col = b.db.QuoteColumnName(col)
sql := fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", b.db.QuoteTableName(table), col, typ)
return b.NewQuery(sql)
}

49
builder_pgsql_test.go Normal file
View file

@ -0,0 +1,49 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestPgsqlBuilder_Upsert(t *testing.T) {
b := getPgsqlBuilder()
q := b.Upsert("users", Params{
"name": "James",
"age": 30,
}, "id")
assert.Equal(t, q.sql, `INSERT INTO "users" ("age", "name") VALUES ({:p0}, {:p1}) ON CONFLICT ("id") DO UPDATE SET "age"={:p2}, "name"={:p3}`, "t1")
assert.Equal(t, q.rawSQL, `INSERT INTO "users" ("age", "name") VALUES ($1, $2) ON CONFLICT ("id") DO UPDATE SET "age"=$3, "name"=$4`, "t2")
assert.Equal(t, q.Params()["p0"], 30, "t3")
assert.Equal(t, q.Params()["p1"], "James", "t4")
assert.Equal(t, q.Params()["p2"], 30, "t5")
assert.Equal(t, q.Params()["p3"], "James", "t6")
}
func TestPgsqlBuilder_DropIndex(t *testing.T) {
b := getPgsqlBuilder()
q := b.DropIndex("users", "idx")
assert.Equal(t, q.SQL(), `DROP INDEX "idx"`, "t1")
}
func TestPgsqlBuilder_RenameTable(t *testing.T) {
b := getPgsqlBuilder()
q := b.RenameTable("users", "user")
assert.Equal(t, q.SQL(), `ALTER TABLE "users" RENAME TO "user"`, "t1")
}
func TestPgsqlBuilder_AlterColumn(t *testing.T) {
b := getPgsqlBuilder()
q := b.AlterColumn("users", "name", "int")
assert.Equal(t, q.SQL(), `ALTER TABLE "users" ALTER COLUMN "name" TYPE int`, "t1")
}
func getPgsqlBuilder() Builder {
db := getDB()
b := NewPgsqlBuilder(db, db.sqlDB)
db.Builder = b
return b
}

120
builder_sqlite.go Normal file
View file

@ -0,0 +1,120 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"errors"
"fmt"
"strings"
)
// SqliteBuilder is the builder for SQLite databases.
type SqliteBuilder struct {
*BaseBuilder
qb *BaseQueryBuilder
}
var _ Builder = &SqliteBuilder{}
// NewSqliteBuilder creates a new SqliteBuilder instance.
func NewSqliteBuilder(db *DB, executor Executor) Builder {
return &SqliteBuilder{
NewBaseBuilder(db, executor),
NewBaseQueryBuilder(db),
}
}
// QueryBuilder returns the query builder supporting the current DB.
func (b *SqliteBuilder) QueryBuilder() QueryBuilder {
return b.qb
}
// Select returns a new SelectQuery object that can be used to build a SELECT statement.
// The parameters to this method should be the list column names to be selected.
// A column name may have an optional alias name. For example, Select("id", "my_name AS name").
func (b *SqliteBuilder) Select(cols ...string) *SelectQuery {
return NewSelectQuery(b, b.db).Select(cols...)
}
// Model returns a new ModelQuery object that can be used to perform model-based DB operations.
// The model passed to this method should be a pointer to a model struct.
func (b *SqliteBuilder) Model(model interface{}) *ModelQuery {
return NewModelQuery(model, b.db.FieldMapper, b.db, b)
}
// QuoteSimpleTableName quotes a simple table name.
// A simple table name does not contain any schema prefix.
func (b *SqliteBuilder) QuoteSimpleTableName(s string) string {
if strings.ContainsAny(s, "`") {
return s
}
return "`" + s + "`"
}
// QuoteSimpleColumnName quotes a simple column name.
// A simple column name does not contain any table prefix.
func (b *SqliteBuilder) QuoteSimpleColumnName(s string) string {
if strings.Contains(s, "`") || s == "*" {
return s
}
return "`" + s + "`"
}
// DropIndex creates a Query that can be used to remove the named index from a table.
func (b *SqliteBuilder) DropIndex(table, name string) *Query {
sql := fmt.Sprintf("DROP INDEX %v", b.db.QuoteColumnName(name))
return b.NewQuery(sql)
}
// TruncateTable creates a Query that can be used to truncate a table.
func (b *SqliteBuilder) TruncateTable(table string) *Query {
sql := "DELETE FROM " + b.db.QuoteTableName(table)
return b.NewQuery(sql)
}
// RenameTable creates a Query that can be used to rename a table.
func (b *SqliteBuilder) RenameTable(oldName, newName string) *Query {
sql := fmt.Sprintf("ALTER TABLE %v RENAME TO %v", b.db.QuoteTableName(oldName), b.db.QuoteTableName(newName))
return b.NewQuery(sql)
}
// AlterColumn creates a Query that can be used to change the definition of a table column.
func (b *SqliteBuilder) AlterColumn(table, col, typ string) *Query {
q := b.NewQuery("")
q.LastError = errors.New("SQLite does not support altering column")
return q
}
// AddPrimaryKey creates a Query that can be used to specify primary key(s) for a table.
// The "name" parameter specifies the name of the primary key constraint.
func (b *SqliteBuilder) AddPrimaryKey(table, name string, cols ...string) *Query {
q := b.NewQuery("")
q.LastError = errors.New("SQLite does not support adding primary key")
return q
}
// DropPrimaryKey creates a Query that can be used to remove the named primary key constraint from a table.
func (b *SqliteBuilder) DropPrimaryKey(table, name string) *Query {
q := b.NewQuery("")
q.LastError = errors.New("SQLite does not support dropping primary key")
return q
}
// AddForeignKey creates a Query that can be used to add a foreign key constraint to a table.
// The length of cols and refCols must be the same as they refer to the primary and referential columns.
// The optional "options" parameters will be appended to the SQL statement. They can be used to
// specify options such as "ON DELETE CASCADE".
func (b *SqliteBuilder) AddForeignKey(table, name string, cols, refCols []string, refTable string, options ...string) *Query {
q := b.NewQuery("")
q.LastError = errors.New("SQLite does not support adding foreign keys")
return q
}
// DropForeignKey creates a Query that can be used to remove the named foreign key constraint from a table.
func (b *SqliteBuilder) DropForeignKey(table, name string) *Query {
q := b.NewQuery("")
q.LastError = errors.New("SQLite does not support dropping foreign keys")
return q
}

83
builder_sqlite_test.go Normal file
View file

@ -0,0 +1,83 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSqliteBuilder_QuoteSimpleTableName(t *testing.T) {
b := getSqliteBuilder()
assert.Equal(t, b.QuoteSimpleTableName(`abc`), "`abc`", "t1")
assert.Equal(t, b.QuoteSimpleTableName("`abc`"), "`abc`", "t2")
assert.Equal(t, b.QuoteSimpleTableName(`{{abc}}`), "`{{abc}}`", "t3")
assert.Equal(t, b.QuoteSimpleTableName(`a.bc`), "`a.bc`", "t4")
}
func TestSqliteBuilder_QuoteSimpleColumnName(t *testing.T) {
b := getSqliteBuilder()
assert.Equal(t, b.QuoteSimpleColumnName(`abc`), "`abc`", "t1")
assert.Equal(t, b.QuoteSimpleColumnName("`abc`"), "`abc`", "t2")
assert.Equal(t, b.QuoteSimpleColumnName(`{{abc}}`), "`{{abc}}`", "t3")
assert.Equal(t, b.QuoteSimpleColumnName(`a.bc`), "`a.bc`", "t4")
assert.Equal(t, b.QuoteSimpleColumnName(`*`), `*`, "t5")
}
func TestSqliteBuilder_DropIndex(t *testing.T) {
b := getSqliteBuilder()
q := b.DropIndex("users", "idx")
assert.Equal(t, q.SQL(), "DROP INDEX `idx`", "t1")
}
func TestSqliteBuilder_TruncateTable(t *testing.T) {
b := getSqliteBuilder()
q := b.TruncateTable("users")
assert.Equal(t, q.SQL(), "DELETE FROM `users`", "t1")
}
func TestSqliteBuilder_RenameTable(t *testing.T) {
b := getSqliteBuilder()
q := b.RenameTable("usersOld", "usersNew")
assert.Equal(t, q.SQL(), "ALTER TABLE `usersOld` RENAME TO `usersNew`", "t1")
}
func TestSqliteBuilder_AlterColumn(t *testing.T) {
b := getSqliteBuilder()
q := b.AlterColumn("users", "name", "int")
assert.NotEqual(t, q.LastError, nil, "t1")
}
func TestSqliteBuilder_AddPrimaryKey(t *testing.T) {
b := getSqliteBuilder()
q := b.AddPrimaryKey("users", "pk", "id1", "id2")
assert.NotEqual(t, q.LastError, nil, "t1")
}
func TestSqliteBuilder_DropPrimaryKey(t *testing.T) {
b := getSqliteBuilder()
q := b.DropPrimaryKey("users", "pk")
assert.NotEqual(t, q.LastError, nil, "t1")
}
func TestSqliteBuilder_AddForeignKey(t *testing.T) {
b := getSqliteBuilder()
q := b.AddForeignKey("users", "fk", []string{"p1", "p2"}, []string{"f1", "f2"}, "profile", "opt")
assert.NotEqual(t, q.LastError, nil, "t1")
}
func TestSqliteBuilder_DropForeignKey(t *testing.T) {
b := getSqliteBuilder()
q := b.DropForeignKey("users", "fk")
assert.NotEqual(t, q.LastError, nil, "t1")
}
func getSqliteBuilder() Builder {
db := getDB()
b := NewSqliteBuilder(db, db.sqlDB)
db.Builder = b
return b
}

39
builder_standard.go Normal file
View file

@ -0,0 +1,39 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
// StandardBuilder is the builder that is used by DB for an unknown driver.
type StandardBuilder struct {
*BaseBuilder
qb *BaseQueryBuilder
}
var _ Builder = &StandardBuilder{}
// NewStandardBuilder creates a new StandardBuilder instance.
func NewStandardBuilder(db *DB, executor Executor) Builder {
return &StandardBuilder{
NewBaseBuilder(db, executor),
NewBaseQueryBuilder(db),
}
}
// QueryBuilder returns the query builder supporting the current DB.
func (b *StandardBuilder) QueryBuilder() QueryBuilder {
return b.qb
}
// Select returns a new SelectQuery object that can be used to build a SELECT statement.
// The parameters to this method should be the list column names to be selected.
// A column name may have an optional alias name. For example, Select("id", "my_name AS name").
func (b *StandardBuilder) Select(cols ...string) *SelectQuery {
return NewSelectQuery(b, b.db).Select(cols...)
}
// Model returns a new ModelQuery object that can be used to perform model-based DB operations.
// The model passed to this method should be a pointer to a model struct.
func (b *StandardBuilder) Model(model interface{}) *ModelQuery {
return NewModelQuery(model, b.db.FieldMapper, b.db, b)
}

183
builder_standard_test.go Normal file
View file

@ -0,0 +1,183 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestStandardBuilder_Quote(t *testing.T) {
b := getStandardBuilder()
assert.Equal(t, b.Quote(`abc`), `'abc'`, "t1")
assert.Equal(t, b.Quote(`I'm`), `'I''m'`, "t2")
assert.Equal(t, b.Quote(``), `''`, "t3")
}
func TestStandardBuilder_QuoteSimpleTableName(t *testing.T) {
b := getStandardBuilder()
assert.Equal(t, b.QuoteSimpleTableName(`abc`), `"abc"`, "t1")
assert.Equal(t, b.QuoteSimpleTableName(`"abc"`), `"abc"`, "t2")
assert.Equal(t, b.QuoteSimpleTableName(`{{abc}}`), `"{{abc}}"`, "t3")
assert.Equal(t, b.QuoteSimpleTableName(`a.bc`), `"a.bc"`, "t4")
}
func TestStandardBuilder_QuoteSimpleColumnName(t *testing.T) {
b := getStandardBuilder()
assert.Equal(t, b.QuoteSimpleColumnName(`abc`), `"abc"`, "t1")
assert.Equal(t, b.QuoteSimpleColumnName(`"abc"`), `"abc"`, "t2")
assert.Equal(t, b.QuoteSimpleColumnName(`{{abc}}`), `"{{abc}}"`, "t3")
assert.Equal(t, b.QuoteSimpleColumnName(`a.bc`), `"a.bc"`, "t4")
assert.Equal(t, b.QuoteSimpleColumnName(`*`), `*`, "t5")
}
func TestStandardBuilder_Insert(t *testing.T) {
b := getStandardBuilder()
q := b.Insert("users", Params{
"name": "James",
"age": 30,
})
assert.Equal(t, q.SQL(), `INSERT INTO "users" ("age", "name") VALUES ({:p0}, {:p1})`, "t1")
assert.Equal(t, q.Params()["p0"], 30, "t2")
assert.Equal(t, q.Params()["p1"], "James", "t3")
q = b.Insert("users", Params{})
assert.Equal(t, q.SQL(), `INSERT INTO "users" DEFAULT VALUES`, "t2")
}
func TestStandardBuilder_Upsert(t *testing.T) {
b := getStandardBuilder()
q := b.Upsert("users", Params{
"name": "James",
"age": 30,
})
assert.NotEqual(t, q.LastError, nil, "t1")
}
func TestStandardBuilder_Update(t *testing.T) {
b := getStandardBuilder()
q := b.Update("users", Params{
"name": "James",
"age": 30,
}, NewExp("id=10"))
assert.Equal(t, q.SQL(), `UPDATE "users" SET "age"={:p0}, "name"={:p1} WHERE id=10`, "t1")
assert.Equal(t, q.Params()["p0"], 30, "t2")
assert.Equal(t, q.Params()["p1"], "James", "t3")
q = b.Update("users", Params{
"name": "James",
"age": 30,
}, nil)
assert.Equal(t, q.SQL(), `UPDATE "users" SET "age"={:p0}, "name"={:p1}`, "t2")
}
func TestStandardBuilder_Delete(t *testing.T) {
b := getStandardBuilder()
q := b.Delete("users", NewExp("id=10"))
assert.Equal(t, q.SQL(), `DELETE FROM "users" WHERE id=10`, "t1")
q = b.Delete("users", nil)
assert.Equal(t, q.SQL(), `DELETE FROM "users"`, "t2")
}
func TestStandardBuilder_CreateTable(t *testing.T) {
b := getStandardBuilder()
q := b.CreateTable("users", map[string]string{
"id": "int primary key",
"name": "varchar(255)",
}, "ON DELETE CASCADE")
assert.Equal(t, q.SQL(), "CREATE TABLE \"users\" (\"id\" int primary key, \"name\" varchar(255)) ON DELETE CASCADE", "t1")
}
func TestStandardBuilder_RenameTable(t *testing.T) {
b := getStandardBuilder()
q := b.RenameTable("users", "user")
assert.Equal(t, q.SQL(), `RENAME TABLE "users" TO "user"`, "t1")
}
func TestStandardBuilder_DropTable(t *testing.T) {
b := getStandardBuilder()
q := b.DropTable("users")
assert.Equal(t, q.SQL(), `DROP TABLE "users"`, "t1")
}
func TestStandardBuilder_TruncateTable(t *testing.T) {
b := getStandardBuilder()
q := b.TruncateTable("users")
assert.Equal(t, q.SQL(), `TRUNCATE TABLE "users"`, "t1")
}
func TestStandardBuilder_AddColumn(t *testing.T) {
b := getStandardBuilder()
q := b.AddColumn("users", "age", "int")
assert.Equal(t, q.SQL(), `ALTER TABLE "users" ADD "age" int`, "t1")
}
func TestStandardBuilder_DropColumn(t *testing.T) {
b := getStandardBuilder()
q := b.DropColumn("users", "age")
assert.Equal(t, q.SQL(), `ALTER TABLE "users" DROP COLUMN "age"`, "t1")
}
func TestStandardBuilder_RenameColumn(t *testing.T) {
b := getStandardBuilder()
q := b.RenameColumn("users", "name", "username")
assert.Equal(t, q.SQL(), `ALTER TABLE "users" RENAME COLUMN "name" TO "username"`, "t1")
}
func TestStandardBuilder_AlterColumn(t *testing.T) {
b := getStandardBuilder()
q := b.AlterColumn("users", "name", "int")
assert.Equal(t, q.SQL(), `ALTER TABLE "users" CHANGE "name" "name" int`, "t1")
}
func TestStandardBuilder_AddPrimaryKey(t *testing.T) {
b := getStandardBuilder()
q := b.AddPrimaryKey("users", "pk", "id1", "id2")
assert.Equal(t, q.SQL(), `ALTER TABLE "users" ADD CONSTRAINT "pk" PRIMARY KEY ("id1", "id2")`, "t1")
}
func TestStandardBuilder_DropPrimaryKey(t *testing.T) {
b := getStandardBuilder()
q := b.DropPrimaryKey("users", "pk")
assert.Equal(t, q.SQL(), `ALTER TABLE "users" DROP CONSTRAINT "pk"`, "t1")
}
func TestStandardBuilder_AddForeignKey(t *testing.T) {
b := getStandardBuilder()
q := b.AddForeignKey("users", "fk", []string{"p1", "p2"}, []string{"f1", "f2"}, "profile", "opt")
assert.Equal(t, q.SQL(), `ALTER TABLE "users" ADD CONSTRAINT "fk" FOREIGN KEY ("p1", "p2") REFERENCES "profile" ("f1", "f2") opt`, "t1")
}
func TestStandardBuilder_DropForeignKey(t *testing.T) {
b := getStandardBuilder()
q := b.DropForeignKey("users", "fk")
assert.Equal(t, q.SQL(), `ALTER TABLE "users" DROP CONSTRAINT "fk"`, "t1")
}
func TestStandardBuilder_CreateIndex(t *testing.T) {
b := getStandardBuilder()
q := b.CreateIndex("users", "idx", "id1", "id2")
assert.Equal(t, q.SQL(), `CREATE INDEX "idx" ON "users" ("id1", "id2")`, "t1")
}
func TestStandardBuilder_CreateUniqueIndex(t *testing.T) {
b := getStandardBuilder()
q := b.CreateUniqueIndex("users", "idx", "id1", "id2")
assert.Equal(t, q.SQL(), `CREATE UNIQUE INDEX "idx" ON "users" ("id1", "id2")`, "t1")
}
func TestStandardBuilder_DropIndex(t *testing.T) {
b := getStandardBuilder()
q := b.DropIndex("users", "idx")
assert.Equal(t, q.SQL(), `DROP INDEX "idx" ON "users"`, "t1")
}
func getStandardBuilder() Builder {
db := getDB()
b := NewStandardBuilder(db, db.sqlDB)
db.Builder = b
return b
}

338
db.go Normal file
View file

@ -0,0 +1,338 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
// Package dbx provides a set of DB-agnostic and easy-to-use query building methods for relational databases.
package dbx
import (
"bytes"
"context"
"database/sql"
"regexp"
"strings"
"time"
)
type (
// LogFunc logs a message for each SQL statement being executed.
// This method takes one or multiple parameters. If a single parameter
// is provided, it will be treated as the log message. If multiple parameters
// are provided, they will be passed to fmt.Sprintf() to generate the log message.
LogFunc func(format string, a ...interface{})
// PerfFunc is called when a query finishes execution.
// The query execution time is passed to this function so that the DB performance
// can be profiled. The "ns" parameter gives the number of nanoseconds that the
// SQL statement takes to execute, while the "execute" parameter indicates whether
// the SQL statement is executed or queried (usually SELECT statements).
PerfFunc func(ns int64, sql string, execute bool)
// QueryLogFunc is called each time when performing a SQL query.
// The "t" parameter gives the time that the SQL statement takes to execute,
// while rows and err are the result of the query.
QueryLogFunc func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error)
// ExecLogFunc is called each time when a SQL statement is executed.
// The "t" parameter gives the time that the SQL statement takes to execute,
// while result and err refer to the result of the execution.
ExecLogFunc func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error)
// BuilderFunc creates a Builder instance using the given DB instance and Executor.
BuilderFunc func(*DB, Executor) Builder
// DB enhances sql.DB by providing a set of DB-agnostic query building methods.
// DB allows easier query building and population of data into Go variables.
DB struct {
Builder
// FieldMapper maps struct fields to DB columns. Defaults to DefaultFieldMapFunc.
FieldMapper FieldMapFunc
// TableMapper maps structs to table names. Defaults to GetTableName.
TableMapper TableMapFunc
// LogFunc logs the SQL statements being executed. Defaults to nil, meaning no logging.
LogFunc LogFunc
// PerfFunc logs the SQL execution time. Defaults to nil, meaning no performance profiling.
// Deprecated: Please use QueryLogFunc and ExecLogFunc instead.
PerfFunc PerfFunc
// QueryLogFunc is called each time when performing a SQL query that returns data.
QueryLogFunc QueryLogFunc
// ExecLogFunc is called each time when a SQL statement is executed.
ExecLogFunc ExecLogFunc
sqlDB *sql.DB
driverName string
ctx context.Context
}
// Errors represents a list of errors.
Errors []error
)
// BuilderFuncMap lists supported BuilderFunc according to DB driver names.
// You may modify this variable to add the builder support for a new DB driver.
// If a DB driver is not listed here, the StandardBuilder will be used.
var BuilderFuncMap = map[string]BuilderFunc{
"sqlite": NewSqliteBuilder,
"sqlite3": NewSqliteBuilder,
"mysql": NewMysqlBuilder,
"postgres": NewPgsqlBuilder,
"pgx": NewPgsqlBuilder,
"mssql": NewMssqlBuilder,
"oci8": NewOciBuilder,
}
// NewFromDB encapsulates an existing database connection.
func NewFromDB(sqlDB *sql.DB, driverName string) *DB {
db := &DB{
driverName: driverName,
sqlDB: sqlDB,
FieldMapper: DefaultFieldMapFunc,
TableMapper: GetTableName,
}
db.Builder = db.newBuilder(db.sqlDB)
return db
}
// Open opens a database specified by a driver name and data source name (DSN).
// Note that Open does not check if DSN is specified correctly. It doesn't try to establish a DB connection either.
// Please refer to sql.Open() for more information.
func Open(driverName, dsn string) (*DB, error) {
sqlDB, err := sql.Open(driverName, dsn)
if err != nil {
return nil, err
}
return NewFromDB(sqlDB, driverName), nil
}
// MustOpen opens a database and establishes a connection to it.
// Please refer to sql.Open() and sql.Ping() for more information.
func MustOpen(driverName, dsn string) (*DB, error) {
db, err := Open(driverName, dsn)
if err != nil {
return nil, err
}
if err := db.sqlDB.Ping(); err != nil {
db.Close()
return nil, err
}
return db, nil
}
// Clone makes a shallow copy of DB.
func (db *DB) Clone() *DB {
db2 := &DB{
driverName: db.driverName,
sqlDB: db.sqlDB,
FieldMapper: db.FieldMapper,
TableMapper: db.TableMapper,
PerfFunc: db.PerfFunc,
LogFunc: db.LogFunc,
QueryLogFunc: db.QueryLogFunc,
ExecLogFunc: db.ExecLogFunc,
}
db2.Builder = db2.newBuilder(db.sqlDB)
return db2
}
// WithContext returns a new instance of DB associated with the given context.
func (db *DB) WithContext(ctx context.Context) *DB {
db2 := db.Clone()
db2.ctx = ctx
return db2
}
// Context returns the context associated with the DB instance.
// It returns nil if no context is associated.
func (db *DB) Context() context.Context {
return db.ctx
}
// DB returns the sql.DB instance encapsulated by dbx.DB.
func (db *DB) DB() *sql.DB {
return db.sqlDB
}
// Close closes the database, releasing any open resources.
// It is rare to Close a DB, as the DB handle is meant to be
// long-lived and shared between many goroutines.
func (db *DB) Close() error {
return db.sqlDB.Close()
}
// Begin starts a transaction.
func (db *DB) Begin() (*Tx, error) {
var tx *sql.Tx
var err error
if db.ctx != nil {
tx, err = db.sqlDB.BeginTx(db.ctx, nil)
} else {
tx, err = db.sqlDB.Begin()
}
if err != nil {
return nil, err
}
return &Tx{db.newBuilder(tx), tx}, nil
}
// BeginTx starts a transaction with the given context and transaction options.
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
tx, err := db.sqlDB.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return &Tx{db.newBuilder(tx), tx}, nil
}
// Wrap encapsulates an existing transaction.
func (db *DB) Wrap(sqlTx *sql.Tx) *Tx {
return &Tx{db.newBuilder(sqlTx), sqlTx}
}
// Transactional starts a transaction and executes the given function.
// If the function returns an error, the transaction will be rolled back.
// Otherwise, the transaction will be committed.
func (db *DB) Transactional(f func(*Tx) error) (err error) {
tx, err := db.Begin()
if err != nil {
return err
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p)
} else if err != nil {
if err2 := tx.Rollback(); err2 != nil {
if err2 == sql.ErrTxDone {
return
}
err = Errors{err, err2}
}
} else {
if err = tx.Commit(); err == sql.ErrTxDone {
err = nil
}
}
}()
err = f(tx)
return err
}
// TransactionalContext starts a transaction and executes the given function with the given context and transaction options.
// If the function returns an error, the transaction will be rolled back.
// Otherwise, the transaction will be committed.
func (db *DB) TransactionalContext(ctx context.Context, opts *sql.TxOptions, f func(*Tx) error) (err error) {
tx, err := db.BeginTx(ctx, opts)
if err != nil {
return err
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p)
} else if err != nil {
if err2 := tx.Rollback(); err2 != nil {
if err2 == sql.ErrTxDone {
return
}
err = Errors{err, err2}
}
} else {
if err = tx.Commit(); err == sql.ErrTxDone {
err = nil
}
}
}()
err = f(tx)
return err
}
// DriverName returns the name of the DB driver.
func (db *DB) DriverName() string {
return db.driverName
}
// QuoteTableName quotes the given table name appropriately.
// If the table name contains DB schema prefix, it will be handled accordingly.
// This method will do nothing if the table name is already quoted or if it contains parenthesis.
func (db *DB) QuoteTableName(s string) string {
if strings.Contains(s, "(") || strings.Contains(s, "{{") {
return s
}
if !strings.Contains(s, ".") {
return db.QuoteSimpleTableName(s)
}
parts := strings.Split(s, ".")
for i, part := range parts {
parts[i] = db.QuoteSimpleTableName(part)
}
return strings.Join(parts, ".")
}
// QuoteColumnName quotes the given column name appropriately.
// If the table name contains table name prefix, it will be handled accordingly.
// This method will do nothing if the column name is already quoted or if it contains parenthesis.
func (db *DB) QuoteColumnName(s string) string {
if strings.Contains(s, "(") || strings.Contains(s, "{{") || strings.Contains(s, "[[") {
return s
}
prefix := ""
if pos := strings.LastIndex(s, "."); pos != -1 {
prefix = db.QuoteTableName(s[:pos]) + "."
s = s[pos+1:]
}
return prefix + db.QuoteSimpleColumnName(s)
}
var (
plRegex = regexp.MustCompile(`\{:\w+\}`)
quoteRegex = regexp.MustCompile(`(\{\{[\w\-\. ]+\}\}|\[\[[\w\-\. ]+\]\])`)
)
// processSQL replaces the named param placeholders in the given SQL with anonymous ones.
// It also quotes table names and column names found in the SQL if these names are enclosed
// within double square/curly brackets. The method will return the updated SQL and the list of parameter names.
func (db *DB) processSQL(s string) (string, []string) {
var placeholders []string
count := 0
s = plRegex.ReplaceAllStringFunc(s, func(m string) string {
count++
placeholders = append(placeholders, m[2:len(m)-1])
return db.GeneratePlaceholder(count)
})
s = quoteRegex.ReplaceAllStringFunc(s, func(m string) string {
if m[0] == '{' {
return db.QuoteTableName(m[2 : len(m)-2])
}
return db.QuoteColumnName(m[2 : len(m)-2])
})
return s, placeholders
}
// newBuilder creates a query builder based on the current driver name.
func (db *DB) newBuilder(executor Executor) Builder {
builderFunc, ok := BuilderFuncMap[db.driverName]
if !ok {
builderFunc = NewStandardBuilder
}
return builderFunc(db, executor)
}
// Error returns the error string of Errors.
func (errs Errors) Error() string {
var b bytes.Buffer
for i, e := range errs {
if i > 0 {
b.WriteRune('\n')
}
b.WriteString(e.Error())
}
return b.String()
}

380
db_test.go Normal file
View file

@ -0,0 +1,380 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"context"
"database/sql"
"errors"
"io/ioutil"
"strings"
"testing"
// @todo change to sqlite
_ "github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
)
const (
TestDSN = "travis:@/pocketbase_dbx_test?parseTime=true"
FixtureFile = "testdata/mysql.sql"
)
func TestDB_NewFromDB(t *testing.T) {
sqlDB, err := sql.Open("mysql", TestDSN)
if assert.Nil(t, err) {
db := NewFromDB(sqlDB, "mysql")
assert.NotNil(t, db.sqlDB)
assert.NotNil(t, db.FieldMapper)
}
}
func TestDB_Open(t *testing.T) {
db, err := Open("mysql", TestDSN)
assert.Nil(t, err)
if assert.NotNil(t, db) {
assert.NotNil(t, db.sqlDB)
assert.NotNil(t, db.FieldMapper)
db2 := db.Clone()
assert.NotEqual(t, db, db2)
assert.Equal(t, db.driverName, db2.driverName)
ctx := context.Background()
db3 := db.WithContext(ctx)
assert.Equal(t, ctx, db3.ctx)
assert.Equal(t, ctx, db3.Context())
assert.NotEqual(t, db, db3)
}
_, err = Open("xyz", TestDSN)
assert.NotNil(t, err)
}
func TestDB_MustOpen(t *testing.T) {
_, err := MustOpen("mysql", TestDSN)
assert.Nil(t, err)
_, err = MustOpen("mysql", "unknown:x@/test")
assert.NotNil(t, err)
}
func TestDB_Close(t *testing.T) {
db := getDB()
assert.Nil(t, db.Close())
}
func TestDB_DriverName(t *testing.T) {
db := getDB()
assert.Equal(t, "mysql", db.DriverName())
}
func TestDB_QuoteTableName(t *testing.T) {
tests := []struct {
input, output string
}{
{"users", "`users`"},
{"`users`", "`users`"},
{"(select)", "(select)"},
{"{{users}}", "{{users}}"},
{"public.db1.users", "`public`.`db1`.`users`"},
}
db := getDB()
for _, test := range tests {
result := db.QuoteTableName(test.input)
assert.Equal(t, test.output, result, test.input)
}
}
func TestDB_QuoteColumnName(t *testing.T) {
tests := []struct {
input, output string
}{
{"*", "*"},
{"users.*", "`users`.*"},
{"name", "`name`"},
{"`name`", "`name`"},
{"(select)", "(select)"},
{"{{name}}", "{{name}}"},
{"[[name]]", "[[name]]"},
{"public.db1.users", "`public`.`db1`.`users`"},
}
db := getDB()
for _, test := range tests {
result := db.QuoteColumnName(test.input)
assert.Equal(t, test.output, result, test.input)
}
}
func TestDB_ProcessSQL(t *testing.T) {
tests := []struct {
tag string
sql string // original SQL
mysql string // expected MySQL version
postgres string // expected PostgreSQL version
oci8 string // expected OCI version
params []string // expected params
}{
{
"normal case",
`INSERT INTO employee (id, name, age) VALUES ({:id}, {:name}, {:age})`,
`INSERT INTO employee (id, name, age) VALUES (?, ?, ?)`,
`INSERT INTO employee (id, name, age) VALUES ($1, $2, $3)`,
`INSERT INTO employee (id, name, age) VALUES (:p1, :p2, :p3)`,
[]string{"id", "name", "age"},
},
{
"the same placeholder is used twice",
`SELECT * FROM employee WHERE first_name LIKE {:keyword} OR last_name LIKE {:keyword}`,
`SELECT * FROM employee WHERE first_name LIKE ? OR last_name LIKE ?`,
`SELECT * FROM employee WHERE first_name LIKE $1 OR last_name LIKE $2`,
`SELECT * FROM employee WHERE first_name LIKE :p1 OR last_name LIKE :p2`,
[]string{"keyword", "keyword"},
},
{
"non-matching placeholder",
`SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE {:keyword}`,
`SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE ?`,
`SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE $1`,
`SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE :p1`,
[]string{"keyword"},
},
{
"quote table/column names",
`SELECT * FROM {{public.user}} WHERE [[user.id]]=1`,
"SELECT * FROM `public`.`user` WHERE `user`.`id`=1",
"SELECT * FROM \"public\".\"user\" WHERE \"user\".\"id\"=1",
"SELECT * FROM \"public\".\"user\" WHERE \"user\".\"id\"=1",
nil,
},
}
mysqlDB := getDB()
mysqlDB.Builder = NewMysqlBuilder(nil, nil)
pgsqlDB := getDB()
pgsqlDB.Builder = NewPgsqlBuilder(nil, nil)
ociDB := getDB()
ociDB.Builder = NewOciBuilder(nil, nil)
for _, test := range tests {
s1, names := mysqlDB.processSQL(test.sql)
assert.Equal(t, test.mysql, s1, test.tag)
s2, _ := pgsqlDB.processSQL(test.sql)
assert.Equal(t, test.postgres, s2, test.tag)
s3, _ := ociDB.processSQL(test.sql)
assert.Equal(t, test.oci8, s3, test.tag)
assert.Equal(t, test.params, names, test.tag)
}
}
func TestDB_Begin(t *testing.T) {
tests := []struct {
makeTx func(db *DB) *Tx
desc string
}{
{
makeTx: func(db *DB) *Tx {
tx, _ := db.Begin()
return tx
},
desc: "Begin",
},
{
makeTx: func(db *DB) *Tx {
sqlTx, _ := db.DB().Begin()
return db.Wrap(sqlTx)
},
desc: "Wrap",
},
{
makeTx: func(db *DB) *Tx {
tx, _ := db.BeginTx(context.Background(), nil)
return tx
},
desc: "BeginTx",
},
}
db := getPreparedDB()
var (
lastID int
name string
tx *Tx
)
db.NewQuery("SELECT MAX(id) FROM item").Row(&lastID)
for _, test := range tests {
t.Log(test.desc)
tx = test.makeTx(db)
_, err1 := tx.Insert("item", Params{
"name": "name1",
}).Execute()
_, err2 := tx.Insert("item", Params{
"name": "name2",
}).Execute()
if err1 == nil && err2 == nil {
tx.Commit()
} else {
t.Errorf("Unexpected TX rollback: %v, %v", err1, err2)
tx.Rollback()
}
q := db.NewQuery("SELECT name FROM item WHERE id={:id}")
q.Bind(Params{"id": lastID + 1}).Row(&name)
assert.Equal(t, "name1", name)
q.Bind(Params{"id": lastID + 2}).Row(&name)
assert.Equal(t, "name2", name)
tx = test.makeTx(db)
_, err3 := tx.NewQuery("DELETE FROM item WHERE id=7").Execute()
_, err4 := tx.NewQuery("DELETE FROM items WHERE id=7").Execute()
if err3 == nil && err4 == nil {
t.Error("Unexpected TX commit")
tx.Commit()
} else {
tx.Rollback()
}
}
}
func TestDB_Transactional(t *testing.T) {
db := getPreparedDB()
var (
lastID int
name string
)
db.NewQuery("SELECT MAX(id) FROM item").Row(&lastID)
err := db.Transactional(func(tx *Tx) error {
_, err := tx.Insert("item", Params{
"name": "name1",
}).Execute()
if err != nil {
return err
}
_, err = tx.Insert("item", Params{
"name": "name2",
}).Execute()
if err != nil {
return err
}
return nil
})
if assert.Nil(t, err) {
q := db.NewQuery("SELECT name FROM item WHERE id={:id}")
q.Bind(Params{"id": lastID + 1}).Row(&name)
assert.Equal(t, "name1", name)
q.Bind(Params{"id": lastID + 2}).Row(&name)
assert.Equal(t, "name2", name)
}
err = db.Transactional(func(tx *Tx) error {
_, err := tx.NewQuery("DELETE FROM item WHERE id=2").Execute()
if err != nil {
return err
}
_, err = tx.NewQuery("DELETE FROM items WHERE id=2").Execute()
if err != nil {
return err
}
return nil
})
if assert.NotNil(t, err) {
db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name)
assert.Equal(t, "Go in Action", name)
}
// Rollback called within Transactional and return error
err = db.Transactional(func(tx *Tx) error {
_, err := tx.NewQuery("DELETE FROM item WHERE id=2").Execute()
if err != nil {
return err
}
_, err = tx.NewQuery("DELETE FROM items WHERE id=2").Execute()
if err != nil {
tx.Rollback()
return err
}
return nil
})
if assert.NotNil(t, err) {
db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name)
assert.Equal(t, "Go in Action", name)
}
// Rollback called within Transactional without returning error
err = db.Transactional(func(tx *Tx) error {
_, err := tx.NewQuery("DELETE FROM item WHERE id=2").Execute()
if err != nil {
return err
}
_, err = tx.NewQuery("DELETE FROM items WHERE id=2").Execute()
if err != nil {
tx.Rollback()
return nil
}
return nil
})
if assert.Nil(t, err) {
db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name)
assert.Equal(t, "Go in Action", name)
}
}
func TestErrors_Error(t *testing.T) {
errs := Errors{}
assert.Equal(t, "", errs.Error())
errs = Errors{errors.New("a")}
assert.Equal(t, "a", errs.Error())
errs = Errors{errors.New("a"), errors.New("b")}
assert.Equal(t, "a\nb", errs.Error())
}
func getDB() *DB {
db, err := Open("mysql", TestDSN)
if err != nil {
panic(err)
}
return db
}
func getPreparedDB() *DB {
db := getDB()
s, err := ioutil.ReadFile(FixtureFile)
if err != nil {
panic(err)
}
lines := strings.Split(string(s), ";")
for _, line := range lines {
if strings.TrimSpace(line) == "" {
continue
}
if _, err := db.NewQuery(line).Execute(); err != nil {
panic(err)
}
}
return db
}
// Naming according to issue 49 ( https://github.com/pocketbase/dbx/issues/49 )
type ArtistDAO struct {
nickname string
}
func (ArtistDAO) TableName() string {
return "artists"
}
func Test_TableNameWithPrefix(t *testing.T) {
db := NewFromDB(nil, "mysql")
db.TableMapper = func(a interface{}) string {
return "tbl_" + GetTableName(a)
}
assert.Equal(t, "tbl_artists", db.TableMapper(ArtistDAO{}))
}

281
example_test.go Normal file
View file

@ -0,0 +1,281 @@
package dbx_test
import (
"fmt"
"github.com/pocketbase/dbx"
)
// This example shows how to populate DB data in different ways.
func Example_dbQueries() {
db, _ := dbx.Open("mysql", "user:pass@/example")
// create a new query
q := db.NewQuery("SELECT id, name FROM users LIMIT 10")
// fetch all rows into a struct array
var users []struct {
ID, Name string
}
q.All(&users)
// fetch a single row into a struct
var user struct {
ID, Name string
}
q.One(&user)
// fetch a single row into a string map
data := dbx.NullStringMap{}
q.One(data)
// fetch row by row
rows2, _ := q.Rows()
for rows2.Next() {
rows2.ScanStruct(&user)
// rows.ScanMap(data)
// rows.Scan(&id, &name)
}
}
// This example shows how to use query builder to build DB queries.
func Example_queryBuilder() {
db, _ := dbx.Open("mysql", "user:pass@/example")
// build a SELECT query
// SELECT `id`, `name` FROM `users` WHERE `name` LIKE '%Charles%' ORDER BY `id`
q := db.Select("id", "name").
From("users").
Where(dbx.Like("name", "Charles")).
OrderBy("id")
// fetch all rows into a struct array
var users []struct {
ID, Name string
}
q.All(&users)
// build an INSERT query
// INSERT INTO `users` (name) VALUES ('James')
db.Insert("users", dbx.Params{
"name": "James",
}).Execute()
}
// This example shows how to use query builder in transactions.
func Example_transactions() {
db, _ := dbx.Open("mysql", "user:pass@/example")
db.Transactional(func(tx *dbx.Tx) error {
_, err := tx.Insert("user", dbx.Params{
"name": "user1",
}).Execute()
if err != nil {
return err
}
_, err = tx.Insert("user", dbx.Params{
"name": "user2",
}).Execute()
return err
})
}
type Customer struct {
ID string
Name string
}
// This example shows how to do CRUD operations.
func Example_crudOperations() {
db, _ := dbx.Open("mysql", "user:pass@/example")
var customer Customer
// read a customer: SELECT * FROM customer WHERE id=100
db.Select().Model(100, &customer)
// create a customer: INSERT INTO customer (name) VALUES ('test')
db.Model(&customer).Insert()
// update a customer: UPDATE customer SET name='test' WHERE id=100
db.Model(&customer).Update()
// delete a customer: DELETE FROM customer WHERE id=100
db.Model(&customer).Delete()
}
func ExampleSchemaBuilder() {
db, _ := dbx.Open("mysql", "user:pass@/example")
db.Insert("users", dbx.Params{
"name": "James",
"age": 30,
}).Execute()
}
func ExampleRows_ScanMap() {
db, _ := dbx.Open("mysql", "user:pass@/example")
user := dbx.NullStringMap{}
sql := "SELECT id, name FROM users LIMIT 10"
rows, _ := db.NewQuery(sql).Rows()
for rows.Next() {
rows.ScanMap(user)
// ...
}
}
func ExampleRows_ScanStruct() {
db, _ := dbx.Open("mysql", "user:pass@/example")
var user struct {
ID, Name string
}
sql := "SELECT id, name FROM users LIMIT 10"
rows, _ := db.NewQuery(sql).Rows()
for rows.Next() {
rows.ScanStruct(&user)
// ...
}
}
func ExampleQuery_All() {
db, _ := dbx.Open("mysql", "user:pass@/example")
sql := "SELECT id, name FROM users LIMIT 10"
// fetches data into a slice of struct
var users []struct {
ID, Name string
}
db.NewQuery(sql).All(&users)
// fetches data into a slice of NullStringMap
var users2 []dbx.NullStringMap
db.NewQuery(sql).All(&users2)
for _, user := range users2 {
fmt.Println(user["id"].String, user["name"].String)
}
}
func ExampleQuery_One() {
db, _ := dbx.Open("mysql", "user:pass@/example")
sql := "SELECT id, name FROM users LIMIT 10"
// fetches data into a struct
var user struct {
ID, Name string
}
db.NewQuery(sql).One(&user)
// fetches data into a NullStringMap
var user2 dbx.NullStringMap
db.NewQuery(sql).All(user2)
fmt.Println(user2["id"].String, user2["name"].String)
}
func ExampleQuery_Row() {
db, _ := dbx.Open("mysql", "user:pass@/example")
sql := "SELECT id, name FROM users LIMIT 10"
// fetches data into a struct
var (
id int
name string
)
db.NewQuery(sql).Row(&id, &name)
}
func ExampleQuery_Rows() {
var user struct {
ID, Name string
}
db, _ := dbx.Open("mysql", "user:pass@/example")
sql := "SELECT id, name FROM users LIMIT 10"
rows, _ := db.NewQuery(sql).Rows()
for rows.Next() {
rows.ScanStruct(&user)
// ...
}
}
func ExampleQuery_Bind() {
var user struct {
ID, Name string
}
db, _ := dbx.Open("mysql", "user:pass@/example")
sql := "SELECT id, name FROM users WHERE age>{:age} AND status={:status}"
q := db.NewQuery(sql)
q.Bind(dbx.Params{"age": 30, "status": 1}).One(&user)
}
func ExampleQuery_Prepare() {
var users1, users2, users3 []struct {
ID, Name string
}
db, _ := dbx.Open("mysql", "user:pass@/example")
sql := "SELECT id, name FROM users WHERE age>{:age} AND status={:status}"
q := db.NewQuery(sql).Prepare()
defer q.Close()
q.Bind(dbx.Params{"age": 30, "status": 1}).All(&users1)
q.Bind(dbx.Params{"age": 20, "status": 1}).All(&users2)
q.Bind(dbx.Params{"age": 10, "status": 1}).All(&users3)
}
func ExampleDB() {
db, _ := dbx.Open("mysql", "user:pass@/example")
// queries data through a plain SQL
var users []struct {
ID, Name string
}
db.NewQuery("SELECT id, name FROM users WHERE age=30").All(&users)
// queries data using query builder
db.Select("id", "name").From("users").Where(dbx.HashExp{"age": 30}).All(&users)
// executes a plain SQL
db.NewQuery("INSERT INTO users (name) SET ({:name})").Bind(dbx.Params{"name": "James"}).Execute()
// executes a SQL using query builder
db.Insert("users", dbx.Params{"name": "James"}).Execute()
}
func ExampleDB_Open() {
db, err := dbx.Open("mysql", "user:pass@/example")
if err != nil {
panic(err)
}
var users []dbx.NullStringMap
if err := db.NewQuery("SELECT * FROM users LIMIT 10").All(&users); err != nil {
panic(err)
}
}
func ExampleDB_Begin() {
db, _ := dbx.Open("mysql", "user:pass@/example")
tx, _ := db.Begin()
_, err1 := tx.Insert("user", dbx.Params{
"name": "user1",
}).Execute()
_, err2 := tx.Insert("user", dbx.Params{
"name": "user2",
}).Execute()
if err1 == nil && err2 == nil {
tx.Commit()
} else {
tx.Rollback()
}
}

421
expression.go Normal file
View file

@ -0,0 +1,421 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"fmt"
"sort"
"strings"
)
// Expression represents a DB expression that can be embedded in a SQL statement.
type Expression interface {
// Build converts an expression into a SQL fragment.
// If the expression contains binding parameters, they will be added to the given Params.
Build(*DB, Params) string
}
// HashExp represents a hash expression.
//
// A hash expression is a map whose keys are DB column names which need to be filtered according
// to the corresponding values. For example, HashExp{"level": 2, "dept": 10} will generate
// the SQL: "level"=2 AND "dept"=10.
//
// HashExp also handles nil values and slice values. For example, HashExp{"level": []interface{}{1, 2}, "dept": nil}
// will generate: "level" IN (1, 2) AND "dept" IS NULL.
type HashExp map[string]interface{}
// NewExp generates an expression with the specified SQL fragment and the optional binding parameters.
func NewExp(e string, params ...Params) Expression {
if len(params) > 0 {
return &Exp{e, params[0]}
}
return &Exp{e, nil}
}
// Not generates a NOT expression which prefixes "NOT" to the specified expression.
func Not(e Expression) Expression {
return &NotExp{e}
}
// And generates an AND expression which concatenates the given expressions with "AND".
func And(exps ...Expression) Expression {
return &AndOrExp{exps, "AND"}
}
// Or generates an OR expression which concatenates the given expressions with "OR".
func Or(exps ...Expression) Expression {
return &AndOrExp{exps, "OR"}
}
// In generates an IN expression for the specified column and the list of allowed values.
// If values is empty, a SQL "0=1" will be generated which represents a false expression.
func In(col string, values ...interface{}) Expression {
return &InExp{col, values, false}
}
// NotIn generates an NOT IN expression for the specified column and the list of disallowed values.
// If values is empty, an empty string will be returned indicating a true expression.
func NotIn(col string, values ...interface{}) Expression {
return &InExp{col, values, true}
}
// DefaultLikeEscape specifies the default special character escaping for LIKE expressions
// The strings at 2i positions are the special characters to be escaped while those at 2i+1 positions
// are the corresponding escaped versions.
var DefaultLikeEscape = []string{"\\", "\\\\", "%", "\\%", "_", "\\_"}
// Like generates a LIKE expression for the specified column and the possible strings that the column should be like.
// If multiple values are present, the column should be like *all* of them. For example, Like("name", "key", "word")
// will generate a SQL expression: "name" LIKE "%key%" AND "name" LIKE "%word%".
//
// By default, each value will be surrounded by "%" to enable partial matching. If a value contains special characters
// such as "%", "\", "_", they will also be properly escaped.
//
// You may call Escape() and/or Match() to change the default behavior. For example, Like("name", "key").Match(false, true)
// generates "name" LIKE "key%".
func Like(col string, values ...string) *LikeExp {
return &LikeExp{
left: true,
right: true,
col: col,
values: values,
escape: DefaultLikeEscape,
Like: "LIKE",
}
}
// NotLike generates a NOT LIKE expression.
// For example, NotLike("name", "key", "word") will generate a SQL expression:
// "name" NOT LIKE "%key%" AND "name" NOT LIKE "%word%". Please see Like() for more details.
func NotLike(col string, values ...string) *LikeExp {
return &LikeExp{
left: true,
right: true,
col: col,
values: values,
escape: DefaultLikeEscape,
Like: "NOT LIKE",
}
}
// OrLike generates an OR LIKE expression.
// This is similar to Like() except that the column should be like one of the possible values.
// For example, OrLike("name", "key", "word") will generate a SQL expression:
// "name" LIKE "%key%" OR "name" LIKE "%word%". Please see Like() for more details.
func OrLike(col string, values ...string) *LikeExp {
return &LikeExp{
or: true,
left: true,
right: true,
col: col,
values: values,
escape: DefaultLikeEscape,
Like: "LIKE",
}
}
// OrNotLike generates an OR NOT LIKE expression.
// For example, OrNotLike("name", "key", "word") will generate a SQL expression:
// "name" NOT LIKE "%key%" OR "name" NOT LIKE "%word%". Please see Like() for more details.
func OrNotLike(col string, values ...string) *LikeExp {
return &LikeExp{
or: true,
left: true,
right: true,
col: col,
values: values,
escape: DefaultLikeEscape,
Like: "NOT LIKE",
}
}
// Exists generates an EXISTS expression by prefixing "EXISTS" to the given expression.
func Exists(exp Expression) Expression {
return &ExistsExp{exp, false}
}
// NotExists generates an EXISTS expression by prefixing "NOT EXISTS" to the given expression.
func NotExists(exp Expression) Expression {
return &ExistsExp{exp, true}
}
// Between generates a BETWEEN expression.
// For example, Between("age", 10, 30) generates: "age" BETWEEN 10 AND 30
func Between(col string, from, to interface{}) Expression {
return &BetweenExp{col, from, to, false}
}
// NotBetween generates a NOT BETWEEN expression.
// For example, NotBetween("age", 10, 30) generates: "age" NOT BETWEEN 10 AND 30
func NotBetween(col string, from, to interface{}) Expression {
return &BetweenExp{col, from, to, true}
}
// Exp represents an expression with a SQL fragment and a list of optional binding parameters.
type Exp struct {
e string
params Params
}
// Build converts an expression into a SQL fragment.
func (e *Exp) Build(db *DB, params Params) string {
if len(e.params) == 0 {
return e.e
}
for k, v := range e.params {
params[k] = v
}
return e.e
}
// Build converts an expression into a SQL fragment.
func (e HashExp) Build(db *DB, params Params) string {
if len(e) == 0 {
return ""
}
// ensure the hash exp generates the same SQL for different runs
names := []string{}
for name := range e {
names = append(names, name)
}
sort.Strings(names)
var parts []string
for _, name := range names {
value := e[name]
switch value.(type) {
case nil:
name = db.QuoteColumnName(name)
parts = append(parts, name+" IS NULL")
case Expression:
if sql := value.(Expression).Build(db, params); sql != "" {
parts = append(parts, "("+sql+")")
}
case []interface{}:
in := In(name, value.([]interface{})...)
if sql := in.Build(db, params); sql != "" {
parts = append(parts, sql)
}
default:
pn := fmt.Sprintf("p%v", len(params))
name = db.QuoteColumnName(name)
parts = append(parts, name+"={:"+pn+"}")
params[pn] = value
}
}
if len(parts) == 1 {
return parts[0]
}
return strings.Join(parts, " AND ")
}
// NotExp represents an expression that should prefix "NOT" to a specified expression.
type NotExp struct {
e Expression
}
// Build converts an expression into a SQL fragment.
func (e *NotExp) Build(db *DB, params Params) string {
if sql := e.e.Build(db, params); sql != "" {
return "NOT (" + sql + ")"
}
return ""
}
// AndOrExp represents an expression that concatenates multiple expressions using either "AND" or "OR".
type AndOrExp struct {
exps []Expression
op string
}
// Build converts an expression into a SQL fragment.
func (e *AndOrExp) Build(db *DB, params Params) string {
if len(e.exps) == 0 {
return ""
}
var parts []string
for _, a := range e.exps {
if a == nil {
continue
}
if sql := a.Build(db, params); sql != "" {
parts = append(parts, sql)
}
}
if len(parts) == 1 {
return parts[0]
}
return "(" + strings.Join(parts, ") "+e.op+" (") + ")"
}
// InExp represents an "IN" or "NOT IN" expression.
type InExp struct {
col string
values []interface{}
not bool
}
// Build converts an expression into a SQL fragment.
func (e *InExp) Build(db *DB, params Params) string {
if len(e.values) == 0 {
if e.not {
return ""
}
return "0=1"
}
var values []string
for _, value := range e.values {
switch value.(type) {
case nil:
values = append(values, "NULL")
case Expression:
sql := value.(Expression).Build(db, params)
values = append(values, sql)
default:
name := fmt.Sprintf("p%v", len(params))
params[name] = value
values = append(values, "{:"+name+"}")
}
}
col := db.QuoteColumnName(e.col)
if len(values) == 1 {
if e.not {
return col + "<>" + values[0]
}
return col + "=" + values[0]
}
in := "IN"
if e.not {
in = "NOT IN"
}
return fmt.Sprintf("%v %v (%v)", col, in, strings.Join(values, ", "))
}
// LikeExp represents a variant of LIKE expressions.
type LikeExp struct {
or bool
left, right bool
col string
values []string
escape []string
// Like stores the LIKE operator. It can be "LIKE", "NOT LIKE".
// It may also be customized as something like "ILIKE".
Like string
}
// Escape specifies how a LIKE expression should be escaped.
// Each string at position 2i represents a special character and the string at position 2i+1 is
// the corresponding escaped version.
func (e *LikeExp) Escape(chars ...string) *LikeExp {
e.escape = chars
return e
}
// Match specifies whether to do wildcard matching on the left and/or right of given strings.
func (e *LikeExp) Match(left, right bool) *LikeExp {
e.left, e.right = left, right
return e
}
// Build converts an expression into a SQL fragment.
func (e *LikeExp) Build(db *DB, params Params) string {
if len(e.values) == 0 {
return ""
}
if len(e.escape)%2 != 0 {
panic("LikeExp.Escape must be a slice of even number of strings")
}
var parts []string
col := db.QuoteColumnName(e.col)
for _, value := range e.values {
name := fmt.Sprintf("p%v", len(params))
for i := 0; i < len(e.escape); i += 2 {
value = strings.Replace(value, e.escape[i], e.escape[i+1], -1)
}
if e.left {
value = "%" + value
}
if e.right {
value += "%"
}
params[name] = value
parts = append(parts, fmt.Sprintf("%v %v {:%v}", col, e.Like, name))
}
if e.or {
return strings.Join(parts, " OR ")
}
return strings.Join(parts, " AND ")
}
// ExistsExp represents an EXISTS or NOT EXISTS expression.
type ExistsExp struct {
exp Expression
not bool
}
// Build converts an expression into a SQL fragment.
func (e *ExistsExp) Build(db *DB, params Params) string {
sql := e.exp.Build(db, params)
if sql == "" {
if e.not {
return ""
}
return "0=1"
}
if e.not {
return "NOT EXISTS (" + sql + ")"
}
return "EXISTS (" + sql + ")"
}
// BetweenExp represents a BETWEEN or a NOT BETWEEN expression.
type BetweenExp struct {
col string
from, to interface{}
not bool
}
// Build converts an expression into a SQL fragment.
func (e *BetweenExp) Build(db *DB, params Params) string {
between := "BETWEEN"
if e.not {
between = "NOT BETWEEN"
}
name1 := fmt.Sprintf("p%v", len(params))
name2 := fmt.Sprintf("p%v", len(params)+1)
params[name1] = e.from
params[name2] = e.to
col := db.QuoteColumnName(e.col)
return fmt.Sprintf("%v %v {:%v} AND {:%v}", col, between, name1, name2)
}
// Enclose surrounds the provided nonempty expression with parenthesis "()".
func Enclose(exp Expression) Expression {
return &EncloseExp{exp}
}
// EncloseExp represents a parenthesis enclosed expression.
type EncloseExp struct {
exp Expression
}
// Build converts an expression into a SQL fragment.
func (e *EncloseExp) Build(db *DB, params Params) string {
str := e.exp.Build(db, params)
if str == "" {
return ""
}
return "(" + str + ")"
}

196
expression_test.go Normal file
View file

@ -0,0 +1,196 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestExp(t *testing.T) {
params := Params{"k2": "v2"}
e1 := NewExp("s1").(*Exp)
assert.Equal(t, e1.Build(nil, params), "s1", "e1.Build()")
assert.Equal(t, len(params), 1, `len(params)@1`)
e2 := NewExp("s2", Params{"k1": "v1"}).(*Exp)
assert.Equal(t, e2.Build(nil, params), "s2", "e2.Build()")
assert.Equal(t, len(params), 2, `len(params)@2`)
}
func TestHashExp(t *testing.T) {
e1 := HashExp{}
assert.Equal(t, e1.Build(nil, nil), "", `e1.Build()`)
e2 := HashExp{
"k1": nil,
"k2": NewExp("s1", Params{"ka": "va"}),
"k3": 1.1,
"k4": "abc",
"k5": []interface{}{1, 2},
}
db := getDB()
params := Params{"k0": "v0"}
expected := "`k1` IS NULL AND (s1) AND `k3`={:p2} AND `k4`={:p3} AND `k5` IN ({:p4}, {:p5})"
assert.Equal(t, e2.Build(db, params), expected, `e2.Build()`)
assert.Equal(t, len(params), 6, `len(params)`)
assert.Equal(t, params["p5"].(int), 2, `params["p5"]`)
}
func TestNotExp(t *testing.T) {
e1 := Not(NewExp("s1"))
assert.Equal(t, e1.Build(nil, nil), "NOT (s1)", `e1.Build()`)
e2 := Not(NewExp(""))
assert.Equal(t, e2.Build(nil, nil), "", `e2.Build()`)
}
func TestAndOrExp(t *testing.T) {
e1 := And(NewExp("s1", Params{"k1": "v1"}), NewExp(""), NewExp("s2", Params{"k2": "v2"}))
params := Params{}
assert.Equal(t, e1.Build(nil, params), "(s1) AND (s2)", `e1.Build()`)
assert.Equal(t, len(params), 2, `len(params)`)
e2 := Or(NewExp("s1"), NewExp("s2"))
assert.Equal(t, e2.Build(nil, nil), "(s1) OR (s2)", `e2.Build()`)
e3 := And()
assert.Equal(t, e3.Build(nil, nil), "", `e3.Build()`)
e4 := And(NewExp("s1"))
assert.Equal(t, e4.Build(nil, nil), "s1", `e4.Build()`)
e5 := And(NewExp("s1"), nil)
assert.Equal(t, e5.Build(nil, nil), "s1", `e5.Build()`)
}
func TestInExp(t *testing.T) {
db := getDB()
e1 := In("age", 1, 2, 3)
params := Params{}
assert.Equal(t, e1.Build(db, params), "`age` IN ({:p0}, {:p1}, {:p2})", `e1.Build()`)
assert.Equal(t, len(params), 3, `len(params)@1`)
e2 := In("age", 1)
params = Params{}
assert.Equal(t, e2.Build(db, params), "`age`={:p0}", `e2.Build()`)
assert.Equal(t, len(params), 1, `len(params)@2`)
e3 := NotIn("age", 1, 2, 3)
params = Params{}
assert.Equal(t, e3.Build(db, params), "`age` NOT IN ({:p0}, {:p1}, {:p2})", `e3.Build()`)
assert.Equal(t, len(params), 3, `len(params)@3`)
e4 := NotIn("age", 1)
params = Params{}
assert.Equal(t, e4.Build(db, params), "`age`<>{:p0}", `e4.Build()`)
assert.Equal(t, len(params), 1, `len(params)@4`)
e5 := In("age")
assert.Equal(t, e5.Build(db, nil), "0=1", `e5.Build()`)
e6 := NotIn("age")
assert.Equal(t, e6.Build(db, nil), "", `e6.Build()`)
}
func TestLikeExp(t *testing.T) {
db := getDB()
e1 := Like("name", "a", "b", "c")
params := Params{}
assert.Equal(t, e1.Build(db, params), "`name` LIKE {:p0} AND `name` LIKE {:p1} AND `name` LIKE {:p2}", `e1.Build()`)
assert.Equal(t, len(params), 3, `len(params)@1`)
e2 := Like("name", "a")
params = Params{}
assert.Equal(t, e2.Build(db, params), "`name` LIKE {:p0}", `e2.Build()`)
assert.Equal(t, len(params), 1, `len(params)@2`)
e3 := Like("name")
assert.Equal(t, e3.Build(db, nil), "", `e3.Build()`)
e4 := NotLike("name", "a", "b", "c")
params = Params{}
assert.Equal(t, e4.Build(db, params), "`name` NOT LIKE {:p0} AND `name` NOT LIKE {:p1} AND `name` NOT LIKE {:p2}", `e4.Build()`)
assert.Equal(t, len(params), 3, `len(params)@4`)
e5 := OrLike("name", "a", "b", "c")
params = Params{}
assert.Equal(t, e5.Build(db, params), "`name` LIKE {:p0} OR `name` LIKE {:p1} OR `name` LIKE {:p2}", `e5.Build()`)
assert.Equal(t, len(params), 3, `len(params)@5`)
e6 := OrNotLike("name", "a", "b", "c")
params = Params{}
assert.Equal(t, e6.Build(db, params), "`name` NOT LIKE {:p0} OR `name` NOT LIKE {:p1} OR `name` NOT LIKE {:p2}", `e6.Build()`)
assert.Equal(t, len(params), 3, `len(params)@6`)
e7 := Like("name", "a_\\%")
params = Params{}
e7.Build(db, params)
assert.Equal(t, params["p0"], "%a\\_\\\\\\%%", `params["p0"]@1`)
e8 := Like("name", "a").Match(false, true)
params = Params{}
e8.Build(db, params)
assert.Equal(t, params["p0"], "a%", `params["p0"]@2`)
e9 := Like("name", "a").Match(true, false)
params = Params{}
e9.Build(db, params)
assert.Equal(t, params["p0"], "%a", `params["p0"]@3`)
e10 := Like("name", "a").Match(false, false)
params = Params{}
e10.Build(db, params)
assert.Equal(t, params["p0"], "a", `params["p0"]@4`)
e11 := Like("name", "%a").Match(false, false).Escape()
params = Params{}
e11.Build(db, params)
assert.Equal(t, params["p0"], "%a", `params["p0"]@5`)
}
func TestBetweenExp(t *testing.T) {
db := getDB()
e1 := Between("age", 30, 40)
params := Params{}
assert.Equal(t, e1.Build(db, params), "`age` BETWEEN {:p0} AND {:p1}", `e1.Build()`)
assert.Equal(t, len(params), 2, `len(params)@1`)
e2 := NotBetween("age", 30, 40)
params = Params{}
assert.Equal(t, e2.Build(db, params), "`age` NOT BETWEEN {:p0} AND {:p1}", `e2.Build()`)
assert.Equal(t, len(params), 2, `len(params)@2`)
}
func TestExistsExp(t *testing.T) {
e1 := Exists(NewExp("s1"))
assert.Equal(t, e1.Build(nil, nil), "EXISTS (s1)", `e1.Build()`)
e2 := NotExists(NewExp("s1"))
assert.Equal(t, e2.Build(nil, nil), "NOT EXISTS (s1)", `e2.Build()`)
e3 := Exists(NewExp(""))
assert.Equal(t, e3.Build(nil, nil), "0=1", `e3.Build()`)
e4 := NotExists(NewExp(""))
assert.Equal(t, e4.Build(nil, nil), "", `e4.Build()`)
}
func TestEncloseExp(t *testing.T) {
e1 := Enclose(NewExp(""))
assert.Equal(t, e1.Build(nil, nil), "", `e1.Build()`)
e2 := Enclose(NewExp("s1"))
assert.Equal(t, e2.Build(nil, nil), "(s1)", `e2.Build()`)
e3 := Enclose(NewExp("(s1)"))
assert.Equal(t, e3.Build(nil, nil), "((s1))", `e3.Build()`)
}

9
go.mod Normal file
View file

@ -0,0 +1,9 @@
module github.com/pocketbase/dbx
go 1.13
require (
github.com/go-sql-driver/mysql v1.4.1
github.com/stretchr/testify v1.4.0
google.golang.org/appengine v1.6.5 // indirect
)

22
go.sum Normal file
View file

@ -0,0 +1,22 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

174
model_query.go Normal file
View file

@ -0,0 +1,174 @@
package dbx
import (
"context"
"errors"
"fmt"
"reflect"
)
type (
// TableModel is the interface that should be implemented by models which have unconventional table names.
TableModel interface {
TableName() string
}
// ModelQuery represents a query associated with a struct model.
ModelQuery struct {
db *DB
ctx context.Context
builder Builder
model *structValue
exclude []string
lastError error
}
)
var (
MissingPKError = errors.New("missing primary key declaration")
CompositePKError = errors.New("composite primary key is not supported")
)
func NewModelQuery(model interface{}, fieldMapFunc FieldMapFunc, db *DB, builder Builder) *ModelQuery {
q := &ModelQuery{
db: db,
ctx: db.ctx,
builder: builder,
model: newStructValue(model, fieldMapFunc, db.TableMapper),
}
if q.model == nil {
q.lastError = VarTypeError("must be a pointer to a struct representing the model")
}
return q
}
// Context returns the context associated with the query.
func (q *ModelQuery) Context() context.Context {
return q.ctx
}
// WithContext associates a context with the query.
func (q *ModelQuery) WithContext(ctx context.Context) *ModelQuery {
q.ctx = ctx
return q
}
// Exclude excludes the specified struct fields from being inserted/updated into the DB table.
func (q *ModelQuery) Exclude(attrs ...string) *ModelQuery {
q.exclude = attrs
return q
}
// Insert inserts a row in the table using the struct model associated with this query.
//
// By default, it inserts *all* public fields into the table, including those nil or empty ones.
// You may pass a list of the fields to this method to indicate that only those fields should be inserted.
// You may also call Exclude to exclude some fields from being inserted.
//
// If a model has an empty primary key, it is considered auto-incremental and the corresponding struct
// field will be filled with the generated primary key value after a successful insertion.
func (q *ModelQuery) Insert(attrs ...string) error {
if q.lastError != nil {
return q.lastError
}
cols := q.model.columns(attrs, q.exclude)
pkName := ""
for name, value := range q.model.pk() {
if isAutoInc(value) {
delete(cols, name)
pkName = name
break
}
}
if pkName == "" {
_, err := q.builder.Insert(q.model.tableName, Params(cols)).WithContext(q.ctx).Execute()
return err
}
// handle auto-incremental PK
query := q.builder.Insert(q.model.tableName, Params(cols)).WithContext(q.ctx)
pkValue, err := insertAndReturnPK(q.db, query, pkName)
if err != nil {
return err
}
pkField := indirect(q.model.dbNameMap[pkName].getField(q.model.value))
switch pkField.Kind() {
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
pkField.SetUint(uint64(pkValue))
default:
pkField.SetInt(pkValue)
}
return nil
}
func insertAndReturnPK(db *DB, query *Query, pkName string) (int64, error) {
if db.DriverName() != "postgres" && db.DriverName() != "pgx" {
result, err := query.Execute()
if err != nil {
return 0, err
}
return result.LastInsertId()
}
// specially handle postgres (lib/pq) as it doesn't support LastInsertId
returning := fmt.Sprintf(" RETURNING %s", db.QuoteColumnName(pkName))
query.sql += returning
query.rawSQL += returning
var pkValue int64
err := query.Row(&pkValue)
return pkValue, err
}
func isAutoInc(value interface{}) bool {
v := reflect.ValueOf(value)
switch v.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Ptr:
return v.IsNil() || isAutoInc(v.Elem())
case reflect.Invalid:
return true
}
return false
}
// Update updates a row in the table using the struct model associated with this query.
// The row being updated has the same primary key as specified by the model.
//
// By default, it updates *all* public fields in the table, including those nil or empty ones.
// You may pass a list of the fields to this method to indicate that only those fields should be updated.
// You may also call Exclude to exclude some fields from being updated.
func (q *ModelQuery) Update(attrs ...string) error {
if q.lastError != nil {
return q.lastError
}
pk := q.model.pk()
if len(pk) == 0 {
return MissingPKError
}
cols := q.model.columns(attrs, q.exclude)
for name := range pk {
delete(cols, name)
}
_, err := q.builder.Update(q.model.tableName, Params(cols), HashExp(pk)).WithContext(q.ctx).Execute()
return err
}
// Delete deletes a row in the table using the primary key specified by the struct model associated with this query.
func (q *ModelQuery) Delete() error {
if q.lastError != nil {
return q.lastError
}
pk := q.model.pk()
if len(pk) == 0 {
return MissingPKError
}
_, err := q.builder.Delete(q.model.tableName, HashExp(pk)).WithContext(q.ctx).Execute()
return err
}

238
model_query_test.go Normal file
View file

@ -0,0 +1,238 @@
package dbx
import (
"database/sql"
"testing"
"github.com/stretchr/testify/assert"
)
type Item struct {
ID2 int
Name string
}
func TestModelQuery_Insert(t *testing.T) {
db := getPreparedDB()
defer db.Close()
name := "test"
email := "test@example.com"
{
// inserting normally
customer := Customer{
Name: name,
Email: email,
}
err := db.Model(&customer).Insert()
if assert.Nil(t, err) {
assert.Equal(t, 4, customer.ID)
var c Customer
db.Select().From("customer").Where(HashExp{"ID": 4}).One(&c)
assert.Equal(t, name, c.Name)
assert.Equal(t, email, c.Email)
assert.Equal(t, 0, c.Status)
assert.False(t, c.Address.Valid)
}
}
{
// inserting with pointer-typed fields
customer := CustomerPtr{
Name: name,
Email: &email,
}
err := db.Model(&customer).Insert()
if assert.Nil(t, err) && assert.NotNil(t, customer.ID) {
assert.Equal(t, 5, *customer.ID)
var c CustomerPtr
db.Select().From("customer").Where(HashExp{"ID": 4}).One(&c)
assert.Equal(t, name, c.Name)
if assert.NotNil(t, c.Email) {
assert.Equal(t, email, *c.Email)
}
if assert.NotNil(t, c.Status) {
assert.Equal(t, 0, *c.Status)
}
assert.Nil(t, c.Address)
}
}
{
// inserting with null-typed fields
customer := CustomerNull{
Name: name,
Email: sql.NullString{email, true},
}
err := db.Model(&customer).Insert()
if assert.Nil(t, err) {
// potential todo: need to check if the field implements sql.Scanner
// assert.Equal(t, int64(6), customer.ID.Int64)
var c CustomerNull
db.Select().From("customer").Where(HashExp{"ID": 4}).One(&c)
assert.Equal(t, name, c.Name)
assert.Equal(t, email, c.Email.String)
if assert.NotNil(t, c.Status) {
assert.Equal(t, int64(0), c.Status.Int64)
}
assert.False(t, c.Address.Valid)
}
}
{
// inserting with embedded structures
customer := CustomerEmbedded{
Id: 100,
Email: &email,
InnerCustomer: InnerCustomer{
Name: &name,
Status: sql.NullInt64{1, true},
},
}
err := db.Model(&customer).Insert()
if assert.Nil(t, err) {
assert.Equal(t, 100, customer.Id)
var c CustomerEmbedded
db.Select().From("customer").Where(HashExp{"ID": 100}).One(&c)
assert.Equal(t, name, *c.Name)
assert.Equal(t, email, *c.Email)
if assert.NotNil(t, c.Status) {
assert.Equal(t, int64(1), c.Status.Int64)
}
assert.False(t, c.Address.Valid)
}
}
{
// inserting with include/exclude fields
customer := Customer{
Name: name,
Email: email,
Status: 1,
}
err := db.Model(&customer).Exclude("Name").Insert("Name", "Email")
if assert.Nil(t, err) {
assert.Equal(t, 101, customer.ID)
var c Customer
db.Select().From("customer").Where(HashExp{"ID": 101}).One(&c)
assert.Equal(t, "", c.Name)
assert.Equal(t, email, c.Email)
assert.Equal(t, 0, c.Status)
assert.False(t, c.Address.Valid)
}
}
var a int
assert.NotNil(t, db.Model(&a).Insert())
}
func TestModelQuery_Update(t *testing.T) {
db := getPreparedDB()
defer db.Close()
id := 2
name := "test"
email := "test@example.com"
{
// updating normally
customer := Customer{
ID: id,
Name: name,
Email: email,
}
err := db.Model(&customer).Update()
if assert.Nil(t, err) {
var c Customer
db.Select().From("customer").Where(HashExp{"ID": id}).One(&c)
assert.Equal(t, name, c.Name)
assert.Equal(t, email, c.Email)
assert.Equal(t, 0, c.Status)
}
}
{
// updating without primary keys
item2 := Item{
Name: name,
}
err := db.Model(&item2).Update()
assert.Equal(t, MissingPKError, err)
}
{
// updating all fields
customer := CustomerPtr{
ID: &id,
Name: name,
Email: &email,
}
err := db.Model(&customer).Update()
if assert.Nil(t, err) {
assert.Equal(t, id, *customer.ID)
var c CustomerPtr
db.Select().From("customer").Where(HashExp{"ID": id}).One(&c)
assert.Equal(t, name, c.Name)
if assert.NotNil(t, c.Email) {
assert.Equal(t, email, *c.Email)
}
assert.Nil(t, c.Status)
}
}
{
// updating selected fields only
id = 3
customer := CustomerPtr{
ID: &id,
Name: name,
Email: &email,
}
err := db.Model(&customer).Update("Name", "Email")
if assert.Nil(t, err) {
assert.Equal(t, id, *customer.ID)
var c CustomerPtr
db.Select().From("customer").Where(HashExp{"ID": id}).One(&c)
assert.Equal(t, name, c.Name)
if assert.NotNil(t, c.Email) {
assert.Equal(t, email, *c.Email)
}
if assert.NotNil(t, c.Status) {
assert.Equal(t, 2, *c.Status)
}
}
}
{
// updating non-struct
var a int
assert.NotNil(t, db.Model(&a).Update())
}
}
func TestModelQuery_Delete(t *testing.T) {
db := getPreparedDB()
defer db.Close()
customer := Customer{
ID: 2,
}
err := db.Model(&customer).Delete()
if assert.Nil(t, err) {
var m Customer
err := db.Select().From("customer").Where(HashExp{"ID": 2}).One(&m)
assert.NotNil(t, err)
}
{
// deleting without primary keys
item2 := Item{
Name: "",
}
err := db.Model(&item2).Delete()
assert.Equal(t, MissingPKError, err)
}
var a int
assert.NotNil(t, db.Model(&a).Delete())
}

384
query.go Normal file
View file

@ -0,0 +1,384 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/hex"
"errors"
"fmt"
"strings"
"time"
)
// ExecHookFunc executes before op allowing custom handling like auto fail/retry.
type ExecHookFunc func(q *Query, op func() error) error
// OneHookFunc executes right before the query populate the row result from One() call (aka. op).
type OneHookFunc func(q *Query, a interface{}, op func(b interface{}) error) error
// AllHookFunc executes right before the query populate the row result from All() call (aka. op).
type AllHookFunc func(q *Query, sliceA interface{}, op func(sliceB interface{}) error) error
// Params represents a list of parameter values to be bound to a SQL statement.
// The map keys are the parameter names while the map values are the corresponding parameter values.
type Params map[string]interface{}
// Executor prepares, executes, or queries a SQL statement.
type Executor interface {
// Exec executes a SQL statement
Exec(query string, args ...interface{}) (sql.Result, error)
// ExecContext executes a SQL statement with the given context
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
// Query queries a SQL statement
Query(query string, args ...interface{}) (*sql.Rows, error)
// QueryContext queries a SQL statement with the given context
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
// Prepare creates a prepared statement
Prepare(query string) (*sql.Stmt, error)
}
// Query represents a SQL statement to be executed.
type Query struct {
executor Executor
sql, rawSQL string
placeholders []string
params Params
stmt *sql.Stmt
ctx context.Context
// hooks
execHook ExecHookFunc
oneHook OneHookFunc
allHook AllHookFunc
// FieldMapper maps struct field names to DB column names.
FieldMapper FieldMapFunc
// LastError contains the last error (if any) of the query.
// LastError is cleared by Execute(), Row(), Rows(), One(), and All().
LastError error
// LogFunc is used to log the SQL statement being executed.
LogFunc LogFunc
// PerfFunc is used to log the SQL execution time. It is ignored if nil.
// Deprecated: Please use QueryLogFunc and ExecLogFunc instead.
PerfFunc PerfFunc
// QueryLogFunc is called each time when performing a SQL query that returns data.
QueryLogFunc QueryLogFunc
// ExecLogFunc is called each time when a SQL statement is executed.
ExecLogFunc ExecLogFunc
}
// NewQuery creates a new Query with the given SQL statement.
func NewQuery(db *DB, executor Executor, sql string) *Query {
rawSQL, placeholders := db.processSQL(sql)
return &Query{
executor: executor,
sql: sql,
rawSQL: rawSQL,
placeholders: placeholders,
params: Params{},
ctx: db.ctx,
FieldMapper: db.FieldMapper,
LogFunc: db.LogFunc,
PerfFunc: db.PerfFunc,
QueryLogFunc: db.QueryLogFunc,
ExecLogFunc: db.ExecLogFunc,
}
}
// SQL returns the original SQL used to create the query.
// The actual SQL (RawSQL) being executed is obtained by replacing the named
// parameter placeholders with anonymous ones.
func (q *Query) SQL() string {
return q.sql
}
// Context returns the context associated with the query.
func (q *Query) Context() context.Context {
return q.ctx
}
// WithContext associates a context with the query.
func (q *Query) WithContext(ctx context.Context) *Query {
q.ctx = ctx
return q
}
// WithExecHook associates the provided exec hook function with the query.
//
// It is called for every Query resolver (Execute(), One(), All(), Row(), Column()),
// allowing you to implement auto fail/retry or any other additional handling.
func (q *Query) WithExecHook(fn ExecHookFunc) *Query {
q.execHook = fn
return q
}
// WithOneHook associates the provided hook function with the query,
// called on q.One(), allowing you to implement custom struct scan based
// on the One() argument and/or result.
func (q *Query) WithOneHook(fn OneHookFunc) *Query {
q.oneHook = fn
return q
}
// WithOneHook associates the provided hook function with the query,
// called on q.All(), allowing you to implement custom slice scan based
// on the All() argument and/or result.
func (q *Query) WithAllHook(fn AllHookFunc) *Query {
q.allHook = fn
return q
}
// logSQL returns the SQL statement with parameters being replaced with the actual values.
// The result is only for logging purpose and should not be used to execute.
func (q *Query) logSQL() string {
s := q.sql
for k, v := range q.params {
if valuer, ok := v.(driver.Valuer); ok && valuer != nil {
v, _ = valuer.Value()
}
var sv string
if str, ok := v.(string); ok {
sv = "'" + strings.Replace(str, "'", "''", -1) + "'"
} else if bs, ok := v.([]byte); ok {
sv = "0x" + hex.EncodeToString(bs)
} else {
sv = fmt.Sprintf("%v", v)
}
s = strings.Replace(s, "{:"+k+"}", sv, -1)
}
return s
}
// Params returns the parameters to be bound to the SQL statement represented by this query.
func (q *Query) Params() Params {
return q.params
}
// Prepare creates a prepared statement for later queries or executions.
// Close() should be called after finishing all queries.
func (q *Query) Prepare() *Query {
stmt, err := q.executor.Prepare(q.rawSQL)
if err != nil {
q.LastError = err
return q
}
q.stmt = stmt
return q
}
// Close closes the underlying prepared statement.
// Close does nothing if the query has not been prepared before.
func (q *Query) Close() error {
if q.stmt == nil {
return nil
}
err := q.stmt.Close()
q.stmt = nil
return err
}
// Bind sets the parameters that should be bound to the SQL statement.
// The parameter placeholders in the SQL statement are in the format of "{:ParamName}".
func (q *Query) Bind(params Params) *Query {
if len(q.params) == 0 {
q.params = params
} else {
for k, v := range params {
q.params[k] = v
}
}
return q
}
// Execute executes the SQL statement without retrieving data.
func (q *Query) Execute() (sql.Result, error) {
var result sql.Result
execErr := q.execWrap(func() error {
var err error
result, err = q.execute()
return err
})
return result, execErr
}
func (q *Query) execute() (result sql.Result, err error) {
err = q.LastError
q.LastError = nil
if err != nil {
return
}
var params []interface{}
params, err = replacePlaceholders(q.placeholders, q.params)
if err != nil {
return
}
start := time.Now()
if q.ctx == nil {
if q.stmt == nil {
result, err = q.executor.Exec(q.rawSQL, params...)
} else {
result, err = q.stmt.Exec(params...)
}
} else {
if q.stmt == nil {
result, err = q.executor.ExecContext(q.ctx, q.rawSQL, params...)
} else {
result, err = q.stmt.ExecContext(q.ctx, params...)
}
}
if q.ExecLogFunc != nil {
q.ExecLogFunc(q.ctx, time.Now().Sub(start), q.logSQL(), result, err)
}
if q.LogFunc != nil {
q.LogFunc("[%.2fms] Execute SQL: %v", float64(time.Now().Sub(start).Milliseconds()), q.logSQL())
}
if q.PerfFunc != nil {
q.PerfFunc(time.Now().Sub(start).Nanoseconds(), q.logSQL(), true)
}
return
}
// One executes the SQL statement and populates the first row of the result into a struct or NullStringMap.
// Refer to Rows.ScanStruct() and Rows.ScanMap() for more details on how to specify
// the variable to be populated.
// Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned.
func (q *Query) One(a interface{}) error {
return q.execWrap(func() error {
rows, err := q.Rows()
if err != nil {
return err
}
if q.oneHook != nil {
return q.oneHook(q, a, rows.one)
}
return rows.one(a)
})
}
// All executes the SQL statement and populates all the resulting rows into a slice of struct or NullStringMap.
// The slice must be given as a pointer. Each slice element must be either a struct or a NullStringMap.
// Refer to Rows.ScanStruct() and Rows.ScanMap() for more details on how each slice element can be.
// If the query returns no row, the slice will be an empty slice (not nil).
func (q *Query) All(slice interface{}) error {
return q.execWrap(func() error {
rows, err := q.Rows()
if err != nil {
return err
}
if q.allHook != nil {
return q.allHook(q, slice, rows.all)
}
return rows.all(slice)
})
}
// Row executes the SQL statement and populates the first row of the result into a list of variables.
// Note that the number of the variables should match to that of the columns in the query result.
// Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned.
func (q *Query) Row(a ...interface{}) error {
return q.execWrap(func() error {
rows, err := q.Rows()
if err != nil {
return err
}
return rows.row(a...)
})
}
// Column executes the SQL statement and populates the first column of the result into a slice.
// Note that the parameter must be a pointer to a slice.
func (q *Query) Column(a interface{}) error {
return q.execWrap(func() error {
rows, err := q.Rows()
if err != nil {
return err
}
return rows.column(a)
})
}
// Rows executes the SQL statement and returns a Rows object to allow retrieving data row by row.
func (q *Query) Rows() (rows *Rows, err error) {
err = q.LastError
q.LastError = nil
if err != nil {
return
}
var params []interface{}
params, err = replacePlaceholders(q.placeholders, q.params)
if err != nil {
return
}
start := time.Now()
var rr *sql.Rows
if q.ctx == nil {
if q.stmt == nil {
rr, err = q.executor.Query(q.rawSQL, params...)
} else {
rr, err = q.stmt.Query(params...)
}
} else {
if q.stmt == nil {
rr, err = q.executor.QueryContext(q.ctx, q.rawSQL, params...)
} else {
rr, err = q.stmt.QueryContext(q.ctx, params...)
}
}
rows = &Rows{rr, q.FieldMapper}
if q.QueryLogFunc != nil {
q.QueryLogFunc(q.ctx, time.Now().Sub(start), q.logSQL(), rr, err)
}
if q.LogFunc != nil {
q.LogFunc("[%.2fms] Query SQL: %v", float64(time.Now().Sub(start).Milliseconds()), q.logSQL())
}
if q.PerfFunc != nil {
q.PerfFunc(time.Now().Sub(start).Nanoseconds(), q.logSQL(), false)
}
return
}
func (q *Query) execWrap(op func() error) error {
if q.execHook != nil {
return q.execHook(q, op)
}
return op()
}
// replacePlaceholders converts a list of named parameters into a list of anonymous parameters.
func replacePlaceholders(placeholders []string, params Params) ([]interface{}, error) {
if len(placeholders) == 0 {
return nil, nil
}
var result []interface{}
for _, name := range placeholders {
if value, ok := params[name]; ok {
result = append(result, value)
} else {
return nil, errors.New("Named parameter not found: " + name)
}
}
return result, nil
}

244
query_builder.go Normal file
View file

@ -0,0 +1,244 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"bytes"
"fmt"
"regexp"
"strings"
)
// QueryBuilder builds different clauses for a SELECT SQL statement.
type QueryBuilder interface {
// BuildSelect generates a SELECT clause from the given selected column names.
BuildSelect(cols []string, distinct bool, option string) string
// BuildFrom generates a FROM clause from the given tables.
BuildFrom(tables []string) string
// BuildGroupBy generates a GROUP BY clause from the given group-by columns.
BuildGroupBy(cols []string) string
// BuildJoin generates a JOIN clause from the given join information.
BuildJoin([]JoinInfo, Params) string
// BuildWhere generates a WHERE clause from the given expression.
BuildWhere(Expression, Params) string
// BuildHaving generates a HAVING clause from the given expression.
BuildHaving(Expression, Params) string
// BuildOrderByAndLimit generates the ORDER BY and LIMIT clauses.
BuildOrderByAndLimit(string, []string, int64, int64) string
// BuildUnion generates a UNION clause from the given union information.
BuildUnion([]UnionInfo, Params) string
}
// BaseQueryBuilder provides a basic implementation of QueryBuilder.
type BaseQueryBuilder struct {
db *DB
}
var _ QueryBuilder = &BaseQueryBuilder{}
// NewBaseQueryBuilder creates a new BaseQueryBuilder instance.
func NewBaseQueryBuilder(db *DB) *BaseQueryBuilder {
return &BaseQueryBuilder{db}
}
// DB returns the DB instance associated with the query builder.
func (q *BaseQueryBuilder) DB() *DB {
return q.db
}
// the regexp for columns and tables.
var selectRegex = regexp.MustCompile(`(?i:\s+as\s+|\s+)([\w\-_\.]+)$`)
// BuildSelect generates a SELECT clause from the given selected column names.
func (q *BaseQueryBuilder) BuildSelect(cols []string, distinct bool, option string) string {
var s bytes.Buffer
s.WriteString("SELECT ")
if distinct {
s.WriteString("DISTINCT ")
}
if option != "" {
s.WriteString(option)
s.WriteString(" ")
}
if len(cols) == 0 {
s.WriteString("*")
return s.String()
}
for i, col := range cols {
if i > 0 {
s.WriteString(", ")
}
matches := selectRegex.FindStringSubmatch(col)
if len(matches) == 0 {
s.WriteString(q.db.QuoteColumnName(col))
} else {
col := col[:len(col)-len(matches[0])]
alias := matches[1]
s.WriteString(q.db.QuoteColumnName(col) + " AS " + q.db.QuoteSimpleColumnName(alias))
}
}
return s.String()
}
// BuildFrom generates a FROM clause from the given tables.
func (q *BaseQueryBuilder) BuildFrom(tables []string) string {
if len(tables) == 0 {
return ""
}
s := ""
for _, table := range tables {
table = q.quoteTableNameAndAlias(table)
if s == "" {
s = table
} else {
s += ", " + table
}
}
return "FROM " + s
}
// BuildJoin generates a JOIN clause from the given join information.
func (q *BaseQueryBuilder) BuildJoin(joins []JoinInfo, params Params) string {
if len(joins) == 0 {
return ""
}
parts := []string{}
for _, join := range joins {
sql := join.Join + " " + q.quoteTableNameAndAlias(join.Table)
on := ""
if join.On != nil {
on = join.On.Build(q.db, params)
}
if on != "" {
sql += " ON " + on
}
parts = append(parts, sql)
}
return strings.Join(parts, " ")
}
// BuildWhere generates a WHERE clause from the given expression.
func (q *BaseQueryBuilder) BuildWhere(e Expression, params Params) string {
if e != nil {
if c := e.Build(q.db, params); c != "" {
return "WHERE " + c
}
}
return ""
}
// BuildHaving generates a HAVING clause from the given expression.
func (q *BaseQueryBuilder) BuildHaving(e Expression, params Params) string {
if e != nil {
if c := e.Build(q.db, params); c != "" {
return "HAVING " + c
}
}
return ""
}
// BuildGroupBy generates a GROUP BY clause from the given group-by columns.
func (q *BaseQueryBuilder) BuildGroupBy(cols []string) string {
if len(cols) == 0 {
return ""
}
s := ""
for i, col := range cols {
if i == 0 {
s = q.db.QuoteColumnName(col)
} else {
s += ", " + q.db.QuoteColumnName(col)
}
}
return "GROUP BY " + s
}
// BuildOrderByAndLimit generates the ORDER BY and LIMIT clauses.
func (q *BaseQueryBuilder) BuildOrderByAndLimit(sql string, cols []string, limit int64, offset int64) string {
if orderBy := q.BuildOrderBy(cols); orderBy != "" {
sql += " " + orderBy
}
if limit := q.BuildLimit(limit, offset); limit != "" {
return sql + " " + limit
}
return sql
}
// BuildUnion generates a UNION clause from the given union information.
func (q *BaseQueryBuilder) BuildUnion(unions []UnionInfo, params Params) string {
if len(unions) == 0 {
return ""
}
sql := ""
for i, union := range unions {
if i > 0 {
sql += " "
}
for k, v := range union.Query.params {
params[k] = v
}
u := "UNION"
if union.All {
u = "UNION ALL"
}
sql += fmt.Sprintf("%v (%v)", u, union.Query.sql)
}
return sql
}
var orderRegex = regexp.MustCompile(`\s+((?i)ASC|DESC)$`)
// BuildOrderBy generates the ORDER BY clause.
func (q *BaseQueryBuilder) BuildOrderBy(cols []string) string {
if len(cols) == 0 {
return ""
}
s := ""
for i, col := range cols {
if i > 0 {
s += ", "
}
matches := orderRegex.FindStringSubmatch(col)
if len(matches) == 0 {
s += q.db.QuoteColumnName(col)
} else {
col := col[:len(col)-len(matches[0])]
dir := matches[1]
s += q.db.QuoteColumnName(col) + " " + dir
}
}
return "ORDER BY " + s
}
// BuildLimit generates the LIMIT clause.
func (q *BaseQueryBuilder) BuildLimit(limit int64, offset int64) string {
if limit < 0 && offset > 0 {
// most DBMS requires LIMIT when OFFSET is present
limit = 9223372036854775807 // 2^63 - 1
}
sql := ""
if limit >= 0 {
sql = fmt.Sprintf("LIMIT %v", limit)
}
if offset <= 0 {
return sql
}
if sql != "" {
sql += " "
}
return sql + fmt.Sprintf("OFFSET %v", offset)
}
func (q *BaseQueryBuilder) quoteTableNameAndAlias(table string) string {
matches := selectRegex.FindStringSubmatch(table)
if len(matches) == 0 {
return q.db.QuoteTableName(table)
}
table = table[:len(table)-len(matches[0])]
return q.db.QuoteTableName(table) + " " + q.db.QuoteSimpleTableName(matches[1])
}

228
query_builder_test.go Normal file
View file

@ -0,0 +1,228 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestQB_BuildSelect(t *testing.T) {
tests := []struct {
tag string
cols []string
distinct bool
option string
expected string
}{
{"empty", []string{}, false, "", "SELECT *"},
{"empty distinct", []string{}, true, "CALC_ROWS", "SELECT DISTINCT CALC_ROWS *"},
{"multi-columns", []string{"name", "DOB1"}, false, "", "SELECT `name`, `DOB1`"},
{"aliased columns", []string{"name As Name", "users.last_name", "u.first1 first"}, false, "", "SELECT `name` AS `Name`, `users`.`last_name`, `u`.`first1` AS `first`"},
}
db := getDB()
qb := db.QueryBuilder()
for _, test := range tests {
s := qb.BuildSelect(test.cols, test.distinct, test.option)
assert.Equal(t, test.expected, s, test.tag)
}
assert.Equal(t, qb.(*BaseQueryBuilder).DB(), db)
}
func TestQB_BuildFrom(t *testing.T) {
tests := []struct {
tag string
tables []string
expected string
}{
{"empty", []string{}, ""},
{"single table", []string{"users"}, "FROM `users`"},
{"multiple tables", []string{"users", "posts"}, "FROM `users`, `posts`"},
{"table alias", []string{"users u", "posts as p"}, "FROM `users` `u`, `posts` `p`"},
{"table prefix and alias", []string{"pub.users p.u", "posts AS p1"}, "FROM `pub`.`users` `p.u`, `posts` `p1`"},
}
qb := getDB().QueryBuilder()
for _, test := range tests {
s := qb.BuildFrom(test.tables)
assert.Equal(t, test.expected, s, test.tag)
}
}
func TestQB_BuildGroupBy(t *testing.T) {
tests := []struct {
tag string
cols []string
expected string
}{
{"empty", []string{}, ""},
{"single column", []string{"name"}, "GROUP BY `name`"},
{"multiple columns", []string{"name", "age"}, "GROUP BY `name`, `age`"},
}
qb := getDB().QueryBuilder()
for _, test := range tests {
s := qb.BuildGroupBy(test.cols)
assert.Equal(t, test.expected, s, test.tag)
}
}
func TestQB_BuildWhere(t *testing.T) {
tests := []struct {
exp Expression
expected string
count int
tag string
}{
{HashExp{"age": 30, "dept": "marketing"}, "WHERE `age`={:p0} AND `dept`={:p1}", 2, "t1"},
{nil, "", 0, "t2"},
{NewExp(""), "", 0, "t3"},
}
qb := getDB().QueryBuilder()
for _, test := range tests {
params := Params{}
s := qb.BuildWhere(test.exp, params)
assert.Equal(t, test.expected, s, test.tag)
assert.Equal(t, test.count, len(params), test.tag)
}
}
func TestQB_BuildHaving(t *testing.T) {
tests := []struct {
exp Expression
expected string
count int
tag string
}{
{HashExp{"age": 30, "dept": "marketing"}, "HAVING `age`={:p0} AND `dept`={:p1}", 2, "t1"},
{nil, "", 0, "t2"},
{NewExp(""), "", 0, "t3"},
}
qb := getDB().QueryBuilder()
for _, test := range tests {
params := Params{}
s := qb.BuildHaving(test.exp, params)
assert.Equal(t, test.expected, s, test.tag)
assert.Equal(t, test.count, len(params), test.tag)
}
}
func TestQB_BuildOrderBy(t *testing.T) {
tests := []struct {
tag string
cols []string
expected string
}{
{"empty", []string{}, ""},
{"single column", []string{"name"}, "ORDER BY `name`"},
{"multiple columns", []string{"name ASC", "age DESC", "id desc"}, "ORDER BY `name` ASC, `age` DESC, `id` desc"},
}
qb := getDB().QueryBuilder().(*BaseQueryBuilder)
for _, test := range tests {
s := qb.BuildOrderBy(test.cols)
assert.Equal(t, test.expected, s, test.tag)
}
}
func TestQB_BuildLimit(t *testing.T) {
tests := []struct {
tag string
limit, offset int64
expected string
}{
{"t1", 10, -1, "LIMIT 10"},
{"t2", 10, 0, "LIMIT 10"},
{"t3", 10, 2, "LIMIT 10 OFFSET 2"},
{"t4", 0, 2, "LIMIT 0 OFFSET 2"},
{"t5", -1, 2, "LIMIT 9223372036854775807 OFFSET 2"},
{"t6", -1, 0, ""},
}
qb := getDB().QueryBuilder().(*BaseQueryBuilder)
for _, test := range tests {
s := qb.BuildLimit(test.limit, test.offset)
assert.Equal(t, test.expected, s, test.tag)
}
}
func TestQB_BuildOrderByAndLimit(t *testing.T) {
qb := getDB().QueryBuilder()
sql := qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, 10, 2)
expected := "SELECT * ORDER BY `name` LIMIT 10 OFFSET 2"
assert.Equal(t, sql, expected, "t1")
sql = qb.BuildOrderByAndLimit("SELECT *", nil, -1, -1)
expected = "SELECT *"
assert.Equal(t, sql, expected, "t2")
sql = qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, -1, -1)
expected = "SELECT * ORDER BY `name`"
assert.Equal(t, sql, expected, "t3")
sql = qb.BuildOrderByAndLimit("SELECT *", nil, 10, -1)
expected = "SELECT * LIMIT 10"
assert.Equal(t, sql, expected, "t4")
}
func TestQB_BuildJoin(t *testing.T) {
qb := getDB().QueryBuilder()
params := Params{}
ji := JoinInfo{"LEFT JOIN", "users u", NewExp("id=u.id", Params{"id": 1})}
sql := qb.BuildJoin([]JoinInfo{ji}, params)
expected := "LEFT JOIN `users` `u` ON id=u.id"
assert.Equal(t, sql, expected, "BuildJoin@1")
assert.Equal(t, len(params), 1, "len(params)@1")
params = Params{}
ji = JoinInfo{"INNER JOIN", "users", nil}
sql = qb.BuildJoin([]JoinInfo{ji}, params)
expected = "INNER JOIN `users`"
assert.Equal(t, sql, expected, "BuildJoin@2")
assert.Equal(t, len(params), 0, "len(params)@2")
sql = qb.BuildJoin([]JoinInfo{}, nil)
expected = ""
assert.Equal(t, sql, expected, "BuildJoin@3")
ji = JoinInfo{"INNER JOIN", "users", nil}
ji2 := JoinInfo{"LEFT JOIN", "posts", nil}
sql = qb.BuildJoin([]JoinInfo{ji, ji2}, nil)
expected = "INNER JOIN `users` LEFT JOIN `posts`"
assert.Equal(t, sql, expected, "BuildJoin@3")
}
func TestQB_BuildUnion(t *testing.T) {
db := getDB()
qb := db.QueryBuilder()
params := Params{}
ui := UnionInfo{false, db.NewQuery("SELECT names").Bind(Params{"id": 1})}
sql := qb.BuildUnion([]UnionInfo{ui}, params)
expected := "UNION (SELECT names)"
assert.Equal(t, sql, expected, "BuildUnion@1")
assert.Equal(t, len(params), 1, "len(params)@1")
params = Params{}
ui = UnionInfo{true, db.NewQuery("SELECT names")}
sql = qb.BuildUnion([]UnionInfo{ui}, params)
expected = "UNION ALL (SELECT names)"
assert.Equal(t, sql, expected, "BuildUnion@2")
assert.Equal(t, len(params), 0, "len(params)@2")
sql = qb.BuildUnion([]UnionInfo{}, nil)
expected = ""
assert.Equal(t, sql, expected, "BuildUnion@3")
ui = UnionInfo{true, db.NewQuery("SELECT names")}
ui2 := UnionInfo{false, db.NewQuery("SELECT ages")}
sql = qb.BuildUnion([]UnionInfo{ui, ui2}, nil)
expected = "UNION ALL (SELECT names) UNION (SELECT ages)"
assert.Equal(t, sql, expected, "BuildUnion@4")
}

600
query_test.go Normal file
View file

@ -0,0 +1,600 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
ss "database/sql"
"encoding/json"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
type City struct {
ID int
Name string
}
func TestNewQuery(t *testing.T) {
db := getDB()
sql := "SELECT * FROM users WHERE id={:id}"
q := NewQuery(db, db.sqlDB, sql)
assert.Equal(t, q.SQL(), sql, "q.SQL()")
assert.Equal(t, q.rawSQL, "SELECT * FROM users WHERE id=?", "q.RawSQL()")
assert.Equal(t, len(q.Params()), 0, "len(q.Params())@1")
q.Bind(Params{"id": 1})
assert.Equal(t, len(q.Params()), 1, "len(q.Params())@2")
}
func TestQuery_Execute(t *testing.T) {
db := getPreparedDB()
defer db.Close()
result, err := db.NewQuery("INSERT INTO item (name) VALUES ('test')").Execute()
if assert.Nil(t, err) {
rows, _ := result.RowsAffected()
assert.Equal(t, rows, int64(1), "Result.RowsAffected()")
lastID, _ := result.LastInsertId()
assert.Equal(t, lastID, int64(6), "Result.LastInsertId()")
}
}
type Customer struct {
scanned bool
ID int
Email string
Status int
Name string
Address ss.NullString
}
func (m Customer) TableName() string {
return "customer"
}
func (m *Customer) PostScan() error {
m.scanned = true
return nil
}
type CustomerPtr struct {
ID *int `db:"pk"`
Email *string
Status *int
Name string
Address *string
}
func (m CustomerPtr) TableName() string {
return "customer"
}
type CustomerNull struct {
ID ss.NullInt64 `db:"pk,id"`
Email ss.NullString
Status *ss.NullInt64
Name string
Address ss.NullString
}
func (m CustomerNull) TableName() string {
return "customer"
}
type CustomerEmbedded struct {
Id int
Email *string
InnerCustomer
}
func (m CustomerEmbedded) TableName() string {
return "customer"
}
type CustomerEmbedded2 struct {
ID int
Email *string
Inner InnerCustomer
}
type InnerCustomer struct {
Status ss.NullInt64
Name *string
Address ss.NullString
}
func TestQuery_Rows(t *testing.T) {
db := getPreparedDB()
defer db.Close()
var (
sql string
err error
)
// Query.All()
var customers []Customer
sql = `SELECT * FROM customer ORDER BY id`
err = db.NewQuery(sql).All(&customers)
if assert.Nil(t, err) {
assert.Equal(t, len(customers), 3, "len(customers)")
assert.Equal(t, customers[2].ID, 3, "customers[2].ID")
assert.Equal(t, customers[2].Email, `user3@example.com`, "customers[2].Email")
assert.Equal(t, customers[2].Status, 2, "customers[2].Status")
assert.Equal(t, customers[0].scanned, true, "customers[0].scanned")
assert.Equal(t, customers[1].scanned, true, "customers[1].scanned")
assert.Equal(t, customers[2].scanned, true, "customers[2].scanned")
}
// Query.All() with slice of pointers
var customersPtrSlice []*Customer
sql = `SELECT * FROM customer ORDER BY id`
err = db.NewQuery(sql).All(&customersPtrSlice)
if assert.Nil(t, err) {
assert.Equal(t, len(customersPtrSlice), 3, "len(customersPtrSlice)")
assert.Equal(t, customersPtrSlice[2].ID, 3, "customersPtrSlice[2].ID")
assert.Equal(t, customersPtrSlice[2].Email, `user3@example.com`, "customersPtrSlice[2].Email")
assert.Equal(t, customersPtrSlice[2].Status, 2, "customersPtrSlice[2].Status")
assert.Equal(t, customersPtrSlice[0].scanned, true, "customersPtrSlice[0].scanned")
assert.Equal(t, customersPtrSlice[1].scanned, true, "customersPtrSlice[1].scanned")
assert.Equal(t, customersPtrSlice[2].scanned, true, "customersPtrSlice[2].scanned")
}
var customers2 []NullStringMap
err = db.NewQuery(sql).All(&customers2)
if assert.Nil(t, err) {
assert.Equal(t, len(customers2), 3, "len(customers2)")
assert.Equal(t, customers2[1]["id"].String, "2", "customers2[1][id]")
assert.Equal(t, customers2[1]["email"].String, `user2@example.com`, "customers2[1][email]")
assert.Equal(t, customers2[1]["status"].String, "1", "customers2[1][status]")
}
err = db.NewQuery(sql).All(customers)
assert.NotNil(t, err)
var customers3 []string
err = db.NewQuery(sql).All(&customers3)
assert.NotNil(t, err)
var customers4 string
err = db.NewQuery(sql).All(&customers4)
assert.NotNil(t, err)
var customers5 []Customer
err = db.NewQuery(`SELECT * FROM customer WHERE id=999`).All(&customers5)
if assert.Nil(t, err) {
assert.NotNil(t, customers5)
assert.Zero(t, len(customers5))
}
// One
var customer Customer
sql = `SELECT * FROM customer WHERE id={:id}`
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customer)
if assert.Nil(t, err) {
assert.Equal(t, customer.ID, 2, "customer.ID")
assert.Equal(t, customer.Email, `user2@example.com`, "customer.Email")
assert.Equal(t, customer.Status, 1, "customer.Status")
}
var customerPtr2 CustomerPtr
sql = `SELECT id, email, address FROM customer WHERE id=2`
rows2, err := db.sqlDB.Query(sql)
defer rows2.Close()
assert.Nil(t, err)
rows2.Next()
err = rows2.Scan(&customerPtr2.ID, &customerPtr2.Email, &customerPtr2.Address)
if assert.Nil(t, err) {
assert.Equal(t, *customerPtr2.ID, 2, "customer.ID")
assert.Equal(t, *customerPtr2.Email, `user2@example.com`)
assert.Nil(t, customerPtr2.Address)
}
// struct fields are pointers
var customerPtr CustomerPtr
sql = `SELECT * FROM customer WHERE id={:id}`
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerPtr)
if assert.Nil(t, err) {
assert.Equal(t, *customerPtr.ID, 2, "customer.ID")
assert.Equal(t, *customerPtr.Email, `user2@example.com`, "customer.Email")
assert.Equal(t, *customerPtr.Status, 1, "customer.Status")
}
// struct fields are null types
var customerNull CustomerNull
sql = `SELECT * FROM customer WHERE id={:id}`
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerNull)
if assert.Nil(t, err) {
assert.Equal(t, customerNull.ID.Int64, int64(2), "customer.ID")
assert.Equal(t, customerNull.Email.String, `user2@example.com`, "customer.Email")
assert.Equal(t, customerNull.Status.Int64, int64(1), "customer.Status")
}
// embedded with anonymous struct
var customerEmbedded CustomerEmbedded
sql = `SELECT * FROM customer WHERE id={:id}`
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerEmbedded)
if assert.Nil(t, err) {
assert.Equal(t, customerEmbedded.Id, 2, "customer.ID")
assert.Equal(t, *customerEmbedded.Email, `user2@example.com`, "customer.Email")
assert.Equal(t, customerEmbedded.Status.Int64, int64(1), "customer.Status")
}
// embedded with named struct
var customerEmbedded2 CustomerEmbedded2
sql = `SELECT id, email, status as "inner.status" FROM customer WHERE id={:id}`
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerEmbedded2)
if assert.Nil(t, err) {
assert.Equal(t, customerEmbedded2.ID, 2, "customer.ID")
assert.Equal(t, *customerEmbedded2.Email, `user2@example.com`, "customer.Email")
assert.Equal(t, customerEmbedded2.Inner.Status.Int64, int64(1), "customer.Status")
}
customer2 := NullStringMap{}
sql = `SELECT * FROM customer WHERE id={:id}`
err = db.NewQuery(sql).Bind(Params{"id": 1}).One(customer2)
if assert.Nil(t, err) {
assert.Equal(t, customer2["id"].String, "1", "customer2[id]")
assert.Equal(t, customer2["email"].String, `user1@example.com`, "customer2[email]")
assert.Equal(t, customer2["status"].String, "1", "customer2[status]")
}
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(customer)
assert.NotNil(t, err)
var customer3 NullStringMap
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(customer3)
assert.NotNil(t, err)
err = db.NewQuery(sql).Bind(Params{"id": 1}).One(&customer3)
if assert.Nil(t, err) {
assert.Equal(t, customer3["id"].String, "1", "customer3[id]")
}
// Rows
sql = `SELECT * FROM customer ORDER BY id DESC`
rows, err := db.NewQuery(sql).Rows()
if assert.Nil(t, err) {
s := ""
for rows.Next() {
rows.ScanStruct(&customer)
s += customer.Email + ","
}
assert.Equal(t, s, "user3@example.com,user2@example.com,user1@example.com,", "Rows().Next()")
}
// FieldMapper
var a struct {
MyID string `db:"id"`
name string
}
sql = `SELECT * FROM customer WHERE id=2`
err = db.NewQuery(sql).One(&a)
if assert.Nil(t, err) {
assert.Equal(t, a.MyID, "2", "a.MyID")
// unexported field is not populated
assert.Equal(t, a.name, "", "a.name")
}
// prepared statement
sql = `SELECT * FROM customer WHERE id={:id}`
q := db.NewQuery(sql).Prepare()
q.Bind(Params{"id": 1}).One(&customer)
assert.Equal(t, customer.ID, 1, "prepared@1")
err = q.Bind(Params{"id": 20}).One(&customer)
assert.Equal(t, err, ss.ErrNoRows, "prepared@2")
q.Bind(Params{"id": 3}).One(&customer)
assert.Equal(t, customer.ID, 3, "prepared@3")
sql = `SELECT name FROM customer WHERE id={:id}`
var name string
q = db.NewQuery(sql).Prepare()
q.Bind(Params{"id": 1}).Row(&name)
assert.Equal(t, name, "user1", "prepared2@1")
err = q.Bind(Params{"id": 20}).Row(&name)
assert.Equal(t, err, ss.ErrNoRows, "prepared2@2")
q.Bind(Params{"id": 3}).Row(&name)
assert.Equal(t, name, "user3", "prepared2@3")
// Query.LastError
sql = `SELECT * FROM a`
q = db.NewQuery(sql).Prepare()
customer.ID = 100
err = q.Bind(Params{"id": 1}).One(&customer)
assert.NotEqual(t, err, nil, "LastError@0")
assert.Equal(t, customer.ID, 100, "LastError@1")
assert.Equal(t, q.LastError, nil, "LastError@2")
// Query.Column
sql = `SELECT name, id FROM customer ORDER BY id`
var names []string
err = db.NewQuery(sql).Column(&names)
if assert.Nil(t, err) && assert.Equal(t, 3, len(names)) {
assert.Equal(t, "user1", names[0])
assert.Equal(t, "user2", names[1])
assert.Equal(t, "user3", names[2])
}
err = db.NewQuery(sql).Column(names)
assert.NotNil(t, err)
}
func TestQuery_logSQL(t *testing.T) {
db := getDB()
q := db.NewQuery("SELECT * FROM users WHERE type={:type} AND id={:id} AND bytes={:bytes}").Bind(Params{
"id": 1,
"type": "a",
"bytes": []byte("test"),
})
expected := "SELECT * FROM users WHERE type='a' AND id=1 AND bytes=0x74657374"
assert.Equal(t, q.logSQL(), expected, "logSQL()")
}
func TestReplacePlaceholders(t *testing.T) {
tests := []struct {
ID string
Placeholders []string
Params Params
ExpectedParams string
HasError bool
}{
{"t1", nil, nil, "null", false},
{"t2", []string{"id", "name"}, Params{"id": 1, "name": "xyz"}, `[1,"xyz"]`, false},
{"t3", []string{"id", "name"}, Params{"id": 1}, `null`, true},
{"t4", []string{"id", "name"}, Params{"id": 1, "name": "xyz", "age": 30}, `[1,"xyz"]`, false},
}
for _, test := range tests {
params, err := replacePlaceholders(test.Placeholders, test.Params)
result, _ := json.Marshal(params)
assert.Equal(t, string(result), test.ExpectedParams, "params@"+test.ID)
assert.Equal(t, err != nil, test.HasError, "error@"+test.ID)
}
}
func TestIssue6(t *testing.T) {
db := getPreparedDB()
q := db.Select("*").From("customer").Where(HashExp{"id": 1})
var customer Customer
assert.Equal(t, q.One(&customer), nil)
assert.Equal(t, 1, customer.ID)
}
type User struct {
ID int64
Email string
Created time.Time
Updated *time.Time
}
func TestIssue13(t *testing.T) {
db := getPreparedDB()
var user User
err := db.Select().From("user").Where(HashExp{"id": 1}).One(&user)
if assert.Nil(t, err) {
assert.NotZero(t, user.Created)
assert.Nil(t, user.Updated)
}
now := time.Now()
user2 := User{
Email: "now@example.com",
Created: now,
}
err = db.Model(&user2).Insert()
if assert.Nil(t, err) {
assert.NotZero(t, user2.ID)
}
user3 := User{
Email: "now@example.com",
Created: now,
Updated: &now,
}
err = db.Model(&user3).Insert()
if assert.Nil(t, err) {
assert.NotZero(t, user2.ID)
}
}
func TestQueryWithExecHook(t *testing.T) {
db := getPreparedDB()
defer db.Close()
// error return
{
err := db.NewQuery("select * from user").
WithExecHook(func(q *Query, op func() error) error {
return errors.New("test")
}).
Row()
assert.Error(t, err)
}
// Row()
{
calls := 0
err := db.NewQuery("select * from user").
WithExecHook(func(q *Query, op func() error) error {
calls++
return nil
}).
Row()
assert.Nil(t, err)
assert.Equal(t, 1, calls, "Row()")
}
// One()
{
calls := 0
err := db.NewQuery("select * from user").
WithExecHook(func(q *Query, op func() error) error {
calls++
return nil
}).
One(nil)
assert.Nil(t, err)
assert.Equal(t, 1, calls, "One()")
}
// All()
{
calls := 0
err := db.NewQuery("select * from user").
WithExecHook(func(q *Query, op func() error) error {
calls++
return nil
}).
All(nil)
assert.Nil(t, err)
assert.Equal(t, 1, calls, "All()")
}
// Column()
{
calls := 0
err := db.NewQuery("select * from user").
WithExecHook(func(q *Query, op func() error) error {
calls++
return nil
}).
Column(nil)
assert.Nil(t, err)
assert.Equal(t, 1, calls, "Column()")
}
// Execute()
{
calls := 0
_, err := db.NewQuery("select * from user").
WithExecHook(func(q *Query, op func() error) error {
calls++
return nil
}).
Execute()
assert.Nil(t, err)
assert.Equal(t, 1, calls, "Execute()")
}
// op call
{
calls := 0
var id int
err := db.NewQuery("select id from user where id = 2").
WithExecHook(func(q *Query, op func() error) error {
calls++
return op()
}).
Row(&id)
assert.Nil(t, err)
assert.Equal(t, 1, calls, "op hook calls")
assert.Equal(t, 2, id, "id mismatch")
}
}
func TestQueryWithOneHook(t *testing.T) {
db := getPreparedDB()
defer db.Close()
// error return
{
err := db.NewQuery("select * from user").
WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
return errors.New("test")
}).
One(nil)
assert.Error(t, err)
}
// hooks call order
{
hookCalls := []string{}
err := db.NewQuery("select * from user").
WithExecHook(func(q *Query, op func() error) error {
hookCalls = append(hookCalls, "exec")
return op()
}).
WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
hookCalls = append(hookCalls, "one")
return nil
}).
One(nil)
assert.Nil(t, err)
assert.Equal(t, hookCalls, []string{"exec", "one"})
}
// op call
{
calls := 0
other := User{}
err := db.NewQuery("select id from user where id = 2").
WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
calls++
return op(&other)
}).
One(nil)
assert.Nil(t, err)
assert.Equal(t, 1, calls, "hook calls")
assert.Equal(t, int64(2), other.ID, "replaced scan struct")
}
}
func TestQueryWithAllHook(t *testing.T) {
db := getPreparedDB()
defer db.Close()
// error return
{
err := db.NewQuery("select * from user").
WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
return errors.New("test")
}).
All(nil)
assert.Error(t, err)
}
// hooks call order
{
hookCalls := []string{}
err := db.NewQuery("select * from user").
WithExecHook(func(q *Query, op func() error) error {
hookCalls = append(hookCalls, "exec")
return op()
}).
WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
hookCalls = append(hookCalls, "all")
return nil
}).
All(nil)
assert.Nil(t, err)
assert.Equal(t, hookCalls, []string{"exec", "all"})
}
// op call
{
calls := 0
other := []User{}
err := db.NewQuery("select id from user order by id asc").
WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
calls++
return op(&other)
}).
All(nil)
assert.Nil(t, err)
assert.Equal(t, 1, calls, "hook calls")
assert.Equal(t, 2, len(other), "users length")
assert.Equal(t, int64(1), other[0].ID, "user 1 id check")
assert.Equal(t, int64(2), other[1].ID, "user 2 id check")
}
}

301
rows.go Normal file
View file

@ -0,0 +1,301 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"database/sql"
"reflect"
)
// VarTypeError indicates a variable type error when trying to populating a variable with DB result.
type VarTypeError string
// Error returns the error message.
func (s VarTypeError) Error() string {
return "Invalid variable type: " + string(s)
}
// NullStringMap is a map of sql.NullString that can be used to hold DB query result.
// The map keys correspond to the DB column names, while the map values are their corresponding column values.
type NullStringMap map[string]sql.NullString
// Rows enhances sql.Rows by providing additional data query methods.
// Rows can be obtained by calling Query.Rows(). It is mainly used to populate data row by row.
type Rows struct {
*sql.Rows
fieldMapFunc FieldMapFunc
}
// ScanMap populates the current row of data into a NullStringMap.
// Note that the NullStringMap must not be nil, or it will panic.
// The NullStringMap will be populated using column names as keys and their values as
// the corresponding element values.
func (r *Rows) ScanMap(a NullStringMap) error {
cols, _ := r.Columns()
var refs []interface{}
for i := 0; i < len(cols); i++ {
var t sql.NullString
refs = append(refs, &t)
}
if err := r.Scan(refs...); err != nil {
return err
}
for i, col := range cols {
a[col] = *refs[i].(*sql.NullString)
}
return nil
}
// ScanStruct populates the current row of data into a struct.
// The struct must be given as a pointer.
//
// ScanStruct associates struct fields with DB table columns through a field mapping function.
// It populates a struct field with the data of its associated column.
// Note that only exported struct fields will be populated.
//
// By default, DefaultFieldMapFunc() is used to map struct fields to table columns.
// This function separates each word in a field name with a underscore and turns every letter into lower case.
// For example, "LastName" is mapped to "last_name", "MyID" is mapped to "my_id", and so on.
// To change the default behavior, set DB.FieldMapper with your custom mapping function.
// You may also set Query.FieldMapper to change the behavior for particular queries.
func (r *Rows) ScanStruct(a interface{}) error {
rv := reflect.ValueOf(a)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return VarTypeError("must be a pointer")
}
rv = indirect(rv)
if rv.Kind() != reflect.Struct {
return VarTypeError("must be a pointer to a struct")
}
si := getStructInfo(rv.Type(), r.fieldMapFunc)
cols, _ := r.Columns()
refs := make([]interface{}, len(cols))
for i, col := range cols {
if fi, ok := si.dbNameMap[col]; ok {
refs[i] = fi.getField(rv).Addr().Interface()
} else {
refs[i] = &sql.NullString{}
}
}
if err := r.Scan(refs...); err != nil {
return err
}
// check for PostScanner
if rv.CanAddr() {
addr := rv.Addr()
if addr.CanInterface() {
if ps, ok := addr.Interface().(PostScanner); ok {
if err := ps.PostScan(); err != nil {
return err
}
}
}
}
return nil
}
// all populates all rows of query result into a slice of struct or NullStringMap.
// Note that the slice must be given as a pointer.
func (r *Rows) all(slice interface{}) error {
defer r.Close()
v := reflect.ValueOf(slice)
if v.Kind() != reflect.Ptr || v.IsNil() {
return VarTypeError("must be a pointer")
}
v = indirect(v)
if v.Kind() != reflect.Slice {
return VarTypeError("must be a slice of struct or NullStringMap")
}
if v.IsNil() {
// create an empty slice
v.Set(reflect.MakeSlice(v.Type(), 0, 0))
}
et := v.Type().Elem()
if et.Kind() == reflect.Map {
for r.Next() {
ev, ok := reflect.MakeMap(et).Interface().(NullStringMap)
if !ok {
return VarTypeError("must be a slice of struct or NullStringMap")
}
if err := r.ScanMap(ev); err != nil {
return err
}
v.Set(reflect.Append(v, reflect.ValueOf(ev)))
}
return r.Close()
}
var isSliceOfPointers bool
if et.Kind() == reflect.Ptr {
isSliceOfPointers = true
et = et.Elem()
}
if et.Kind() != reflect.Struct {
return VarTypeError("must be a slice of struct or NullStringMap")
}
etPtr := reflect.PtrTo(et)
implementsPostScanner := etPtr.Implements(postScannerType)
si := getStructInfo(et, r.fieldMapFunc)
cols, _ := r.Columns()
for r.Next() {
ev := reflect.New(et).Elem()
refs := make([]interface{}, len(cols))
for i, col := range cols {
if fi, ok := si.dbNameMap[col]; ok {
refs[i] = fi.getField(ev).Addr().Interface()
} else {
refs[i] = &sql.NullString{}
}
}
if err := r.Scan(refs...); err != nil {
return err
}
if isSliceOfPointers {
ev = ev.Addr()
}
// check for PostScanner
if implementsPostScanner {
evAddr := ev
if ev.CanAddr() {
evAddr = ev.Addr()
}
if evAddr.CanInterface() {
if ps, ok := evAddr.Interface().(PostScanner); ok {
if err := ps.PostScan(); err != nil {
return err
}
}
}
}
v.Set(reflect.Append(v, ev))
}
return r.Close()
}
// column populates the given slice with the first column of the query result.
// Note that the slice must be given as a pointer.
func (r *Rows) column(slice interface{}) error {
defer r.Close()
v := reflect.ValueOf(slice)
if v.Kind() != reflect.Ptr || v.IsNil() {
return VarTypeError("must be a pointer to a slice")
}
v = indirect(v)
if v.Kind() != reflect.Slice {
return VarTypeError("must be a pointer to a slice")
}
et := v.Type().Elem()
cols, _ := r.Columns()
for r.Next() {
ev := reflect.New(et)
refs := make([]interface{}, len(cols))
for i := range cols {
if i == 0 {
refs[i] = ev.Interface()
} else {
refs[i] = &sql.NullString{}
}
}
if err := r.Scan(refs...); err != nil {
return err
}
v.Set(reflect.Append(v, ev.Elem()))
}
return r.Close()
}
// one populates a single row of query result into a struct or a NullStringMap.
// Note that if a struct is given, it should be a pointer.
func (r *Rows) one(a interface{}) error {
defer r.Close()
if !r.Next() {
if err := r.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
var err error
rt := reflect.TypeOf(a)
if rt.Kind() == reflect.Ptr && rt.Elem().Kind() == reflect.Map {
// pointer to map
v := indirect(reflect.ValueOf(a))
if v.IsNil() {
v.Set(reflect.MakeMap(v.Type()))
}
a = v.Interface()
rt = reflect.TypeOf(a)
}
if rt.Kind() == reflect.Map {
v, ok := a.(NullStringMap)
if !ok {
return VarTypeError("must be a NullStringMap")
}
if v == nil {
return VarTypeError("NullStringMap is nil")
}
err = r.ScanMap(v)
} else {
err = r.ScanStruct(a)
}
if err != nil {
return err
}
return r.Close()
}
// row populates a single row of query result into a list of variables.
func (r *Rows) row(a ...interface{}) error {
defer r.Close()
for _, dp := range a {
if _, ok := dp.(*sql.RawBytes); ok {
return VarTypeError("RawBytes isn't allowed on Row()")
}
}
if !r.Next() {
if err := r.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
if err := r.Scan(a...); err != nil {
return err
}
return r.Close()
}

445
select.go Normal file
View file

@ -0,0 +1,445 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"context"
"fmt"
"reflect"
)
// BuildHookFunc defines a callback function that is executed on Query creation.
type BuildHookFunc func(q *Query)
// SelectQuery represents a DB-agnostic SELECT query.
// It can be built into a DB-specific query by calling the Build() method.
type SelectQuery struct {
// FieldMapper maps struct field names to DB column names.
FieldMapper FieldMapFunc
// TableMapper maps structs to DB table names.
TableMapper TableMapFunc
builder Builder
ctx context.Context
buildHook BuildHookFunc
preFragment string
postFragment string
selects []string
distinct bool
selectOption string
from []string
where Expression
join []JoinInfo
orderBy []string
groupBy []string
having Expression
union []UnionInfo
limit int64
offset int64
params Params
}
// JoinInfo contains the specification for a JOIN clause.
type JoinInfo struct {
Join string
Table string
On Expression
}
// UnionInfo contains the specification for a UNION clause.
type UnionInfo struct {
All bool
Query *Query
}
// NewSelectQuery creates a new SelectQuery instance.
func NewSelectQuery(builder Builder, db *DB) *SelectQuery {
return &SelectQuery{
builder: builder,
selects: []string{},
from: []string{},
join: []JoinInfo{},
orderBy: []string{},
groupBy: []string{},
union: []UnionInfo{},
limit: -1,
params: Params{},
ctx: db.ctx,
FieldMapper: db.FieldMapper,
TableMapper: db.TableMapper,
}
}
// WithBuildHook runs the provided hook function with the query created on Build().
func (q *SelectQuery) WithBuildHook(fn BuildHookFunc) *SelectQuery {
q.buildHook = fn
return q
}
// Context returns the context associated with the query.
func (q *SelectQuery) Context() context.Context {
return q.ctx
}
// WithContext associates a context with the query.
func (q *SelectQuery) WithContext(ctx context.Context) *SelectQuery {
q.ctx = ctx
return q
}
// PreFragment sets SQL fragment that should be prepended before the select query (e.g. WITH clause).
func (s *SelectQuery) PreFragment(fragment string) *SelectQuery {
s.preFragment = fragment
return s
}
// PostFragment sets SQL fragment that should be appended at the end of the select query.
func (s *SelectQuery) PostFragment(fragment string) *SelectQuery {
s.postFragment = fragment
return s
}
// Select specifies the columns to be selected.
// Column names will be automatically quoted.
func (s *SelectQuery) Select(cols ...string) *SelectQuery {
s.selects = cols
return s
}
// AndSelect adds additional columns to be selected.
// Column names will be automatically quoted.
func (s *SelectQuery) AndSelect(cols ...string) *SelectQuery {
s.selects = append(s.selects, cols...)
return s
}
// Distinct specifies whether to select columns distinctively.
// By default, distinct is false.
func (s *SelectQuery) Distinct(v bool) *SelectQuery {
s.distinct = v
return s
}
// SelectOption specifies additional option that should be append to "SELECT".
func (s *SelectQuery) SelectOption(option string) *SelectQuery {
s.selectOption = option
return s
}
// From specifies which tables to select from.
// Table names will be automatically quoted.
func (s *SelectQuery) From(tables ...string) *SelectQuery {
s.from = tables
return s
}
// Where specifies the WHERE condition.
func (s *SelectQuery) Where(e Expression) *SelectQuery {
s.where = e
return s
}
// AndWhere concatenates a new WHERE condition with the existing one (if any) using "AND".
func (s *SelectQuery) AndWhere(e Expression) *SelectQuery {
s.where = And(s.where, e)
return s
}
// OrWhere concatenates a new WHERE condition with the existing one (if any) using "OR".
func (s *SelectQuery) OrWhere(e Expression) *SelectQuery {
s.where = Or(s.where, e)
return s
}
// Join specifies a JOIN clause.
// The "typ" parameter specifies the JOIN type (e.g. "INNER JOIN", "LEFT JOIN").
func (s *SelectQuery) Join(typ string, table string, on Expression) *SelectQuery {
s.join = append(s.join, JoinInfo{typ, table, on})
return s
}
// InnerJoin specifies an INNER JOIN clause.
// This is a shortcut method for Join.
func (s *SelectQuery) InnerJoin(table string, on Expression) *SelectQuery {
return s.Join("INNER JOIN", table, on)
}
// LeftJoin specifies a LEFT JOIN clause.
// This is a shortcut method for Join.
func (s *SelectQuery) LeftJoin(table string, on Expression) *SelectQuery {
return s.Join("LEFT JOIN", table, on)
}
// RightJoin specifies a RIGHT JOIN clause.
// This is a shortcut method for Join.
func (s *SelectQuery) RightJoin(table string, on Expression) *SelectQuery {
return s.Join("RIGHT JOIN", table, on)
}
// OrderBy specifies the ORDER BY clause.
// Column names will be properly quoted. A column name can contain "ASC" or "DESC" to indicate its ordering direction.
func (s *SelectQuery) OrderBy(cols ...string) *SelectQuery {
s.orderBy = cols
return s
}
// AndOrderBy appends additional columns to the existing ORDER BY clause.
// Column names will be properly quoted. A column name can contain "ASC" or "DESC" to indicate its ordering direction.
func (s *SelectQuery) AndOrderBy(cols ...string) *SelectQuery {
s.orderBy = append(s.orderBy, cols...)
return s
}
// GroupBy specifies the GROUP BY clause.
// Column names will be properly quoted.
func (s *SelectQuery) GroupBy(cols ...string) *SelectQuery {
s.groupBy = cols
return s
}
// AndGroupBy appends additional columns to the existing GROUP BY clause.
// Column names will be properly quoted.
func (s *SelectQuery) AndGroupBy(cols ...string) *SelectQuery {
s.groupBy = append(s.groupBy, cols...)
return s
}
// Having specifies the HAVING clause.
func (s *SelectQuery) Having(e Expression) *SelectQuery {
s.having = e
return s
}
// AndHaving concatenates a new HAVING condition with the existing one (if any) using "AND".
func (s *SelectQuery) AndHaving(e Expression) *SelectQuery {
s.having = And(s.having, e)
return s
}
// OrHaving concatenates a new HAVING condition with the existing one (if any) using "OR".
func (s *SelectQuery) OrHaving(e Expression) *SelectQuery {
s.having = Or(s.having, e)
return s
}
// Union specifies a UNION clause.
func (s *SelectQuery) Union(q *Query) *SelectQuery {
s.union = append(s.union, UnionInfo{false, q})
return s
}
// UnionAll specifies a UNION ALL clause.
func (s *SelectQuery) UnionAll(q *Query) *SelectQuery {
s.union = append(s.union, UnionInfo{true, q})
return s
}
// Limit specifies the LIMIT clause.
// A negative limit means no limit.
func (s *SelectQuery) Limit(limit int64) *SelectQuery {
s.limit = limit
return s
}
// Offset specifies the OFFSET clause.
// A negative offset means no offset.
func (s *SelectQuery) Offset(offset int64) *SelectQuery {
s.offset = offset
return s
}
// Bind specifies the parameter values to be bound to the query.
func (s *SelectQuery) Bind(params Params) *SelectQuery {
s.params = params
return s
}
// AndBind appends additional parameters to be bound to the query.
func (s *SelectQuery) AndBind(params Params) *SelectQuery {
if len(s.params) == 0 {
s.params = params
} else {
for k, v := range params {
s.params[k] = v
}
}
return s
}
// Build builds the SELECT query and returns an executable Query object.
func (s *SelectQuery) Build() *Query {
params := Params{}
for k, v := range s.params {
params[k] = v
}
qb := s.builder.QueryBuilder()
clauses := []string{
s.preFragment,
qb.BuildSelect(s.selects, s.distinct, s.selectOption),
qb.BuildFrom(s.from),
qb.BuildJoin(s.join, params),
qb.BuildWhere(s.where, params),
qb.BuildGroupBy(s.groupBy),
qb.BuildHaving(s.having, params),
}
sql := ""
for _, clause := range clauses {
if clause != "" {
if sql == "" {
sql = clause
} else {
sql += " " + clause
}
}
}
sql = qb.BuildOrderByAndLimit(sql, s.orderBy, s.limit, s.offset)
if s.postFragment != "" {
sql += " " + s.postFragment
}
if union := qb.BuildUnion(s.union, params); union != "" {
sql = fmt.Sprintf("(%v) %v", sql, union)
}
query := s.builder.NewQuery(sql).Bind(params).WithContext(s.ctx)
if s.buildHook != nil {
s.buildHook(query)
}
return query
}
// One executes the SELECT query and populates the first row of the result into the specified variable.
//
// If the query does not specify a "from" clause, the method will try to infer the name of the table
// to be selected from by calling getTableName() which will return either the variable type name
// or the TableName() method if the variable implements the TableModel interface.
//
// Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned.
func (s *SelectQuery) One(a interface{}) error {
if len(s.from) == 0 {
if tableName := s.TableMapper(a); tableName != "" {
s.from = []string{tableName}
}
}
return s.Build().One(a)
}
// Model selects the row with the specified primary key and populates the model with the row data.
//
// The model variable should be a pointer to a struct. If the query does not specify a "from" clause,
// it will use the model struct to determine which table to select data from. It will also use the model
// to infer the name of the primary key column. Only simple primary key is supported. For composite primary keys,
// please use Where() to specify the filtering condition.
func (s *SelectQuery) Model(pk, model interface{}) error {
t := reflect.TypeOf(model)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return VarTypeError("must be a pointer to a struct")
}
si := getStructInfo(t, s.FieldMapper)
if len(si.pkNames) == 1 {
return s.AndWhere(HashExp{si.nameMap[si.pkNames[0]].dbName: pk}).One(model)
}
if len(si.pkNames) == 0 {
return MissingPKError
}
return CompositePKError
}
// All executes the SELECT query and populates all rows of the result into a slice.
//
// Note that the slice must be passed in as a pointer.
//
// If the query does not specify a "from" clause, the method will try to infer the name of the table
// to be selected from by calling getTableName() which will return either the type name of the slice elements
// or the TableName() method if the slice element implements the TableModel interface.
func (s *SelectQuery) All(slice interface{}) error {
if len(s.from) == 0 {
if tableName := s.TableMapper(slice); tableName != "" {
s.from = []string{tableName}
}
}
return s.Build().All(slice)
}
// Rows builds and executes the SELECT query and returns a Rows object for data retrieval purpose.
// This is a shortcut to SelectQuery.Build().Rows()
func (s *SelectQuery) Rows() (*Rows, error) {
return s.Build().Rows()
}
// Row builds and executes the SELECT query and populates the first row of the result into the specified variables.
// This is a shortcut to SelectQuery.Build().Row()
func (s *SelectQuery) Row(a ...interface{}) error {
return s.Build().Row(a...)
}
// Column builds and executes the SELECT statement and populates the first column of the result into a slice.
// Note that the parameter must be a pointer to a slice.
// This is a shortcut to SelectQuery.Build().Column()
func (s *SelectQuery) Column(a interface{}) error {
return s.Build().Column(a)
}
// QueryInfo represents a debug/info struct with exported SelectQuery fields.
type QueryInfo struct {
PreFragment string
PostFragment string
Builder Builder
Selects []string
Distinct bool
SelectOption string
From []string
Where Expression
Join []JoinInfo
OrderBy []string
GroupBy []string
Having Expression
Union []UnionInfo
Limit int64
Offset int64
Params Params
Context context.Context
BuildHook BuildHookFunc
}
// Info exports common SelectQuery fields allowing to inspect the
// current select query options.
func (s *SelectQuery) Info() *QueryInfo {
return &QueryInfo{
Builder: s.builder,
PreFragment: s.preFragment,
PostFragment: s.postFragment,
Selects: s.selects,
Distinct: s.distinct,
SelectOption: s.selectOption,
From: s.from,
Where: s.where,
Join: s.join,
OrderBy: s.orderBy,
GroupBy: s.groupBy,
Having: s.having,
Union: s.union,
Limit: s.limit,
Offset: s.offset,
Params: s.params,
Context: s.ctx,
BuildHook: s.buildHook,
}
}

179
select_test.go Normal file
View file

@ -0,0 +1,179 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"testing"
"database/sql"
"github.com/stretchr/testify/assert"
)
func TestSelectQuery(t *testing.T) {
db := getDB()
// minimal select query
q := db.Select().From("users").Build()
expected := "SELECT * FROM `users`"
assert.Equal(t, q.SQL(), expected, "t1")
assert.Equal(t, len(q.Params()), 0, "t2")
// a full select query
q = db.Select("id", "name").
PreFragment("pre").
PostFragment("post").
AndSelect("age").
Distinct(true).
SelectOption("CALC").
From("users").
Where(NewExp("age>30")).
AndWhere(NewExp("status=1")).
OrWhere(NewExp("type=2")).
InnerJoin("profile", NewExp("user.id=profile.id")).
LeftJoin("team", nil).
RightJoin("dept", nil).
OrderBy("age DESC", "type").
AndOrderBy("id").
GroupBy("id").
AndGroupBy("age").
Having(NewExp("id>10")).
AndHaving(NewExp("id<20")).
OrHaving(NewExp("type=3")).
Limit(10).
Offset(20).
Bind(Params{"id": 1}).
AndBind(Params{"age": 30}).
Build()
expected = "pre SELECT DISTINCT CALC `id`, `name`, `age` FROM `users` INNER JOIN `profile` ON user.id=profile.id LEFT JOIN `team` RIGHT JOIN `dept` WHERE ((age>30) AND (status=1)) OR (type=2) GROUP BY `id`, `age` HAVING ((id>10) AND (id<20)) OR (type=3) ORDER BY `age` DESC, `type`, `id` LIMIT 10 OFFSET 20 post"
assert.Equal(t, q.SQL(), expected, "t3")
assert.Equal(t, len(q.Params()), 2, "t4")
q3 := db.Select().AndBind(Params{"id": 1}).Build()
assert.Equal(t, len(q3.Params()), 1)
// union
q1 := db.Select().From("users").PreFragment("pre_q1").Build()
q2 := db.Select().From("posts").PostFragment("post_q2").Build()
q = db.Select().From("profiles").Union(q1).UnionAll(q2).PreFragment("pre").PostFragment("post").Build()
expected = "(pre SELECT * FROM `profiles` post) UNION (pre_q1 SELECT * FROM `users`) UNION ALL (SELECT * FROM `posts` post_q2)"
assert.Equal(t, q.SQL(), expected, "t5")
}
func TestSelectQuery_Data(t *testing.T) {
db := getPreparedDB()
defer db.Close()
q := db.Select("id", "email").From("customer").OrderBy("id")
var customer Customer
q.One(&customer)
assert.Equal(t, customer.Email, "user1@example.com", "customer.Email")
var customers []Customer
q.All(&customers)
assert.Equal(t, len(customers), 3, "len(customers)")
rows, _ := q.Rows()
customer.Email = ""
rows.one(&customer)
assert.Equal(t, customer.Email, "user1@example.com", "customer.Email")
var id, email string
q.Row(&id, &email)
assert.Equal(t, id, "1", "id")
assert.Equal(t, email, "user1@example.com", "email")
var emails []string
err := db.Select("email").From("customer").Column(&emails)
if assert.Nil(t, err) {
assert.Equal(t, 3, len(emails))
}
var e int
err = db.Select().From("customer").One(&e)
assert.NotNil(t, err)
err = db.Select().From("customer").All(&e)
assert.NotNil(t, err)
}
func TestSelectQuery_Model(t *testing.T) {
db := getPreparedDB()
defer db.Close()
{
// One without specifying FROM
var customer CustomerPtr
err := db.Select().OrderBy("id").One(&customer)
if assert.Nil(t, err) {
assert.Equal(t, "user1@example.com", *customer.Email)
}
}
{
// All without specifying FROM
var customers []CustomerPtr
err := db.Select().OrderBy("id").All(&customers)
if assert.Nil(t, err) {
assert.Equal(t, 3, len(customers))
}
}
{
// Model without specifying FROM
var customer CustomerPtr
err := db.Select().Model(2, &customer)
if assert.Nil(t, err) {
assert.Equal(t, "user2@example.com", *customer.Email)
}
}
{
// Model with WHERE
var customer CustomerPtr
err := db.Select().Where(HashExp{"id": 1}).Model(2, &customer)
assert.Equal(t, sql.ErrNoRows, err)
err = db.Select().Where(HashExp{"id": 2}).Model(2, &customer)
assert.Nil(t, err)
}
{
// errors
var i int
err := db.Select().Model(1, &i)
assert.Equal(t, VarTypeError("must be a pointer to a struct"), err)
var a struct {
Name string
}
err = db.Select().Model(1, &a)
assert.Equal(t, MissingPKError, err)
var b struct {
ID1 string `db:"pk"`
ID2 string `db:"pk"`
}
err = db.Select().Model(1, &b)
assert.Equal(t, CompositePKError, err)
}
}
func TestSelectWithBuildHook(t *testing.T) {
db := getPreparedDB()
defer db.Close()
var buildSQL string
db.Select("id").
From("user").
WithBuildHook(func(q *Query) {
buildSQL = q.SQL()
}).
Build()
assert.Equal(t, "SELECT `id` FROM `user`", buildSQL)
}

281
struct.go Normal file
View file

@ -0,0 +1,281 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"database/sql"
"reflect"
"regexp"
"strings"
"sync"
)
type (
// FieldMapFunc converts a struct field name into a DB column name.
FieldMapFunc func(string) string
// TableMapFunc converts a sample struct into a DB table name.
TableMapFunc func(a interface{}) string
structInfo struct {
nameMap map[string]*fieldInfo // mapping from struct field names to field infos
dbNameMap map[string]*fieldInfo // mapping from db column names to field infos
pkNames []string // struct field names representing PKs
}
structValue struct {
*structInfo
value reflect.Value // the struct value
tableName string // the db table name for the struct
}
fieldInfo struct {
name string // field name
dbName string // db column name
path []int // index path to the struct field reflection
}
structInfoMapKey struct {
t reflect.Type
m reflect.Value
}
)
var (
// DbTag is the name of the struct tag used to specify the column name for the associated struct field
DbTag = "db"
fieldRegex = regexp.MustCompile(`([^A-Z_])([A-Z])`)
scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
postScannerType = reflect.TypeOf((*PostScanner)(nil)).Elem()
structInfoMap = make(map[structInfoMapKey]*structInfo)
muStructInfoMap sync.Mutex
)
// PostScanner is an optional interface used by ScanStruct.
type PostScanner interface {
// PostScan executes right after the struct has been populated
// with the DB values, allowing you to further normalize or validate
// the loaded data.
PostScan() error
}
// DefaultFieldMapFunc maps a field name to a DB column name.
// The mapping rule set by this method is that words in a field name will be separated by underscores
// and the name will be turned into lower case. For example, "FirstName" maps to "first_name", and "MyID" becomes "my_id".
// See DB.FieldMapper for more details.
func DefaultFieldMapFunc(f string) string {
return strings.ToLower(fieldRegex.ReplaceAllString(f, "${1}_$2"))
}
func getStructInfo(a reflect.Type, mapper FieldMapFunc) *structInfo {
muStructInfoMap.Lock()
defer muStructInfoMap.Unlock()
key := structInfoMapKey{a, reflect.ValueOf(mapper)}
if si, ok := structInfoMap[key]; ok {
return si
}
si := &structInfo{
nameMap: map[string]*fieldInfo{},
dbNameMap: map[string]*fieldInfo{},
}
si.build(a, make([]int, 0), "", "", mapper)
structInfoMap[key] = si
return si
}
func newStructValue(model interface{}, fieldMapFunc FieldMapFunc, tableMapFunc TableMapFunc) *structValue {
value := reflect.ValueOf(model)
if value.Kind() != reflect.Ptr || value.Elem().Kind() != reflect.Struct || value.IsNil() {
return nil
}
return &structValue{
structInfo: getStructInfo(reflect.TypeOf(model).Elem(), fieldMapFunc),
value: value.Elem(),
tableName: tableMapFunc(model),
}
}
// pk returns the primary key values indexed by the corresponding primary key column names.
func (s *structValue) pk() map[string]interface{} {
if len(s.pkNames) == 0 {
return nil
}
return s.columns(s.pkNames, nil)
}
// columns returns the struct field values indexed by their corresponding DB column names.
func (s *structValue) columns(include, exclude []string) map[string]interface{} {
v := make(map[string]interface{}, len(s.nameMap))
if len(include) == 0 {
for _, fi := range s.nameMap {
v[fi.dbName] = fi.getValue(s.value)
}
} else {
for _, attr := range include {
if fi, ok := s.nameMap[attr]; ok {
v[fi.dbName] = fi.getValue(s.value)
}
}
}
if len(exclude) > 0 {
for _, name := range exclude {
if fi, ok := s.nameMap[name]; ok {
delete(v, fi.dbName)
}
}
}
return v
}
// getValue returns the field value for the given struct value.
func (fi *fieldInfo) getValue(a reflect.Value) interface{} {
for _, i := range fi.path {
a = a.Field(i)
if a.Kind() == reflect.Ptr {
if a.IsNil() {
return nil
}
a = a.Elem()
}
}
return a.Interface()
}
// getField returns the reflection value of the field for the given struct value.
func (fi *fieldInfo) getField(a reflect.Value) reflect.Value {
i := 0
for ; i < len(fi.path)-1; i++ {
a = indirect(a.Field(fi.path[i]))
}
return a.Field(fi.path[i])
}
func (si *structInfo) build(a reflect.Type, path []int, namePrefix, dbNamePrefix string, mapper FieldMapFunc) {
n := a.NumField()
for i := 0; i < n; i++ {
field := a.Field(i)
tag := field.Tag.Get(DbTag)
// only handle anonymous or exported fields
if !field.Anonymous && field.PkgPath != "" || tag == "-" {
continue
}
path2 := make([]int, len(path), len(path)+1)
copy(path2, path)
path2 = append(path2, i)
ft := field.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
name := field.Name
dbName, isPK := parseTag(tag)
if dbName == "" && !field.Anonymous {
if mapper != nil {
dbName = mapper(field.Name)
} else {
dbName = field.Name
}
}
if field.Anonymous {
name = ""
}
if isNestedStruct(ft) {
// dive into non-scanner struct
si.build(ft, path2, concat(namePrefix, name), concat(dbNamePrefix, dbName), mapper)
} else if dbName != "" {
// non-anonymous scanner or struct field
fi := &fieldInfo{
name: concat(namePrefix, name),
dbName: concat(dbNamePrefix, dbName),
path: path2,
}
// a field in an anonymous struct may be shadowed
if _, ok := si.nameMap[fi.name]; !ok || len(path2) < len(si.nameMap[fi.name].path) {
si.nameMap[fi.name] = fi
si.dbNameMap[fi.dbName] = fi
if isPK {
si.pkNames = append(si.pkNames, fi.name)
}
}
}
}
if len(si.pkNames) == 0 {
if _, ok := si.nameMap["ID"]; ok {
si.pkNames = append(si.pkNames, "ID")
} else if _, ok := si.nameMap["Id"]; ok {
si.pkNames = append(si.pkNames, "Id")
}
}
}
func isNestedStruct(t reflect.Type) bool {
if t.PkgPath() == "time" && t.Name() == "Time" {
return false
}
return t.Kind() == reflect.Struct && !reflect.PtrTo(t).Implements(scannerType)
}
func parseTag(tag string) (string, bool) {
if tag == "pk" {
return "", true
}
if strings.HasPrefix(tag, "pk,") {
return tag[3:], true
}
return tag, false
}
func concat(s1, s2 string) string {
if s1 == "" {
return s2
} else if s2 == "" {
return s1
} else {
return s1 + "." + s2
}
}
// indirect dereferences pointers and returns the actual value it points to.
// If a pointer is nil, it will be initialized with a new value.
func indirect(v reflect.Value) reflect.Value {
for v.Kind() == reflect.Ptr {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
return v
}
// GetTableName implements the default way of determining the table name corresponding to the given model struct
// or slice of structs. To get the actual table name for a model, you should use DB.TableMapFunc() instead.
// Do not call this method in a model's TableName() method because it will cause infinite loop.
func GetTableName(a interface{}) string {
if tm, ok := a.(TableModel); ok {
v := reflect.ValueOf(a)
if v.Kind() == reflect.Ptr && v.IsNil() {
a = reflect.New(v.Type().Elem()).Interface()
return a.(TableModel).TableName()
}
return tm.TableName()
}
t := reflect.TypeOf(a)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() == reflect.Slice {
return GetTableName(reflect.Zero(t.Elem()).Interface())
}
return DefaultFieldMapFunc(t.Name())
}

154
struct_test.go Normal file
View file

@ -0,0 +1,154 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import (
"database/sql"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
func TestDefaultFieldMapFunc(t *testing.T) {
tests := []struct {
input, output string
}{
{"Name", "name"},
{"FirstName", "first_name"},
{"Name0", "name0"},
{"ID", "id"},
{"UserID", "user_id"},
{"User0ID", "user0_id"},
{"MyURL", "my_url"},
{"URLPath", "urlpath"},
{"MyURLPath", "my_urlpath"},
{"First_Name", "first_name"},
{"first_name", "first_name"},
{"_FirstName", "_first_name"},
{"_First_Name", "_first_name"},
}
for _, test := range tests {
r := DefaultFieldMapFunc(test.input)
assert.Equal(t, test.output, r, test.input)
}
}
func Test_concat(t *testing.T) {
assert.Equal(t, "a.b", concat("a", "b"))
assert.Equal(t, "a", concat("a", ""))
assert.Equal(t, "b", concat("", "b"))
}
func Test_parseTag(t *testing.T) {
name, pk := parseTag("abc")
assert.Equal(t, "abc", name)
assert.False(t, pk)
name, pk = parseTag("pk,abc")
assert.Equal(t, "abc", name)
assert.True(t, pk)
name, pk = parseTag("pk")
assert.Equal(t, "", name)
assert.True(t, pk)
}
func Test_indirect(t *testing.T) {
var a int
assert.Equal(t, reflect.ValueOf(a).Kind(), indirect(reflect.ValueOf(a)).Kind())
var b *int
bi := indirect(reflect.ValueOf(&b))
assert.Equal(t, reflect.ValueOf(a).Kind(), bi.Kind())
if assert.NotNil(t, b) {
assert.Equal(t, 0, *b)
}
}
func Test_structValue_columns(t *testing.T) {
customer := Customer{
ID: 1,
Name: "abc",
Status: 2,
Email: "abc@example.com",
}
sv := newStructValue(&customer, DefaultFieldMapFunc, GetTableName)
cols := sv.columns(nil, nil)
assert.Equal(t, map[string]interface{}{"id": 1, "name": "abc", "status": 2, "email": "abc@example.com", "address": sql.NullString{}}, cols)
cols = sv.columns([]string{"ID", "name"}, nil)
assert.Equal(t, map[string]interface{}{"id": 1}, cols)
cols = sv.columns([]string{"ID", "Name"}, []string{"ID"})
assert.Equal(t, map[string]interface{}{"name": "abc"}, cols)
cols = sv.columns(nil, []string{"ID", "Address"})
assert.Equal(t, map[string]interface{}{"name": "abc", "status": 2, "email": "abc@example.com"}, cols)
sv = newStructValue(&customer, nil, GetTableName)
cols = sv.columns([]string{"ID", "Name"}, []string{"ID"})
assert.Equal(t, map[string]interface{}{"Name": "abc"}, cols)
}
func TestIssue37(t *testing.T) {
customer := Customer{
ID: 1,
Name: "abc",
Status: 2,
Email: "abc@example.com",
}
ev := struct {
Customer
Status string
}{customer, "20"}
sv := newStructValue(&ev, nil, GetTableName)
cols := sv.columns([]string{"ID", "Status"}, nil)
assert.Equal(t, map[string]interface{}{"ID": 1, "Status": "20"}, cols)
ev2 := struct {
Status string
Customer
}{"20", customer}
sv = newStructValue(&ev2, nil, GetTableName)
cols = sv.columns([]string{"ID", "Status"}, nil)
assert.Equal(t, map[string]interface{}{"ID": 1, "Status": "20"}, cols)
}
type MyCustomer struct{}
func TestGetTableName(t *testing.T) {
var c1 Customer
assert.Equal(t, "customer", GetTableName(c1))
var c2 *Customer
assert.Equal(t, "customer", GetTableName(c2))
var c3 MyCustomer
assert.Equal(t, "my_customer", GetTableName(c3))
var c4 []Customer
assert.Equal(t, "customer", GetTableName(c4))
var c5 *[]Customer
assert.Equal(t, "customer", GetTableName(c5))
var c6 []MyCustomer
assert.Equal(t, "my_customer", GetTableName(c6))
var c7 []CustomerPtr
assert.Equal(t, "customer", GetTableName(c7))
var c8 **int
assert.Equal(t, "", GetTableName(c8))
}
type FA struct {
A1 string
A2 int
}
type FB struct {
B1 string
}

81
testdata/mysql.sql vendored Normal file
View file

@ -0,0 +1,81 @@
/**
* This is the database schema for testing MySQL support of ozzo-dbx.
* The following database setup is required in order to run the test:
* - host: 127.0.0.1
* - user: travis
* - pass: <none>
* - database: pocketbase_dbx_test
*/
DROP TABLE IF EXISTS `order_item` CASCADE;
DROP TABLE IF EXISTS `item` CASCADE;
DROP TABLE IF EXISTS `order` CASCADE;
DROP TABLE IF EXISTS `customer` CASCADE;
DROP TABLE IF EXISTS `user` CASCADE;
CREATE TABLE `customer` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`email` varchar(128) NOT NULL,
`name` varchar(128),
`address` text,
`status` int (11) DEFAULT 0,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
CREATE TABLE `user` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`email` varchar(128) NOT NULL,
`created` date,
`updated` date,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
CREATE TABLE `item` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`name` varchar(128) NOT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
CREATE TABLE `order` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`customer_id` int(11) NOT NULL,
`created_at` int(11) NOT NULL,
`total` decimal(10,0) NOT NULL,
PRIMARY KEY (`id`),
CONSTRAINT `FK_order_customer_id` FOREIGN KEY (`customer_id`) REFERENCES `customer` (`id`) ON DELETE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
CREATE TABLE `order_item` (
`order_id` int(11) NOT NULL,
`item_id` int(11) NOT NULL,
`quantity` int(11) NOT NULL,
`subtotal` decimal(10,0) NOT NULL,
PRIMARY KEY (`order_id`,`item_id`),
KEY `FK_order_item_item_id` (`item_id`),
CONSTRAINT `FK_order_item_order_id` FOREIGN KEY (`order_id`) REFERENCES `order` (`id`) ON DELETE CASCADE,
CONSTRAINT `FK_order_item_item_id` FOREIGN KEY (`item_id`) REFERENCES `item` (`id`) ON DELETE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
INSERT INTO `customer` (email, name, address, status) VALUES ('user1@example.com', 'user1', 'address1', 1);
INSERT INTO `customer` (email, name, address, status) VALUES ('user2@example.com', 'user2', NULL, 1);
INSERT INTO `customer` (email, name, address, status) VALUES ('user3@example.com', 'user3', 'address3', 2);
INSERT INTO `user` (email, created) VALUES ('user1@example.com', '2015-01-02');
INSERT INTO `user` (email, created) VALUES ('user2@example.com', now());
INSERT INTO `item` (name) VALUES ('The Go Programming Language');
INSERT INTO `item` (name) VALUES ('Go in Action');
INSERT INTO `item` (name) VALUES ('Go Programming Blueprints');
INSERT INTO `item` (name) VALUES ('Building Microservices');
INSERT INTO `item` (name) VALUES ('Go Web Programming');
INSERT INTO `order` (customer_id, created_at, total) VALUES (1, 1325282384, 110.0);
INSERT INTO `order` (customer_id, created_at, total) VALUES (2, 1325334482, 33.0);
INSERT INTO `order` (customer_id, created_at, total) VALUES (2, 1325502201, 40.0);
INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (1, 1, 1, 30.0);
INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (1, 2, 2, 40.0);
INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (2, 4, 1, 10.0);
INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (2, 5, 1, 15.0);
INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (2, 3, 1, 8.0);
INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (3, 2, 1, 40.0);

23
tx.go Normal file
View file

@ -0,0 +1,23 @@
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbx
import "database/sql"
// Tx enhances sql.Tx with additional querying methods.
type Tx struct {
Builder
tx *sql.Tx
}
// Commit commits the transaction.
func (t *Tx) Commit() error {
return t.tx.Commit()
}
// Rollback aborts the transaction.
func (t *Tx) Rollback() error {
return t.tx.Rollback()
}