Adding upstream version 1.11.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
40df376a7f
commit
02cacc5b45
37 changed files with 7317 additions and 0 deletions
26
.gitignore
vendored
Normal file
26
.gitignore
vendored
Normal file
|
@ -0,0 +1,26 @@
|
|||
.idea
|
||||
|
||||
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||
*.o
|
||||
*.a
|
||||
*.so
|
||||
|
||||
# Folders
|
||||
_obj
|
||||
_test
|
||||
|
||||
# Architecture specific extensions/prefixes
|
||||
*.[568vq]
|
||||
[568vq].out
|
||||
|
||||
*.cgo1.go
|
||||
*.cgo2.c
|
||||
_cgo_defun.c
|
||||
_cgo_gotypes.go
|
||||
_cgo_export.*
|
||||
|
||||
_testmain.go
|
||||
|
||||
*.exe
|
||||
*.test
|
||||
*.prof
|
22
.travis.yml
Normal file
22
.travis.yml
Normal file
|
@ -0,0 +1,22 @@
|
|||
dist: bionic
|
||||
|
||||
language: go
|
||||
|
||||
go:
|
||||
- 1.13.x
|
||||
|
||||
services:
|
||||
- mysql
|
||||
|
||||
install:
|
||||
- go get golang.org/x/tools/cmd/cover
|
||||
- go get github.com/mattn/goveralls
|
||||
- go get golang.org/x/lint/golint
|
||||
|
||||
before_script:
|
||||
- mysql -e 'CREATE DATABASE pocketbase_dbx_test;';
|
||||
|
||||
script:
|
||||
- test -z "`gofmt -l -d .`"
|
||||
- go test -v -covermode=count -coverprofile=coverage.out
|
||||
- $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci
|
17
LICENSE
Normal file
17
LICENSE
Normal file
|
@ -0,0 +1,17 @@
|
|||
The MIT License (MIT)
|
||||
Copyright (c) 2016, Qiang Xue
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software
|
||||
and associated documentation files (the "Software"), to deal in the Software without restriction,
|
||||
including without limitation the rights to use, copy, modify, merge, publish, distribute,
|
||||
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software
|
||||
is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or
|
||||
substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
|
||||
BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||||
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
748
README.md
Normal file
748
README.md
Normal file
|
@ -0,0 +1,748 @@
|
|||
dbx
|
||||
[](https://goreportcard.com/report/github.com/pocketbase/dbx)
|
||||
[](https://pkg.go.dev/github.com/pocketbase/dbx)
|
||||
================================================================================
|
||||
|
||||
> ⚠️ This is a maintained fork of [go-ozzo/ozzo-dbx](https://github.com/go-ozzo/ozzo-dbx) (see [#103](https://github.com/go-ozzo/ozzo-dbx/issues/103)).
|
||||
>
|
||||
> Currently, the changes are primarily related to better SQLite support and some other minor improvements, implementing [#99](https://github.com/go-ozzo/ozzo-dbx/pull/99), [#100](https://github.com/go-ozzo/ozzo-dbx/pull/100) and [#102](https://github.com/go-ozzo/ozzo-dbx/pull/102).
|
||||
|
||||
|
||||
## Summary
|
||||
|
||||
- [Description](#description)
|
||||
- [Requirements](#requirements)
|
||||
- [Installation](#installation)
|
||||
- [Supported Databases](#supported-databases)
|
||||
- [Getting Started](#getting-started)
|
||||
- [Connecting to Database](#connecting-to-database)
|
||||
- [Executing Queries](#executing-queries)
|
||||
- [Binding Parameters](#binding-parameters)
|
||||
- [Building Queries](#building-queries)
|
||||
- [Building SELECT Queries](#building-select-queries)
|
||||
- [Building Query Conditions](#building-query-conditions)
|
||||
- [Building Data Manipulation Queries](#building-data-manipulation-queries)
|
||||
- [Building Schema Manipulation Queries](#building-schema-manipulation-queries)
|
||||
- [CRUD Operations](#crud-operations)
|
||||
- [Create](#create)
|
||||
- [Read](#read)
|
||||
- [Update](#update)
|
||||
- [Delete](#delete)
|
||||
- [Null Handling](#null-handling)
|
||||
- [Quoting Table and Column Names](#quoting-table-and-column-names)
|
||||
- [Using Transactions](#using-transactions)
|
||||
- [Logging Executed SQL Statements](#logging-executed-sql-statements)
|
||||
- [Supporting New Databases](#supporting-new-databases)
|
||||
|
||||
|
||||
## Description
|
||||
|
||||
`dbx` is a Go package that enhances the standard `database/sql` package by providing powerful data retrieval methods
|
||||
as well as DB-agnostic query building capabilities. `dbx` is not an ORM. It has the following features:
|
||||
|
||||
* Populating data into structs and NullString maps
|
||||
* Named parameter binding
|
||||
* DB-agnostic query building methods, including SELECT queries, data manipulation queries, and schema manipulation queries
|
||||
* Inserting, updating, and deleting model structs
|
||||
* Powerful query condition building
|
||||
* Open architecture allowing addition of new database support or customization of existing support
|
||||
* Logging executed SQL statements
|
||||
* Supporting major relational databases
|
||||
|
||||
For an example on how this library is used in an application, please refer to [go-rest-api](https://github.com/qiangxue/go-rest-api) which is a starter kit for building RESTful APIs in Go.
|
||||
|
||||
## Requirements
|
||||
|
||||
Go 1.13 or above.
|
||||
|
||||
## Installation
|
||||
|
||||
Run the following command to install the package:
|
||||
|
||||
```
|
||||
go get github.com/pocketbase/dbx
|
||||
```
|
||||
|
||||
In addition, install the specific DB driver package for the kind of database to be used. Please refer to
|
||||
[SQL database drivers](https://github.com/golang/go/wiki/SQLDrivers) for a complete list. For example, if you are
|
||||
using MySQL, you may install the following package:
|
||||
|
||||
```sh
|
||||
go get github.com/go-sql-driver/mysql
|
||||
```
|
||||
|
||||
and import it in your main code like the following:
|
||||
|
||||
```go
|
||||
import _ "github.com/go-sql-driver/mysql"
|
||||
```
|
||||
|
||||
## Supported Databases
|
||||
|
||||
The following databases are fully supported out of box:
|
||||
|
||||
* SQLite
|
||||
* MySQL
|
||||
* PostgreSQL
|
||||
* MS SQL Server (2012 or above)
|
||||
* Oracle
|
||||
|
||||
For other databases, the query building feature may not work as expected. You can create a custom builder to
|
||||
solve the problem. Please see the last section for more details.
|
||||
|
||||
## Getting Started
|
||||
|
||||
The following code snippet shows how you can use this package in order to access data from a MySQL database.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/pocketbase/dbx"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
// create a new query
|
||||
q := db.NewQuery("SELECT id, name FROM users LIMIT 10")
|
||||
|
||||
// fetch all rows into a struct array
|
||||
var users []struct {
|
||||
ID, Name string
|
||||
}
|
||||
q.All(&users)
|
||||
|
||||
// fetch a single row into a struct
|
||||
var user struct {
|
||||
ID, Name string
|
||||
}
|
||||
q.One(&user)
|
||||
|
||||
// fetch a single row into a string map
|
||||
data := dbx.NullStringMap{}
|
||||
q.One(data)
|
||||
|
||||
// fetch row by row
|
||||
rows2, _ := q.Rows()
|
||||
for rows2.Next() {
|
||||
rows2.ScanStruct(&user)
|
||||
// rows.ScanMap(data)
|
||||
// rows.Scan(&id, &name)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
And the following example shows how to use the query building capability of this package.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/pocketbase/dbx"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
// build a SELECT query
|
||||
// SELECT `id`, `name` FROM `users` WHERE `name` LIKE '%Charles%' ORDER BY `id`
|
||||
q := db.Select("id", "name").
|
||||
From("users").
|
||||
Where(dbx.Like("name", "Charles")).
|
||||
OrderBy("id")
|
||||
|
||||
// fetch all rows into a struct array
|
||||
var users []struct {
|
||||
ID, Name string
|
||||
}
|
||||
q.All(&users)
|
||||
|
||||
// build an INSERT query
|
||||
// INSERT INTO `users` (`name`) VALUES ('James')
|
||||
db.Insert("users", dbx.Params{
|
||||
"name": "James",
|
||||
}).Execute()
|
||||
}
|
||||
```
|
||||
|
||||
## Connecting to Database
|
||||
|
||||
To connect to a database, call `dbx.Open()` in the same way as you would do with the `Open()` method in `database/sql`.
|
||||
|
||||
```go
|
||||
db, err := dbx.Open("mysql", "user:pass@hostname/db_name")
|
||||
```
|
||||
|
||||
The method returns a `dbx.DB` instance which can be used to create and execute DB queries. Note that the method
|
||||
does not really establish a connection until a query is made using the returned `dbx.DB` instance. It also
|
||||
does not check the correctness of the data source name either. Call `dbx.MustOpen()` to make sure the data
|
||||
source name is correct.
|
||||
|
||||
## Executing Queries
|
||||
|
||||
To execute a SQL statement, first create a `dbx.Query` instance by calling `DB.NewQuery()` with the SQL statement
|
||||
to be executed. And then call `Query.Execute()` to execute the query if the query is not meant to retrieving data.
|
||||
For example,
|
||||
|
||||
```go
|
||||
q := db.NewQuery("UPDATE users SET status=1 WHERE id=100")
|
||||
result, err := q.Execute()
|
||||
```
|
||||
|
||||
If the SQL statement does retrieve data (e.g. a SELECT statement), one of the following methods should be called,
|
||||
which will execute the query and populate the result into the specified variable(s).
|
||||
|
||||
* `Query.All()`: populate all rows of the result into a slice of structs or `NullString` maps.
|
||||
* `Query.One()`: populate the first row of the result into a struct or a `NullString` map.
|
||||
* `Query.Column()`: populate the first column of the result into a slice.
|
||||
* `Query.Row()`: populate the first row of the result into a list of variables, one for each returning column.
|
||||
* `Query.Rows()`: returns a `dbx.Rows` instance to allow retrieving data row by row.
|
||||
|
||||
For example,
|
||||
|
||||
```go
|
||||
type User struct {
|
||||
ID int
|
||||
Name string
|
||||
}
|
||||
|
||||
var (
|
||||
users []User
|
||||
user User
|
||||
|
||||
row dbx.NullStringMap
|
||||
|
||||
id int
|
||||
name string
|
||||
|
||||
err error
|
||||
)
|
||||
|
||||
q := db.NewQuery("SELECT id, name FROM users LIMIT 10")
|
||||
|
||||
// populate all rows into a User slice
|
||||
err = q.All(&users)
|
||||
fmt.Println(users[0].ID, users[0].Name)
|
||||
|
||||
// populate the first row into a User struct
|
||||
err = q.One(&user)
|
||||
fmt.Println(user.ID, user.Name)
|
||||
|
||||
// populate the first row into a NullString map
|
||||
err = q.One(&row)
|
||||
fmt.Println(row["id"], row["name"])
|
||||
|
||||
var ids []int
|
||||
err = q.Column(&ids)
|
||||
fmt.Println(ids)
|
||||
|
||||
// populate the first row into id and name
|
||||
err = q.Row(&id, &name)
|
||||
|
||||
// populate data row by row
|
||||
rows, _ := q.Rows()
|
||||
for rows.Next() {
|
||||
_ = rows.ScanMap(&row)
|
||||
}
|
||||
```
|
||||
|
||||
When populating a struct, the following rules are used to determine which columns should go into which struct fields:
|
||||
|
||||
* Only exported struct fields can be populated.
|
||||
* A field receives data if its name is mapped to a column according to the field mapping function `Query.FieldMapper`.
|
||||
The default field mapping function separates words in a field name by underscores and turns them into lower case.
|
||||
For example, a field name `FirstName` will be mapped to the column name `first_name`, and `MyID` to `my_id`.
|
||||
* If a field has a `db` tag, the tag value will be used as the corresponding column name. If the `db` tag is a dash `-`,
|
||||
it means the field should NOT be populated.
|
||||
* For anonymous fields that are of struct type, they will be expanded and their component fields will be populated
|
||||
according to the rules described above.
|
||||
* For named fields that are of struct type, they will also be expanded. But their component fields will be prefixed
|
||||
with the struct names when being populated.
|
||||
|
||||
An exception to the above struct expansion is that when a struct type implements `sql.Scanner` or when it is `time.Time`.
|
||||
In this case, the field will be populated as a whole by the DB driver. Also, if a field is a pointer to some type,
|
||||
the field will be allocated memory and populated with the query result if it is not null.
|
||||
|
||||
The following example shows how fields are populated according to the rules above:
|
||||
|
||||
```go
|
||||
type User struct {
|
||||
id int
|
||||
Type int `db:"-"`
|
||||
MyName string `db:"name"`
|
||||
Profile
|
||||
Address Address `db:"addr"`
|
||||
}
|
||||
|
||||
type Profile struct {
|
||||
Age int
|
||||
}
|
||||
|
||||
type Address struct {
|
||||
City string
|
||||
}
|
||||
```
|
||||
|
||||
* `User.id`: not populated because the field is not exported;
|
||||
* `User.Type`: not populated because the `db` tag is `-`;
|
||||
* `User.MyName`: to be populated from the `name` column, according to the `db` tag;
|
||||
* `Profile.Age`: to be populated from the `age` column, since `Profile` is an anonymous field;
|
||||
* `Address.City`: to be populated from the `addr.city` column, since `Address` is a named field of struct type
|
||||
and its fields will be prefixed with `addr.` according to the `db` tag.
|
||||
|
||||
Note that if a column in the result does not have a corresponding struct field, it will be ignored. Similarly,
|
||||
if a struct field does not have a corresponding column in the result, it will not be populated.
|
||||
|
||||
## Binding Parameters
|
||||
|
||||
A SQL statement is usually parameterized with dynamic values. For example, you may want to select the user record
|
||||
according to the user ID received from the client. Parameter binding should be used in this case, and it is almost
|
||||
always preferred to prevent from SQL injection attacks. Unlike `database/sql` which does anonymous parameter binding,
|
||||
`dbx` uses named parameter binding. *Anonymous parameter binding is not supported*, as it will mess up with named
|
||||
parameters. For example,
|
||||
|
||||
```go
|
||||
q := db.NewQuery("SELECT id, name FROM users WHERE id={:id}")
|
||||
q.Bind(dbx.Params{"id": 100})
|
||||
err := q.One(&user)
|
||||
```
|
||||
|
||||
The above example will select the user record whose `id` is 100. The method `Query.Bind()` binds a set
|
||||
of named parameters to a SQL statement which contains parameter placeholders in the format of `{:ParamName}`.
|
||||
|
||||
If a SQL statement needs to be executed multiple times with different parameter values, it may be prepared
|
||||
to improve the performance. For example,
|
||||
|
||||
```go
|
||||
q := db.NewQuery("SELECT id, name FROM users WHERE id={:id}")
|
||||
q.Prepare()
|
||||
defer q.Close()
|
||||
|
||||
q.Bind(dbx.Params{"id": 100})
|
||||
err := q.One(&user)
|
||||
|
||||
q.Bind(dbx.Params{"id": 200})
|
||||
err = q.One(&user)
|
||||
|
||||
// ...
|
||||
```
|
||||
|
||||
|
||||
## Cancelable Queries
|
||||
|
||||
Queries are cancelable when they are used with `context.Context`. In particular, by calling `Query.WithContext()` you
|
||||
can associate a context with a query and use the context to cancel the query while it is running. For example,
|
||||
|
||||
```go
|
||||
q := db.NewQuery("SELECT id, name FROM users")
|
||||
err := q.WithContext(ctx).All(&users)
|
||||
```
|
||||
|
||||
|
||||
## Building Queries
|
||||
|
||||
Instead of writing plain SQLs, `dbx` allows you to build SQLs programmatically, which often leads to cleaner,
|
||||
more secure, and DB-agnostic code. You can build three types of queries: the SELECT queries, the data manipulation
|
||||
queries, and the schema manipulation queries.
|
||||
|
||||
### Building SELECT Queries
|
||||
|
||||
Building a SELECT query starts by calling `DB.Select()`. You can build different clauses of a SELECT query using
|
||||
the corresponding query building methods. For example,
|
||||
|
||||
```go
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
err := db.Select("id", "name").
|
||||
From("users").
|
||||
Where(dbx.HashExp{"id": 100}).
|
||||
One(&user)
|
||||
```
|
||||
|
||||
The above code will generate and execute the following SQL statement:
|
||||
|
||||
```sql
|
||||
SELECT `id`, `name` FROM `users` WHERE `id`={:p0}
|
||||
```
|
||||
|
||||
Notice how the table and column names are properly quoted according to the currently using database type.
|
||||
And parameter binding is used to populate the value of `p0` in the `WHERE` clause.
|
||||
|
||||
Every SQL keyword has a corresponding query building method. For example, `SELECT` corresponds to `Select()`,
|
||||
`FROM` corresponds to `From()`, `WHERE` corresponds to `Where()`, and so on. You can chain these method calls
|
||||
together, just like you would do when writing a plain SQL. Each of these methods returns the query instance
|
||||
(of type `dbx.SelectQuery`) that is being built. Once you finish building a query, you may call methods such as
|
||||
`One()`, `All()` to execute the query and populate data into variables. You may also explicitly call `Build()`
|
||||
to build the query and turn it into a `dbx.Query` instance which may allow you to get the SQL statement and do
|
||||
other interesting work.
|
||||
|
||||
|
||||
### Building Query Conditions
|
||||
|
||||
`dbx` supports very flexible and powerful query condition building which can be used to build SQL clauses
|
||||
such as `WHERE`, `HAVING`, etc. For example,
|
||||
|
||||
```go
|
||||
// id=100
|
||||
dbx.NewExp("id={:id}", dbx.Params{"id": 100})
|
||||
|
||||
// id=100 AND status=1
|
||||
dbx.HashExp{"id": 100, "status": 1}
|
||||
|
||||
// status=1 OR age>30
|
||||
dbx.Or(dbx.HashExp{"status": 1}, dbx.NewExp("age>30"))
|
||||
|
||||
// name LIKE '%admin%' AND name LIKE '%example%'
|
||||
dbx.Like("name", "admin", "example")
|
||||
```
|
||||
|
||||
When building a query condition expression, its parameter values will be populated using parameter binding, which
|
||||
prevents SQL injection from happening. Also if an expression involves column names, they will be properly quoted.
|
||||
The following condition building functions are available:
|
||||
|
||||
* `dbx.NewExp()`: creating a condition using the given expression string and binding parameters. For example,
|
||||
`dbx.NewExp("id={:id}", dbx.Params{"id":100})` would create the expression `id=100`.
|
||||
* `dbx.HashExp`: a map type that represents name-value pairs concatenated by `AND` operators. For example,
|
||||
`dbx.HashExp{"id":100, "status":1}` would create `id=100 AND status=1`.
|
||||
* `dbx.Not()`: creating a `NOT` expression by prepending `NOT` to the given expression.
|
||||
* `dbx.And()`: creating an `AND` expression by concatenating the given expressions with the `AND` operators.
|
||||
* `dbx.Or()`: creating an `OR` expression by concatenating the given expressions with the `OR` operators.
|
||||
* `dbx.In()`: creating an `IN` expression for the specified column and the range of values.
|
||||
For example, `dbx.In("age", 30, 40, 50)` would create the expression `age IN (30, 40, 50)`.
|
||||
Note that if the value range is empty, it will generate an expression representing a false value.
|
||||
* `dbx.NotIn()`: creating an `NOT IN` expression. This is very similar to `dbx.In()`.
|
||||
* `dbx.Like()`: creating a `LIKE` expression for the specified column and the range of values. For example,
|
||||
`dbx.Like("title", "golang", "framework")` would create the expression `title LIKE "%golang%" AND title LIKE "%framework%"`.
|
||||
You can further customize a LIKE expression by calling `Escape()` and/or `Match()` functions of the resulting expression.
|
||||
Note that if the value range is empty, it will generate an empty expression.
|
||||
* `dbx.NotLike()`: creating a `NOT LIKE` expression. This is very similar to `dbx.Like()`.
|
||||
* `dbx.OrLike()`: creating a `LIKE` expression but concatenating different `LIKE` sub-expressions using `OR` instead of `AND`.
|
||||
* `dbx.OrNotLike()`: creating a `NOT LIKE` expression and concatenating different `NOT LIKE` sub-expressions using `OR` instead of `AND`.
|
||||
* `dbx.Exists()`: creating an `EXISTS` expression by prepending `EXISTS` to the given expression.
|
||||
* `dbx.NotExists()`: creating a `NOT EXISTS` expression by prepending `NOT EXISTS` to the given expression.
|
||||
* `dbx.Between()`: creating a `BETWEEN` expression. For example, `dbx.Between("age", 30, 40)` would create the
|
||||
expression `age BETWEEN 30 AND 40`.
|
||||
* `dbx.NotBetween()`: creating a `NOT BETWEEN` expression. For example
|
||||
|
||||
You may also create other convenient functions to help building query conditions, as long as the functions return
|
||||
an object implementing the `dbx.Expression` interface.
|
||||
|
||||
|
||||
### Building Data Manipulation Queries
|
||||
|
||||
Data manipulation queries are those changing the data in the database, such as INSERT, UPDATE, DELETE statements.
|
||||
Such queries can be built by calling the corresponding methods of `DB`. For example,
|
||||
|
||||
```go
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
// INSERT INTO `users` (`name`, `email`) VALUES ({:p0}, {:p1})
|
||||
err := db.Insert("users", dbx.Params{
|
||||
"name": "James",
|
||||
"email": "james@example.com",
|
||||
}).Execute()
|
||||
|
||||
// UPDATE `users` SET `status`={:p0} WHERE `id`={:p1}
|
||||
err = db.Update("users", dbx.Params{"status": 1}, dbx.HashExp{"id": 100}).Execute()
|
||||
|
||||
// DELETE FROM `users` WHERE `status`={:p0}
|
||||
err = db.Delete("users", dbx.HashExp{"status": 2}).Execute()
|
||||
```
|
||||
|
||||
When building data manipulation queries, remember to call `Execute()` at the end to execute the queries.
|
||||
|
||||
### Building Schema Manipulation Queries
|
||||
|
||||
Schema manipulation queries are those changing the database schema, such as creating a new table, adding a new column.
|
||||
These queries can be built by calling the corresponding methods of `DB`. For example,
|
||||
|
||||
```go
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
// CREATE TABLE `users` (`id` int primary key, `name` varchar(255))
|
||||
q := db.CreateTable("users", map[string]string{
|
||||
"id": "int primary key",
|
||||
"name": "varchar(255)",
|
||||
})
|
||||
err := q.Execute()
|
||||
```
|
||||
|
||||
## CRUD Operations
|
||||
|
||||
Although `dbx` is not an ORM, it does provide a very convenient way to do typical CRUD (Create, Read, Update, Delete)
|
||||
operations without the need of writing plain SQL statements.
|
||||
|
||||
To use the CRUD feature, first define a struct type for a table. By default, a struct is associated with a table
|
||||
whose name is the snake case version of the struct type name. For example, a struct named `MyCustomer`
|
||||
corresponds to the table name `my_customer`. You may explicitly specify the table name for a struct by implementing
|
||||
the `dbx.TableModel` interface. For example,
|
||||
|
||||
```go
|
||||
type MyCustomer struct{}
|
||||
|
||||
func (c MyCustomer) TableName() string {
|
||||
return "customer"
|
||||
}
|
||||
```
|
||||
|
||||
Note that the `TableName` method should be defined with a value receiver instead of a pointer receiver.
|
||||
|
||||
If the struct has a field named `ID` or `Id`, by default the field will be treated as the primary key field.
|
||||
If you want to use a different field as the primary key, tag it with `db:"pk"`. You may tag multiple fields
|
||||
for composite primary keys. Note that if you also want to explicitly specify the column name for a primary key field,
|
||||
you should use the tag format `db:"pk,col_name"`.
|
||||
|
||||
You can give a common prefix or suffix to your table names by defining your own table name mapping via
|
||||
`DB.TableMapFunc`. For example, the following code prefixes `tbl_` to all table names.
|
||||
|
||||
```go
|
||||
db.TableMapper = func(a interface{}) string {
|
||||
return "tbl_" + GetTableName(a)
|
||||
}
|
||||
```
|
||||
|
||||
### Create
|
||||
|
||||
To create (insert) a new row using a model, call the `ModelQuery.Insert()` method. For example,
|
||||
|
||||
```go
|
||||
type Customer struct {
|
||||
ID int
|
||||
Name string
|
||||
Email string
|
||||
Status int
|
||||
}
|
||||
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
customer := Customer{
|
||||
Name: "example",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
// INSERT INTO customer (name, email, status) VALUES ('example', 'test@example.com', 0)
|
||||
err := db.Model(&customer).Insert()
|
||||
```
|
||||
|
||||
This will insert a row using the values from *all* public fields (except the primary key field if it is empty) in the struct.
|
||||
If a primary key field is zero (a integer zero or a nil pointer), it is assumed to be auto-incremental and
|
||||
will be automatically filled with the last insertion ID after a successful insertion.
|
||||
|
||||
You can explicitly specify the fields that should be inserted by passing the list of the field names to the `Insert()` method.
|
||||
You can also exclude certain fields from being inserted by calling `Exclude()` before calling `Insert()`. For example,
|
||||
|
||||
```go
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
// insert only Name and Email fields
|
||||
err := db.Model(&customer).Insert("Name", "Email")
|
||||
// insert all public fields except Status
|
||||
err = db.Model(&customer).Exclude("Status").Insert()
|
||||
// insert only Name
|
||||
err = db.Model(&customer).Exclude("Status").Insert("Name", "Status")
|
||||
```
|
||||
|
||||
### Read
|
||||
|
||||
To read a model by a given primary key value, call `SelectQuery.Model()`.
|
||||
|
||||
```go
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
var customer Customer
|
||||
// SELECT * FROM customer WHERE id=100
|
||||
err := db.Select().Model(100, &customer)
|
||||
|
||||
// SELECT name, email FROM customer WHERE status=1 AND id=100
|
||||
err = db.Select("name", "email").Where(dbx.HashExp{"status": 1}).Model(100, &customer)
|
||||
```
|
||||
|
||||
Note that `SelectQuery.Model()` does not support composite primary keys. You should use `SelectQuery.One()` in this case.
|
||||
For example,
|
||||
|
||||
```go
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
var orderItem OrderItem
|
||||
|
||||
// SELECT * FROM order_item WHERE order_id=100 AND item_id=20
|
||||
err := db.Select().Where(dbx.HashExp{"order_id": 100, "item_id": 20}).One(&orderItem)
|
||||
```
|
||||
|
||||
In the above queries, we do not call `From()` to specify which table to select data from. This is because the select
|
||||
query automatically sets the table according to the model struct being populated. If the struct implements `TableModel`,
|
||||
the value returned by its `TableName()` method will be used as the table name. Otherwise, the snake case version
|
||||
of the struct type name will be the table name.
|
||||
|
||||
You may also call `SelectQuery.All()` to read a list of model structs. Similarly, you do not need to call `From()`
|
||||
if the table name can be inferred from the model structs.
|
||||
|
||||
|
||||
### Update
|
||||
|
||||
To update a model, call the `ModelQuery.Update()` method. Like `Insert()`, by default, the `Update()` method will
|
||||
update *all* public fields except primary key fields of the model. You can explicitly specify which fields can
|
||||
be updated and which cannot in the same way as described for the `Insert()` method. For example,
|
||||
|
||||
```go
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
// update all public fields of customer
|
||||
err := db.Model(&customer).Update()
|
||||
// update only Status
|
||||
err = db.Model(&customer).Update("Status")
|
||||
// update all public fields except Status
|
||||
err = db.Model(&customer).Exclude("Status").Update()
|
||||
```
|
||||
|
||||
Note that the `Update()` method assumes that the primary keys are immutable. It uses the primary key value of the model
|
||||
to look for the row that should be updated. An error will be returned if a model does not have a primary key.
|
||||
|
||||
|
||||
### Delete
|
||||
|
||||
To delete a model, call the `ModelQuery.Delete()` method. The method deletes the row using the primary key value
|
||||
specified by the model. If the model does not have a primary key, an error will be returned. For example,
|
||||
|
||||
```go
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
err := db.Model(&customer).Delete()
|
||||
```
|
||||
|
||||
### Null Handling
|
||||
|
||||
To represent a nullable database value, you can use a pointer type. If the pointer is nil, it means the corresponding
|
||||
database value is null.
|
||||
|
||||
Another option to represent a database null is to use `sql.NullXyz` types. For example, if a string column is nullable,
|
||||
you may use `sql.NullString`. The `NullString.Valid` field indicates whether the value is a null or not, and
|
||||
`NullString.String` returns the string value when it is not null. Because `sql.NulLXyz` types do not handle JSON
|
||||
marshalling, you may use the [null package](https://github.com/guregu/null), instead.
|
||||
|
||||
Below is an example of handling nulls:
|
||||
|
||||
```go
|
||||
type Customer struct {
|
||||
ID int
|
||||
Email string
|
||||
FirstName *string // use pointer to represent null
|
||||
LastName sql.NullString // use sql.NullString to represent null
|
||||
}
|
||||
```
|
||||
|
||||
## Quoting Table and Column Names
|
||||
|
||||
Databases vary in quoting table and column names. To allow writing DB-agnostic SQLs, `dbx` introduces a special
|
||||
syntax in quoting table and column names. A word enclosed within `{{` and `}}` is treated as a table name and will
|
||||
be quoted according to the particular DB driver. Similarly, a word enclosed within `[[` and `]]` is treated as a
|
||||
column name and will be quoted accordingly as well. For example, when working with a MySQL database, the following
|
||||
query will be properly quoted:
|
||||
|
||||
```go
|
||||
// SELECT * FROM `users` WHERE `status`=1
|
||||
q := db.NewQuery("SELECT * FROM {{users}} WHERE [[status]]=1")
|
||||
```
|
||||
|
||||
Note that if a table or column name contains a prefix, it will still be properly quoted. For example, `{{public.users}}`
|
||||
will be quoted as `"public"."users"` for PostgreSQL.
|
||||
|
||||
## Using Transactions
|
||||
|
||||
You can use all aforementioned query execution and building methods with transaction. For example,
|
||||
|
||||
```go
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
tx, _ := db.Begin()
|
||||
|
||||
_, err1 := tx.Insert("users", dbx.Params{
|
||||
"name": "user1",
|
||||
}).Execute()
|
||||
_, err2 := tx.Insert("users", dbx.Params{
|
||||
"name": "user2",
|
||||
}).Execute()
|
||||
|
||||
if err1 == nil && err2 == nil {
|
||||
tx.Commit()
|
||||
} else {
|
||||
tx.Rollback()
|
||||
}
|
||||
```
|
||||
|
||||
You may use `DB.Transactional()` to simplify your transactional code without explicitly committing or rolling back
|
||||
transactions. The method will start a transaction and automatically roll back the transaction if the callback
|
||||
returns an error. Otherwise it will
|
||||
automatically commit the transaction.
|
||||
|
||||
|
||||
```go
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
err := db.Transactional(func(tx *dbx.Tx) error {
|
||||
var err error
|
||||
_, err = tx.Insert("users", dbx.Params{
|
||||
"name": "user1",
|
||||
}).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Insert("users", dbx.Params{
|
||||
"name": "user2",
|
||||
}).Execute()
|
||||
return err
|
||||
})
|
||||
|
||||
fmt.Println(err)
|
||||
```
|
||||
|
||||
## Logging Executed SQL Statements
|
||||
|
||||
You can log and instrument DB queries by installing loggers with a DB connection. There are three kinds of loggers you
|
||||
can install:
|
||||
* `DB.LogFunc`: this is called each time when a SQL statement is queried or executed. The function signature is the
|
||||
same as that of `fmt.Printf`, which makes it very easy to use.
|
||||
* `DB.QueryLogFunc`: this is called each time when querying with a SQL statement.
|
||||
* `DB.ExecLogFunc`: this is called when executing a SQL statement.
|
||||
|
||||
The following example shows how you can make use of these loggers.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
// simple logging
|
||||
db.LogFunc = log.Printf
|
||||
|
||||
// or you can use the following more flexible logging
|
||||
db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
log.Printf("[%.2fms] Query SQL: %v", float64(t.Milliseconds()), sql)
|
||||
}
|
||||
db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
log.Printf("[%.2fms] Execute SQL: %v", float64(t.Milliseconds()), sql)
|
||||
}
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
## Supporting New Databases
|
||||
|
||||
While `dbx` provides out-of-box query building support for most major relational databases, its open architecture
|
||||
allows you to add support for new databases. The effort of adding support for a new database involves:
|
||||
|
||||
* Create a struct that implements the `QueryBuilder` interface. You may use `BaseQueryBuilder` directly or extend it
|
||||
via composition.
|
||||
* Create a struct that implements the `Builder` interface. You may extend `BaseBuilder` via composition.
|
||||
* Write an `init()` function to register the new builder in `dbx.BuilderFuncMap`.
|
402
builder.go
Normal file
402
builder.go
Normal file
|
@ -0,0 +1,402 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Builder supports building SQL statements in a DB-agnostic way.
|
||||
// Builder mainly provides two sets of query building methods: those building SELECT statements
|
||||
// and those manipulating DB data or schema (e.g. INSERT statements, CREATE TABLE statements).
|
||||
type Builder interface {
|
||||
// NewQuery creates a new Query object with the given SQL statement.
|
||||
// The SQL statement may contain parameter placeholders which can be bound with actual parameter
|
||||
// values before the statement is executed.
|
||||
NewQuery(string) *Query
|
||||
// Select returns a new SelectQuery object that can be used to build a SELECT statement.
|
||||
// The parameters to this method should be the list column names to be selected.
|
||||
// A column name may have an optional alias name. For example, Select("id", "my_name AS name").
|
||||
Select(...string) *SelectQuery
|
||||
// ModelQuery returns a new ModelQuery object that can be used to perform model insertion, update, and deletion.
|
||||
// The parameter to this method should be a pointer to the model struct that needs to be inserted, updated, or deleted.
|
||||
Model(interface{}) *ModelQuery
|
||||
|
||||
// GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID.
|
||||
GeneratePlaceholder(int) string
|
||||
|
||||
// Quote quotes a string so that it can be embedded in a SQL statement as a string value.
|
||||
Quote(string) string
|
||||
// QuoteSimpleTableName quotes a simple table name.
|
||||
// A simple table name does not contain any schema prefix.
|
||||
QuoteSimpleTableName(string) string
|
||||
// QuoteSimpleColumnName quotes a simple column name.
|
||||
// A simple column name does not contain any table prefix.
|
||||
QuoteSimpleColumnName(string) string
|
||||
|
||||
// QueryBuilder returns the query builder supporting the current DB.
|
||||
QueryBuilder() QueryBuilder
|
||||
|
||||
// Insert creates a Query that represents an INSERT SQL statement.
|
||||
// The keys of cols are the column names, while the values of cols are the corresponding column
|
||||
// values to be inserted.
|
||||
Insert(table string, cols Params) *Query
|
||||
// Upsert creates a Query that represents an UPSERT SQL statement.
|
||||
// Upsert inserts a row into the table if the primary key or unique index is not found.
|
||||
// Otherwise it will update the row with the new values.
|
||||
// The keys of cols are the column names, while the values of cols are the corresponding column
|
||||
// values to be inserted.
|
||||
Upsert(table string, cols Params, constraints ...string) *Query
|
||||
// Update creates a Query that represents an UPDATE SQL statement.
|
||||
// The keys of cols are the column names, while the values of cols are the corresponding new column
|
||||
// values. If the "where" expression is nil, the UPDATE SQL statement will have no WHERE clause
|
||||
// (be careful in this case as the SQL statement will update ALL rows in the table).
|
||||
Update(table string, cols Params, where Expression) *Query
|
||||
// Delete creates a Query that represents a DELETE SQL statement.
|
||||
// If the "where" expression is nil, the DELETE SQL statement will have no WHERE clause
|
||||
// (be careful in this case as the SQL statement will delete ALL rows in the table).
|
||||
Delete(table string, where Expression) *Query
|
||||
|
||||
// CreateTable creates a Query that represents a CREATE TABLE SQL statement.
|
||||
// The keys of cols are the column names, while the values of cols are the corresponding column types.
|
||||
// The optional "options" parameters will be appended to the generated SQL statement.
|
||||
CreateTable(table string, cols map[string]string, options ...string) *Query
|
||||
// RenameTable creates a Query that can be used to rename a table.
|
||||
RenameTable(oldName, newName string) *Query
|
||||
// DropTable creates a Query that can be used to drop a table.
|
||||
DropTable(table string) *Query
|
||||
// TruncateTable creates a Query that can be used to truncate a table.
|
||||
TruncateTable(table string) *Query
|
||||
|
||||
// AddColumn creates a Query that can be used to add a column to a table.
|
||||
AddColumn(table, col, typ string) *Query
|
||||
// DropColumn creates a Query that can be used to drop a column from a table.
|
||||
DropColumn(table, col string) *Query
|
||||
// RenameColumn creates a Query that can be used to rename a column in a table.
|
||||
RenameColumn(table, oldName, newName string) *Query
|
||||
// AlterColumn creates a Query that can be used to change the definition of a table column.
|
||||
AlterColumn(table, col, typ string) *Query
|
||||
|
||||
// AddPrimaryKey creates a Query that can be used to specify primary key(s) for a table.
|
||||
// The "name" parameter specifies the name of the primary key constraint.
|
||||
AddPrimaryKey(table, name string, cols ...string) *Query
|
||||
// DropPrimaryKey creates a Query that can be used to remove the named primary key constraint from a table.
|
||||
DropPrimaryKey(table, name string) *Query
|
||||
|
||||
// AddForeignKey creates a Query that can be used to add a foreign key constraint to a table.
|
||||
// The length of cols and refCols must be the same as they refer to the primary and referential columns.
|
||||
// The optional "options" parameters will be appended to the SQL statement. They can be used to
|
||||
// specify options such as "ON DELETE CASCADE".
|
||||
AddForeignKey(table, name string, cols, refCols []string, refTable string, options ...string) *Query
|
||||
// DropForeignKey creates a Query that can be used to remove the named foreign key constraint from a table.
|
||||
DropForeignKey(table, name string) *Query
|
||||
|
||||
// CreateIndex creates a Query that can be used to create an index for a table.
|
||||
CreateIndex(table, name string, cols ...string) *Query
|
||||
// CreateUniqueIndex creates a Query that can be used to create a unique index for a table.
|
||||
CreateUniqueIndex(table, name string, cols ...string) *Query
|
||||
// DropIndex creates a Query that can be used to remove the named index from a table.
|
||||
DropIndex(table, name string) *Query
|
||||
}
|
||||
|
||||
// BaseBuilder provides a basic implementation of the Builder interface.
|
||||
type BaseBuilder struct {
|
||||
db *DB
|
||||
executor Executor
|
||||
}
|
||||
|
||||
// NewBaseBuilder creates a new BaseBuilder instance.
|
||||
func NewBaseBuilder(db *DB, executor Executor) *BaseBuilder {
|
||||
return &BaseBuilder{db, executor}
|
||||
}
|
||||
|
||||
// DB returns the DB instance that this builder is associated with.
|
||||
func (b *BaseBuilder) DB() *DB {
|
||||
return b.db
|
||||
}
|
||||
|
||||
// Executor returns the executor object (a DB instance or a transaction) for executing SQL statements.
|
||||
func (b *BaseBuilder) Executor() Executor {
|
||||
return b.executor
|
||||
}
|
||||
|
||||
// NewQuery creates a new Query object with the given SQL statement.
|
||||
// The SQL statement may contain parameter placeholders which can be bound with actual parameter
|
||||
// values before the statement is executed.
|
||||
func (b *BaseBuilder) NewQuery(sql string) *Query {
|
||||
return NewQuery(b.db, b.executor, sql)
|
||||
}
|
||||
|
||||
// GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID.
|
||||
func (b *BaseBuilder) GeneratePlaceholder(int) string {
|
||||
return "?"
|
||||
}
|
||||
|
||||
// Quote quotes a string so that it can be embedded in a SQL statement as a string value.
|
||||
func (b *BaseBuilder) Quote(s string) string {
|
||||
return "'" + strings.Replace(s, "'", "''", -1) + "'"
|
||||
}
|
||||
|
||||
// QuoteSimpleTableName quotes a simple table name.
|
||||
// A simple table name does not contain any schema prefix.
|
||||
func (b *BaseBuilder) QuoteSimpleTableName(s string) string {
|
||||
if strings.Contains(s, `"`) {
|
||||
return s
|
||||
}
|
||||
return `"` + s + `"`
|
||||
}
|
||||
|
||||
// QuoteSimpleColumnName quotes a simple column name.
|
||||
// A simple column name does not contain any table prefix.
|
||||
func (b *BaseBuilder) QuoteSimpleColumnName(s string) string {
|
||||
if strings.Contains(s, `"`) || s == "*" {
|
||||
return s
|
||||
}
|
||||
return `"` + s + `"`
|
||||
}
|
||||
|
||||
// Insert creates a Query that represents an INSERT SQL statement.
|
||||
// The keys of cols are the column names, while the values of cols are the corresponding column
|
||||
// values to be inserted.
|
||||
func (b *BaseBuilder) Insert(table string, cols Params) *Query {
|
||||
names := make([]string, 0, len(cols))
|
||||
for name := range cols {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
|
||||
params := Params{}
|
||||
columns := make([]string, 0, len(names))
|
||||
values := make([]string, 0, len(names))
|
||||
for _, name := range names {
|
||||
columns = append(columns, b.db.QuoteColumnName(name))
|
||||
value := cols[name]
|
||||
if e, ok := value.(Expression); ok {
|
||||
values = append(values, e.Build(b.db, params))
|
||||
} else {
|
||||
values = append(values, fmt.Sprintf("{:p%v}", len(params)))
|
||||
params[fmt.Sprintf("p%v", len(params))] = value
|
||||
}
|
||||
}
|
||||
|
||||
var sql string
|
||||
if len(names) == 0 {
|
||||
sql = fmt.Sprintf("INSERT INTO %v DEFAULT VALUES", b.db.QuoteTableName(table))
|
||||
} else {
|
||||
sql = fmt.Sprintf("INSERT INTO %v (%v) VALUES (%v)",
|
||||
b.db.QuoteTableName(table),
|
||||
strings.Join(columns, ", "),
|
||||
strings.Join(values, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
return b.NewQuery(sql).Bind(params)
|
||||
}
|
||||
|
||||
// Upsert creates a Query that represents an UPSERT SQL statement.
|
||||
// Upsert inserts a row into the table if the primary key or unique index is not found.
|
||||
// Otherwise it will update the row with the new values.
|
||||
// The keys of cols are the column names, while the values of cols are the corresponding column
|
||||
// values to be inserted.
|
||||
func (b *BaseBuilder) Upsert(table string, cols Params, constraints ...string) *Query {
|
||||
q := b.NewQuery("")
|
||||
q.LastError = errors.New("Upsert is not supported")
|
||||
return q
|
||||
}
|
||||
|
||||
// Update creates a Query that represents an UPDATE SQL statement.
|
||||
// The keys of cols are the column names, while the values of cols are the corresponding new column
|
||||
// values. If the "where" expression is nil, the UPDATE SQL statement will have no WHERE clause
|
||||
// (be careful in this case as the SQL statement will update ALL rows in the table).
|
||||
func (b *BaseBuilder) Update(table string, cols Params, where Expression) *Query {
|
||||
names := make([]string, 0, len(cols))
|
||||
for name := range cols {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
|
||||
params := Params{}
|
||||
lines := make([]string, 0, len(names))
|
||||
for _, name := range names {
|
||||
value := cols[name]
|
||||
name = b.db.QuoteColumnName(name)
|
||||
if e, ok := value.(Expression); ok {
|
||||
lines = append(lines, name+"="+e.Build(b.db, params))
|
||||
} else {
|
||||
lines = append(lines, fmt.Sprintf("%v={:p%v}", name, len(params)))
|
||||
params[fmt.Sprintf("p%v", len(params))] = value
|
||||
}
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf("UPDATE %v SET %v", b.db.QuoteTableName(table), strings.Join(lines, ", "))
|
||||
if where != nil {
|
||||
w := where.Build(b.db, params)
|
||||
if w != "" {
|
||||
sql += " WHERE " + w
|
||||
}
|
||||
}
|
||||
|
||||
return b.NewQuery(sql).Bind(params)
|
||||
}
|
||||
|
||||
// Delete creates a Query that represents a DELETE SQL statement.
|
||||
// If the "where" expression is nil, the DELETE SQL statement will have no WHERE clause
|
||||
// (be careful in this case as the SQL statement will delete ALL rows in the table).
|
||||
func (b *BaseBuilder) Delete(table string, where Expression) *Query {
|
||||
sql := "DELETE FROM " + b.db.QuoteTableName(table)
|
||||
params := Params{}
|
||||
if where != nil {
|
||||
w := where.Build(b.db, params)
|
||||
if w != "" {
|
||||
sql += " WHERE " + w
|
||||
}
|
||||
}
|
||||
return b.NewQuery(sql).Bind(params)
|
||||
}
|
||||
|
||||
// CreateTable creates a Query that represents a CREATE TABLE SQL statement.
|
||||
// The keys of cols are the column names, while the values of cols are the corresponding column types.
|
||||
// The optional "options" parameters will be appended to the generated SQL statement.
|
||||
func (b *BaseBuilder) CreateTable(table string, cols map[string]string, options ...string) *Query {
|
||||
names := []string{}
|
||||
for name := range cols {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
|
||||
columns := []string{}
|
||||
for _, name := range names {
|
||||
columns = append(columns, b.db.QuoteColumnName(name)+" "+cols[name])
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf("CREATE TABLE %v (%v)", b.db.QuoteTableName(table), strings.Join(columns, ", "))
|
||||
for _, opt := range options {
|
||||
sql += " " + opt
|
||||
}
|
||||
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// RenameTable creates a Query that can be used to rename a table.
|
||||
func (b *BaseBuilder) RenameTable(oldName, newName string) *Query {
|
||||
sql := fmt.Sprintf("RENAME TABLE %v TO %v", b.db.QuoteTableName(oldName), b.db.QuoteTableName(newName))
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// DropTable creates a Query that can be used to drop a table.
|
||||
func (b *BaseBuilder) DropTable(table string) *Query {
|
||||
sql := "DROP TABLE " + b.db.QuoteTableName(table)
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// TruncateTable creates a Query that can be used to truncate a table.
|
||||
func (b *BaseBuilder) TruncateTable(table string) *Query {
|
||||
sql := "TRUNCATE TABLE " + b.db.QuoteTableName(table)
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// AddColumn creates a Query that can be used to add a column to a table.
|
||||
func (b *BaseBuilder) AddColumn(table, col, typ string) *Query {
|
||||
sql := fmt.Sprintf("ALTER TABLE %v ADD %v %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(col), typ)
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// DropColumn creates a Query that can be used to drop a column from a table.
|
||||
func (b *BaseBuilder) DropColumn(table, col string) *Query {
|
||||
sql := fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(col))
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// RenameColumn creates a Query that can be used to rename a column in a table.
|
||||
func (b *BaseBuilder) RenameColumn(table, oldName, newName string) *Query {
|
||||
sql := fmt.Sprintf("ALTER TABLE %v RENAME COLUMN %v TO %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(oldName), b.db.QuoteColumnName(newName))
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// AlterColumn creates a Query that can be used to change the definition of a table column.
|
||||
func (b *BaseBuilder) AlterColumn(table, col, typ string) *Query {
|
||||
col = b.db.QuoteColumnName(col)
|
||||
sql := fmt.Sprintf("ALTER TABLE %v CHANGE %v %v %v", b.db.QuoteTableName(table), col, col, typ)
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// AddPrimaryKey creates a Query that can be used to specify primary key(s) for a table.
|
||||
// The "name" parameter specifies the name of the primary key constraint.
|
||||
func (b *BaseBuilder) AddPrimaryKey(table, name string, cols ...string) *Query {
|
||||
sql := fmt.Sprintf("ALTER TABLE %v ADD CONSTRAINT %v PRIMARY KEY (%v)",
|
||||
b.db.QuoteTableName(table),
|
||||
b.db.QuoteColumnName(name),
|
||||
b.quoteColumns(cols))
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// DropPrimaryKey creates a Query that can be used to remove the named primary key constraint from a table.
|
||||
func (b *BaseBuilder) DropPrimaryKey(table, name string) *Query {
|
||||
sql := fmt.Sprintf("ALTER TABLE %v DROP CONSTRAINT %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(name))
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// AddForeignKey creates a Query that can be used to add a foreign key constraint to a table.
|
||||
// The length of cols and refCols must be the same as they refer to the primary and referential columns.
|
||||
// The optional "options" parameters will be appended to the SQL statement. They can be used to
|
||||
// specify options such as "ON DELETE CASCADE".
|
||||
func (b *BaseBuilder) AddForeignKey(table, name string, cols, refCols []string, refTable string, options ...string) *Query {
|
||||
sql := fmt.Sprintf("ALTER TABLE %v ADD CONSTRAINT %v FOREIGN KEY (%v) REFERENCES %v (%v)",
|
||||
b.db.QuoteTableName(table),
|
||||
b.db.QuoteColumnName(name),
|
||||
b.quoteColumns(cols),
|
||||
b.db.QuoteTableName(refTable),
|
||||
b.quoteColumns(refCols))
|
||||
for _, opt := range options {
|
||||
sql += " " + opt
|
||||
}
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// DropForeignKey creates a Query that can be used to remove the named foreign key constraint from a table.
|
||||
func (b *BaseBuilder) DropForeignKey(table, name string) *Query {
|
||||
sql := fmt.Sprintf("ALTER TABLE %v DROP CONSTRAINT %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(name))
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// CreateIndex creates a Query that can be used to create an index for a table.
|
||||
func (b *BaseBuilder) CreateIndex(table, name string, cols ...string) *Query {
|
||||
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v)",
|
||||
b.db.QuoteColumnName(name),
|
||||
b.db.QuoteTableName(table),
|
||||
b.quoteColumns(cols))
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// CreateUniqueIndex creates a Query that can be used to create a unique index for a table.
|
||||
func (b *BaseBuilder) CreateUniqueIndex(table, name string, cols ...string) *Query {
|
||||
sql := fmt.Sprintf("CREATE UNIQUE INDEX %v ON %v (%v)",
|
||||
b.db.QuoteColumnName(name),
|
||||
b.db.QuoteTableName(table),
|
||||
b.quoteColumns(cols))
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// DropIndex creates a Query that can be used to remove the named index from a table.
|
||||
func (b *BaseBuilder) DropIndex(table, name string) *Query {
|
||||
sql := fmt.Sprintf("DROP INDEX %v ON %v", b.db.QuoteColumnName(name), b.db.QuoteTableName(table))
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// quoteColumns quotes a list of columns and concatenates them with commas.
|
||||
func (b *BaseBuilder) quoteColumns(cols []string) string {
|
||||
s := ""
|
||||
for i, col := range cols {
|
||||
if i == 0 {
|
||||
s = b.db.QuoteColumnName(col)
|
||||
} else {
|
||||
s += ", " + b.db.QuoteColumnName(col)
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
115
builder_mssql.go
Normal file
115
builder_mssql.go
Normal file
|
@ -0,0 +1,115 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// MssqlBuilder is the builder for SQL Server databases.
|
||||
type MssqlBuilder struct {
|
||||
*BaseBuilder
|
||||
qb *MssqlQueryBuilder
|
||||
}
|
||||
|
||||
var _ Builder = &MssqlBuilder{}
|
||||
|
||||
// MssqlQueryBuilder is the query builder for SQL Server databases.
|
||||
type MssqlQueryBuilder struct {
|
||||
*BaseQueryBuilder
|
||||
}
|
||||
|
||||
// NewMssqlBuilder creates a new MssqlBuilder instance.
|
||||
func NewMssqlBuilder(db *DB, executor Executor) Builder {
|
||||
return &MssqlBuilder{
|
||||
NewBaseBuilder(db, executor),
|
||||
&MssqlQueryBuilder{NewBaseQueryBuilder(db)},
|
||||
}
|
||||
}
|
||||
|
||||
// QueryBuilder returns the query builder supporting the current DB.
|
||||
func (b *MssqlBuilder) QueryBuilder() QueryBuilder {
|
||||
return b.qb
|
||||
}
|
||||
|
||||
// Select returns a new SelectQuery object that can be used to build a SELECT statement.
|
||||
// The parameters to this method should be the list column names to be selected.
|
||||
// A column name may have an optional alias name. For example, Select("id", "my_name AS name").
|
||||
func (b *MssqlBuilder) Select(cols ...string) *SelectQuery {
|
||||
return NewSelectQuery(b, b.db).Select(cols...)
|
||||
}
|
||||
|
||||
// Model returns a new ModelQuery object that can be used to perform model-based DB operations.
|
||||
// The model passed to this method should be a pointer to a model struct.
|
||||
func (b *MssqlBuilder) Model(model interface{}) *ModelQuery {
|
||||
return NewModelQuery(model, b.db.FieldMapper, b.db, b)
|
||||
}
|
||||
|
||||
// QuoteSimpleTableName quotes a simple table name.
|
||||
// A simple table name does not contain any schema prefix.
|
||||
func (b *MssqlBuilder) QuoteSimpleTableName(s string) string {
|
||||
if strings.Contains(s, `[`) {
|
||||
return s
|
||||
}
|
||||
return `[` + s + `]`
|
||||
}
|
||||
|
||||
// QuoteSimpleColumnName quotes a simple column name.
|
||||
// A simple column name does not contain any table prefix.
|
||||
func (b *MssqlBuilder) QuoteSimpleColumnName(s string) string {
|
||||
if strings.Contains(s, `[`) || s == "*" {
|
||||
return s
|
||||
}
|
||||
return `[` + s + `]`
|
||||
}
|
||||
|
||||
// RenameTable creates a Query that can be used to rename a table.
|
||||
func (b *MssqlBuilder) RenameTable(oldName, newName string) *Query {
|
||||
sql := fmt.Sprintf("sp_name '%v', '%v'", oldName, newName)
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// RenameColumn creates a Query that can be used to rename a column in a table.
|
||||
func (b *MssqlBuilder) RenameColumn(table, oldName, newName string) *Query {
|
||||
sql := fmt.Sprintf("sp_name '%v.%v', '%v', 'COLUMN'", table, oldName, newName)
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// AlterColumn creates a Query that can be used to change the definition of a table column.
|
||||
func (b *MssqlBuilder) AlterColumn(table, col, typ string) *Query {
|
||||
col = b.db.QuoteColumnName(col)
|
||||
sql := fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", b.db.QuoteTableName(table), col, typ)
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// BuildOrderByAndLimit generates the ORDER BY and LIMIT clauses.
|
||||
func (q *MssqlQueryBuilder) BuildOrderByAndLimit(sql string, cols []string, limit int64, offset int64) string {
|
||||
orderBy := q.BuildOrderBy(cols)
|
||||
if limit < 0 && offset < 0 {
|
||||
if orderBy == "" {
|
||||
return sql
|
||||
}
|
||||
return sql + "\n" + orderBy
|
||||
}
|
||||
|
||||
// only SQL SERVER 2012 or newer are supported by this method
|
||||
|
||||
if orderBy == "" {
|
||||
// ORDER BY clause is required when FETCH and OFFSET are in the SQL
|
||||
orderBy = "ORDER BY (SELECT NULL)"
|
||||
}
|
||||
sql += "\n" + orderBy
|
||||
|
||||
// http://technet.microsoft.com/en-us/library/gg699618.aspx
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
sql += "\n" + fmt.Sprintf("OFFSET %v ROWS", offset)
|
||||
if limit >= 0 {
|
||||
sql += "\n" + fmt.Sprintf("FETCH NEXT %v ROWS ONLY", limit)
|
||||
}
|
||||
return sql
|
||||
}
|
73
builder_mssql_test.go
Normal file
73
builder_mssql_test.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMssqlBuilder_QuoteSimpleTableName(t *testing.T) {
|
||||
b := getMssqlBuilder()
|
||||
assert.Equal(t, b.QuoteSimpleTableName(`abc`), "[abc]", "t1")
|
||||
assert.Equal(t, b.QuoteSimpleTableName("[abc]"), "[abc]", "t2")
|
||||
assert.Equal(t, b.QuoteSimpleTableName(`{{abc}}`), "[{{abc}}]", "t3")
|
||||
assert.Equal(t, b.QuoteSimpleTableName(`a.bc`), "[a.bc]", "t4")
|
||||
}
|
||||
|
||||
func TestMssqlBuilder_QuoteSimpleColumnName(t *testing.T) {
|
||||
b := getMssqlBuilder()
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`abc`), "[abc]", "t1")
|
||||
assert.Equal(t, b.QuoteSimpleColumnName("[abc]"), "[abc]", "t2")
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`{{abc}}`), "[{{abc}}]", "t3")
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`a.bc`), "[a.bc]", "t4")
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`*`), `*`, "t5")
|
||||
}
|
||||
|
||||
func TestMssqlBuilder_RenameTable(t *testing.T) {
|
||||
b := getMssqlBuilder()
|
||||
q := b.RenameTable("users", "user")
|
||||
assert.Equal(t, q.SQL(), `sp_name 'users', 'user'`, "t1")
|
||||
}
|
||||
|
||||
func TestMssqlBuilder_RenameColumn(t *testing.T) {
|
||||
b := getMssqlBuilder()
|
||||
q := b.RenameColumn("users", "name", "username")
|
||||
assert.Equal(t, q.SQL(), `sp_name 'users.name', 'username', 'COLUMN'`, "t1")
|
||||
}
|
||||
|
||||
func TestMssqlBuilder_AlterColumn(t *testing.T) {
|
||||
b := getMssqlBuilder()
|
||||
q := b.AlterColumn("users", "name", "int")
|
||||
assert.Equal(t, q.SQL(), `ALTER TABLE [users] ALTER COLUMN [name] int`, "t1")
|
||||
}
|
||||
|
||||
func TestMssqlQueryBuilder_BuildOrderByAndLimit(t *testing.T) {
|
||||
qb := getMssqlBuilder().QueryBuilder()
|
||||
|
||||
sql := qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, 10, 2)
|
||||
expected := "SELECT *\nORDER BY [name]\nOFFSET 2 ROWS\nFETCH NEXT 10 ROWS ONLY"
|
||||
assert.Equal(t, sql, expected, "t1")
|
||||
|
||||
sql = qb.BuildOrderByAndLimit("SELECT *", nil, -1, -1)
|
||||
expected = "SELECT *"
|
||||
assert.Equal(t, sql, expected, "t2")
|
||||
|
||||
sql = qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, -1, -1)
|
||||
expected = "SELECT *\nORDER BY [name]"
|
||||
assert.Equal(t, sql, expected, "t3")
|
||||
|
||||
sql = qb.BuildOrderByAndLimit("SELECT *", nil, 10, -1)
|
||||
expected = "SELECT *\nORDER BY (SELECT NULL)\nOFFSET 0 ROWS\nFETCH NEXT 10 ROWS ONLY"
|
||||
assert.Equal(t, sql, expected, "t4")
|
||||
}
|
||||
|
||||
func getMssqlBuilder() Builder {
|
||||
db := getDB()
|
||||
b := NewMssqlBuilder(db, db.sqlDB)
|
||||
db.Builder = b
|
||||
return b
|
||||
}
|
133
builder_mysql.go
Normal file
133
builder_mysql.go
Normal file
|
@ -0,0 +1,133 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// MysqlBuilder is the builder for MySQL databases.
|
||||
type MysqlBuilder struct {
|
||||
*BaseBuilder
|
||||
qb *BaseQueryBuilder
|
||||
}
|
||||
|
||||
var _ Builder = &MysqlBuilder{}
|
||||
|
||||
// NewMysqlBuilder creates a new MysqlBuilder instance.
|
||||
func NewMysqlBuilder(db *DB, executor Executor) Builder {
|
||||
return &MysqlBuilder{
|
||||
NewBaseBuilder(db, executor),
|
||||
NewBaseQueryBuilder(db),
|
||||
}
|
||||
}
|
||||
|
||||
// QueryBuilder returns the query builder supporting the current DB.
|
||||
func (b *MysqlBuilder) QueryBuilder() QueryBuilder {
|
||||
return b.qb
|
||||
}
|
||||
|
||||
// Select returns a new SelectQuery object that can be used to build a SELECT statement.
|
||||
// The parameters to this method should be the list column names to be selected.
|
||||
// A column name may have an optional alias name. For example, Select("id", "my_name AS name").
|
||||
func (b *MysqlBuilder) Select(cols ...string) *SelectQuery {
|
||||
return NewSelectQuery(b, b.db).Select(cols...)
|
||||
}
|
||||
|
||||
// Model returns a new ModelQuery object that can be used to perform model-based DB operations.
|
||||
// The model passed to this method should be a pointer to a model struct.
|
||||
func (b *MysqlBuilder) Model(model interface{}) *ModelQuery {
|
||||
return NewModelQuery(model, b.db.FieldMapper, b.db, b)
|
||||
}
|
||||
|
||||
// QuoteSimpleTableName quotes a simple table name.
|
||||
// A simple table name does not contain any schema prefix.
|
||||
func (b *MysqlBuilder) QuoteSimpleTableName(s string) string {
|
||||
if strings.ContainsAny(s, "`") {
|
||||
return s
|
||||
}
|
||||
return "`" + s + "`"
|
||||
}
|
||||
|
||||
// QuoteSimpleColumnName quotes a simple column name.
|
||||
// A simple column name does not contain any table prefix.
|
||||
func (b *MysqlBuilder) QuoteSimpleColumnName(s string) string {
|
||||
if strings.Contains(s, "`") || s == "*" {
|
||||
return s
|
||||
}
|
||||
return "`" + s + "`"
|
||||
}
|
||||
|
||||
// Upsert creates a Query that represents an UPSERT SQL statement.
|
||||
// Upsert inserts a row into the table if the primary key or unique index is not found.
|
||||
// Otherwise it will update the row with the new values.
|
||||
// The keys of cols are the column names, while the values of cols are the corresponding column
|
||||
// values to be inserted.
|
||||
func (b *MysqlBuilder) Upsert(table string, cols Params, constraints ...string) *Query {
|
||||
q := b.Insert(table, cols)
|
||||
|
||||
names := []string{}
|
||||
for name := range cols {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
|
||||
lines := []string{}
|
||||
for _, name := range names {
|
||||
value := cols[name]
|
||||
name = b.db.QuoteColumnName(name)
|
||||
if e, ok := value.(Expression); ok {
|
||||
lines = append(lines, name+"="+e.Build(b.db, q.params))
|
||||
} else {
|
||||
lines = append(lines, fmt.Sprintf("%v={:p%v}", name, len(q.params)))
|
||||
q.params[fmt.Sprintf("p%v", len(q.params))] = value
|
||||
}
|
||||
}
|
||||
|
||||
q.sql += " ON DUPLICATE KEY UPDATE " + strings.Join(lines, ", ")
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
var mysqlColumnRegexp = regexp.MustCompile("(?m)^\\s*[`\"](.*?)[`\"]\\s+(.*?),?$")
|
||||
|
||||
// RenameColumn creates a Query that can be used to rename a column in a table.
|
||||
func (b *MysqlBuilder) RenameColumn(table, oldName, newName string) *Query {
|
||||
qt := b.db.QuoteTableName(table)
|
||||
sql := fmt.Sprintf("ALTER TABLE %v CHANGE %v %v", qt, b.db.QuoteColumnName(oldName), b.db.QuoteColumnName(newName))
|
||||
|
||||
var info struct {
|
||||
SQL string `db:"Create Table"`
|
||||
}
|
||||
if err := b.db.NewQuery("SHOW CREATE TABLE " + qt).One(&info); err != nil {
|
||||
return b.db.NewQuery(sql)
|
||||
}
|
||||
|
||||
if matches := mysqlColumnRegexp.FindAllStringSubmatch(info.SQL, -1); matches != nil {
|
||||
for _, match := range matches {
|
||||
if match[1] == oldName {
|
||||
sql += " " + match[2]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return b.db.NewQuery(sql)
|
||||
}
|
||||
|
||||
// DropPrimaryKey creates a Query that can be used to remove the named primary key constraint from a table.
|
||||
func (b *MysqlBuilder) DropPrimaryKey(table, name string) *Query {
|
||||
sql := fmt.Sprintf("ALTER TABLE %v DROP PRIMARY KEY", b.db.QuoteTableName(table))
|
||||
return b.db.NewQuery(sql)
|
||||
}
|
||||
|
||||
// DropForeignKey creates a Query that can be used to remove the named foreign key constraint from a table.
|
||||
func (b *MysqlBuilder) DropForeignKey(table, name string) *Query {
|
||||
sql := fmt.Sprintf("ALTER TABLE %v DROP FOREIGN KEY %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(name))
|
||||
return b.db.NewQuery(sql)
|
||||
}
|
69
builder_mysql_test.go
Normal file
69
builder_mysql_test.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMysqlBuilder_QuoteSimpleTableName(t *testing.T) {
|
||||
b := getMysqlBuilder()
|
||||
assert.Equal(t, b.QuoteSimpleTableName(`abc`), "`abc`", "t1")
|
||||
assert.Equal(t, b.QuoteSimpleTableName("`abc`"), "`abc`", "t2")
|
||||
assert.Equal(t, b.QuoteSimpleTableName(`{{abc}}`), "`{{abc}}`", "t3")
|
||||
assert.Equal(t, b.QuoteSimpleTableName(`a.bc`), "`a.bc`", "t4")
|
||||
}
|
||||
|
||||
func TestMysqlBuilder_QuoteSimpleColumnName(t *testing.T) {
|
||||
b := getMysqlBuilder()
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`abc`), "`abc`", "t1")
|
||||
assert.Equal(t, b.QuoteSimpleColumnName("`abc`"), "`abc`", "t2")
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`{{abc}}`), "`{{abc}}`", "t3")
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`a.bc`), "`a.bc`", "t4")
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`*`), `*`, "t5")
|
||||
}
|
||||
|
||||
func TestMysqlBuilder_Upsert(t *testing.T) {
|
||||
getPreparedDB()
|
||||
b := getMysqlBuilder()
|
||||
q := b.Upsert("users", Params{
|
||||
"name": "James",
|
||||
"age": 30,
|
||||
})
|
||||
assert.Equal(t, q.SQL(), "INSERT INTO `users` (`age`, `name`) VALUES ({:p0}, {:p1}) ON DUPLICATE KEY UPDATE `age`={:p2}, `name`={:p3}", "t1")
|
||||
assert.Equal(t, q.Params()["p0"], 30, "t2")
|
||||
assert.Equal(t, q.Params()["p1"], "James", "t3")
|
||||
assert.Equal(t, q.Params()["p2"], 30, "t2")
|
||||
assert.Equal(t, q.Params()["p3"], "James", "t3")
|
||||
}
|
||||
|
||||
func TestMysqlBuilder_RenameColumn(t *testing.T) {
|
||||
b := getMysqlBuilder()
|
||||
q := b.RenameColumn("users", "name", "username")
|
||||
assert.Equal(t, q.SQL(), "ALTER TABLE `users` CHANGE `name` `username`")
|
||||
q = b.RenameColumn("customer", "email", "e")
|
||||
assert.Equal(t, q.SQL(), "ALTER TABLE `customer` CHANGE `email` `e` varchar(128) NOT NULL")
|
||||
}
|
||||
|
||||
func TestMysqlBuilder_DropPrimaryKey(t *testing.T) {
|
||||
b := getMysqlBuilder()
|
||||
q := b.DropPrimaryKey("users", "pk")
|
||||
assert.Equal(t, q.SQL(), "ALTER TABLE `users` DROP PRIMARY KEY", "t1")
|
||||
}
|
||||
|
||||
func TestMysqlBuilder_DropForeignKey(t *testing.T) {
|
||||
b := getMysqlBuilder()
|
||||
q := b.DropForeignKey("users", "fk")
|
||||
assert.Equal(t, q.SQL(), "ALTER TABLE `users` DROP FOREIGN KEY `fk`", "t1")
|
||||
}
|
||||
|
||||
func getMysqlBuilder() Builder {
|
||||
db := getDB()
|
||||
b := NewMysqlBuilder(db, db.sqlDB)
|
||||
db.Builder = b
|
||||
return b
|
||||
}
|
98
builder_oci.go
Normal file
98
builder_oci.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// OciBuilder is the builder for Oracle databases.
|
||||
type OciBuilder struct {
|
||||
*BaseBuilder
|
||||
qb *OciQueryBuilder
|
||||
}
|
||||
|
||||
var _ Builder = &OciBuilder{}
|
||||
|
||||
// OciQueryBuilder is the query builder for Oracle databases.
|
||||
type OciQueryBuilder struct {
|
||||
*BaseQueryBuilder
|
||||
}
|
||||
|
||||
// NewOciBuilder creates a new OciBuilder instance.
|
||||
func NewOciBuilder(db *DB, executor Executor) Builder {
|
||||
return &OciBuilder{
|
||||
NewBaseBuilder(db, executor),
|
||||
&OciQueryBuilder{NewBaseQueryBuilder(db)},
|
||||
}
|
||||
}
|
||||
|
||||
// Select returns a new SelectQuery object that can be used to build a SELECT statement.
|
||||
// The parameters to this method should be the list column names to be selected.
|
||||
// A column name may have an optional alias name. For example, Select("id", "my_name AS name").
|
||||
func (b *OciBuilder) Select(cols ...string) *SelectQuery {
|
||||
return NewSelectQuery(b, b.db).Select(cols...)
|
||||
}
|
||||
|
||||
// Model returns a new ModelQuery object that can be used to perform model-based DB operations.
|
||||
// The model passed to this method should be a pointer to a model struct.
|
||||
func (b *OciBuilder) Model(model interface{}) *ModelQuery {
|
||||
return NewModelQuery(model, b.db.FieldMapper, b.db, b)
|
||||
}
|
||||
|
||||
// GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID.
|
||||
func (b *OciBuilder) GeneratePlaceholder(i int) string {
|
||||
return fmt.Sprintf(":p%v", i)
|
||||
}
|
||||
|
||||
// QueryBuilder returns the query builder supporting the current DB.
|
||||
func (b *OciBuilder) QueryBuilder() QueryBuilder {
|
||||
return b.qb
|
||||
}
|
||||
|
||||
// DropIndex creates a Query that can be used to remove the named index from a table.
|
||||
func (b *OciBuilder) DropIndex(table, name string) *Query {
|
||||
sql := fmt.Sprintf("DROP INDEX %v", b.db.QuoteColumnName(name))
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// RenameTable creates a Query that can be used to rename a table.
|
||||
func (b *OciBuilder) RenameTable(oldName, newName string) *Query {
|
||||
sql := fmt.Sprintf("ALTER TABLE %v RENAME TO %v", b.db.QuoteTableName(oldName), b.db.QuoteTableName(newName))
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// AlterColumn creates a Query that can be used to change the definition of a table column.
|
||||
func (b *OciBuilder) AlterColumn(table, col, typ string) *Query {
|
||||
col = b.db.QuoteColumnName(col)
|
||||
sql := fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", b.db.QuoteTableName(table), col, typ)
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// BuildOrderByAndLimit generates the ORDER BY and LIMIT clauses.
|
||||
func (q *OciQueryBuilder) BuildOrderByAndLimit(sql string, cols []string, limit int64, offset int64) string {
|
||||
if orderBy := q.BuildOrderBy(cols); orderBy != "" {
|
||||
sql += "\n" + orderBy
|
||||
}
|
||||
|
||||
c := ""
|
||||
if offset > 0 {
|
||||
c = fmt.Sprintf("rowNumId > %v", offset)
|
||||
}
|
||||
if limit >= 0 {
|
||||
if c != "" {
|
||||
c += " AND "
|
||||
}
|
||||
c += fmt.Sprintf("rowNum <= %v", limit)
|
||||
}
|
||||
|
||||
if c == "" {
|
||||
return sql
|
||||
}
|
||||
|
||||
return `WITH USER_SQL AS (` + sql + `),
|
||||
PAGINATION AS (SELECT USER_SQL.*, rownum as rowNumId FROM USER_SQL)
|
||||
SELECT * FROM PAGINATION WHERE ` + c
|
||||
}
|
56
builder_oci_test.go
Normal file
56
builder_oci_test.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestOciBuilder_DropIndex(t *testing.T) {
|
||||
b := getOciBuilder()
|
||||
q := b.DropIndex("users", "idx")
|
||||
assert.Equal(t, q.SQL(), `DROP INDEX "idx"`, "t1")
|
||||
}
|
||||
|
||||
func TestOciBuilder_RenameTable(t *testing.T) {
|
||||
b := getOciBuilder()
|
||||
q := b.RenameTable("users", "user")
|
||||
assert.Equal(t, q.SQL(), `ALTER TABLE "users" RENAME TO "user"`, "t1")
|
||||
}
|
||||
|
||||
func TestOciBuilder_AlterColumn(t *testing.T) {
|
||||
b := getOciBuilder()
|
||||
q := b.AlterColumn("users", "name", "int")
|
||||
assert.Equal(t, q.SQL(), `ALTER TABLE "users" MODIFY "name" int`, "t1")
|
||||
}
|
||||
|
||||
func TestOciQueryBuilder_BuildOrderByAndLimit(t *testing.T) {
|
||||
qb := getOciBuilder().QueryBuilder()
|
||||
|
||||
sql := qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, 10, 2)
|
||||
expected := "WITH USER_SQL AS (SELECT *\nORDER BY \"name\"),\n\tPAGINATION AS (SELECT USER_SQL.*, rownum as rowNumId FROM USER_SQL)\nSELECT * FROM PAGINATION WHERE rowNumId > 2 AND rowNum <= 10"
|
||||
assert.Equal(t, sql, expected, "t1")
|
||||
|
||||
sql = qb.BuildOrderByAndLimit("SELECT *", nil, -1, -1)
|
||||
expected = "SELECT *"
|
||||
assert.Equal(t, sql, expected, "t2")
|
||||
|
||||
sql = qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, -1, -1)
|
||||
expected = "SELECT *\nORDER BY \"name\""
|
||||
assert.Equal(t, sql, expected, "t3")
|
||||
|
||||
sql = qb.BuildOrderByAndLimit("SELECT *", nil, 10, -1)
|
||||
expected = "WITH USER_SQL AS (SELECT *),\n\tPAGINATION AS (SELECT USER_SQL.*, rownum as rowNumId FROM USER_SQL)\nSELECT * FROM PAGINATION WHERE rowNum <= 10"
|
||||
assert.Equal(t, sql, expected, "t4")
|
||||
}
|
||||
|
||||
func getOciBuilder() Builder {
|
||||
db := getDB()
|
||||
b := NewOciBuilder(db, db.sqlDB)
|
||||
db.Builder = b
|
||||
return b
|
||||
}
|
105
builder_pgsql.go
Normal file
105
builder_pgsql.go
Normal file
|
@ -0,0 +1,105 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PgsqlBuilder is the builder for PostgreSQL databases.
|
||||
type PgsqlBuilder struct {
|
||||
*BaseBuilder
|
||||
qb *BaseQueryBuilder
|
||||
}
|
||||
|
||||
var _ Builder = &PgsqlBuilder{}
|
||||
|
||||
// NewPgsqlBuilder creates a new PgsqlBuilder instance.
|
||||
func NewPgsqlBuilder(db *DB, executor Executor) Builder {
|
||||
return &PgsqlBuilder{
|
||||
NewBaseBuilder(db, executor),
|
||||
NewBaseQueryBuilder(db),
|
||||
}
|
||||
}
|
||||
|
||||
// Select returns a new SelectQuery object that can be used to build a SELECT statement.
|
||||
// The parameters to this method should be the list column names to be selected.
|
||||
// A column name may have an optional alias name. For example, Select("id", "my_name AS name").
|
||||
func (b *PgsqlBuilder) Select(cols ...string) *SelectQuery {
|
||||
return NewSelectQuery(b, b.db).Select(cols...)
|
||||
}
|
||||
|
||||
// Model returns a new ModelQuery object that can be used to perform model-based DB operations.
|
||||
// The model passed to this method should be a pointer to a model struct.
|
||||
func (b *PgsqlBuilder) Model(model interface{}) *ModelQuery {
|
||||
return NewModelQuery(model, b.db.FieldMapper, b.db, b)
|
||||
}
|
||||
|
||||
// GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID.
|
||||
func (b *PgsqlBuilder) GeneratePlaceholder(i int) string {
|
||||
return fmt.Sprintf("$%v", i)
|
||||
}
|
||||
|
||||
// QueryBuilder returns the query builder supporting the current DB.
|
||||
func (b *PgsqlBuilder) QueryBuilder() QueryBuilder {
|
||||
return b.qb
|
||||
}
|
||||
|
||||
// Upsert creates a Query that represents an UPSERT SQL statement.
|
||||
// Upsert inserts a row into the table if the primary key or unique index is not found.
|
||||
// Otherwise it will update the row with the new values.
|
||||
// The keys of cols are the column names, while the values of cols are the corresponding column
|
||||
// values to be inserted.
|
||||
func (b *PgsqlBuilder) Upsert(table string, cols Params, constraints ...string) *Query {
|
||||
q := b.Insert(table, cols)
|
||||
|
||||
names := []string{}
|
||||
for name := range cols {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
|
||||
lines := []string{}
|
||||
for _, name := range names {
|
||||
value := cols[name]
|
||||
name = b.db.QuoteColumnName(name)
|
||||
if e, ok := value.(Expression); ok {
|
||||
lines = append(lines, name+"="+e.Build(b.db, q.params))
|
||||
} else {
|
||||
lines = append(lines, fmt.Sprintf("%v={:p%v}", name, len(q.params)))
|
||||
q.params[fmt.Sprintf("p%v", len(q.params))] = value
|
||||
}
|
||||
}
|
||||
|
||||
if len(constraints) > 0 {
|
||||
c := b.quoteColumns(constraints)
|
||||
q.sql += " ON CONFLICT (" + c + ") DO UPDATE SET " + strings.Join(lines, ", ")
|
||||
} else {
|
||||
q.sql += " ON CONFLICT DO UPDATE SET " + strings.Join(lines, ", ")
|
||||
}
|
||||
|
||||
return b.NewQuery(q.sql).Bind(q.params)
|
||||
}
|
||||
|
||||
// DropIndex creates a Query that can be used to remove the named index from a table.
|
||||
func (b *PgsqlBuilder) DropIndex(table, name string) *Query {
|
||||
sql := fmt.Sprintf("DROP INDEX %v", b.db.QuoteColumnName(name))
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// RenameTable creates a Query that can be used to rename a table.
|
||||
func (b *PgsqlBuilder) RenameTable(oldName, newName string) *Query {
|
||||
sql := fmt.Sprintf("ALTER TABLE %v RENAME TO %v", b.db.QuoteTableName(oldName), b.db.QuoteTableName(newName))
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// AlterColumn creates a Query that can be used to change the definition of a table column.
|
||||
func (b *PgsqlBuilder) AlterColumn(table, col, typ string) *Query {
|
||||
col = b.db.QuoteColumnName(col)
|
||||
sql := fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", b.db.QuoteTableName(table), col, typ)
|
||||
return b.NewQuery(sql)
|
||||
}
|
49
builder_pgsql_test.go
Normal file
49
builder_pgsql_test.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPgsqlBuilder_Upsert(t *testing.T) {
|
||||
b := getPgsqlBuilder()
|
||||
q := b.Upsert("users", Params{
|
||||
"name": "James",
|
||||
"age": 30,
|
||||
}, "id")
|
||||
assert.Equal(t, q.sql, `INSERT INTO "users" ("age", "name") VALUES ({:p0}, {:p1}) ON CONFLICT ("id") DO UPDATE SET "age"={:p2}, "name"={:p3}`, "t1")
|
||||
assert.Equal(t, q.rawSQL, `INSERT INTO "users" ("age", "name") VALUES ($1, $2) ON CONFLICT ("id") DO UPDATE SET "age"=$3, "name"=$4`, "t2")
|
||||
assert.Equal(t, q.Params()["p0"], 30, "t3")
|
||||
assert.Equal(t, q.Params()["p1"], "James", "t4")
|
||||
assert.Equal(t, q.Params()["p2"], 30, "t5")
|
||||
assert.Equal(t, q.Params()["p3"], "James", "t6")
|
||||
}
|
||||
func TestPgsqlBuilder_DropIndex(t *testing.T) {
|
||||
b := getPgsqlBuilder()
|
||||
q := b.DropIndex("users", "idx")
|
||||
assert.Equal(t, q.SQL(), `DROP INDEX "idx"`, "t1")
|
||||
}
|
||||
|
||||
func TestPgsqlBuilder_RenameTable(t *testing.T) {
|
||||
b := getPgsqlBuilder()
|
||||
q := b.RenameTable("users", "user")
|
||||
assert.Equal(t, q.SQL(), `ALTER TABLE "users" RENAME TO "user"`, "t1")
|
||||
}
|
||||
|
||||
func TestPgsqlBuilder_AlterColumn(t *testing.T) {
|
||||
b := getPgsqlBuilder()
|
||||
q := b.AlterColumn("users", "name", "int")
|
||||
assert.Equal(t, q.SQL(), `ALTER TABLE "users" ALTER COLUMN "name" TYPE int`, "t1")
|
||||
}
|
||||
|
||||
func getPgsqlBuilder() Builder {
|
||||
db := getDB()
|
||||
b := NewPgsqlBuilder(db, db.sqlDB)
|
||||
db.Builder = b
|
||||
return b
|
||||
}
|
120
builder_sqlite.go
Normal file
120
builder_sqlite.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SqliteBuilder is the builder for SQLite databases.
|
||||
type SqliteBuilder struct {
|
||||
*BaseBuilder
|
||||
qb *BaseQueryBuilder
|
||||
}
|
||||
|
||||
var _ Builder = &SqliteBuilder{}
|
||||
|
||||
// NewSqliteBuilder creates a new SqliteBuilder instance.
|
||||
func NewSqliteBuilder(db *DB, executor Executor) Builder {
|
||||
return &SqliteBuilder{
|
||||
NewBaseBuilder(db, executor),
|
||||
NewBaseQueryBuilder(db),
|
||||
}
|
||||
}
|
||||
|
||||
// QueryBuilder returns the query builder supporting the current DB.
|
||||
func (b *SqliteBuilder) QueryBuilder() QueryBuilder {
|
||||
return b.qb
|
||||
}
|
||||
|
||||
// Select returns a new SelectQuery object that can be used to build a SELECT statement.
|
||||
// The parameters to this method should be the list column names to be selected.
|
||||
// A column name may have an optional alias name. For example, Select("id", "my_name AS name").
|
||||
func (b *SqliteBuilder) Select(cols ...string) *SelectQuery {
|
||||
return NewSelectQuery(b, b.db).Select(cols...)
|
||||
}
|
||||
|
||||
// Model returns a new ModelQuery object that can be used to perform model-based DB operations.
|
||||
// The model passed to this method should be a pointer to a model struct.
|
||||
func (b *SqliteBuilder) Model(model interface{}) *ModelQuery {
|
||||
return NewModelQuery(model, b.db.FieldMapper, b.db, b)
|
||||
}
|
||||
|
||||
// QuoteSimpleTableName quotes a simple table name.
|
||||
// A simple table name does not contain any schema prefix.
|
||||
func (b *SqliteBuilder) QuoteSimpleTableName(s string) string {
|
||||
if strings.ContainsAny(s, "`") {
|
||||
return s
|
||||
}
|
||||
return "`" + s + "`"
|
||||
}
|
||||
|
||||
// QuoteSimpleColumnName quotes a simple column name.
|
||||
// A simple column name does not contain any table prefix.
|
||||
func (b *SqliteBuilder) QuoteSimpleColumnName(s string) string {
|
||||
if strings.Contains(s, "`") || s == "*" {
|
||||
return s
|
||||
}
|
||||
return "`" + s + "`"
|
||||
}
|
||||
|
||||
// DropIndex creates a Query that can be used to remove the named index from a table.
|
||||
func (b *SqliteBuilder) DropIndex(table, name string) *Query {
|
||||
sql := fmt.Sprintf("DROP INDEX %v", b.db.QuoteColumnName(name))
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// TruncateTable creates a Query that can be used to truncate a table.
|
||||
func (b *SqliteBuilder) TruncateTable(table string) *Query {
|
||||
sql := "DELETE FROM " + b.db.QuoteTableName(table)
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// RenameTable creates a Query that can be used to rename a table.
|
||||
func (b *SqliteBuilder) RenameTable(oldName, newName string) *Query {
|
||||
sql := fmt.Sprintf("ALTER TABLE %v RENAME TO %v", b.db.QuoteTableName(oldName), b.db.QuoteTableName(newName))
|
||||
return b.NewQuery(sql)
|
||||
}
|
||||
|
||||
// AlterColumn creates a Query that can be used to change the definition of a table column.
|
||||
func (b *SqliteBuilder) AlterColumn(table, col, typ string) *Query {
|
||||
q := b.NewQuery("")
|
||||
q.LastError = errors.New("SQLite does not support altering column")
|
||||
return q
|
||||
}
|
||||
|
||||
// AddPrimaryKey creates a Query that can be used to specify primary key(s) for a table.
|
||||
// The "name" parameter specifies the name of the primary key constraint.
|
||||
func (b *SqliteBuilder) AddPrimaryKey(table, name string, cols ...string) *Query {
|
||||
q := b.NewQuery("")
|
||||
q.LastError = errors.New("SQLite does not support adding primary key")
|
||||
return q
|
||||
}
|
||||
|
||||
// DropPrimaryKey creates a Query that can be used to remove the named primary key constraint from a table.
|
||||
func (b *SqliteBuilder) DropPrimaryKey(table, name string) *Query {
|
||||
q := b.NewQuery("")
|
||||
q.LastError = errors.New("SQLite does not support dropping primary key")
|
||||
return q
|
||||
}
|
||||
|
||||
// AddForeignKey creates a Query that can be used to add a foreign key constraint to a table.
|
||||
// The length of cols and refCols must be the same as they refer to the primary and referential columns.
|
||||
// The optional "options" parameters will be appended to the SQL statement. They can be used to
|
||||
// specify options such as "ON DELETE CASCADE".
|
||||
func (b *SqliteBuilder) AddForeignKey(table, name string, cols, refCols []string, refTable string, options ...string) *Query {
|
||||
q := b.NewQuery("")
|
||||
q.LastError = errors.New("SQLite does not support adding foreign keys")
|
||||
return q
|
||||
}
|
||||
|
||||
// DropForeignKey creates a Query that can be used to remove the named foreign key constraint from a table.
|
||||
func (b *SqliteBuilder) DropForeignKey(table, name string) *Query {
|
||||
q := b.NewQuery("")
|
||||
q.LastError = errors.New("SQLite does not support dropping foreign keys")
|
||||
return q
|
||||
}
|
83
builder_sqlite_test.go
Normal file
83
builder_sqlite_test.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSqliteBuilder_QuoteSimpleTableName(t *testing.T) {
|
||||
b := getSqliteBuilder()
|
||||
assert.Equal(t, b.QuoteSimpleTableName(`abc`), "`abc`", "t1")
|
||||
assert.Equal(t, b.QuoteSimpleTableName("`abc`"), "`abc`", "t2")
|
||||
assert.Equal(t, b.QuoteSimpleTableName(`{{abc}}`), "`{{abc}}`", "t3")
|
||||
assert.Equal(t, b.QuoteSimpleTableName(`a.bc`), "`a.bc`", "t4")
|
||||
}
|
||||
|
||||
func TestSqliteBuilder_QuoteSimpleColumnName(t *testing.T) {
|
||||
b := getSqliteBuilder()
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`abc`), "`abc`", "t1")
|
||||
assert.Equal(t, b.QuoteSimpleColumnName("`abc`"), "`abc`", "t2")
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`{{abc}}`), "`{{abc}}`", "t3")
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`a.bc`), "`a.bc`", "t4")
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`*`), `*`, "t5")
|
||||
}
|
||||
|
||||
func TestSqliteBuilder_DropIndex(t *testing.T) {
|
||||
b := getSqliteBuilder()
|
||||
q := b.DropIndex("users", "idx")
|
||||
assert.Equal(t, q.SQL(), "DROP INDEX `idx`", "t1")
|
||||
}
|
||||
|
||||
func TestSqliteBuilder_TruncateTable(t *testing.T) {
|
||||
b := getSqliteBuilder()
|
||||
q := b.TruncateTable("users")
|
||||
assert.Equal(t, q.SQL(), "DELETE FROM `users`", "t1")
|
||||
}
|
||||
|
||||
func TestSqliteBuilder_RenameTable(t *testing.T) {
|
||||
b := getSqliteBuilder()
|
||||
q := b.RenameTable("usersOld", "usersNew")
|
||||
assert.Equal(t, q.SQL(), "ALTER TABLE `usersOld` RENAME TO `usersNew`", "t1")
|
||||
}
|
||||
|
||||
func TestSqliteBuilder_AlterColumn(t *testing.T) {
|
||||
b := getSqliteBuilder()
|
||||
q := b.AlterColumn("users", "name", "int")
|
||||
assert.NotEqual(t, q.LastError, nil, "t1")
|
||||
}
|
||||
|
||||
func TestSqliteBuilder_AddPrimaryKey(t *testing.T) {
|
||||
b := getSqliteBuilder()
|
||||
q := b.AddPrimaryKey("users", "pk", "id1", "id2")
|
||||
assert.NotEqual(t, q.LastError, nil, "t1")
|
||||
}
|
||||
|
||||
func TestSqliteBuilder_DropPrimaryKey(t *testing.T) {
|
||||
b := getSqliteBuilder()
|
||||
q := b.DropPrimaryKey("users", "pk")
|
||||
assert.NotEqual(t, q.LastError, nil, "t1")
|
||||
}
|
||||
|
||||
func TestSqliteBuilder_AddForeignKey(t *testing.T) {
|
||||
b := getSqliteBuilder()
|
||||
q := b.AddForeignKey("users", "fk", []string{"p1", "p2"}, []string{"f1", "f2"}, "profile", "opt")
|
||||
assert.NotEqual(t, q.LastError, nil, "t1")
|
||||
}
|
||||
|
||||
func TestSqliteBuilder_DropForeignKey(t *testing.T) {
|
||||
b := getSqliteBuilder()
|
||||
q := b.DropForeignKey("users", "fk")
|
||||
assert.NotEqual(t, q.LastError, nil, "t1")
|
||||
}
|
||||
|
||||
func getSqliteBuilder() Builder {
|
||||
db := getDB()
|
||||
b := NewSqliteBuilder(db, db.sqlDB)
|
||||
db.Builder = b
|
||||
return b
|
||||
}
|
39
builder_standard.go
Normal file
39
builder_standard.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
// StandardBuilder is the builder that is used by DB for an unknown driver.
|
||||
type StandardBuilder struct {
|
||||
*BaseBuilder
|
||||
qb *BaseQueryBuilder
|
||||
}
|
||||
|
||||
var _ Builder = &StandardBuilder{}
|
||||
|
||||
// NewStandardBuilder creates a new StandardBuilder instance.
|
||||
func NewStandardBuilder(db *DB, executor Executor) Builder {
|
||||
return &StandardBuilder{
|
||||
NewBaseBuilder(db, executor),
|
||||
NewBaseQueryBuilder(db),
|
||||
}
|
||||
}
|
||||
|
||||
// QueryBuilder returns the query builder supporting the current DB.
|
||||
func (b *StandardBuilder) QueryBuilder() QueryBuilder {
|
||||
return b.qb
|
||||
}
|
||||
|
||||
// Select returns a new SelectQuery object that can be used to build a SELECT statement.
|
||||
// The parameters to this method should be the list column names to be selected.
|
||||
// A column name may have an optional alias name. For example, Select("id", "my_name AS name").
|
||||
func (b *StandardBuilder) Select(cols ...string) *SelectQuery {
|
||||
return NewSelectQuery(b, b.db).Select(cols...)
|
||||
}
|
||||
|
||||
// Model returns a new ModelQuery object that can be used to perform model-based DB operations.
|
||||
// The model passed to this method should be a pointer to a model struct.
|
||||
func (b *StandardBuilder) Model(model interface{}) *ModelQuery {
|
||||
return NewModelQuery(model, b.db.FieldMapper, b.db, b)
|
||||
}
|
183
builder_standard_test.go
Normal file
183
builder_standard_test.go
Normal file
|
@ -0,0 +1,183 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStandardBuilder_Quote(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
assert.Equal(t, b.Quote(`abc`), `'abc'`, "t1")
|
||||
assert.Equal(t, b.Quote(`I'm`), `'I''m'`, "t2")
|
||||
assert.Equal(t, b.Quote(``), `''`, "t3")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_QuoteSimpleTableName(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
assert.Equal(t, b.QuoteSimpleTableName(`abc`), `"abc"`, "t1")
|
||||
assert.Equal(t, b.QuoteSimpleTableName(`"abc"`), `"abc"`, "t2")
|
||||
assert.Equal(t, b.QuoteSimpleTableName(`{{abc}}`), `"{{abc}}"`, "t3")
|
||||
assert.Equal(t, b.QuoteSimpleTableName(`a.bc`), `"a.bc"`, "t4")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_QuoteSimpleColumnName(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`abc`), `"abc"`, "t1")
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`"abc"`), `"abc"`, "t2")
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`{{abc}}`), `"{{abc}}"`, "t3")
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`a.bc`), `"a.bc"`, "t4")
|
||||
assert.Equal(t, b.QuoteSimpleColumnName(`*`), `*`, "t5")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_Insert(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.Insert("users", Params{
|
||||
"name": "James",
|
||||
"age": 30,
|
||||
})
|
||||
assert.Equal(t, q.SQL(), `INSERT INTO "users" ("age", "name") VALUES ({:p0}, {:p1})`, "t1")
|
||||
assert.Equal(t, q.Params()["p0"], 30, "t2")
|
||||
assert.Equal(t, q.Params()["p1"], "James", "t3")
|
||||
|
||||
q = b.Insert("users", Params{})
|
||||
assert.Equal(t, q.SQL(), `INSERT INTO "users" DEFAULT VALUES`, "t2")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_Upsert(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.Upsert("users", Params{
|
||||
"name": "James",
|
||||
"age": 30,
|
||||
})
|
||||
assert.NotEqual(t, q.LastError, nil, "t1")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_Update(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.Update("users", Params{
|
||||
"name": "James",
|
||||
"age": 30,
|
||||
}, NewExp("id=10"))
|
||||
assert.Equal(t, q.SQL(), `UPDATE "users" SET "age"={:p0}, "name"={:p1} WHERE id=10`, "t1")
|
||||
assert.Equal(t, q.Params()["p0"], 30, "t2")
|
||||
assert.Equal(t, q.Params()["p1"], "James", "t3")
|
||||
|
||||
q = b.Update("users", Params{
|
||||
"name": "James",
|
||||
"age": 30,
|
||||
}, nil)
|
||||
assert.Equal(t, q.SQL(), `UPDATE "users" SET "age"={:p0}, "name"={:p1}`, "t2")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_Delete(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.Delete("users", NewExp("id=10"))
|
||||
assert.Equal(t, q.SQL(), `DELETE FROM "users" WHERE id=10`, "t1")
|
||||
q = b.Delete("users", nil)
|
||||
assert.Equal(t, q.SQL(), `DELETE FROM "users"`, "t2")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_CreateTable(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.CreateTable("users", map[string]string{
|
||||
"id": "int primary key",
|
||||
"name": "varchar(255)",
|
||||
}, "ON DELETE CASCADE")
|
||||
assert.Equal(t, q.SQL(), "CREATE TABLE \"users\" (\"id\" int primary key, \"name\" varchar(255)) ON DELETE CASCADE", "t1")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_RenameTable(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.RenameTable("users", "user")
|
||||
assert.Equal(t, q.SQL(), `RENAME TABLE "users" TO "user"`, "t1")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_DropTable(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.DropTable("users")
|
||||
assert.Equal(t, q.SQL(), `DROP TABLE "users"`, "t1")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_TruncateTable(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.TruncateTable("users")
|
||||
assert.Equal(t, q.SQL(), `TRUNCATE TABLE "users"`, "t1")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_AddColumn(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.AddColumn("users", "age", "int")
|
||||
assert.Equal(t, q.SQL(), `ALTER TABLE "users" ADD "age" int`, "t1")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_DropColumn(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.DropColumn("users", "age")
|
||||
assert.Equal(t, q.SQL(), `ALTER TABLE "users" DROP COLUMN "age"`, "t1")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_RenameColumn(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.RenameColumn("users", "name", "username")
|
||||
assert.Equal(t, q.SQL(), `ALTER TABLE "users" RENAME COLUMN "name" TO "username"`, "t1")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_AlterColumn(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.AlterColumn("users", "name", "int")
|
||||
assert.Equal(t, q.SQL(), `ALTER TABLE "users" CHANGE "name" "name" int`, "t1")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_AddPrimaryKey(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.AddPrimaryKey("users", "pk", "id1", "id2")
|
||||
assert.Equal(t, q.SQL(), `ALTER TABLE "users" ADD CONSTRAINT "pk" PRIMARY KEY ("id1", "id2")`, "t1")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_DropPrimaryKey(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.DropPrimaryKey("users", "pk")
|
||||
assert.Equal(t, q.SQL(), `ALTER TABLE "users" DROP CONSTRAINT "pk"`, "t1")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_AddForeignKey(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.AddForeignKey("users", "fk", []string{"p1", "p2"}, []string{"f1", "f2"}, "profile", "opt")
|
||||
assert.Equal(t, q.SQL(), `ALTER TABLE "users" ADD CONSTRAINT "fk" FOREIGN KEY ("p1", "p2") REFERENCES "profile" ("f1", "f2") opt`, "t1")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_DropForeignKey(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.DropForeignKey("users", "fk")
|
||||
assert.Equal(t, q.SQL(), `ALTER TABLE "users" DROP CONSTRAINT "fk"`, "t1")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_CreateIndex(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.CreateIndex("users", "idx", "id1", "id2")
|
||||
assert.Equal(t, q.SQL(), `CREATE INDEX "idx" ON "users" ("id1", "id2")`, "t1")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_CreateUniqueIndex(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.CreateUniqueIndex("users", "idx", "id1", "id2")
|
||||
assert.Equal(t, q.SQL(), `CREATE UNIQUE INDEX "idx" ON "users" ("id1", "id2")`, "t1")
|
||||
}
|
||||
|
||||
func TestStandardBuilder_DropIndex(t *testing.T) {
|
||||
b := getStandardBuilder()
|
||||
q := b.DropIndex("users", "idx")
|
||||
assert.Equal(t, q.SQL(), `DROP INDEX "idx" ON "users"`, "t1")
|
||||
}
|
||||
|
||||
func getStandardBuilder() Builder {
|
||||
db := getDB()
|
||||
b := NewStandardBuilder(db, db.sqlDB)
|
||||
db.Builder = b
|
||||
return b
|
||||
}
|
338
db.go
Normal file
338
db.go
Normal file
|
@ -0,0 +1,338 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package dbx provides a set of DB-agnostic and easy-to-use query building methods for relational databases.
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type (
|
||||
// LogFunc logs a message for each SQL statement being executed.
|
||||
// This method takes one or multiple parameters. If a single parameter
|
||||
// is provided, it will be treated as the log message. If multiple parameters
|
||||
// are provided, they will be passed to fmt.Sprintf() to generate the log message.
|
||||
LogFunc func(format string, a ...interface{})
|
||||
|
||||
// PerfFunc is called when a query finishes execution.
|
||||
// The query execution time is passed to this function so that the DB performance
|
||||
// can be profiled. The "ns" parameter gives the number of nanoseconds that the
|
||||
// SQL statement takes to execute, while the "execute" parameter indicates whether
|
||||
// the SQL statement is executed or queried (usually SELECT statements).
|
||||
PerfFunc func(ns int64, sql string, execute bool)
|
||||
|
||||
// QueryLogFunc is called each time when performing a SQL query.
|
||||
// The "t" parameter gives the time that the SQL statement takes to execute,
|
||||
// while rows and err are the result of the query.
|
||||
QueryLogFunc func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error)
|
||||
|
||||
// ExecLogFunc is called each time when a SQL statement is executed.
|
||||
// The "t" parameter gives the time that the SQL statement takes to execute,
|
||||
// while result and err refer to the result of the execution.
|
||||
ExecLogFunc func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error)
|
||||
|
||||
// BuilderFunc creates a Builder instance using the given DB instance and Executor.
|
||||
BuilderFunc func(*DB, Executor) Builder
|
||||
|
||||
// DB enhances sql.DB by providing a set of DB-agnostic query building methods.
|
||||
// DB allows easier query building and population of data into Go variables.
|
||||
DB struct {
|
||||
Builder
|
||||
|
||||
// FieldMapper maps struct fields to DB columns. Defaults to DefaultFieldMapFunc.
|
||||
FieldMapper FieldMapFunc
|
||||
// TableMapper maps structs to table names. Defaults to GetTableName.
|
||||
TableMapper TableMapFunc
|
||||
// LogFunc logs the SQL statements being executed. Defaults to nil, meaning no logging.
|
||||
LogFunc LogFunc
|
||||
// PerfFunc logs the SQL execution time. Defaults to nil, meaning no performance profiling.
|
||||
// Deprecated: Please use QueryLogFunc and ExecLogFunc instead.
|
||||
PerfFunc PerfFunc
|
||||
// QueryLogFunc is called each time when performing a SQL query that returns data.
|
||||
QueryLogFunc QueryLogFunc
|
||||
// ExecLogFunc is called each time when a SQL statement is executed.
|
||||
ExecLogFunc ExecLogFunc
|
||||
|
||||
sqlDB *sql.DB
|
||||
driverName string
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// Errors represents a list of errors.
|
||||
Errors []error
|
||||
)
|
||||
|
||||
// BuilderFuncMap lists supported BuilderFunc according to DB driver names.
|
||||
// You may modify this variable to add the builder support for a new DB driver.
|
||||
// If a DB driver is not listed here, the StandardBuilder will be used.
|
||||
var BuilderFuncMap = map[string]BuilderFunc{
|
||||
"sqlite": NewSqliteBuilder,
|
||||
"sqlite3": NewSqliteBuilder,
|
||||
"mysql": NewMysqlBuilder,
|
||||
"postgres": NewPgsqlBuilder,
|
||||
"pgx": NewPgsqlBuilder,
|
||||
"mssql": NewMssqlBuilder,
|
||||
"oci8": NewOciBuilder,
|
||||
}
|
||||
|
||||
// NewFromDB encapsulates an existing database connection.
|
||||
func NewFromDB(sqlDB *sql.DB, driverName string) *DB {
|
||||
db := &DB{
|
||||
driverName: driverName,
|
||||
sqlDB: sqlDB,
|
||||
FieldMapper: DefaultFieldMapFunc,
|
||||
TableMapper: GetTableName,
|
||||
}
|
||||
db.Builder = db.newBuilder(db.sqlDB)
|
||||
return db
|
||||
}
|
||||
|
||||
// Open opens a database specified by a driver name and data source name (DSN).
|
||||
// Note that Open does not check if DSN is specified correctly. It doesn't try to establish a DB connection either.
|
||||
// Please refer to sql.Open() for more information.
|
||||
func Open(driverName, dsn string) (*DB, error) {
|
||||
sqlDB, err := sql.Open(driverName, dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewFromDB(sqlDB, driverName), nil
|
||||
}
|
||||
|
||||
// MustOpen opens a database and establishes a connection to it.
|
||||
// Please refer to sql.Open() and sql.Ping() for more information.
|
||||
func MustOpen(driverName, dsn string) (*DB, error) {
|
||||
db, err := Open(driverName, dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.sqlDB.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// Clone makes a shallow copy of DB.
|
||||
func (db *DB) Clone() *DB {
|
||||
db2 := &DB{
|
||||
driverName: db.driverName,
|
||||
sqlDB: db.sqlDB,
|
||||
FieldMapper: db.FieldMapper,
|
||||
TableMapper: db.TableMapper,
|
||||
PerfFunc: db.PerfFunc,
|
||||
LogFunc: db.LogFunc,
|
||||
QueryLogFunc: db.QueryLogFunc,
|
||||
ExecLogFunc: db.ExecLogFunc,
|
||||
}
|
||||
db2.Builder = db2.newBuilder(db.sqlDB)
|
||||
return db2
|
||||
}
|
||||
|
||||
// WithContext returns a new instance of DB associated with the given context.
|
||||
func (db *DB) WithContext(ctx context.Context) *DB {
|
||||
db2 := db.Clone()
|
||||
db2.ctx = ctx
|
||||
return db2
|
||||
}
|
||||
|
||||
// Context returns the context associated with the DB instance.
|
||||
// It returns nil if no context is associated.
|
||||
func (db *DB) Context() context.Context {
|
||||
return db.ctx
|
||||
}
|
||||
|
||||
// DB returns the sql.DB instance encapsulated by dbx.DB.
|
||||
func (db *DB) DB() *sql.DB {
|
||||
return db.sqlDB
|
||||
}
|
||||
|
||||
// Close closes the database, releasing any open resources.
|
||||
// It is rare to Close a DB, as the DB handle is meant to be
|
||||
// long-lived and shared between many goroutines.
|
||||
func (db *DB) Close() error {
|
||||
return db.sqlDB.Close()
|
||||
}
|
||||
|
||||
// Begin starts a transaction.
|
||||
func (db *DB) Begin() (*Tx, error) {
|
||||
var tx *sql.Tx
|
||||
var err error
|
||||
if db.ctx != nil {
|
||||
tx, err = db.sqlDB.BeginTx(db.ctx, nil)
|
||||
} else {
|
||||
tx, err = db.sqlDB.Begin()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Tx{db.newBuilder(tx), tx}, nil
|
||||
}
|
||||
|
||||
// BeginTx starts a transaction with the given context and transaction options.
|
||||
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
|
||||
tx, err := db.sqlDB.BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Tx{db.newBuilder(tx), tx}, nil
|
||||
}
|
||||
|
||||
// Wrap encapsulates an existing transaction.
|
||||
func (db *DB) Wrap(sqlTx *sql.Tx) *Tx {
|
||||
return &Tx{db.newBuilder(sqlTx), sqlTx}
|
||||
}
|
||||
|
||||
// Transactional starts a transaction and executes the given function.
|
||||
// If the function returns an error, the transaction will be rolled back.
|
||||
// Otherwise, the transaction will be committed.
|
||||
func (db *DB) Transactional(f func(*Tx) error) (err error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
tx.Rollback()
|
||||
panic(p)
|
||||
} else if err != nil {
|
||||
if err2 := tx.Rollback(); err2 != nil {
|
||||
if err2 == sql.ErrTxDone {
|
||||
return
|
||||
}
|
||||
err = Errors{err, err2}
|
||||
}
|
||||
} else {
|
||||
if err = tx.Commit(); err == sql.ErrTxDone {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err = f(tx)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// TransactionalContext starts a transaction and executes the given function with the given context and transaction options.
|
||||
// If the function returns an error, the transaction will be rolled back.
|
||||
// Otherwise, the transaction will be committed.
|
||||
func (db *DB) TransactionalContext(ctx context.Context, opts *sql.TxOptions, f func(*Tx) error) (err error) {
|
||||
tx, err := db.BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
tx.Rollback()
|
||||
panic(p)
|
||||
} else if err != nil {
|
||||
if err2 := tx.Rollback(); err2 != nil {
|
||||
if err2 == sql.ErrTxDone {
|
||||
return
|
||||
}
|
||||
err = Errors{err, err2}
|
||||
}
|
||||
} else {
|
||||
if err = tx.Commit(); err == sql.ErrTxDone {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err = f(tx)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DriverName returns the name of the DB driver.
|
||||
func (db *DB) DriverName() string {
|
||||
return db.driverName
|
||||
}
|
||||
|
||||
// QuoteTableName quotes the given table name appropriately.
|
||||
// If the table name contains DB schema prefix, it will be handled accordingly.
|
||||
// This method will do nothing if the table name is already quoted or if it contains parenthesis.
|
||||
func (db *DB) QuoteTableName(s string) string {
|
||||
if strings.Contains(s, "(") || strings.Contains(s, "{{") {
|
||||
return s
|
||||
}
|
||||
if !strings.Contains(s, ".") {
|
||||
return db.QuoteSimpleTableName(s)
|
||||
}
|
||||
parts := strings.Split(s, ".")
|
||||
for i, part := range parts {
|
||||
parts[i] = db.QuoteSimpleTableName(part)
|
||||
}
|
||||
return strings.Join(parts, ".")
|
||||
}
|
||||
|
||||
// QuoteColumnName quotes the given column name appropriately.
|
||||
// If the table name contains table name prefix, it will be handled accordingly.
|
||||
// This method will do nothing if the column name is already quoted or if it contains parenthesis.
|
||||
func (db *DB) QuoteColumnName(s string) string {
|
||||
if strings.Contains(s, "(") || strings.Contains(s, "{{") || strings.Contains(s, "[[") {
|
||||
return s
|
||||
}
|
||||
prefix := ""
|
||||
if pos := strings.LastIndex(s, "."); pos != -1 {
|
||||
prefix = db.QuoteTableName(s[:pos]) + "."
|
||||
s = s[pos+1:]
|
||||
}
|
||||
return prefix + db.QuoteSimpleColumnName(s)
|
||||
}
|
||||
|
||||
var (
|
||||
plRegex = regexp.MustCompile(`\{:\w+\}`)
|
||||
quoteRegex = regexp.MustCompile(`(\{\{[\w\-\. ]+\}\}|\[\[[\w\-\. ]+\]\])`)
|
||||
)
|
||||
|
||||
// processSQL replaces the named param placeholders in the given SQL with anonymous ones.
|
||||
// It also quotes table names and column names found in the SQL if these names are enclosed
|
||||
// within double square/curly brackets. The method will return the updated SQL and the list of parameter names.
|
||||
func (db *DB) processSQL(s string) (string, []string) {
|
||||
var placeholders []string
|
||||
count := 0
|
||||
s = plRegex.ReplaceAllStringFunc(s, func(m string) string {
|
||||
count++
|
||||
placeholders = append(placeholders, m[2:len(m)-1])
|
||||
return db.GeneratePlaceholder(count)
|
||||
})
|
||||
s = quoteRegex.ReplaceAllStringFunc(s, func(m string) string {
|
||||
if m[0] == '{' {
|
||||
return db.QuoteTableName(m[2 : len(m)-2])
|
||||
}
|
||||
return db.QuoteColumnName(m[2 : len(m)-2])
|
||||
})
|
||||
return s, placeholders
|
||||
}
|
||||
|
||||
// newBuilder creates a query builder based on the current driver name.
|
||||
func (db *DB) newBuilder(executor Executor) Builder {
|
||||
builderFunc, ok := BuilderFuncMap[db.driverName]
|
||||
if !ok {
|
||||
builderFunc = NewStandardBuilder
|
||||
}
|
||||
return builderFunc(db, executor)
|
||||
}
|
||||
|
||||
// Error returns the error string of Errors.
|
||||
func (errs Errors) Error() string {
|
||||
var b bytes.Buffer
|
||||
for i, e := range errs {
|
||||
if i > 0 {
|
||||
b.WriteRune('\n')
|
||||
}
|
||||
b.WriteString(e.Error())
|
||||
}
|
||||
return b.String()
|
||||
}
|
380
db_test.go
Normal file
380
db_test.go
Normal file
|
@ -0,0 +1,380 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
// @todo change to sqlite
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
TestDSN = "travis:@/pocketbase_dbx_test?parseTime=true"
|
||||
FixtureFile = "testdata/mysql.sql"
|
||||
)
|
||||
|
||||
func TestDB_NewFromDB(t *testing.T) {
|
||||
sqlDB, err := sql.Open("mysql", TestDSN)
|
||||
if assert.Nil(t, err) {
|
||||
db := NewFromDB(sqlDB, "mysql")
|
||||
assert.NotNil(t, db.sqlDB)
|
||||
assert.NotNil(t, db.FieldMapper)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_Open(t *testing.T) {
|
||||
db, err := Open("mysql", TestDSN)
|
||||
assert.Nil(t, err)
|
||||
if assert.NotNil(t, db) {
|
||||
assert.NotNil(t, db.sqlDB)
|
||||
assert.NotNil(t, db.FieldMapper)
|
||||
db2 := db.Clone()
|
||||
assert.NotEqual(t, db, db2)
|
||||
assert.Equal(t, db.driverName, db2.driverName)
|
||||
ctx := context.Background()
|
||||
db3 := db.WithContext(ctx)
|
||||
assert.Equal(t, ctx, db3.ctx)
|
||||
assert.Equal(t, ctx, db3.Context())
|
||||
assert.NotEqual(t, db, db3)
|
||||
}
|
||||
|
||||
_, err = Open("xyz", TestDSN)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestDB_MustOpen(t *testing.T) {
|
||||
_, err := MustOpen("mysql", TestDSN)
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, err = MustOpen("mysql", "unknown:x@/test")
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestDB_Close(t *testing.T) {
|
||||
db := getDB()
|
||||
assert.Nil(t, db.Close())
|
||||
}
|
||||
|
||||
func TestDB_DriverName(t *testing.T) {
|
||||
db := getDB()
|
||||
assert.Equal(t, "mysql", db.DriverName())
|
||||
}
|
||||
|
||||
func TestDB_QuoteTableName(t *testing.T) {
|
||||
tests := []struct {
|
||||
input, output string
|
||||
}{
|
||||
{"users", "`users`"},
|
||||
{"`users`", "`users`"},
|
||||
{"(select)", "(select)"},
|
||||
{"{{users}}", "{{users}}"},
|
||||
{"public.db1.users", "`public`.`db1`.`users`"},
|
||||
}
|
||||
db := getDB()
|
||||
for _, test := range tests {
|
||||
result := db.QuoteTableName(test.input)
|
||||
assert.Equal(t, test.output, result, test.input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_QuoteColumnName(t *testing.T) {
|
||||
tests := []struct {
|
||||
input, output string
|
||||
}{
|
||||
{"*", "*"},
|
||||
{"users.*", "`users`.*"},
|
||||
{"name", "`name`"},
|
||||
{"`name`", "`name`"},
|
||||
{"(select)", "(select)"},
|
||||
{"{{name}}", "{{name}}"},
|
||||
{"[[name]]", "[[name]]"},
|
||||
{"public.db1.users", "`public`.`db1`.`users`"},
|
||||
}
|
||||
db := getDB()
|
||||
for _, test := range tests {
|
||||
result := db.QuoteColumnName(test.input)
|
||||
assert.Equal(t, test.output, result, test.input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_ProcessSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
tag string
|
||||
sql string // original SQL
|
||||
mysql string // expected MySQL version
|
||||
postgres string // expected PostgreSQL version
|
||||
oci8 string // expected OCI version
|
||||
params []string // expected params
|
||||
}{
|
||||
{
|
||||
"normal case",
|
||||
`INSERT INTO employee (id, name, age) VALUES ({:id}, {:name}, {:age})`,
|
||||
`INSERT INTO employee (id, name, age) VALUES (?, ?, ?)`,
|
||||
`INSERT INTO employee (id, name, age) VALUES ($1, $2, $3)`,
|
||||
`INSERT INTO employee (id, name, age) VALUES (:p1, :p2, :p3)`,
|
||||
[]string{"id", "name", "age"},
|
||||
},
|
||||
{
|
||||
"the same placeholder is used twice",
|
||||
`SELECT * FROM employee WHERE first_name LIKE {:keyword} OR last_name LIKE {:keyword}`,
|
||||
`SELECT * FROM employee WHERE first_name LIKE ? OR last_name LIKE ?`,
|
||||
`SELECT * FROM employee WHERE first_name LIKE $1 OR last_name LIKE $2`,
|
||||
`SELECT * FROM employee WHERE first_name LIKE :p1 OR last_name LIKE :p2`,
|
||||
[]string{"keyword", "keyword"},
|
||||
},
|
||||
{
|
||||
"non-matching placeholder",
|
||||
`SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE {:keyword}`,
|
||||
`SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE ?`,
|
||||
`SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE $1`,
|
||||
`SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE :p1`,
|
||||
[]string{"keyword"},
|
||||
},
|
||||
{
|
||||
"quote table/column names",
|
||||
`SELECT * FROM {{public.user}} WHERE [[user.id]]=1`,
|
||||
"SELECT * FROM `public`.`user` WHERE `user`.`id`=1",
|
||||
"SELECT * FROM \"public\".\"user\" WHERE \"user\".\"id\"=1",
|
||||
"SELECT * FROM \"public\".\"user\" WHERE \"user\".\"id\"=1",
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
mysqlDB := getDB()
|
||||
mysqlDB.Builder = NewMysqlBuilder(nil, nil)
|
||||
pgsqlDB := getDB()
|
||||
pgsqlDB.Builder = NewPgsqlBuilder(nil, nil)
|
||||
ociDB := getDB()
|
||||
ociDB.Builder = NewOciBuilder(nil, nil)
|
||||
|
||||
for _, test := range tests {
|
||||
s1, names := mysqlDB.processSQL(test.sql)
|
||||
assert.Equal(t, test.mysql, s1, test.tag)
|
||||
s2, _ := pgsqlDB.processSQL(test.sql)
|
||||
assert.Equal(t, test.postgres, s2, test.tag)
|
||||
s3, _ := ociDB.processSQL(test.sql)
|
||||
assert.Equal(t, test.oci8, s3, test.tag)
|
||||
|
||||
assert.Equal(t, test.params, names, test.tag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_Begin(t *testing.T) {
|
||||
tests := []struct {
|
||||
makeTx func(db *DB) *Tx
|
||||
desc string
|
||||
}{
|
||||
{
|
||||
makeTx: func(db *DB) *Tx {
|
||||
tx, _ := db.Begin()
|
||||
return tx
|
||||
},
|
||||
desc: "Begin",
|
||||
},
|
||||
{
|
||||
makeTx: func(db *DB) *Tx {
|
||||
sqlTx, _ := db.DB().Begin()
|
||||
return db.Wrap(sqlTx)
|
||||
},
|
||||
desc: "Wrap",
|
||||
},
|
||||
{
|
||||
makeTx: func(db *DB) *Tx {
|
||||
tx, _ := db.BeginTx(context.Background(), nil)
|
||||
return tx
|
||||
},
|
||||
desc: "BeginTx",
|
||||
},
|
||||
}
|
||||
|
||||
db := getPreparedDB()
|
||||
|
||||
var (
|
||||
lastID int
|
||||
name string
|
||||
tx *Tx
|
||||
)
|
||||
db.NewQuery("SELECT MAX(id) FROM item").Row(&lastID)
|
||||
|
||||
for _, test := range tests {
|
||||
t.Log(test.desc)
|
||||
|
||||
tx = test.makeTx(db)
|
||||
_, err1 := tx.Insert("item", Params{
|
||||
"name": "name1",
|
||||
}).Execute()
|
||||
_, err2 := tx.Insert("item", Params{
|
||||
"name": "name2",
|
||||
}).Execute()
|
||||
if err1 == nil && err2 == nil {
|
||||
tx.Commit()
|
||||
} else {
|
||||
t.Errorf("Unexpected TX rollback: %v, %v", err1, err2)
|
||||
tx.Rollback()
|
||||
}
|
||||
|
||||
q := db.NewQuery("SELECT name FROM item WHERE id={:id}")
|
||||
q.Bind(Params{"id": lastID + 1}).Row(&name)
|
||||
assert.Equal(t, "name1", name)
|
||||
q.Bind(Params{"id": lastID + 2}).Row(&name)
|
||||
assert.Equal(t, "name2", name)
|
||||
|
||||
tx = test.makeTx(db)
|
||||
_, err3 := tx.NewQuery("DELETE FROM item WHERE id=7").Execute()
|
||||
_, err4 := tx.NewQuery("DELETE FROM items WHERE id=7").Execute()
|
||||
if err3 == nil && err4 == nil {
|
||||
t.Error("Unexpected TX commit")
|
||||
tx.Commit()
|
||||
} else {
|
||||
tx.Rollback()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_Transactional(t *testing.T) {
|
||||
db := getPreparedDB()
|
||||
|
||||
var (
|
||||
lastID int
|
||||
name string
|
||||
)
|
||||
db.NewQuery("SELECT MAX(id) FROM item").Row(&lastID)
|
||||
|
||||
err := db.Transactional(func(tx *Tx) error {
|
||||
_, err := tx.Insert("item", Params{
|
||||
"name": "name1",
|
||||
}).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Insert("item", Params{
|
||||
"name": "name2",
|
||||
}).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if assert.Nil(t, err) {
|
||||
q := db.NewQuery("SELECT name FROM item WHERE id={:id}")
|
||||
q.Bind(Params{"id": lastID + 1}).Row(&name)
|
||||
assert.Equal(t, "name1", name)
|
||||
q.Bind(Params{"id": lastID + 2}).Row(&name)
|
||||
assert.Equal(t, "name2", name)
|
||||
}
|
||||
|
||||
err = db.Transactional(func(tx *Tx) error {
|
||||
_, err := tx.NewQuery("DELETE FROM item WHERE id=2").Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.NewQuery("DELETE FROM items WHERE id=2").Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if assert.NotNil(t, err) {
|
||||
db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name)
|
||||
assert.Equal(t, "Go in Action", name)
|
||||
}
|
||||
|
||||
// Rollback called within Transactional and return error
|
||||
err = db.Transactional(func(tx *Tx) error {
|
||||
_, err := tx.NewQuery("DELETE FROM item WHERE id=2").Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.NewQuery("DELETE FROM items WHERE id=2").Execute()
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if assert.NotNil(t, err) {
|
||||
db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name)
|
||||
assert.Equal(t, "Go in Action", name)
|
||||
}
|
||||
|
||||
// Rollback called within Transactional without returning error
|
||||
err = db.Transactional(func(tx *Tx) error {
|
||||
_, err := tx.NewQuery("DELETE FROM item WHERE id=2").Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.NewQuery("DELETE FROM items WHERE id=2").Execute()
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if assert.Nil(t, err) {
|
||||
db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name)
|
||||
assert.Equal(t, "Go in Action", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrors_Error(t *testing.T) {
|
||||
errs := Errors{}
|
||||
assert.Equal(t, "", errs.Error())
|
||||
errs = Errors{errors.New("a")}
|
||||
assert.Equal(t, "a", errs.Error())
|
||||
errs = Errors{errors.New("a"), errors.New("b")}
|
||||
assert.Equal(t, "a\nb", errs.Error())
|
||||
}
|
||||
|
||||
func getDB() *DB {
|
||||
db, err := Open("mysql", TestDSN)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func getPreparedDB() *DB {
|
||||
db := getDB()
|
||||
s, err := ioutil.ReadFile(FixtureFile)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
lines := strings.Split(string(s), ";")
|
||||
for _, line := range lines {
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := db.NewQuery(line).Execute(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// Naming according to issue 49 ( https://github.com/pocketbase/dbx/issues/49 )
|
||||
|
||||
type ArtistDAO struct {
|
||||
nickname string
|
||||
}
|
||||
|
||||
func (ArtistDAO) TableName() string {
|
||||
return "artists"
|
||||
}
|
||||
|
||||
func Test_TableNameWithPrefix(t *testing.T) {
|
||||
db := NewFromDB(nil, "mysql")
|
||||
db.TableMapper = func(a interface{}) string {
|
||||
return "tbl_" + GetTableName(a)
|
||||
}
|
||||
assert.Equal(t, "tbl_artists", db.TableMapper(ArtistDAO{}))
|
||||
}
|
281
example_test.go
Normal file
281
example_test.go
Normal file
|
@ -0,0 +1,281 @@
|
|||
package dbx_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
// This example shows how to populate DB data in different ways.
|
||||
func Example_dbQueries() {
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
// create a new query
|
||||
q := db.NewQuery("SELECT id, name FROM users LIMIT 10")
|
||||
|
||||
// fetch all rows into a struct array
|
||||
var users []struct {
|
||||
ID, Name string
|
||||
}
|
||||
q.All(&users)
|
||||
|
||||
// fetch a single row into a struct
|
||||
var user struct {
|
||||
ID, Name string
|
||||
}
|
||||
q.One(&user)
|
||||
|
||||
// fetch a single row into a string map
|
||||
data := dbx.NullStringMap{}
|
||||
q.One(data)
|
||||
|
||||
// fetch row by row
|
||||
rows2, _ := q.Rows()
|
||||
for rows2.Next() {
|
||||
rows2.ScanStruct(&user)
|
||||
// rows.ScanMap(data)
|
||||
// rows.Scan(&id, &name)
|
||||
}
|
||||
}
|
||||
|
||||
// This example shows how to use query builder to build DB queries.
|
||||
func Example_queryBuilder() {
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
// build a SELECT query
|
||||
// SELECT `id`, `name` FROM `users` WHERE `name` LIKE '%Charles%' ORDER BY `id`
|
||||
q := db.Select("id", "name").
|
||||
From("users").
|
||||
Where(dbx.Like("name", "Charles")).
|
||||
OrderBy("id")
|
||||
|
||||
// fetch all rows into a struct array
|
||||
var users []struct {
|
||||
ID, Name string
|
||||
}
|
||||
q.All(&users)
|
||||
|
||||
// build an INSERT query
|
||||
// INSERT INTO `users` (name) VALUES ('James')
|
||||
db.Insert("users", dbx.Params{
|
||||
"name": "James",
|
||||
}).Execute()
|
||||
}
|
||||
|
||||
// This example shows how to use query builder in transactions.
|
||||
func Example_transactions() {
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
db.Transactional(func(tx *dbx.Tx) error {
|
||||
_, err := tx.Insert("user", dbx.Params{
|
||||
"name": "user1",
|
||||
}).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Insert("user", dbx.Params{
|
||||
"name": "user2",
|
||||
}).Execute()
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
type Customer struct {
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
// This example shows how to do CRUD operations.
|
||||
func Example_crudOperations() {
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
var customer Customer
|
||||
|
||||
// read a customer: SELECT * FROM customer WHERE id=100
|
||||
db.Select().Model(100, &customer)
|
||||
|
||||
// create a customer: INSERT INTO customer (name) VALUES ('test')
|
||||
db.Model(&customer).Insert()
|
||||
|
||||
// update a customer: UPDATE customer SET name='test' WHERE id=100
|
||||
db.Model(&customer).Update()
|
||||
|
||||
// delete a customer: DELETE FROM customer WHERE id=100
|
||||
db.Model(&customer).Delete()
|
||||
}
|
||||
|
||||
func ExampleSchemaBuilder() {
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
db.Insert("users", dbx.Params{
|
||||
"name": "James",
|
||||
"age": 30,
|
||||
}).Execute()
|
||||
}
|
||||
|
||||
func ExampleRows_ScanMap() {
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
user := dbx.NullStringMap{}
|
||||
|
||||
sql := "SELECT id, name FROM users LIMIT 10"
|
||||
rows, _ := db.NewQuery(sql).Rows()
|
||||
for rows.Next() {
|
||||
rows.ScanMap(user)
|
||||
// ...
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleRows_ScanStruct() {
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
var user struct {
|
||||
ID, Name string
|
||||
}
|
||||
|
||||
sql := "SELECT id, name FROM users LIMIT 10"
|
||||
rows, _ := db.NewQuery(sql).Rows()
|
||||
for rows.Next() {
|
||||
rows.ScanStruct(&user)
|
||||
// ...
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleQuery_All() {
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
sql := "SELECT id, name FROM users LIMIT 10"
|
||||
|
||||
// fetches data into a slice of struct
|
||||
var users []struct {
|
||||
ID, Name string
|
||||
}
|
||||
db.NewQuery(sql).All(&users)
|
||||
|
||||
// fetches data into a slice of NullStringMap
|
||||
var users2 []dbx.NullStringMap
|
||||
db.NewQuery(sql).All(&users2)
|
||||
for _, user := range users2 {
|
||||
fmt.Println(user["id"].String, user["name"].String)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleQuery_One() {
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
sql := "SELECT id, name FROM users LIMIT 10"
|
||||
|
||||
// fetches data into a struct
|
||||
var user struct {
|
||||
ID, Name string
|
||||
}
|
||||
db.NewQuery(sql).One(&user)
|
||||
|
||||
// fetches data into a NullStringMap
|
||||
var user2 dbx.NullStringMap
|
||||
db.NewQuery(sql).All(user2)
|
||||
fmt.Println(user2["id"].String, user2["name"].String)
|
||||
}
|
||||
|
||||
func ExampleQuery_Row() {
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
sql := "SELECT id, name FROM users LIMIT 10"
|
||||
|
||||
// fetches data into a struct
|
||||
var (
|
||||
id int
|
||||
name string
|
||||
)
|
||||
db.NewQuery(sql).Row(&id, &name)
|
||||
}
|
||||
|
||||
func ExampleQuery_Rows() {
|
||||
var user struct {
|
||||
ID, Name string
|
||||
}
|
||||
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
sql := "SELECT id, name FROM users LIMIT 10"
|
||||
|
||||
rows, _ := db.NewQuery(sql).Rows()
|
||||
for rows.Next() {
|
||||
rows.ScanStruct(&user)
|
||||
// ...
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleQuery_Bind() {
|
||||
var user struct {
|
||||
ID, Name string
|
||||
}
|
||||
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
sql := "SELECT id, name FROM users WHERE age>{:age} AND status={:status}"
|
||||
|
||||
q := db.NewQuery(sql)
|
||||
q.Bind(dbx.Params{"age": 30, "status": 1}).One(&user)
|
||||
}
|
||||
|
||||
func ExampleQuery_Prepare() {
|
||||
var users1, users2, users3 []struct {
|
||||
ID, Name string
|
||||
}
|
||||
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
sql := "SELECT id, name FROM users WHERE age>{:age} AND status={:status}"
|
||||
|
||||
q := db.NewQuery(sql).Prepare()
|
||||
defer q.Close()
|
||||
|
||||
q.Bind(dbx.Params{"age": 30, "status": 1}).All(&users1)
|
||||
q.Bind(dbx.Params{"age": 20, "status": 1}).All(&users2)
|
||||
q.Bind(dbx.Params{"age": 10, "status": 1}).All(&users3)
|
||||
}
|
||||
|
||||
func ExampleDB() {
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
// queries data through a plain SQL
|
||||
var users []struct {
|
||||
ID, Name string
|
||||
}
|
||||
db.NewQuery("SELECT id, name FROM users WHERE age=30").All(&users)
|
||||
|
||||
// queries data using query builder
|
||||
db.Select("id", "name").From("users").Where(dbx.HashExp{"age": 30}).All(&users)
|
||||
|
||||
// executes a plain SQL
|
||||
db.NewQuery("INSERT INTO users (name) SET ({:name})").Bind(dbx.Params{"name": "James"}).Execute()
|
||||
|
||||
// executes a SQL using query builder
|
||||
db.Insert("users", dbx.Params{"name": "James"}).Execute()
|
||||
}
|
||||
|
||||
func ExampleDB_Open() {
|
||||
db, err := dbx.Open("mysql", "user:pass@/example")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var users []dbx.NullStringMap
|
||||
if err := db.NewQuery("SELECT * FROM users LIMIT 10").All(&users); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleDB_Begin() {
|
||||
db, _ := dbx.Open("mysql", "user:pass@/example")
|
||||
|
||||
tx, _ := db.Begin()
|
||||
|
||||
_, err1 := tx.Insert("user", dbx.Params{
|
||||
"name": "user1",
|
||||
}).Execute()
|
||||
_, err2 := tx.Insert("user", dbx.Params{
|
||||
"name": "user2",
|
||||
}).Execute()
|
||||
|
||||
if err1 == nil && err2 == nil {
|
||||
tx.Commit()
|
||||
} else {
|
||||
tx.Rollback()
|
||||
}
|
||||
}
|
421
expression.go
Normal file
421
expression.go
Normal file
|
@ -0,0 +1,421 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Expression represents a DB expression that can be embedded in a SQL statement.
|
||||
type Expression interface {
|
||||
// Build converts an expression into a SQL fragment.
|
||||
// If the expression contains binding parameters, they will be added to the given Params.
|
||||
Build(*DB, Params) string
|
||||
}
|
||||
|
||||
// HashExp represents a hash expression.
|
||||
//
|
||||
// A hash expression is a map whose keys are DB column names which need to be filtered according
|
||||
// to the corresponding values. For example, HashExp{"level": 2, "dept": 10} will generate
|
||||
// the SQL: "level"=2 AND "dept"=10.
|
||||
//
|
||||
// HashExp also handles nil values and slice values. For example, HashExp{"level": []interface{}{1, 2}, "dept": nil}
|
||||
// will generate: "level" IN (1, 2) AND "dept" IS NULL.
|
||||
type HashExp map[string]interface{}
|
||||
|
||||
// NewExp generates an expression with the specified SQL fragment and the optional binding parameters.
|
||||
func NewExp(e string, params ...Params) Expression {
|
||||
if len(params) > 0 {
|
||||
return &Exp{e, params[0]}
|
||||
}
|
||||
return &Exp{e, nil}
|
||||
}
|
||||
|
||||
// Not generates a NOT expression which prefixes "NOT" to the specified expression.
|
||||
func Not(e Expression) Expression {
|
||||
return &NotExp{e}
|
||||
}
|
||||
|
||||
// And generates an AND expression which concatenates the given expressions with "AND".
|
||||
func And(exps ...Expression) Expression {
|
||||
return &AndOrExp{exps, "AND"}
|
||||
}
|
||||
|
||||
// Or generates an OR expression which concatenates the given expressions with "OR".
|
||||
func Or(exps ...Expression) Expression {
|
||||
return &AndOrExp{exps, "OR"}
|
||||
}
|
||||
|
||||
// In generates an IN expression for the specified column and the list of allowed values.
|
||||
// If values is empty, a SQL "0=1" will be generated which represents a false expression.
|
||||
func In(col string, values ...interface{}) Expression {
|
||||
return &InExp{col, values, false}
|
||||
}
|
||||
|
||||
// NotIn generates an NOT IN expression for the specified column and the list of disallowed values.
|
||||
// If values is empty, an empty string will be returned indicating a true expression.
|
||||
func NotIn(col string, values ...interface{}) Expression {
|
||||
return &InExp{col, values, true}
|
||||
}
|
||||
|
||||
// DefaultLikeEscape specifies the default special character escaping for LIKE expressions
|
||||
// The strings at 2i positions are the special characters to be escaped while those at 2i+1 positions
|
||||
// are the corresponding escaped versions.
|
||||
var DefaultLikeEscape = []string{"\\", "\\\\", "%", "\\%", "_", "\\_"}
|
||||
|
||||
// Like generates a LIKE expression for the specified column and the possible strings that the column should be like.
|
||||
// If multiple values are present, the column should be like *all* of them. For example, Like("name", "key", "word")
|
||||
// will generate a SQL expression: "name" LIKE "%key%" AND "name" LIKE "%word%".
|
||||
//
|
||||
// By default, each value will be surrounded by "%" to enable partial matching. If a value contains special characters
|
||||
// such as "%", "\", "_", they will also be properly escaped.
|
||||
//
|
||||
// You may call Escape() and/or Match() to change the default behavior. For example, Like("name", "key").Match(false, true)
|
||||
// generates "name" LIKE "key%".
|
||||
func Like(col string, values ...string) *LikeExp {
|
||||
return &LikeExp{
|
||||
left: true,
|
||||
right: true,
|
||||
col: col,
|
||||
values: values,
|
||||
escape: DefaultLikeEscape,
|
||||
Like: "LIKE",
|
||||
}
|
||||
}
|
||||
|
||||
// NotLike generates a NOT LIKE expression.
|
||||
// For example, NotLike("name", "key", "word") will generate a SQL expression:
|
||||
// "name" NOT LIKE "%key%" AND "name" NOT LIKE "%word%". Please see Like() for more details.
|
||||
func NotLike(col string, values ...string) *LikeExp {
|
||||
return &LikeExp{
|
||||
left: true,
|
||||
right: true,
|
||||
col: col,
|
||||
values: values,
|
||||
escape: DefaultLikeEscape,
|
||||
Like: "NOT LIKE",
|
||||
}
|
||||
}
|
||||
|
||||
// OrLike generates an OR LIKE expression.
|
||||
// This is similar to Like() except that the column should be like one of the possible values.
|
||||
// For example, OrLike("name", "key", "word") will generate a SQL expression:
|
||||
// "name" LIKE "%key%" OR "name" LIKE "%word%". Please see Like() for more details.
|
||||
func OrLike(col string, values ...string) *LikeExp {
|
||||
return &LikeExp{
|
||||
or: true,
|
||||
left: true,
|
||||
right: true,
|
||||
col: col,
|
||||
values: values,
|
||||
escape: DefaultLikeEscape,
|
||||
Like: "LIKE",
|
||||
}
|
||||
}
|
||||
|
||||
// OrNotLike generates an OR NOT LIKE expression.
|
||||
// For example, OrNotLike("name", "key", "word") will generate a SQL expression:
|
||||
// "name" NOT LIKE "%key%" OR "name" NOT LIKE "%word%". Please see Like() for more details.
|
||||
func OrNotLike(col string, values ...string) *LikeExp {
|
||||
return &LikeExp{
|
||||
or: true,
|
||||
left: true,
|
||||
right: true,
|
||||
col: col,
|
||||
values: values,
|
||||
escape: DefaultLikeEscape,
|
||||
Like: "NOT LIKE",
|
||||
}
|
||||
}
|
||||
|
||||
// Exists generates an EXISTS expression by prefixing "EXISTS" to the given expression.
|
||||
func Exists(exp Expression) Expression {
|
||||
return &ExistsExp{exp, false}
|
||||
}
|
||||
|
||||
// NotExists generates an EXISTS expression by prefixing "NOT EXISTS" to the given expression.
|
||||
func NotExists(exp Expression) Expression {
|
||||
return &ExistsExp{exp, true}
|
||||
}
|
||||
|
||||
// Between generates a BETWEEN expression.
|
||||
// For example, Between("age", 10, 30) generates: "age" BETWEEN 10 AND 30
|
||||
func Between(col string, from, to interface{}) Expression {
|
||||
return &BetweenExp{col, from, to, false}
|
||||
}
|
||||
|
||||
// NotBetween generates a NOT BETWEEN expression.
|
||||
// For example, NotBetween("age", 10, 30) generates: "age" NOT BETWEEN 10 AND 30
|
||||
func NotBetween(col string, from, to interface{}) Expression {
|
||||
return &BetweenExp{col, from, to, true}
|
||||
}
|
||||
|
||||
// Exp represents an expression with a SQL fragment and a list of optional binding parameters.
|
||||
type Exp struct {
|
||||
e string
|
||||
params Params
|
||||
}
|
||||
|
||||
// Build converts an expression into a SQL fragment.
|
||||
func (e *Exp) Build(db *DB, params Params) string {
|
||||
if len(e.params) == 0 {
|
||||
return e.e
|
||||
}
|
||||
for k, v := range e.params {
|
||||
params[k] = v
|
||||
}
|
||||
return e.e
|
||||
}
|
||||
|
||||
// Build converts an expression into a SQL fragment.
|
||||
func (e HashExp) Build(db *DB, params Params) string {
|
||||
if len(e) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// ensure the hash exp generates the same SQL for different runs
|
||||
names := []string{}
|
||||
for name := range e {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
|
||||
var parts []string
|
||||
for _, name := range names {
|
||||
value := e[name]
|
||||
switch value.(type) {
|
||||
case nil:
|
||||
name = db.QuoteColumnName(name)
|
||||
parts = append(parts, name+" IS NULL")
|
||||
case Expression:
|
||||
if sql := value.(Expression).Build(db, params); sql != "" {
|
||||
parts = append(parts, "("+sql+")")
|
||||
}
|
||||
case []interface{}:
|
||||
in := In(name, value.([]interface{})...)
|
||||
if sql := in.Build(db, params); sql != "" {
|
||||
parts = append(parts, sql)
|
||||
}
|
||||
default:
|
||||
pn := fmt.Sprintf("p%v", len(params))
|
||||
name = db.QuoteColumnName(name)
|
||||
parts = append(parts, name+"={:"+pn+"}")
|
||||
params[pn] = value
|
||||
}
|
||||
}
|
||||
if len(parts) == 1 {
|
||||
return parts[0]
|
||||
}
|
||||
return strings.Join(parts, " AND ")
|
||||
}
|
||||
|
||||
// NotExp represents an expression that should prefix "NOT" to a specified expression.
|
||||
type NotExp struct {
|
||||
e Expression
|
||||
}
|
||||
|
||||
// Build converts an expression into a SQL fragment.
|
||||
func (e *NotExp) Build(db *DB, params Params) string {
|
||||
if sql := e.e.Build(db, params); sql != "" {
|
||||
return "NOT (" + sql + ")"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// AndOrExp represents an expression that concatenates multiple expressions using either "AND" or "OR".
|
||||
type AndOrExp struct {
|
||||
exps []Expression
|
||||
op string
|
||||
}
|
||||
|
||||
// Build converts an expression into a SQL fragment.
|
||||
func (e *AndOrExp) Build(db *DB, params Params) string {
|
||||
if len(e.exps) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var parts []string
|
||||
for _, a := range e.exps {
|
||||
if a == nil {
|
||||
continue
|
||||
}
|
||||
if sql := a.Build(db, params); sql != "" {
|
||||
parts = append(parts, sql)
|
||||
}
|
||||
}
|
||||
if len(parts) == 1 {
|
||||
return parts[0]
|
||||
}
|
||||
return "(" + strings.Join(parts, ") "+e.op+" (") + ")"
|
||||
}
|
||||
|
||||
// InExp represents an "IN" or "NOT IN" expression.
|
||||
type InExp struct {
|
||||
col string
|
||||
values []interface{}
|
||||
not bool
|
||||
}
|
||||
|
||||
// Build converts an expression into a SQL fragment.
|
||||
func (e *InExp) Build(db *DB, params Params) string {
|
||||
if len(e.values) == 0 {
|
||||
if e.not {
|
||||
return ""
|
||||
}
|
||||
return "0=1"
|
||||
}
|
||||
|
||||
var values []string
|
||||
for _, value := range e.values {
|
||||
switch value.(type) {
|
||||
case nil:
|
||||
values = append(values, "NULL")
|
||||
case Expression:
|
||||
sql := value.(Expression).Build(db, params)
|
||||
values = append(values, sql)
|
||||
default:
|
||||
name := fmt.Sprintf("p%v", len(params))
|
||||
params[name] = value
|
||||
values = append(values, "{:"+name+"}")
|
||||
}
|
||||
}
|
||||
col := db.QuoteColumnName(e.col)
|
||||
if len(values) == 1 {
|
||||
if e.not {
|
||||
return col + "<>" + values[0]
|
||||
}
|
||||
return col + "=" + values[0]
|
||||
}
|
||||
in := "IN"
|
||||
if e.not {
|
||||
in = "NOT IN"
|
||||
}
|
||||
return fmt.Sprintf("%v %v (%v)", col, in, strings.Join(values, ", "))
|
||||
}
|
||||
|
||||
// LikeExp represents a variant of LIKE expressions.
|
||||
type LikeExp struct {
|
||||
or bool
|
||||
left, right bool
|
||||
col string
|
||||
values []string
|
||||
escape []string
|
||||
|
||||
// Like stores the LIKE operator. It can be "LIKE", "NOT LIKE".
|
||||
// It may also be customized as something like "ILIKE".
|
||||
Like string
|
||||
}
|
||||
|
||||
// Escape specifies how a LIKE expression should be escaped.
|
||||
// Each string at position 2i represents a special character and the string at position 2i+1 is
|
||||
// the corresponding escaped version.
|
||||
func (e *LikeExp) Escape(chars ...string) *LikeExp {
|
||||
e.escape = chars
|
||||
return e
|
||||
}
|
||||
|
||||
// Match specifies whether to do wildcard matching on the left and/or right of given strings.
|
||||
func (e *LikeExp) Match(left, right bool) *LikeExp {
|
||||
e.left, e.right = left, right
|
||||
return e
|
||||
}
|
||||
|
||||
// Build converts an expression into a SQL fragment.
|
||||
func (e *LikeExp) Build(db *DB, params Params) string {
|
||||
if len(e.values) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(e.escape)%2 != 0 {
|
||||
panic("LikeExp.Escape must be a slice of even number of strings")
|
||||
}
|
||||
|
||||
var parts []string
|
||||
col := db.QuoteColumnName(e.col)
|
||||
for _, value := range e.values {
|
||||
name := fmt.Sprintf("p%v", len(params))
|
||||
for i := 0; i < len(e.escape); i += 2 {
|
||||
value = strings.Replace(value, e.escape[i], e.escape[i+1], -1)
|
||||
}
|
||||
if e.left {
|
||||
value = "%" + value
|
||||
}
|
||||
if e.right {
|
||||
value += "%"
|
||||
}
|
||||
params[name] = value
|
||||
parts = append(parts, fmt.Sprintf("%v %v {:%v}", col, e.Like, name))
|
||||
}
|
||||
|
||||
if e.or {
|
||||
return strings.Join(parts, " OR ")
|
||||
}
|
||||
return strings.Join(parts, " AND ")
|
||||
}
|
||||
|
||||
// ExistsExp represents an EXISTS or NOT EXISTS expression.
|
||||
type ExistsExp struct {
|
||||
exp Expression
|
||||
not bool
|
||||
}
|
||||
|
||||
// Build converts an expression into a SQL fragment.
|
||||
func (e *ExistsExp) Build(db *DB, params Params) string {
|
||||
sql := e.exp.Build(db, params)
|
||||
if sql == "" {
|
||||
if e.not {
|
||||
return ""
|
||||
}
|
||||
return "0=1"
|
||||
}
|
||||
if e.not {
|
||||
return "NOT EXISTS (" + sql + ")"
|
||||
}
|
||||
return "EXISTS (" + sql + ")"
|
||||
}
|
||||
|
||||
// BetweenExp represents a BETWEEN or a NOT BETWEEN expression.
|
||||
type BetweenExp struct {
|
||||
col string
|
||||
from, to interface{}
|
||||
not bool
|
||||
}
|
||||
|
||||
// Build converts an expression into a SQL fragment.
|
||||
func (e *BetweenExp) Build(db *DB, params Params) string {
|
||||
between := "BETWEEN"
|
||||
if e.not {
|
||||
between = "NOT BETWEEN"
|
||||
}
|
||||
name1 := fmt.Sprintf("p%v", len(params))
|
||||
name2 := fmt.Sprintf("p%v", len(params)+1)
|
||||
params[name1] = e.from
|
||||
params[name2] = e.to
|
||||
col := db.QuoteColumnName(e.col)
|
||||
return fmt.Sprintf("%v %v {:%v} AND {:%v}", col, between, name1, name2)
|
||||
}
|
||||
|
||||
// Enclose surrounds the provided nonempty expression with parenthesis "()".
|
||||
func Enclose(exp Expression) Expression {
|
||||
return &EncloseExp{exp}
|
||||
}
|
||||
|
||||
// EncloseExp represents a parenthesis enclosed expression.
|
||||
type EncloseExp struct {
|
||||
exp Expression
|
||||
}
|
||||
|
||||
// Build converts an expression into a SQL fragment.
|
||||
func (e *EncloseExp) Build(db *DB, params Params) string {
|
||||
str := e.exp.Build(db, params)
|
||||
|
||||
if str == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return "(" + str + ")"
|
||||
}
|
196
expression_test.go
Normal file
196
expression_test.go
Normal file
|
@ -0,0 +1,196 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestExp(t *testing.T) {
|
||||
params := Params{"k2": "v2"}
|
||||
|
||||
e1 := NewExp("s1").(*Exp)
|
||||
assert.Equal(t, e1.Build(nil, params), "s1", "e1.Build()")
|
||||
assert.Equal(t, len(params), 1, `len(params)@1`)
|
||||
|
||||
e2 := NewExp("s2", Params{"k1": "v1"}).(*Exp)
|
||||
assert.Equal(t, e2.Build(nil, params), "s2", "e2.Build()")
|
||||
assert.Equal(t, len(params), 2, `len(params)@2`)
|
||||
}
|
||||
|
||||
func TestHashExp(t *testing.T) {
|
||||
e1 := HashExp{}
|
||||
assert.Equal(t, e1.Build(nil, nil), "", `e1.Build()`)
|
||||
|
||||
e2 := HashExp{
|
||||
"k1": nil,
|
||||
"k2": NewExp("s1", Params{"ka": "va"}),
|
||||
"k3": 1.1,
|
||||
"k4": "abc",
|
||||
"k5": []interface{}{1, 2},
|
||||
}
|
||||
db := getDB()
|
||||
params := Params{"k0": "v0"}
|
||||
expected := "`k1` IS NULL AND (s1) AND `k3`={:p2} AND `k4`={:p3} AND `k5` IN ({:p4}, {:p5})"
|
||||
|
||||
assert.Equal(t, e2.Build(db, params), expected, `e2.Build()`)
|
||||
assert.Equal(t, len(params), 6, `len(params)`)
|
||||
assert.Equal(t, params["p5"].(int), 2, `params["p5"]`)
|
||||
}
|
||||
|
||||
func TestNotExp(t *testing.T) {
|
||||
e1 := Not(NewExp("s1"))
|
||||
assert.Equal(t, e1.Build(nil, nil), "NOT (s1)", `e1.Build()`)
|
||||
|
||||
e2 := Not(NewExp(""))
|
||||
assert.Equal(t, e2.Build(nil, nil), "", `e2.Build()`)
|
||||
}
|
||||
|
||||
func TestAndOrExp(t *testing.T) {
|
||||
e1 := And(NewExp("s1", Params{"k1": "v1"}), NewExp(""), NewExp("s2", Params{"k2": "v2"}))
|
||||
params := Params{}
|
||||
assert.Equal(t, e1.Build(nil, params), "(s1) AND (s2)", `e1.Build()`)
|
||||
assert.Equal(t, len(params), 2, `len(params)`)
|
||||
|
||||
e2 := Or(NewExp("s1"), NewExp("s2"))
|
||||
assert.Equal(t, e2.Build(nil, nil), "(s1) OR (s2)", `e2.Build()`)
|
||||
|
||||
e3 := And()
|
||||
assert.Equal(t, e3.Build(nil, nil), "", `e3.Build()`)
|
||||
|
||||
e4 := And(NewExp("s1"))
|
||||
assert.Equal(t, e4.Build(nil, nil), "s1", `e4.Build()`)
|
||||
|
||||
e5 := And(NewExp("s1"), nil)
|
||||
assert.Equal(t, e5.Build(nil, nil), "s1", `e5.Build()`)
|
||||
}
|
||||
|
||||
func TestInExp(t *testing.T) {
|
||||
db := getDB()
|
||||
|
||||
e1 := In("age", 1, 2, 3)
|
||||
params := Params{}
|
||||
assert.Equal(t, e1.Build(db, params), "`age` IN ({:p0}, {:p1}, {:p2})", `e1.Build()`)
|
||||
assert.Equal(t, len(params), 3, `len(params)@1`)
|
||||
|
||||
e2 := In("age", 1)
|
||||
params = Params{}
|
||||
assert.Equal(t, e2.Build(db, params), "`age`={:p0}", `e2.Build()`)
|
||||
assert.Equal(t, len(params), 1, `len(params)@2`)
|
||||
|
||||
e3 := NotIn("age", 1, 2, 3)
|
||||
params = Params{}
|
||||
assert.Equal(t, e3.Build(db, params), "`age` NOT IN ({:p0}, {:p1}, {:p2})", `e3.Build()`)
|
||||
assert.Equal(t, len(params), 3, `len(params)@3`)
|
||||
|
||||
e4 := NotIn("age", 1)
|
||||
params = Params{}
|
||||
assert.Equal(t, e4.Build(db, params), "`age`<>{:p0}", `e4.Build()`)
|
||||
assert.Equal(t, len(params), 1, `len(params)@4`)
|
||||
|
||||
e5 := In("age")
|
||||
assert.Equal(t, e5.Build(db, nil), "0=1", `e5.Build()`)
|
||||
|
||||
e6 := NotIn("age")
|
||||
assert.Equal(t, e6.Build(db, nil), "", `e6.Build()`)
|
||||
}
|
||||
|
||||
func TestLikeExp(t *testing.T) {
|
||||
db := getDB()
|
||||
|
||||
e1 := Like("name", "a", "b", "c")
|
||||
params := Params{}
|
||||
assert.Equal(t, e1.Build(db, params), "`name` LIKE {:p0} AND `name` LIKE {:p1} AND `name` LIKE {:p2}", `e1.Build()`)
|
||||
assert.Equal(t, len(params), 3, `len(params)@1`)
|
||||
|
||||
e2 := Like("name", "a")
|
||||
params = Params{}
|
||||
assert.Equal(t, e2.Build(db, params), "`name` LIKE {:p0}", `e2.Build()`)
|
||||
assert.Equal(t, len(params), 1, `len(params)@2`)
|
||||
|
||||
e3 := Like("name")
|
||||
assert.Equal(t, e3.Build(db, nil), "", `e3.Build()`)
|
||||
|
||||
e4 := NotLike("name", "a", "b", "c")
|
||||
params = Params{}
|
||||
assert.Equal(t, e4.Build(db, params), "`name` NOT LIKE {:p0} AND `name` NOT LIKE {:p1} AND `name` NOT LIKE {:p2}", `e4.Build()`)
|
||||
assert.Equal(t, len(params), 3, `len(params)@4`)
|
||||
|
||||
e5 := OrLike("name", "a", "b", "c")
|
||||
params = Params{}
|
||||
assert.Equal(t, e5.Build(db, params), "`name` LIKE {:p0} OR `name` LIKE {:p1} OR `name` LIKE {:p2}", `e5.Build()`)
|
||||
assert.Equal(t, len(params), 3, `len(params)@5`)
|
||||
|
||||
e6 := OrNotLike("name", "a", "b", "c")
|
||||
params = Params{}
|
||||
assert.Equal(t, e6.Build(db, params), "`name` NOT LIKE {:p0} OR `name` NOT LIKE {:p1} OR `name` NOT LIKE {:p2}", `e6.Build()`)
|
||||
assert.Equal(t, len(params), 3, `len(params)@6`)
|
||||
|
||||
e7 := Like("name", "a_\\%")
|
||||
params = Params{}
|
||||
e7.Build(db, params)
|
||||
assert.Equal(t, params["p0"], "%a\\_\\\\\\%%", `params["p0"]@1`)
|
||||
|
||||
e8 := Like("name", "a").Match(false, true)
|
||||
params = Params{}
|
||||
e8.Build(db, params)
|
||||
assert.Equal(t, params["p0"], "a%", `params["p0"]@2`)
|
||||
|
||||
e9 := Like("name", "a").Match(true, false)
|
||||
params = Params{}
|
||||
e9.Build(db, params)
|
||||
assert.Equal(t, params["p0"], "%a", `params["p0"]@3`)
|
||||
|
||||
e10 := Like("name", "a").Match(false, false)
|
||||
params = Params{}
|
||||
e10.Build(db, params)
|
||||
assert.Equal(t, params["p0"], "a", `params["p0"]@4`)
|
||||
|
||||
e11 := Like("name", "%a").Match(false, false).Escape()
|
||||
params = Params{}
|
||||
e11.Build(db, params)
|
||||
assert.Equal(t, params["p0"], "%a", `params["p0"]@5`)
|
||||
}
|
||||
|
||||
func TestBetweenExp(t *testing.T) {
|
||||
db := getDB()
|
||||
|
||||
e1 := Between("age", 30, 40)
|
||||
params := Params{}
|
||||
assert.Equal(t, e1.Build(db, params), "`age` BETWEEN {:p0} AND {:p1}", `e1.Build()`)
|
||||
assert.Equal(t, len(params), 2, `len(params)@1`)
|
||||
|
||||
e2 := NotBetween("age", 30, 40)
|
||||
params = Params{}
|
||||
assert.Equal(t, e2.Build(db, params), "`age` NOT BETWEEN {:p0} AND {:p1}", `e2.Build()`)
|
||||
assert.Equal(t, len(params), 2, `len(params)@2`)
|
||||
}
|
||||
|
||||
func TestExistsExp(t *testing.T) {
|
||||
e1 := Exists(NewExp("s1"))
|
||||
assert.Equal(t, e1.Build(nil, nil), "EXISTS (s1)", `e1.Build()`)
|
||||
|
||||
e2 := NotExists(NewExp("s1"))
|
||||
assert.Equal(t, e2.Build(nil, nil), "NOT EXISTS (s1)", `e2.Build()`)
|
||||
|
||||
e3 := Exists(NewExp(""))
|
||||
assert.Equal(t, e3.Build(nil, nil), "0=1", `e3.Build()`)
|
||||
|
||||
e4 := NotExists(NewExp(""))
|
||||
assert.Equal(t, e4.Build(nil, nil), "", `e4.Build()`)
|
||||
}
|
||||
|
||||
func TestEncloseExp(t *testing.T) {
|
||||
e1 := Enclose(NewExp(""))
|
||||
assert.Equal(t, e1.Build(nil, nil), "", `e1.Build()`)
|
||||
|
||||
e2 := Enclose(NewExp("s1"))
|
||||
assert.Equal(t, e2.Build(nil, nil), "(s1)", `e2.Build()`)
|
||||
|
||||
e3 := Enclose(NewExp("(s1)"))
|
||||
assert.Equal(t, e3.Build(nil, nil), "((s1))", `e3.Build()`)
|
||||
}
|
9
go.mod
Normal file
9
go.mod
Normal file
|
@ -0,0 +1,9 @@
|
|||
module github.com/pocketbase/dbx
|
||||
|
||||
go 1.13
|
||||
|
||||
require (
|
||||
github.com/go-sql-driver/mysql v1.4.1
|
||||
github.com/stretchr/testify v1.4.0
|
||||
google.golang.org/appengine v1.6.5 // indirect
|
||||
)
|
22
go.sum
Normal file
22
go.sum
Normal file
|
@ -0,0 +1,22 @@
|
|||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
|
||||
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
|
||||
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
|
||||
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
174
model_query.go
Normal file
174
model_query.go
Normal file
|
@ -0,0 +1,174 @@
|
|||
package dbx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type (
|
||||
// TableModel is the interface that should be implemented by models which have unconventional table names.
|
||||
TableModel interface {
|
||||
TableName() string
|
||||
}
|
||||
|
||||
// ModelQuery represents a query associated with a struct model.
|
||||
ModelQuery struct {
|
||||
db *DB
|
||||
ctx context.Context
|
||||
builder Builder
|
||||
model *structValue
|
||||
exclude []string
|
||||
lastError error
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
MissingPKError = errors.New("missing primary key declaration")
|
||||
CompositePKError = errors.New("composite primary key is not supported")
|
||||
)
|
||||
|
||||
func NewModelQuery(model interface{}, fieldMapFunc FieldMapFunc, db *DB, builder Builder) *ModelQuery {
|
||||
q := &ModelQuery{
|
||||
db: db,
|
||||
ctx: db.ctx,
|
||||
builder: builder,
|
||||
model: newStructValue(model, fieldMapFunc, db.TableMapper),
|
||||
}
|
||||
if q.model == nil {
|
||||
q.lastError = VarTypeError("must be a pointer to a struct representing the model")
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// Context returns the context associated with the query.
|
||||
func (q *ModelQuery) Context() context.Context {
|
||||
return q.ctx
|
||||
}
|
||||
|
||||
// WithContext associates a context with the query.
|
||||
func (q *ModelQuery) WithContext(ctx context.Context) *ModelQuery {
|
||||
q.ctx = ctx
|
||||
return q
|
||||
}
|
||||
|
||||
// Exclude excludes the specified struct fields from being inserted/updated into the DB table.
|
||||
func (q *ModelQuery) Exclude(attrs ...string) *ModelQuery {
|
||||
q.exclude = attrs
|
||||
return q
|
||||
}
|
||||
|
||||
// Insert inserts a row in the table using the struct model associated with this query.
|
||||
//
|
||||
// By default, it inserts *all* public fields into the table, including those nil or empty ones.
|
||||
// You may pass a list of the fields to this method to indicate that only those fields should be inserted.
|
||||
// You may also call Exclude to exclude some fields from being inserted.
|
||||
//
|
||||
// If a model has an empty primary key, it is considered auto-incremental and the corresponding struct
|
||||
// field will be filled with the generated primary key value after a successful insertion.
|
||||
func (q *ModelQuery) Insert(attrs ...string) error {
|
||||
if q.lastError != nil {
|
||||
return q.lastError
|
||||
}
|
||||
cols := q.model.columns(attrs, q.exclude)
|
||||
pkName := ""
|
||||
for name, value := range q.model.pk() {
|
||||
if isAutoInc(value) {
|
||||
delete(cols, name)
|
||||
pkName = name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if pkName == "" {
|
||||
_, err := q.builder.Insert(q.model.tableName, Params(cols)).WithContext(q.ctx).Execute()
|
||||
return err
|
||||
}
|
||||
|
||||
// handle auto-incremental PK
|
||||
query := q.builder.Insert(q.model.tableName, Params(cols)).WithContext(q.ctx)
|
||||
pkValue, err := insertAndReturnPK(q.db, query, pkName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pkField := indirect(q.model.dbNameMap[pkName].getField(q.model.value))
|
||||
switch pkField.Kind() {
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
pkField.SetUint(uint64(pkValue))
|
||||
default:
|
||||
pkField.SetInt(pkValue)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func insertAndReturnPK(db *DB, query *Query, pkName string) (int64, error) {
|
||||
if db.DriverName() != "postgres" && db.DriverName() != "pgx" {
|
||||
result, err := query.Execute()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.LastInsertId()
|
||||
}
|
||||
|
||||
// specially handle postgres (lib/pq) as it doesn't support LastInsertId
|
||||
returning := fmt.Sprintf(" RETURNING %s", db.QuoteColumnName(pkName))
|
||||
query.sql += returning
|
||||
query.rawSQL += returning
|
||||
var pkValue int64
|
||||
err := query.Row(&pkValue)
|
||||
return pkValue, err
|
||||
}
|
||||
|
||||
func isAutoInc(value interface{}) bool {
|
||||
v := reflect.ValueOf(value)
|
||||
switch v.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return v.Int() == 0
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return v.Uint() == 0
|
||||
case reflect.Ptr:
|
||||
return v.IsNil() || isAutoInc(v.Elem())
|
||||
case reflect.Invalid:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Update updates a row in the table using the struct model associated with this query.
|
||||
// The row being updated has the same primary key as specified by the model.
|
||||
//
|
||||
// By default, it updates *all* public fields in the table, including those nil or empty ones.
|
||||
// You may pass a list of the fields to this method to indicate that only those fields should be updated.
|
||||
// You may also call Exclude to exclude some fields from being updated.
|
||||
func (q *ModelQuery) Update(attrs ...string) error {
|
||||
if q.lastError != nil {
|
||||
return q.lastError
|
||||
}
|
||||
pk := q.model.pk()
|
||||
if len(pk) == 0 {
|
||||
return MissingPKError
|
||||
}
|
||||
|
||||
cols := q.model.columns(attrs, q.exclude)
|
||||
for name := range pk {
|
||||
delete(cols, name)
|
||||
}
|
||||
_, err := q.builder.Update(q.model.tableName, Params(cols), HashExp(pk)).WithContext(q.ctx).Execute()
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete deletes a row in the table using the primary key specified by the struct model associated with this query.
|
||||
func (q *ModelQuery) Delete() error {
|
||||
if q.lastError != nil {
|
||||
return q.lastError
|
||||
}
|
||||
pk := q.model.pk()
|
||||
if len(pk) == 0 {
|
||||
return MissingPKError
|
||||
}
|
||||
_, err := q.builder.Delete(q.model.tableName, HashExp(pk)).WithContext(q.ctx).Execute()
|
||||
return err
|
||||
}
|
238
model_query_test.go
Normal file
238
model_query_test.go
Normal file
|
@ -0,0 +1,238 @@
|
|||
package dbx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type Item struct {
|
||||
ID2 int
|
||||
Name string
|
||||
}
|
||||
|
||||
func TestModelQuery_Insert(t *testing.T) {
|
||||
db := getPreparedDB()
|
||||
defer db.Close()
|
||||
|
||||
name := "test"
|
||||
email := "test@example.com"
|
||||
|
||||
{
|
||||
// inserting normally
|
||||
customer := Customer{
|
||||
Name: name,
|
||||
Email: email,
|
||||
}
|
||||
err := db.Model(&customer).Insert()
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, 4, customer.ID)
|
||||
var c Customer
|
||||
db.Select().From("customer").Where(HashExp{"ID": 4}).One(&c)
|
||||
assert.Equal(t, name, c.Name)
|
||||
assert.Equal(t, email, c.Email)
|
||||
assert.Equal(t, 0, c.Status)
|
||||
assert.False(t, c.Address.Valid)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// inserting with pointer-typed fields
|
||||
customer := CustomerPtr{
|
||||
Name: name,
|
||||
Email: &email,
|
||||
}
|
||||
err := db.Model(&customer).Insert()
|
||||
if assert.Nil(t, err) && assert.NotNil(t, customer.ID) {
|
||||
assert.Equal(t, 5, *customer.ID)
|
||||
var c CustomerPtr
|
||||
db.Select().From("customer").Where(HashExp{"ID": 4}).One(&c)
|
||||
assert.Equal(t, name, c.Name)
|
||||
if assert.NotNil(t, c.Email) {
|
||||
assert.Equal(t, email, *c.Email)
|
||||
}
|
||||
if assert.NotNil(t, c.Status) {
|
||||
assert.Equal(t, 0, *c.Status)
|
||||
}
|
||||
assert.Nil(t, c.Address)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// inserting with null-typed fields
|
||||
customer := CustomerNull{
|
||||
Name: name,
|
||||
Email: sql.NullString{email, true},
|
||||
}
|
||||
err := db.Model(&customer).Insert()
|
||||
if assert.Nil(t, err) {
|
||||
// potential todo: need to check if the field implements sql.Scanner
|
||||
// assert.Equal(t, int64(6), customer.ID.Int64)
|
||||
var c CustomerNull
|
||||
db.Select().From("customer").Where(HashExp{"ID": 4}).One(&c)
|
||||
assert.Equal(t, name, c.Name)
|
||||
assert.Equal(t, email, c.Email.String)
|
||||
if assert.NotNil(t, c.Status) {
|
||||
assert.Equal(t, int64(0), c.Status.Int64)
|
||||
}
|
||||
assert.False(t, c.Address.Valid)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// inserting with embedded structures
|
||||
customer := CustomerEmbedded{
|
||||
Id: 100,
|
||||
Email: &email,
|
||||
InnerCustomer: InnerCustomer{
|
||||
Name: &name,
|
||||
Status: sql.NullInt64{1, true},
|
||||
},
|
||||
}
|
||||
err := db.Model(&customer).Insert()
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, 100, customer.Id)
|
||||
var c CustomerEmbedded
|
||||
db.Select().From("customer").Where(HashExp{"ID": 100}).One(&c)
|
||||
assert.Equal(t, name, *c.Name)
|
||||
assert.Equal(t, email, *c.Email)
|
||||
if assert.NotNil(t, c.Status) {
|
||||
assert.Equal(t, int64(1), c.Status.Int64)
|
||||
}
|
||||
assert.False(t, c.Address.Valid)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// inserting with include/exclude fields
|
||||
customer := Customer{
|
||||
Name: name,
|
||||
Email: email,
|
||||
Status: 1,
|
||||
}
|
||||
err := db.Model(&customer).Exclude("Name").Insert("Name", "Email")
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, 101, customer.ID)
|
||||
var c Customer
|
||||
db.Select().From("customer").Where(HashExp{"ID": 101}).One(&c)
|
||||
assert.Equal(t, "", c.Name)
|
||||
assert.Equal(t, email, c.Email)
|
||||
assert.Equal(t, 0, c.Status)
|
||||
assert.False(t, c.Address.Valid)
|
||||
}
|
||||
}
|
||||
|
||||
var a int
|
||||
assert.NotNil(t, db.Model(&a).Insert())
|
||||
}
|
||||
|
||||
func TestModelQuery_Update(t *testing.T) {
|
||||
db := getPreparedDB()
|
||||
defer db.Close()
|
||||
|
||||
id := 2
|
||||
name := "test"
|
||||
email := "test@example.com"
|
||||
{
|
||||
// updating normally
|
||||
customer := Customer{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Email: email,
|
||||
}
|
||||
err := db.Model(&customer).Update()
|
||||
if assert.Nil(t, err) {
|
||||
var c Customer
|
||||
db.Select().From("customer").Where(HashExp{"ID": id}).One(&c)
|
||||
assert.Equal(t, name, c.Name)
|
||||
assert.Equal(t, email, c.Email)
|
||||
assert.Equal(t, 0, c.Status)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// updating without primary keys
|
||||
item2 := Item{
|
||||
Name: name,
|
||||
}
|
||||
err := db.Model(&item2).Update()
|
||||
assert.Equal(t, MissingPKError, err)
|
||||
}
|
||||
|
||||
{
|
||||
// updating all fields
|
||||
customer := CustomerPtr{
|
||||
ID: &id,
|
||||
Name: name,
|
||||
Email: &email,
|
||||
}
|
||||
err := db.Model(&customer).Update()
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, id, *customer.ID)
|
||||
var c CustomerPtr
|
||||
db.Select().From("customer").Where(HashExp{"ID": id}).One(&c)
|
||||
assert.Equal(t, name, c.Name)
|
||||
if assert.NotNil(t, c.Email) {
|
||||
assert.Equal(t, email, *c.Email)
|
||||
}
|
||||
assert.Nil(t, c.Status)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// updating selected fields only
|
||||
id = 3
|
||||
customer := CustomerPtr{
|
||||
ID: &id,
|
||||
Name: name,
|
||||
Email: &email,
|
||||
}
|
||||
err := db.Model(&customer).Update("Name", "Email")
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, id, *customer.ID)
|
||||
var c CustomerPtr
|
||||
db.Select().From("customer").Where(HashExp{"ID": id}).One(&c)
|
||||
assert.Equal(t, name, c.Name)
|
||||
if assert.NotNil(t, c.Email) {
|
||||
assert.Equal(t, email, *c.Email)
|
||||
}
|
||||
if assert.NotNil(t, c.Status) {
|
||||
assert.Equal(t, 2, *c.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// updating non-struct
|
||||
var a int
|
||||
assert.NotNil(t, db.Model(&a).Update())
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelQuery_Delete(t *testing.T) {
|
||||
db := getPreparedDB()
|
||||
defer db.Close()
|
||||
|
||||
customer := Customer{
|
||||
ID: 2,
|
||||
}
|
||||
err := db.Model(&customer).Delete()
|
||||
if assert.Nil(t, err) {
|
||||
var m Customer
|
||||
err := db.Select().From("customer").Where(HashExp{"ID": 2}).One(&m)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
{
|
||||
// deleting without primary keys
|
||||
item2 := Item{
|
||||
Name: "",
|
||||
}
|
||||
err := db.Model(&item2).Delete()
|
||||
assert.Equal(t, MissingPKError, err)
|
||||
}
|
||||
|
||||
var a int
|
||||
assert.NotNil(t, db.Model(&a).Delete())
|
||||
}
|
384
query.go
Normal file
384
query.go
Normal file
|
@ -0,0 +1,384 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ExecHookFunc executes before op allowing custom handling like auto fail/retry.
|
||||
type ExecHookFunc func(q *Query, op func() error) error
|
||||
|
||||
// OneHookFunc executes right before the query populate the row result from One() call (aka. op).
|
||||
type OneHookFunc func(q *Query, a interface{}, op func(b interface{}) error) error
|
||||
|
||||
// AllHookFunc executes right before the query populate the row result from All() call (aka. op).
|
||||
type AllHookFunc func(q *Query, sliceA interface{}, op func(sliceB interface{}) error) error
|
||||
|
||||
// Params represents a list of parameter values to be bound to a SQL statement.
|
||||
// The map keys are the parameter names while the map values are the corresponding parameter values.
|
||||
type Params map[string]interface{}
|
||||
|
||||
// Executor prepares, executes, or queries a SQL statement.
|
||||
type Executor interface {
|
||||
// Exec executes a SQL statement
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
// ExecContext executes a SQL statement with the given context
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
// Query queries a SQL statement
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
// QueryContext queries a SQL statement with the given context
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
// Prepare creates a prepared statement
|
||||
Prepare(query string) (*sql.Stmt, error)
|
||||
}
|
||||
|
||||
// Query represents a SQL statement to be executed.
|
||||
type Query struct {
|
||||
executor Executor
|
||||
|
||||
sql, rawSQL string
|
||||
placeholders []string
|
||||
params Params
|
||||
|
||||
stmt *sql.Stmt
|
||||
ctx context.Context
|
||||
|
||||
// hooks
|
||||
execHook ExecHookFunc
|
||||
oneHook OneHookFunc
|
||||
allHook AllHookFunc
|
||||
|
||||
// FieldMapper maps struct field names to DB column names.
|
||||
FieldMapper FieldMapFunc
|
||||
// LastError contains the last error (if any) of the query.
|
||||
// LastError is cleared by Execute(), Row(), Rows(), One(), and All().
|
||||
LastError error
|
||||
// LogFunc is used to log the SQL statement being executed.
|
||||
LogFunc LogFunc
|
||||
// PerfFunc is used to log the SQL execution time. It is ignored if nil.
|
||||
// Deprecated: Please use QueryLogFunc and ExecLogFunc instead.
|
||||
PerfFunc PerfFunc
|
||||
// QueryLogFunc is called each time when performing a SQL query that returns data.
|
||||
QueryLogFunc QueryLogFunc
|
||||
// ExecLogFunc is called each time when a SQL statement is executed.
|
||||
ExecLogFunc ExecLogFunc
|
||||
}
|
||||
|
||||
// NewQuery creates a new Query with the given SQL statement.
|
||||
func NewQuery(db *DB, executor Executor, sql string) *Query {
|
||||
rawSQL, placeholders := db.processSQL(sql)
|
||||
return &Query{
|
||||
executor: executor,
|
||||
sql: sql,
|
||||
rawSQL: rawSQL,
|
||||
placeholders: placeholders,
|
||||
params: Params{},
|
||||
ctx: db.ctx,
|
||||
FieldMapper: db.FieldMapper,
|
||||
LogFunc: db.LogFunc,
|
||||
PerfFunc: db.PerfFunc,
|
||||
QueryLogFunc: db.QueryLogFunc,
|
||||
ExecLogFunc: db.ExecLogFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// SQL returns the original SQL used to create the query.
|
||||
// The actual SQL (RawSQL) being executed is obtained by replacing the named
|
||||
// parameter placeholders with anonymous ones.
|
||||
func (q *Query) SQL() string {
|
||||
return q.sql
|
||||
}
|
||||
|
||||
// Context returns the context associated with the query.
|
||||
func (q *Query) Context() context.Context {
|
||||
return q.ctx
|
||||
}
|
||||
|
||||
// WithContext associates a context with the query.
|
||||
func (q *Query) WithContext(ctx context.Context) *Query {
|
||||
q.ctx = ctx
|
||||
return q
|
||||
}
|
||||
|
||||
// WithExecHook associates the provided exec hook function with the query.
|
||||
//
|
||||
// It is called for every Query resolver (Execute(), One(), All(), Row(), Column()),
|
||||
// allowing you to implement auto fail/retry or any other additional handling.
|
||||
func (q *Query) WithExecHook(fn ExecHookFunc) *Query {
|
||||
q.execHook = fn
|
||||
return q
|
||||
}
|
||||
|
||||
// WithOneHook associates the provided hook function with the query,
|
||||
// called on q.One(), allowing you to implement custom struct scan based
|
||||
// on the One() argument and/or result.
|
||||
func (q *Query) WithOneHook(fn OneHookFunc) *Query {
|
||||
q.oneHook = fn
|
||||
return q
|
||||
}
|
||||
|
||||
// WithOneHook associates the provided hook function with the query,
|
||||
// called on q.All(), allowing you to implement custom slice scan based
|
||||
// on the All() argument and/or result.
|
||||
func (q *Query) WithAllHook(fn AllHookFunc) *Query {
|
||||
q.allHook = fn
|
||||
return q
|
||||
}
|
||||
|
||||
// logSQL returns the SQL statement with parameters being replaced with the actual values.
|
||||
// The result is only for logging purpose and should not be used to execute.
|
||||
func (q *Query) logSQL() string {
|
||||
s := q.sql
|
||||
for k, v := range q.params {
|
||||
if valuer, ok := v.(driver.Valuer); ok && valuer != nil {
|
||||
v, _ = valuer.Value()
|
||||
}
|
||||
var sv string
|
||||
if str, ok := v.(string); ok {
|
||||
sv = "'" + strings.Replace(str, "'", "''", -1) + "'"
|
||||
} else if bs, ok := v.([]byte); ok {
|
||||
sv = "0x" + hex.EncodeToString(bs)
|
||||
} else {
|
||||
sv = fmt.Sprintf("%v", v)
|
||||
}
|
||||
s = strings.Replace(s, "{:"+k+"}", sv, -1)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Params returns the parameters to be bound to the SQL statement represented by this query.
|
||||
func (q *Query) Params() Params {
|
||||
return q.params
|
||||
}
|
||||
|
||||
// Prepare creates a prepared statement for later queries or executions.
|
||||
// Close() should be called after finishing all queries.
|
||||
func (q *Query) Prepare() *Query {
|
||||
stmt, err := q.executor.Prepare(q.rawSQL)
|
||||
if err != nil {
|
||||
q.LastError = err
|
||||
return q
|
||||
}
|
||||
q.stmt = stmt
|
||||
return q
|
||||
}
|
||||
|
||||
// Close closes the underlying prepared statement.
|
||||
// Close does nothing if the query has not been prepared before.
|
||||
func (q *Query) Close() error {
|
||||
if q.stmt == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := q.stmt.Close()
|
||||
q.stmt = nil
|
||||
return err
|
||||
}
|
||||
|
||||
// Bind sets the parameters that should be bound to the SQL statement.
|
||||
// The parameter placeholders in the SQL statement are in the format of "{:ParamName}".
|
||||
func (q *Query) Bind(params Params) *Query {
|
||||
if len(q.params) == 0 {
|
||||
q.params = params
|
||||
} else {
|
||||
for k, v := range params {
|
||||
q.params[k] = v
|
||||
}
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// Execute executes the SQL statement without retrieving data.
|
||||
func (q *Query) Execute() (sql.Result, error) {
|
||||
var result sql.Result
|
||||
|
||||
execErr := q.execWrap(func() error {
|
||||
var err error
|
||||
result, err = q.execute()
|
||||
return err
|
||||
})
|
||||
|
||||
return result, execErr
|
||||
}
|
||||
|
||||
func (q *Query) execute() (result sql.Result, err error) {
|
||||
err = q.LastError
|
||||
q.LastError = nil
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var params []interface{}
|
||||
params, err = replacePlaceholders(q.placeholders, q.params)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
if q.ctx == nil {
|
||||
if q.stmt == nil {
|
||||
result, err = q.executor.Exec(q.rawSQL, params...)
|
||||
} else {
|
||||
result, err = q.stmt.Exec(params...)
|
||||
}
|
||||
} else {
|
||||
if q.stmt == nil {
|
||||
result, err = q.executor.ExecContext(q.ctx, q.rawSQL, params...)
|
||||
} else {
|
||||
result, err = q.stmt.ExecContext(q.ctx, params...)
|
||||
}
|
||||
}
|
||||
|
||||
if q.ExecLogFunc != nil {
|
||||
q.ExecLogFunc(q.ctx, time.Now().Sub(start), q.logSQL(), result, err)
|
||||
}
|
||||
if q.LogFunc != nil {
|
||||
q.LogFunc("[%.2fms] Execute SQL: %v", float64(time.Now().Sub(start).Milliseconds()), q.logSQL())
|
||||
}
|
||||
if q.PerfFunc != nil {
|
||||
q.PerfFunc(time.Now().Sub(start).Nanoseconds(), q.logSQL(), true)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// One executes the SQL statement and populates the first row of the result into a struct or NullStringMap.
|
||||
// Refer to Rows.ScanStruct() and Rows.ScanMap() for more details on how to specify
|
||||
// the variable to be populated.
|
||||
// Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned.
|
||||
func (q *Query) One(a interface{}) error {
|
||||
return q.execWrap(func() error {
|
||||
rows, err := q.Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if q.oneHook != nil {
|
||||
return q.oneHook(q, a, rows.one)
|
||||
}
|
||||
|
||||
return rows.one(a)
|
||||
})
|
||||
}
|
||||
|
||||
// All executes the SQL statement and populates all the resulting rows into a slice of struct or NullStringMap.
|
||||
// The slice must be given as a pointer. Each slice element must be either a struct or a NullStringMap.
|
||||
// Refer to Rows.ScanStruct() and Rows.ScanMap() for more details on how each slice element can be.
|
||||
// If the query returns no row, the slice will be an empty slice (not nil).
|
||||
func (q *Query) All(slice interface{}) error {
|
||||
return q.execWrap(func() error {
|
||||
rows, err := q.Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if q.allHook != nil {
|
||||
return q.allHook(q, slice, rows.all)
|
||||
}
|
||||
|
||||
return rows.all(slice)
|
||||
})
|
||||
}
|
||||
|
||||
// Row executes the SQL statement and populates the first row of the result into a list of variables.
|
||||
// Note that the number of the variables should match to that of the columns in the query result.
|
||||
// Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned.
|
||||
func (q *Query) Row(a ...interface{}) error {
|
||||
return q.execWrap(func() error {
|
||||
rows, err := q.Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return rows.row(a...)
|
||||
})
|
||||
}
|
||||
|
||||
// Column executes the SQL statement and populates the first column of the result into a slice.
|
||||
// Note that the parameter must be a pointer to a slice.
|
||||
func (q *Query) Column(a interface{}) error {
|
||||
return q.execWrap(func() error {
|
||||
rows, err := q.Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return rows.column(a)
|
||||
})
|
||||
}
|
||||
|
||||
// Rows executes the SQL statement and returns a Rows object to allow retrieving data row by row.
|
||||
func (q *Query) Rows() (rows *Rows, err error) {
|
||||
err = q.LastError
|
||||
q.LastError = nil
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var params []interface{}
|
||||
params, err = replacePlaceholders(q.placeholders, q.params)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
var rr *sql.Rows
|
||||
if q.ctx == nil {
|
||||
if q.stmt == nil {
|
||||
rr, err = q.executor.Query(q.rawSQL, params...)
|
||||
} else {
|
||||
rr, err = q.stmt.Query(params...)
|
||||
}
|
||||
} else {
|
||||
if q.stmt == nil {
|
||||
rr, err = q.executor.QueryContext(q.ctx, q.rawSQL, params...)
|
||||
} else {
|
||||
rr, err = q.stmt.QueryContext(q.ctx, params...)
|
||||
}
|
||||
}
|
||||
rows = &Rows{rr, q.FieldMapper}
|
||||
|
||||
if q.QueryLogFunc != nil {
|
||||
q.QueryLogFunc(q.ctx, time.Now().Sub(start), q.logSQL(), rr, err)
|
||||
}
|
||||
if q.LogFunc != nil {
|
||||
q.LogFunc("[%.2fms] Query SQL: %v", float64(time.Now().Sub(start).Milliseconds()), q.logSQL())
|
||||
}
|
||||
if q.PerfFunc != nil {
|
||||
q.PerfFunc(time.Now().Sub(start).Nanoseconds(), q.logSQL(), false)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (q *Query) execWrap(op func() error) error {
|
||||
if q.execHook != nil {
|
||||
return q.execHook(q, op)
|
||||
}
|
||||
return op()
|
||||
}
|
||||
|
||||
// replacePlaceholders converts a list of named parameters into a list of anonymous parameters.
|
||||
func replacePlaceholders(placeholders []string, params Params) ([]interface{}, error) {
|
||||
if len(placeholders) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result []interface{}
|
||||
for _, name := range placeholders {
|
||||
if value, ok := params[name]; ok {
|
||||
result = append(result, value)
|
||||
} else {
|
||||
return nil, errors.New("Named parameter not found: " + name)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
244
query_builder.go
Normal file
244
query_builder.go
Normal file
|
@ -0,0 +1,244 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// QueryBuilder builds different clauses for a SELECT SQL statement.
|
||||
type QueryBuilder interface {
|
||||
// BuildSelect generates a SELECT clause from the given selected column names.
|
||||
BuildSelect(cols []string, distinct bool, option string) string
|
||||
// BuildFrom generates a FROM clause from the given tables.
|
||||
BuildFrom(tables []string) string
|
||||
// BuildGroupBy generates a GROUP BY clause from the given group-by columns.
|
||||
BuildGroupBy(cols []string) string
|
||||
// BuildJoin generates a JOIN clause from the given join information.
|
||||
BuildJoin([]JoinInfo, Params) string
|
||||
// BuildWhere generates a WHERE clause from the given expression.
|
||||
BuildWhere(Expression, Params) string
|
||||
// BuildHaving generates a HAVING clause from the given expression.
|
||||
BuildHaving(Expression, Params) string
|
||||
// BuildOrderByAndLimit generates the ORDER BY and LIMIT clauses.
|
||||
BuildOrderByAndLimit(string, []string, int64, int64) string
|
||||
// BuildUnion generates a UNION clause from the given union information.
|
||||
BuildUnion([]UnionInfo, Params) string
|
||||
}
|
||||
|
||||
// BaseQueryBuilder provides a basic implementation of QueryBuilder.
|
||||
type BaseQueryBuilder struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
var _ QueryBuilder = &BaseQueryBuilder{}
|
||||
|
||||
// NewBaseQueryBuilder creates a new BaseQueryBuilder instance.
|
||||
func NewBaseQueryBuilder(db *DB) *BaseQueryBuilder {
|
||||
return &BaseQueryBuilder{db}
|
||||
}
|
||||
|
||||
// DB returns the DB instance associated with the query builder.
|
||||
func (q *BaseQueryBuilder) DB() *DB {
|
||||
return q.db
|
||||
}
|
||||
|
||||
// the regexp for columns and tables.
|
||||
var selectRegex = regexp.MustCompile(`(?i:\s+as\s+|\s+)([\w\-_\.]+)$`)
|
||||
|
||||
// BuildSelect generates a SELECT clause from the given selected column names.
|
||||
func (q *BaseQueryBuilder) BuildSelect(cols []string, distinct bool, option string) string {
|
||||
var s bytes.Buffer
|
||||
s.WriteString("SELECT ")
|
||||
if distinct {
|
||||
s.WriteString("DISTINCT ")
|
||||
}
|
||||
if option != "" {
|
||||
s.WriteString(option)
|
||||
s.WriteString(" ")
|
||||
}
|
||||
if len(cols) == 0 {
|
||||
s.WriteString("*")
|
||||
return s.String()
|
||||
}
|
||||
|
||||
for i, col := range cols {
|
||||
if i > 0 {
|
||||
s.WriteString(", ")
|
||||
}
|
||||
matches := selectRegex.FindStringSubmatch(col)
|
||||
if len(matches) == 0 {
|
||||
s.WriteString(q.db.QuoteColumnName(col))
|
||||
} else {
|
||||
col := col[:len(col)-len(matches[0])]
|
||||
alias := matches[1]
|
||||
s.WriteString(q.db.QuoteColumnName(col) + " AS " + q.db.QuoteSimpleColumnName(alias))
|
||||
}
|
||||
}
|
||||
|
||||
return s.String()
|
||||
}
|
||||
|
||||
// BuildFrom generates a FROM clause from the given tables.
|
||||
func (q *BaseQueryBuilder) BuildFrom(tables []string) string {
|
||||
if len(tables) == 0 {
|
||||
return ""
|
||||
}
|
||||
s := ""
|
||||
for _, table := range tables {
|
||||
table = q.quoteTableNameAndAlias(table)
|
||||
if s == "" {
|
||||
s = table
|
||||
} else {
|
||||
s += ", " + table
|
||||
}
|
||||
}
|
||||
return "FROM " + s
|
||||
}
|
||||
|
||||
// BuildJoin generates a JOIN clause from the given join information.
|
||||
func (q *BaseQueryBuilder) BuildJoin(joins []JoinInfo, params Params) string {
|
||||
if len(joins) == 0 {
|
||||
return ""
|
||||
}
|
||||
parts := []string{}
|
||||
for _, join := range joins {
|
||||
sql := join.Join + " " + q.quoteTableNameAndAlias(join.Table)
|
||||
on := ""
|
||||
if join.On != nil {
|
||||
on = join.On.Build(q.db, params)
|
||||
}
|
||||
if on != "" {
|
||||
sql += " ON " + on
|
||||
}
|
||||
parts = append(parts, sql)
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// BuildWhere generates a WHERE clause from the given expression.
|
||||
func (q *BaseQueryBuilder) BuildWhere(e Expression, params Params) string {
|
||||
if e != nil {
|
||||
if c := e.Build(q.db, params); c != "" {
|
||||
return "WHERE " + c
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// BuildHaving generates a HAVING clause from the given expression.
|
||||
func (q *BaseQueryBuilder) BuildHaving(e Expression, params Params) string {
|
||||
if e != nil {
|
||||
if c := e.Build(q.db, params); c != "" {
|
||||
return "HAVING " + c
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// BuildGroupBy generates a GROUP BY clause from the given group-by columns.
|
||||
func (q *BaseQueryBuilder) BuildGroupBy(cols []string) string {
|
||||
if len(cols) == 0 {
|
||||
return ""
|
||||
}
|
||||
s := ""
|
||||
for i, col := range cols {
|
||||
if i == 0 {
|
||||
s = q.db.QuoteColumnName(col)
|
||||
} else {
|
||||
s += ", " + q.db.QuoteColumnName(col)
|
||||
}
|
||||
}
|
||||
return "GROUP BY " + s
|
||||
}
|
||||
|
||||
// BuildOrderByAndLimit generates the ORDER BY and LIMIT clauses.
|
||||
func (q *BaseQueryBuilder) BuildOrderByAndLimit(sql string, cols []string, limit int64, offset int64) string {
|
||||
if orderBy := q.BuildOrderBy(cols); orderBy != "" {
|
||||
sql += " " + orderBy
|
||||
}
|
||||
if limit := q.BuildLimit(limit, offset); limit != "" {
|
||||
return sql + " " + limit
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
// BuildUnion generates a UNION clause from the given union information.
|
||||
func (q *BaseQueryBuilder) BuildUnion(unions []UnionInfo, params Params) string {
|
||||
if len(unions) == 0 {
|
||||
return ""
|
||||
}
|
||||
sql := ""
|
||||
for i, union := range unions {
|
||||
if i > 0 {
|
||||
sql += " "
|
||||
}
|
||||
for k, v := range union.Query.params {
|
||||
params[k] = v
|
||||
}
|
||||
u := "UNION"
|
||||
if union.All {
|
||||
u = "UNION ALL"
|
||||
}
|
||||
sql += fmt.Sprintf("%v (%v)", u, union.Query.sql)
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
var orderRegex = regexp.MustCompile(`\s+((?i)ASC|DESC)$`)
|
||||
|
||||
// BuildOrderBy generates the ORDER BY clause.
|
||||
func (q *BaseQueryBuilder) BuildOrderBy(cols []string) string {
|
||||
if len(cols) == 0 {
|
||||
return ""
|
||||
}
|
||||
s := ""
|
||||
for i, col := range cols {
|
||||
if i > 0 {
|
||||
s += ", "
|
||||
}
|
||||
matches := orderRegex.FindStringSubmatch(col)
|
||||
if len(matches) == 0 {
|
||||
s += q.db.QuoteColumnName(col)
|
||||
} else {
|
||||
col := col[:len(col)-len(matches[0])]
|
||||
dir := matches[1]
|
||||
s += q.db.QuoteColumnName(col) + " " + dir
|
||||
}
|
||||
}
|
||||
return "ORDER BY " + s
|
||||
}
|
||||
|
||||
// BuildLimit generates the LIMIT clause.
|
||||
func (q *BaseQueryBuilder) BuildLimit(limit int64, offset int64) string {
|
||||
if limit < 0 && offset > 0 {
|
||||
// most DBMS requires LIMIT when OFFSET is present
|
||||
limit = 9223372036854775807 // 2^63 - 1
|
||||
}
|
||||
|
||||
sql := ""
|
||||
if limit >= 0 {
|
||||
sql = fmt.Sprintf("LIMIT %v", limit)
|
||||
}
|
||||
if offset <= 0 {
|
||||
return sql
|
||||
}
|
||||
if sql != "" {
|
||||
sql += " "
|
||||
}
|
||||
return sql + fmt.Sprintf("OFFSET %v", offset)
|
||||
}
|
||||
|
||||
func (q *BaseQueryBuilder) quoteTableNameAndAlias(table string) string {
|
||||
matches := selectRegex.FindStringSubmatch(table)
|
||||
if len(matches) == 0 {
|
||||
return q.db.QuoteTableName(table)
|
||||
}
|
||||
table = table[:len(table)-len(matches[0])]
|
||||
return q.db.QuoteTableName(table) + " " + q.db.QuoteSimpleTableName(matches[1])
|
||||
}
|
228
query_builder_test.go
Normal file
228
query_builder_test.go
Normal file
|
@ -0,0 +1,228 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestQB_BuildSelect(t *testing.T) {
|
||||
tests := []struct {
|
||||
tag string
|
||||
cols []string
|
||||
distinct bool
|
||||
option string
|
||||
expected string
|
||||
}{
|
||||
{"empty", []string{}, false, "", "SELECT *"},
|
||||
{"empty distinct", []string{}, true, "CALC_ROWS", "SELECT DISTINCT CALC_ROWS *"},
|
||||
{"multi-columns", []string{"name", "DOB1"}, false, "", "SELECT `name`, `DOB1`"},
|
||||
{"aliased columns", []string{"name As Name", "users.last_name", "u.first1 first"}, false, "", "SELECT `name` AS `Name`, `users`.`last_name`, `u`.`first1` AS `first`"},
|
||||
}
|
||||
|
||||
db := getDB()
|
||||
qb := db.QueryBuilder()
|
||||
for _, test := range tests {
|
||||
s := qb.BuildSelect(test.cols, test.distinct, test.option)
|
||||
assert.Equal(t, test.expected, s, test.tag)
|
||||
}
|
||||
assert.Equal(t, qb.(*BaseQueryBuilder).DB(), db)
|
||||
}
|
||||
|
||||
func TestQB_BuildFrom(t *testing.T) {
|
||||
tests := []struct {
|
||||
tag string
|
||||
tables []string
|
||||
expected string
|
||||
}{
|
||||
{"empty", []string{}, ""},
|
||||
{"single table", []string{"users"}, "FROM `users`"},
|
||||
{"multiple tables", []string{"users", "posts"}, "FROM `users`, `posts`"},
|
||||
{"table alias", []string{"users u", "posts as p"}, "FROM `users` `u`, `posts` `p`"},
|
||||
{"table prefix and alias", []string{"pub.users p.u", "posts AS p1"}, "FROM `pub`.`users` `p.u`, `posts` `p1`"},
|
||||
}
|
||||
|
||||
qb := getDB().QueryBuilder()
|
||||
for _, test := range tests {
|
||||
s := qb.BuildFrom(test.tables)
|
||||
assert.Equal(t, test.expected, s, test.tag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQB_BuildGroupBy(t *testing.T) {
|
||||
tests := []struct {
|
||||
tag string
|
||||
cols []string
|
||||
expected string
|
||||
}{
|
||||
{"empty", []string{}, ""},
|
||||
{"single column", []string{"name"}, "GROUP BY `name`"},
|
||||
{"multiple columns", []string{"name", "age"}, "GROUP BY `name`, `age`"},
|
||||
}
|
||||
|
||||
qb := getDB().QueryBuilder()
|
||||
for _, test := range tests {
|
||||
s := qb.BuildGroupBy(test.cols)
|
||||
assert.Equal(t, test.expected, s, test.tag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQB_BuildWhere(t *testing.T) {
|
||||
tests := []struct {
|
||||
exp Expression
|
||||
expected string
|
||||
count int
|
||||
tag string
|
||||
}{
|
||||
{HashExp{"age": 30, "dept": "marketing"}, "WHERE `age`={:p0} AND `dept`={:p1}", 2, "t1"},
|
||||
{nil, "", 0, "t2"},
|
||||
{NewExp(""), "", 0, "t3"},
|
||||
}
|
||||
|
||||
qb := getDB().QueryBuilder()
|
||||
for _, test := range tests {
|
||||
params := Params{}
|
||||
s := qb.BuildWhere(test.exp, params)
|
||||
assert.Equal(t, test.expected, s, test.tag)
|
||||
assert.Equal(t, test.count, len(params), test.tag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQB_BuildHaving(t *testing.T) {
|
||||
tests := []struct {
|
||||
exp Expression
|
||||
expected string
|
||||
count int
|
||||
tag string
|
||||
}{
|
||||
{HashExp{"age": 30, "dept": "marketing"}, "HAVING `age`={:p0} AND `dept`={:p1}", 2, "t1"},
|
||||
{nil, "", 0, "t2"},
|
||||
{NewExp(""), "", 0, "t3"},
|
||||
}
|
||||
|
||||
qb := getDB().QueryBuilder()
|
||||
for _, test := range tests {
|
||||
params := Params{}
|
||||
s := qb.BuildHaving(test.exp, params)
|
||||
assert.Equal(t, test.expected, s, test.tag)
|
||||
assert.Equal(t, test.count, len(params), test.tag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQB_BuildOrderBy(t *testing.T) {
|
||||
tests := []struct {
|
||||
tag string
|
||||
cols []string
|
||||
expected string
|
||||
}{
|
||||
{"empty", []string{}, ""},
|
||||
{"single column", []string{"name"}, "ORDER BY `name`"},
|
||||
{"multiple columns", []string{"name ASC", "age DESC", "id desc"}, "ORDER BY `name` ASC, `age` DESC, `id` desc"},
|
||||
}
|
||||
qb := getDB().QueryBuilder().(*BaseQueryBuilder)
|
||||
for _, test := range tests {
|
||||
s := qb.BuildOrderBy(test.cols)
|
||||
assert.Equal(t, test.expected, s, test.tag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQB_BuildLimit(t *testing.T) {
|
||||
tests := []struct {
|
||||
tag string
|
||||
limit, offset int64
|
||||
expected string
|
||||
}{
|
||||
{"t1", 10, -1, "LIMIT 10"},
|
||||
{"t2", 10, 0, "LIMIT 10"},
|
||||
{"t3", 10, 2, "LIMIT 10 OFFSET 2"},
|
||||
{"t4", 0, 2, "LIMIT 0 OFFSET 2"},
|
||||
{"t5", -1, 2, "LIMIT 9223372036854775807 OFFSET 2"},
|
||||
{"t6", -1, 0, ""},
|
||||
}
|
||||
qb := getDB().QueryBuilder().(*BaseQueryBuilder)
|
||||
for _, test := range tests {
|
||||
s := qb.BuildLimit(test.limit, test.offset)
|
||||
assert.Equal(t, test.expected, s, test.tag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQB_BuildOrderByAndLimit(t *testing.T) {
|
||||
qb := getDB().QueryBuilder()
|
||||
|
||||
sql := qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, 10, 2)
|
||||
expected := "SELECT * ORDER BY `name` LIMIT 10 OFFSET 2"
|
||||
assert.Equal(t, sql, expected, "t1")
|
||||
|
||||
sql = qb.BuildOrderByAndLimit("SELECT *", nil, -1, -1)
|
||||
expected = "SELECT *"
|
||||
assert.Equal(t, sql, expected, "t2")
|
||||
|
||||
sql = qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, -1, -1)
|
||||
expected = "SELECT * ORDER BY `name`"
|
||||
assert.Equal(t, sql, expected, "t3")
|
||||
|
||||
sql = qb.BuildOrderByAndLimit("SELECT *", nil, 10, -1)
|
||||
expected = "SELECT * LIMIT 10"
|
||||
assert.Equal(t, sql, expected, "t4")
|
||||
}
|
||||
|
||||
func TestQB_BuildJoin(t *testing.T) {
|
||||
qb := getDB().QueryBuilder()
|
||||
|
||||
params := Params{}
|
||||
ji := JoinInfo{"LEFT JOIN", "users u", NewExp("id=u.id", Params{"id": 1})}
|
||||
sql := qb.BuildJoin([]JoinInfo{ji}, params)
|
||||
expected := "LEFT JOIN `users` `u` ON id=u.id"
|
||||
assert.Equal(t, sql, expected, "BuildJoin@1")
|
||||
assert.Equal(t, len(params), 1, "len(params)@1")
|
||||
|
||||
params = Params{}
|
||||
ji = JoinInfo{"INNER JOIN", "users", nil}
|
||||
sql = qb.BuildJoin([]JoinInfo{ji}, params)
|
||||
expected = "INNER JOIN `users`"
|
||||
assert.Equal(t, sql, expected, "BuildJoin@2")
|
||||
assert.Equal(t, len(params), 0, "len(params)@2")
|
||||
|
||||
sql = qb.BuildJoin([]JoinInfo{}, nil)
|
||||
expected = ""
|
||||
assert.Equal(t, sql, expected, "BuildJoin@3")
|
||||
|
||||
ji = JoinInfo{"INNER JOIN", "users", nil}
|
||||
ji2 := JoinInfo{"LEFT JOIN", "posts", nil}
|
||||
sql = qb.BuildJoin([]JoinInfo{ji, ji2}, nil)
|
||||
expected = "INNER JOIN `users` LEFT JOIN `posts`"
|
||||
assert.Equal(t, sql, expected, "BuildJoin@3")
|
||||
}
|
||||
|
||||
func TestQB_BuildUnion(t *testing.T) {
|
||||
db := getDB()
|
||||
qb := db.QueryBuilder()
|
||||
|
||||
params := Params{}
|
||||
ui := UnionInfo{false, db.NewQuery("SELECT names").Bind(Params{"id": 1})}
|
||||
sql := qb.BuildUnion([]UnionInfo{ui}, params)
|
||||
expected := "UNION (SELECT names)"
|
||||
assert.Equal(t, sql, expected, "BuildUnion@1")
|
||||
assert.Equal(t, len(params), 1, "len(params)@1")
|
||||
|
||||
params = Params{}
|
||||
ui = UnionInfo{true, db.NewQuery("SELECT names")}
|
||||
sql = qb.BuildUnion([]UnionInfo{ui}, params)
|
||||
expected = "UNION ALL (SELECT names)"
|
||||
assert.Equal(t, sql, expected, "BuildUnion@2")
|
||||
assert.Equal(t, len(params), 0, "len(params)@2")
|
||||
|
||||
sql = qb.BuildUnion([]UnionInfo{}, nil)
|
||||
expected = ""
|
||||
assert.Equal(t, sql, expected, "BuildUnion@3")
|
||||
|
||||
ui = UnionInfo{true, db.NewQuery("SELECT names")}
|
||||
ui2 := UnionInfo{false, db.NewQuery("SELECT ages")}
|
||||
sql = qb.BuildUnion([]UnionInfo{ui, ui2}, nil)
|
||||
expected = "UNION ALL (SELECT names) UNION (SELECT ages)"
|
||||
assert.Equal(t, sql, expected, "BuildUnion@4")
|
||||
}
|
600
query_test.go
Normal file
600
query_test.go
Normal file
|
@ -0,0 +1,600 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
ss "database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type City struct {
|
||||
ID int
|
||||
Name string
|
||||
}
|
||||
|
||||
func TestNewQuery(t *testing.T) {
|
||||
db := getDB()
|
||||
sql := "SELECT * FROM users WHERE id={:id}"
|
||||
q := NewQuery(db, db.sqlDB, sql)
|
||||
assert.Equal(t, q.SQL(), sql, "q.SQL()")
|
||||
assert.Equal(t, q.rawSQL, "SELECT * FROM users WHERE id=?", "q.RawSQL()")
|
||||
|
||||
assert.Equal(t, len(q.Params()), 0, "len(q.Params())@1")
|
||||
q.Bind(Params{"id": 1})
|
||||
assert.Equal(t, len(q.Params()), 1, "len(q.Params())@2")
|
||||
}
|
||||
|
||||
func TestQuery_Execute(t *testing.T) {
|
||||
db := getPreparedDB()
|
||||
defer db.Close()
|
||||
|
||||
result, err := db.NewQuery("INSERT INTO item (name) VALUES ('test')").Execute()
|
||||
if assert.Nil(t, err) {
|
||||
rows, _ := result.RowsAffected()
|
||||
assert.Equal(t, rows, int64(1), "Result.RowsAffected()")
|
||||
lastID, _ := result.LastInsertId()
|
||||
assert.Equal(t, lastID, int64(6), "Result.LastInsertId()")
|
||||
}
|
||||
}
|
||||
|
||||
type Customer struct {
|
||||
scanned bool
|
||||
|
||||
ID int
|
||||
Email string
|
||||
Status int
|
||||
Name string
|
||||
Address ss.NullString
|
||||
}
|
||||
|
||||
func (m Customer) TableName() string {
|
||||
return "customer"
|
||||
}
|
||||
|
||||
func (m *Customer) PostScan() error {
|
||||
m.scanned = true
|
||||
return nil
|
||||
}
|
||||
|
||||
type CustomerPtr struct {
|
||||
ID *int `db:"pk"`
|
||||
Email *string
|
||||
Status *int
|
||||
Name string
|
||||
Address *string
|
||||
}
|
||||
|
||||
func (m CustomerPtr) TableName() string {
|
||||
return "customer"
|
||||
}
|
||||
|
||||
type CustomerNull struct {
|
||||
ID ss.NullInt64 `db:"pk,id"`
|
||||
Email ss.NullString
|
||||
Status *ss.NullInt64
|
||||
Name string
|
||||
Address ss.NullString
|
||||
}
|
||||
|
||||
func (m CustomerNull) TableName() string {
|
||||
return "customer"
|
||||
}
|
||||
|
||||
type CustomerEmbedded struct {
|
||||
Id int
|
||||
Email *string
|
||||
InnerCustomer
|
||||
}
|
||||
|
||||
func (m CustomerEmbedded) TableName() string {
|
||||
return "customer"
|
||||
}
|
||||
|
||||
type CustomerEmbedded2 struct {
|
||||
ID int
|
||||
Email *string
|
||||
Inner InnerCustomer
|
||||
}
|
||||
|
||||
type InnerCustomer struct {
|
||||
Status ss.NullInt64
|
||||
Name *string
|
||||
Address ss.NullString
|
||||
}
|
||||
|
||||
func TestQuery_Rows(t *testing.T) {
|
||||
db := getPreparedDB()
|
||||
defer db.Close()
|
||||
|
||||
var (
|
||||
sql string
|
||||
err error
|
||||
)
|
||||
|
||||
// Query.All()
|
||||
var customers []Customer
|
||||
sql = `SELECT * FROM customer ORDER BY id`
|
||||
err = db.NewQuery(sql).All(&customers)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, len(customers), 3, "len(customers)")
|
||||
assert.Equal(t, customers[2].ID, 3, "customers[2].ID")
|
||||
assert.Equal(t, customers[2].Email, `user3@example.com`, "customers[2].Email")
|
||||
assert.Equal(t, customers[2].Status, 2, "customers[2].Status")
|
||||
assert.Equal(t, customers[0].scanned, true, "customers[0].scanned")
|
||||
assert.Equal(t, customers[1].scanned, true, "customers[1].scanned")
|
||||
assert.Equal(t, customers[2].scanned, true, "customers[2].scanned")
|
||||
}
|
||||
|
||||
// Query.All() with slice of pointers
|
||||
var customersPtrSlice []*Customer
|
||||
sql = `SELECT * FROM customer ORDER BY id`
|
||||
err = db.NewQuery(sql).All(&customersPtrSlice)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, len(customersPtrSlice), 3, "len(customersPtrSlice)")
|
||||
assert.Equal(t, customersPtrSlice[2].ID, 3, "customersPtrSlice[2].ID")
|
||||
assert.Equal(t, customersPtrSlice[2].Email, `user3@example.com`, "customersPtrSlice[2].Email")
|
||||
assert.Equal(t, customersPtrSlice[2].Status, 2, "customersPtrSlice[2].Status")
|
||||
assert.Equal(t, customersPtrSlice[0].scanned, true, "customersPtrSlice[0].scanned")
|
||||
assert.Equal(t, customersPtrSlice[1].scanned, true, "customersPtrSlice[1].scanned")
|
||||
assert.Equal(t, customersPtrSlice[2].scanned, true, "customersPtrSlice[2].scanned")
|
||||
}
|
||||
|
||||
var customers2 []NullStringMap
|
||||
err = db.NewQuery(sql).All(&customers2)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, len(customers2), 3, "len(customers2)")
|
||||
assert.Equal(t, customers2[1]["id"].String, "2", "customers2[1][id]")
|
||||
assert.Equal(t, customers2[1]["email"].String, `user2@example.com`, "customers2[1][email]")
|
||||
assert.Equal(t, customers2[1]["status"].String, "1", "customers2[1][status]")
|
||||
}
|
||||
err = db.NewQuery(sql).All(customers)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
var customers3 []string
|
||||
err = db.NewQuery(sql).All(&customers3)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
var customers4 string
|
||||
err = db.NewQuery(sql).All(&customers4)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
var customers5 []Customer
|
||||
err = db.NewQuery(`SELECT * FROM customer WHERE id=999`).All(&customers5)
|
||||
if assert.Nil(t, err) {
|
||||
assert.NotNil(t, customers5)
|
||||
assert.Zero(t, len(customers5))
|
||||
}
|
||||
|
||||
// One
|
||||
var customer Customer
|
||||
sql = `SELECT * FROM customer WHERE id={:id}`
|
||||
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customer)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, customer.ID, 2, "customer.ID")
|
||||
assert.Equal(t, customer.Email, `user2@example.com`, "customer.Email")
|
||||
assert.Equal(t, customer.Status, 1, "customer.Status")
|
||||
}
|
||||
|
||||
var customerPtr2 CustomerPtr
|
||||
sql = `SELECT id, email, address FROM customer WHERE id=2`
|
||||
rows2, err := db.sqlDB.Query(sql)
|
||||
defer rows2.Close()
|
||||
assert.Nil(t, err)
|
||||
rows2.Next()
|
||||
err = rows2.Scan(&customerPtr2.ID, &customerPtr2.Email, &customerPtr2.Address)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, *customerPtr2.ID, 2, "customer.ID")
|
||||
assert.Equal(t, *customerPtr2.Email, `user2@example.com`)
|
||||
assert.Nil(t, customerPtr2.Address)
|
||||
}
|
||||
|
||||
// struct fields are pointers
|
||||
var customerPtr CustomerPtr
|
||||
sql = `SELECT * FROM customer WHERE id={:id}`
|
||||
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerPtr)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, *customerPtr.ID, 2, "customer.ID")
|
||||
assert.Equal(t, *customerPtr.Email, `user2@example.com`, "customer.Email")
|
||||
assert.Equal(t, *customerPtr.Status, 1, "customer.Status")
|
||||
}
|
||||
|
||||
// struct fields are null types
|
||||
var customerNull CustomerNull
|
||||
sql = `SELECT * FROM customer WHERE id={:id}`
|
||||
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerNull)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, customerNull.ID.Int64, int64(2), "customer.ID")
|
||||
assert.Equal(t, customerNull.Email.String, `user2@example.com`, "customer.Email")
|
||||
assert.Equal(t, customerNull.Status.Int64, int64(1), "customer.Status")
|
||||
}
|
||||
|
||||
// embedded with anonymous struct
|
||||
var customerEmbedded CustomerEmbedded
|
||||
sql = `SELECT * FROM customer WHERE id={:id}`
|
||||
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerEmbedded)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, customerEmbedded.Id, 2, "customer.ID")
|
||||
assert.Equal(t, *customerEmbedded.Email, `user2@example.com`, "customer.Email")
|
||||
assert.Equal(t, customerEmbedded.Status.Int64, int64(1), "customer.Status")
|
||||
}
|
||||
|
||||
// embedded with named struct
|
||||
var customerEmbedded2 CustomerEmbedded2
|
||||
sql = `SELECT id, email, status as "inner.status" FROM customer WHERE id={:id}`
|
||||
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerEmbedded2)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, customerEmbedded2.ID, 2, "customer.ID")
|
||||
assert.Equal(t, *customerEmbedded2.Email, `user2@example.com`, "customer.Email")
|
||||
assert.Equal(t, customerEmbedded2.Inner.Status.Int64, int64(1), "customer.Status")
|
||||
}
|
||||
|
||||
customer2 := NullStringMap{}
|
||||
sql = `SELECT * FROM customer WHERE id={:id}`
|
||||
err = db.NewQuery(sql).Bind(Params{"id": 1}).One(customer2)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, customer2["id"].String, "1", "customer2[id]")
|
||||
assert.Equal(t, customer2["email"].String, `user1@example.com`, "customer2[email]")
|
||||
assert.Equal(t, customer2["status"].String, "1", "customer2[status]")
|
||||
}
|
||||
|
||||
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(customer)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
var customer3 NullStringMap
|
||||
err = db.NewQuery(sql).Bind(Params{"id": 2}).One(customer3)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
err = db.NewQuery(sql).Bind(Params{"id": 1}).One(&customer3)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, customer3["id"].String, "1", "customer3[id]")
|
||||
}
|
||||
|
||||
// Rows
|
||||
sql = `SELECT * FROM customer ORDER BY id DESC`
|
||||
rows, err := db.NewQuery(sql).Rows()
|
||||
if assert.Nil(t, err) {
|
||||
s := ""
|
||||
for rows.Next() {
|
||||
rows.ScanStruct(&customer)
|
||||
s += customer.Email + ","
|
||||
}
|
||||
assert.Equal(t, s, "user3@example.com,user2@example.com,user1@example.com,", "Rows().Next()")
|
||||
}
|
||||
|
||||
// FieldMapper
|
||||
var a struct {
|
||||
MyID string `db:"id"`
|
||||
name string
|
||||
}
|
||||
sql = `SELECT * FROM customer WHERE id=2`
|
||||
err = db.NewQuery(sql).One(&a)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, a.MyID, "2", "a.MyID")
|
||||
// unexported field is not populated
|
||||
assert.Equal(t, a.name, "", "a.name")
|
||||
}
|
||||
|
||||
// prepared statement
|
||||
sql = `SELECT * FROM customer WHERE id={:id}`
|
||||
q := db.NewQuery(sql).Prepare()
|
||||
q.Bind(Params{"id": 1}).One(&customer)
|
||||
assert.Equal(t, customer.ID, 1, "prepared@1")
|
||||
err = q.Bind(Params{"id": 20}).One(&customer)
|
||||
assert.Equal(t, err, ss.ErrNoRows, "prepared@2")
|
||||
q.Bind(Params{"id": 3}).One(&customer)
|
||||
assert.Equal(t, customer.ID, 3, "prepared@3")
|
||||
|
||||
sql = `SELECT name FROM customer WHERE id={:id}`
|
||||
var name string
|
||||
q = db.NewQuery(sql).Prepare()
|
||||
q.Bind(Params{"id": 1}).Row(&name)
|
||||
assert.Equal(t, name, "user1", "prepared2@1")
|
||||
err = q.Bind(Params{"id": 20}).Row(&name)
|
||||
assert.Equal(t, err, ss.ErrNoRows, "prepared2@2")
|
||||
q.Bind(Params{"id": 3}).Row(&name)
|
||||
assert.Equal(t, name, "user3", "prepared2@3")
|
||||
|
||||
// Query.LastError
|
||||
sql = `SELECT * FROM a`
|
||||
q = db.NewQuery(sql).Prepare()
|
||||
customer.ID = 100
|
||||
err = q.Bind(Params{"id": 1}).One(&customer)
|
||||
assert.NotEqual(t, err, nil, "LastError@0")
|
||||
assert.Equal(t, customer.ID, 100, "LastError@1")
|
||||
assert.Equal(t, q.LastError, nil, "LastError@2")
|
||||
|
||||
// Query.Column
|
||||
sql = `SELECT name, id FROM customer ORDER BY id`
|
||||
var names []string
|
||||
err = db.NewQuery(sql).Column(&names)
|
||||
if assert.Nil(t, err) && assert.Equal(t, 3, len(names)) {
|
||||
assert.Equal(t, "user1", names[0])
|
||||
assert.Equal(t, "user2", names[1])
|
||||
assert.Equal(t, "user3", names[2])
|
||||
}
|
||||
err = db.NewQuery(sql).Column(names)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestQuery_logSQL(t *testing.T) {
|
||||
db := getDB()
|
||||
q := db.NewQuery("SELECT * FROM users WHERE type={:type} AND id={:id} AND bytes={:bytes}").Bind(Params{
|
||||
"id": 1,
|
||||
"type": "a",
|
||||
"bytes": []byte("test"),
|
||||
})
|
||||
expected := "SELECT * FROM users WHERE type='a' AND id=1 AND bytes=0x74657374"
|
||||
assert.Equal(t, q.logSQL(), expected, "logSQL()")
|
||||
}
|
||||
|
||||
func TestReplacePlaceholders(t *testing.T) {
|
||||
tests := []struct {
|
||||
ID string
|
||||
Placeholders []string
|
||||
Params Params
|
||||
ExpectedParams string
|
||||
HasError bool
|
||||
}{
|
||||
{"t1", nil, nil, "null", false},
|
||||
{"t2", []string{"id", "name"}, Params{"id": 1, "name": "xyz"}, `[1,"xyz"]`, false},
|
||||
{"t3", []string{"id", "name"}, Params{"id": 1}, `null`, true},
|
||||
{"t4", []string{"id", "name"}, Params{"id": 1, "name": "xyz", "age": 30}, `[1,"xyz"]`, false},
|
||||
}
|
||||
for _, test := range tests {
|
||||
params, err := replacePlaceholders(test.Placeholders, test.Params)
|
||||
result, _ := json.Marshal(params)
|
||||
assert.Equal(t, string(result), test.ExpectedParams, "params@"+test.ID)
|
||||
assert.Equal(t, err != nil, test.HasError, "error@"+test.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIssue6(t *testing.T) {
|
||||
db := getPreparedDB()
|
||||
q := db.Select("*").From("customer").Where(HashExp{"id": 1})
|
||||
var customer Customer
|
||||
assert.Equal(t, q.One(&customer), nil)
|
||||
assert.Equal(t, 1, customer.ID)
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int64
|
||||
Email string
|
||||
Created time.Time
|
||||
Updated *time.Time
|
||||
}
|
||||
|
||||
func TestIssue13(t *testing.T) {
|
||||
db := getPreparedDB()
|
||||
var user User
|
||||
err := db.Select().From("user").Where(HashExp{"id": 1}).One(&user)
|
||||
if assert.Nil(t, err) {
|
||||
assert.NotZero(t, user.Created)
|
||||
assert.Nil(t, user.Updated)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
user2 := User{
|
||||
Email: "now@example.com",
|
||||
Created: now,
|
||||
}
|
||||
err = db.Model(&user2).Insert()
|
||||
if assert.Nil(t, err) {
|
||||
assert.NotZero(t, user2.ID)
|
||||
}
|
||||
|
||||
user3 := User{
|
||||
Email: "now@example.com",
|
||||
Created: now,
|
||||
Updated: &now,
|
||||
}
|
||||
err = db.Model(&user3).Insert()
|
||||
if assert.Nil(t, err) {
|
||||
assert.NotZero(t, user2.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryWithExecHook(t *testing.T) {
|
||||
db := getPreparedDB()
|
||||
defer db.Close()
|
||||
|
||||
// error return
|
||||
{
|
||||
err := db.NewQuery("select * from user").
|
||||
WithExecHook(func(q *Query, op func() error) error {
|
||||
return errors.New("test")
|
||||
}).
|
||||
Row()
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// Row()
|
||||
{
|
||||
calls := 0
|
||||
err := db.NewQuery("select * from user").
|
||||
WithExecHook(func(q *Query, op func() error) error {
|
||||
calls++
|
||||
return nil
|
||||
}).
|
||||
Row()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, calls, "Row()")
|
||||
}
|
||||
|
||||
// One()
|
||||
{
|
||||
calls := 0
|
||||
err := db.NewQuery("select * from user").
|
||||
WithExecHook(func(q *Query, op func() error) error {
|
||||
calls++
|
||||
return nil
|
||||
}).
|
||||
One(nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, calls, "One()")
|
||||
}
|
||||
|
||||
// All()
|
||||
{
|
||||
calls := 0
|
||||
err := db.NewQuery("select * from user").
|
||||
WithExecHook(func(q *Query, op func() error) error {
|
||||
calls++
|
||||
return nil
|
||||
}).
|
||||
All(nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, calls, "All()")
|
||||
}
|
||||
|
||||
// Column()
|
||||
{
|
||||
calls := 0
|
||||
err := db.NewQuery("select * from user").
|
||||
WithExecHook(func(q *Query, op func() error) error {
|
||||
calls++
|
||||
return nil
|
||||
}).
|
||||
Column(nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, calls, "Column()")
|
||||
}
|
||||
|
||||
// Execute()
|
||||
{
|
||||
calls := 0
|
||||
_, err := db.NewQuery("select * from user").
|
||||
WithExecHook(func(q *Query, op func() error) error {
|
||||
calls++
|
||||
return nil
|
||||
}).
|
||||
Execute()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, calls, "Execute()")
|
||||
}
|
||||
|
||||
// op call
|
||||
{
|
||||
calls := 0
|
||||
var id int
|
||||
err := db.NewQuery("select id from user where id = 2").
|
||||
WithExecHook(func(q *Query, op func() error) error {
|
||||
calls++
|
||||
return op()
|
||||
}).
|
||||
Row(&id)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, calls, "op hook calls")
|
||||
assert.Equal(t, 2, id, "id mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryWithOneHook(t *testing.T) {
|
||||
db := getPreparedDB()
|
||||
defer db.Close()
|
||||
|
||||
// error return
|
||||
{
|
||||
err := db.NewQuery("select * from user").
|
||||
WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
|
||||
return errors.New("test")
|
||||
}).
|
||||
One(nil)
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// hooks call order
|
||||
{
|
||||
hookCalls := []string{}
|
||||
err := db.NewQuery("select * from user").
|
||||
WithExecHook(func(q *Query, op func() error) error {
|
||||
hookCalls = append(hookCalls, "exec")
|
||||
return op()
|
||||
}).
|
||||
WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
|
||||
hookCalls = append(hookCalls, "one")
|
||||
return nil
|
||||
}).
|
||||
One(nil)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, hookCalls, []string{"exec", "one"})
|
||||
}
|
||||
|
||||
// op call
|
||||
{
|
||||
calls := 0
|
||||
other := User{}
|
||||
err := db.NewQuery("select id from user where id = 2").
|
||||
WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
|
||||
calls++
|
||||
return op(&other)
|
||||
}).
|
||||
One(nil)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, calls, "hook calls")
|
||||
assert.Equal(t, int64(2), other.ID, "replaced scan struct")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryWithAllHook(t *testing.T) {
|
||||
db := getPreparedDB()
|
||||
defer db.Close()
|
||||
|
||||
// error return
|
||||
{
|
||||
err := db.NewQuery("select * from user").
|
||||
WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
|
||||
return errors.New("test")
|
||||
}).
|
||||
All(nil)
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// hooks call order
|
||||
{
|
||||
hookCalls := []string{}
|
||||
err := db.NewQuery("select * from user").
|
||||
WithExecHook(func(q *Query, op func() error) error {
|
||||
hookCalls = append(hookCalls, "exec")
|
||||
return op()
|
||||
}).
|
||||
WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
|
||||
hookCalls = append(hookCalls, "all")
|
||||
return nil
|
||||
}).
|
||||
All(nil)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, hookCalls, []string{"exec", "all"})
|
||||
}
|
||||
|
||||
// op call
|
||||
{
|
||||
calls := 0
|
||||
other := []User{}
|
||||
err := db.NewQuery("select id from user order by id asc").
|
||||
WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
|
||||
calls++
|
||||
return op(&other)
|
||||
}).
|
||||
All(nil)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, calls, "hook calls")
|
||||
assert.Equal(t, 2, len(other), "users length")
|
||||
assert.Equal(t, int64(1), other[0].ID, "user 1 id check")
|
||||
assert.Equal(t, int64(2), other[1].ID, "user 2 id check")
|
||||
}
|
||||
}
|
301
rows.go
Normal file
301
rows.go
Normal file
|
@ -0,0 +1,301 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// VarTypeError indicates a variable type error when trying to populating a variable with DB result.
|
||||
type VarTypeError string
|
||||
|
||||
// Error returns the error message.
|
||||
func (s VarTypeError) Error() string {
|
||||
return "Invalid variable type: " + string(s)
|
||||
}
|
||||
|
||||
// NullStringMap is a map of sql.NullString that can be used to hold DB query result.
|
||||
// The map keys correspond to the DB column names, while the map values are their corresponding column values.
|
||||
type NullStringMap map[string]sql.NullString
|
||||
|
||||
// Rows enhances sql.Rows by providing additional data query methods.
|
||||
// Rows can be obtained by calling Query.Rows(). It is mainly used to populate data row by row.
|
||||
type Rows struct {
|
||||
*sql.Rows
|
||||
fieldMapFunc FieldMapFunc
|
||||
}
|
||||
|
||||
// ScanMap populates the current row of data into a NullStringMap.
|
||||
// Note that the NullStringMap must not be nil, or it will panic.
|
||||
// The NullStringMap will be populated using column names as keys and their values as
|
||||
// the corresponding element values.
|
||||
func (r *Rows) ScanMap(a NullStringMap) error {
|
||||
cols, _ := r.Columns()
|
||||
var refs []interface{}
|
||||
for i := 0; i < len(cols); i++ {
|
||||
var t sql.NullString
|
||||
refs = append(refs, &t)
|
||||
}
|
||||
if err := r.Scan(refs...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, col := range cols {
|
||||
a[col] = *refs[i].(*sql.NullString)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ScanStruct populates the current row of data into a struct.
|
||||
// The struct must be given as a pointer.
|
||||
//
|
||||
// ScanStruct associates struct fields with DB table columns through a field mapping function.
|
||||
// It populates a struct field with the data of its associated column.
|
||||
// Note that only exported struct fields will be populated.
|
||||
//
|
||||
// By default, DefaultFieldMapFunc() is used to map struct fields to table columns.
|
||||
// This function separates each word in a field name with a underscore and turns every letter into lower case.
|
||||
// For example, "LastName" is mapped to "last_name", "MyID" is mapped to "my_id", and so on.
|
||||
// To change the default behavior, set DB.FieldMapper with your custom mapping function.
|
||||
// You may also set Query.FieldMapper to change the behavior for particular queries.
|
||||
func (r *Rows) ScanStruct(a interface{}) error {
|
||||
rv := reflect.ValueOf(a)
|
||||
if rv.Kind() != reflect.Ptr || rv.IsNil() {
|
||||
return VarTypeError("must be a pointer")
|
||||
}
|
||||
rv = indirect(rv)
|
||||
if rv.Kind() != reflect.Struct {
|
||||
return VarTypeError("must be a pointer to a struct")
|
||||
}
|
||||
|
||||
si := getStructInfo(rv.Type(), r.fieldMapFunc)
|
||||
|
||||
cols, _ := r.Columns()
|
||||
refs := make([]interface{}, len(cols))
|
||||
|
||||
for i, col := range cols {
|
||||
if fi, ok := si.dbNameMap[col]; ok {
|
||||
refs[i] = fi.getField(rv).Addr().Interface()
|
||||
} else {
|
||||
refs[i] = &sql.NullString{}
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.Scan(refs...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check for PostScanner
|
||||
if rv.CanAddr() {
|
||||
addr := rv.Addr()
|
||||
if addr.CanInterface() {
|
||||
if ps, ok := addr.Interface().(PostScanner); ok {
|
||||
if err := ps.PostScan(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// all populates all rows of query result into a slice of struct or NullStringMap.
|
||||
// Note that the slice must be given as a pointer.
|
||||
func (r *Rows) all(slice interface{}) error {
|
||||
defer r.Close()
|
||||
|
||||
v := reflect.ValueOf(slice)
|
||||
if v.Kind() != reflect.Ptr || v.IsNil() {
|
||||
return VarTypeError("must be a pointer")
|
||||
}
|
||||
v = indirect(v)
|
||||
|
||||
if v.Kind() != reflect.Slice {
|
||||
return VarTypeError("must be a slice of struct or NullStringMap")
|
||||
}
|
||||
|
||||
if v.IsNil() {
|
||||
// create an empty slice
|
||||
v.Set(reflect.MakeSlice(v.Type(), 0, 0))
|
||||
}
|
||||
|
||||
et := v.Type().Elem()
|
||||
|
||||
if et.Kind() == reflect.Map {
|
||||
for r.Next() {
|
||||
ev, ok := reflect.MakeMap(et).Interface().(NullStringMap)
|
||||
if !ok {
|
||||
return VarTypeError("must be a slice of struct or NullStringMap")
|
||||
}
|
||||
if err := r.ScanMap(ev); err != nil {
|
||||
return err
|
||||
}
|
||||
v.Set(reflect.Append(v, reflect.ValueOf(ev)))
|
||||
}
|
||||
return r.Close()
|
||||
}
|
||||
|
||||
var isSliceOfPointers bool
|
||||
if et.Kind() == reflect.Ptr {
|
||||
isSliceOfPointers = true
|
||||
et = et.Elem()
|
||||
}
|
||||
|
||||
if et.Kind() != reflect.Struct {
|
||||
return VarTypeError("must be a slice of struct or NullStringMap")
|
||||
}
|
||||
|
||||
etPtr := reflect.PtrTo(et)
|
||||
implementsPostScanner := etPtr.Implements(postScannerType)
|
||||
|
||||
si := getStructInfo(et, r.fieldMapFunc)
|
||||
|
||||
cols, _ := r.Columns()
|
||||
for r.Next() {
|
||||
ev := reflect.New(et).Elem()
|
||||
refs := make([]interface{}, len(cols))
|
||||
for i, col := range cols {
|
||||
if fi, ok := si.dbNameMap[col]; ok {
|
||||
refs[i] = fi.getField(ev).Addr().Interface()
|
||||
} else {
|
||||
refs[i] = &sql.NullString{}
|
||||
}
|
||||
}
|
||||
if err := r.Scan(refs...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isSliceOfPointers {
|
||||
ev = ev.Addr()
|
||||
}
|
||||
|
||||
// check for PostScanner
|
||||
if implementsPostScanner {
|
||||
evAddr := ev
|
||||
if ev.CanAddr() {
|
||||
evAddr = ev.Addr()
|
||||
}
|
||||
if evAddr.CanInterface() {
|
||||
if ps, ok := evAddr.Interface().(PostScanner); ok {
|
||||
if err := ps.PostScan(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
v.Set(reflect.Append(v, ev))
|
||||
}
|
||||
|
||||
return r.Close()
|
||||
}
|
||||
|
||||
// column populates the given slice with the first column of the query result.
|
||||
// Note that the slice must be given as a pointer.
|
||||
func (r *Rows) column(slice interface{}) error {
|
||||
defer r.Close()
|
||||
|
||||
v := reflect.ValueOf(slice)
|
||||
if v.Kind() != reflect.Ptr || v.IsNil() {
|
||||
return VarTypeError("must be a pointer to a slice")
|
||||
}
|
||||
v = indirect(v)
|
||||
|
||||
if v.Kind() != reflect.Slice {
|
||||
return VarTypeError("must be a pointer to a slice")
|
||||
}
|
||||
|
||||
et := v.Type().Elem()
|
||||
|
||||
cols, _ := r.Columns()
|
||||
for r.Next() {
|
||||
ev := reflect.New(et)
|
||||
refs := make([]interface{}, len(cols))
|
||||
for i := range cols {
|
||||
if i == 0 {
|
||||
refs[i] = ev.Interface()
|
||||
} else {
|
||||
refs[i] = &sql.NullString{}
|
||||
}
|
||||
}
|
||||
if err := r.Scan(refs...); err != nil {
|
||||
return err
|
||||
}
|
||||
v.Set(reflect.Append(v, ev.Elem()))
|
||||
}
|
||||
|
||||
return r.Close()
|
||||
}
|
||||
|
||||
// one populates a single row of query result into a struct or a NullStringMap.
|
||||
// Note that if a struct is given, it should be a pointer.
|
||||
func (r *Rows) one(a interface{}) error {
|
||||
defer r.Close()
|
||||
|
||||
if !r.Next() {
|
||||
if err := r.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
rt := reflect.TypeOf(a)
|
||||
if rt.Kind() == reflect.Ptr && rt.Elem().Kind() == reflect.Map {
|
||||
// pointer to map
|
||||
v := indirect(reflect.ValueOf(a))
|
||||
if v.IsNil() {
|
||||
v.Set(reflect.MakeMap(v.Type()))
|
||||
}
|
||||
a = v.Interface()
|
||||
rt = reflect.TypeOf(a)
|
||||
}
|
||||
|
||||
if rt.Kind() == reflect.Map {
|
||||
v, ok := a.(NullStringMap)
|
||||
if !ok {
|
||||
return VarTypeError("must be a NullStringMap")
|
||||
}
|
||||
if v == nil {
|
||||
return VarTypeError("NullStringMap is nil")
|
||||
}
|
||||
err = r.ScanMap(v)
|
||||
} else {
|
||||
err = r.ScanStruct(a)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return r.Close()
|
||||
}
|
||||
|
||||
// row populates a single row of query result into a list of variables.
|
||||
func (r *Rows) row(a ...interface{}) error {
|
||||
defer r.Close()
|
||||
|
||||
for _, dp := range a {
|
||||
if _, ok := dp.(*sql.RawBytes); ok {
|
||||
return VarTypeError("RawBytes isn't allowed on Row()")
|
||||
}
|
||||
}
|
||||
|
||||
if !r.Next() {
|
||||
if err := r.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
if err := r.Scan(a...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return r.Close()
|
||||
}
|
445
select.go
Normal file
445
select.go
Normal file
|
@ -0,0 +1,445 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// BuildHookFunc defines a callback function that is executed on Query creation.
|
||||
type BuildHookFunc func(q *Query)
|
||||
|
||||
// SelectQuery represents a DB-agnostic SELECT query.
|
||||
// It can be built into a DB-specific query by calling the Build() method.
|
||||
type SelectQuery struct {
|
||||
// FieldMapper maps struct field names to DB column names.
|
||||
FieldMapper FieldMapFunc
|
||||
// TableMapper maps structs to DB table names.
|
||||
TableMapper TableMapFunc
|
||||
|
||||
builder Builder
|
||||
ctx context.Context
|
||||
buildHook BuildHookFunc
|
||||
|
||||
preFragment string
|
||||
postFragment string
|
||||
selects []string
|
||||
distinct bool
|
||||
selectOption string
|
||||
from []string
|
||||
where Expression
|
||||
join []JoinInfo
|
||||
orderBy []string
|
||||
groupBy []string
|
||||
having Expression
|
||||
union []UnionInfo
|
||||
limit int64
|
||||
offset int64
|
||||
params Params
|
||||
}
|
||||
|
||||
// JoinInfo contains the specification for a JOIN clause.
|
||||
type JoinInfo struct {
|
||||
Join string
|
||||
Table string
|
||||
On Expression
|
||||
}
|
||||
|
||||
// UnionInfo contains the specification for a UNION clause.
|
||||
type UnionInfo struct {
|
||||
All bool
|
||||
Query *Query
|
||||
}
|
||||
|
||||
// NewSelectQuery creates a new SelectQuery instance.
|
||||
func NewSelectQuery(builder Builder, db *DB) *SelectQuery {
|
||||
return &SelectQuery{
|
||||
builder: builder,
|
||||
selects: []string{},
|
||||
from: []string{},
|
||||
join: []JoinInfo{},
|
||||
orderBy: []string{},
|
||||
groupBy: []string{},
|
||||
union: []UnionInfo{},
|
||||
limit: -1,
|
||||
params: Params{},
|
||||
ctx: db.ctx,
|
||||
FieldMapper: db.FieldMapper,
|
||||
TableMapper: db.TableMapper,
|
||||
}
|
||||
}
|
||||
|
||||
// WithBuildHook runs the provided hook function with the query created on Build().
|
||||
func (q *SelectQuery) WithBuildHook(fn BuildHookFunc) *SelectQuery {
|
||||
q.buildHook = fn
|
||||
return q
|
||||
}
|
||||
|
||||
// Context returns the context associated with the query.
|
||||
func (q *SelectQuery) Context() context.Context {
|
||||
return q.ctx
|
||||
}
|
||||
|
||||
// WithContext associates a context with the query.
|
||||
func (q *SelectQuery) WithContext(ctx context.Context) *SelectQuery {
|
||||
q.ctx = ctx
|
||||
return q
|
||||
}
|
||||
|
||||
// PreFragment sets SQL fragment that should be prepended before the select query (e.g. WITH clause).
|
||||
func (s *SelectQuery) PreFragment(fragment string) *SelectQuery {
|
||||
s.preFragment = fragment
|
||||
return s
|
||||
}
|
||||
|
||||
// PostFragment sets SQL fragment that should be appended at the end of the select query.
|
||||
func (s *SelectQuery) PostFragment(fragment string) *SelectQuery {
|
||||
s.postFragment = fragment
|
||||
return s
|
||||
}
|
||||
|
||||
// Select specifies the columns to be selected.
|
||||
// Column names will be automatically quoted.
|
||||
func (s *SelectQuery) Select(cols ...string) *SelectQuery {
|
||||
s.selects = cols
|
||||
return s
|
||||
}
|
||||
|
||||
// AndSelect adds additional columns to be selected.
|
||||
// Column names will be automatically quoted.
|
||||
func (s *SelectQuery) AndSelect(cols ...string) *SelectQuery {
|
||||
s.selects = append(s.selects, cols...)
|
||||
return s
|
||||
}
|
||||
|
||||
// Distinct specifies whether to select columns distinctively.
|
||||
// By default, distinct is false.
|
||||
func (s *SelectQuery) Distinct(v bool) *SelectQuery {
|
||||
s.distinct = v
|
||||
return s
|
||||
}
|
||||
|
||||
// SelectOption specifies additional option that should be append to "SELECT".
|
||||
func (s *SelectQuery) SelectOption(option string) *SelectQuery {
|
||||
s.selectOption = option
|
||||
return s
|
||||
}
|
||||
|
||||
// From specifies which tables to select from.
|
||||
// Table names will be automatically quoted.
|
||||
func (s *SelectQuery) From(tables ...string) *SelectQuery {
|
||||
s.from = tables
|
||||
return s
|
||||
}
|
||||
|
||||
// Where specifies the WHERE condition.
|
||||
func (s *SelectQuery) Where(e Expression) *SelectQuery {
|
||||
s.where = e
|
||||
return s
|
||||
}
|
||||
|
||||
// AndWhere concatenates a new WHERE condition with the existing one (if any) using "AND".
|
||||
func (s *SelectQuery) AndWhere(e Expression) *SelectQuery {
|
||||
s.where = And(s.where, e)
|
||||
return s
|
||||
}
|
||||
|
||||
// OrWhere concatenates a new WHERE condition with the existing one (if any) using "OR".
|
||||
func (s *SelectQuery) OrWhere(e Expression) *SelectQuery {
|
||||
s.where = Or(s.where, e)
|
||||
return s
|
||||
}
|
||||
|
||||
// Join specifies a JOIN clause.
|
||||
// The "typ" parameter specifies the JOIN type (e.g. "INNER JOIN", "LEFT JOIN").
|
||||
func (s *SelectQuery) Join(typ string, table string, on Expression) *SelectQuery {
|
||||
s.join = append(s.join, JoinInfo{typ, table, on})
|
||||
return s
|
||||
}
|
||||
|
||||
// InnerJoin specifies an INNER JOIN clause.
|
||||
// This is a shortcut method for Join.
|
||||
func (s *SelectQuery) InnerJoin(table string, on Expression) *SelectQuery {
|
||||
return s.Join("INNER JOIN", table, on)
|
||||
}
|
||||
|
||||
// LeftJoin specifies a LEFT JOIN clause.
|
||||
// This is a shortcut method for Join.
|
||||
func (s *SelectQuery) LeftJoin(table string, on Expression) *SelectQuery {
|
||||
return s.Join("LEFT JOIN", table, on)
|
||||
}
|
||||
|
||||
// RightJoin specifies a RIGHT JOIN clause.
|
||||
// This is a shortcut method for Join.
|
||||
func (s *SelectQuery) RightJoin(table string, on Expression) *SelectQuery {
|
||||
return s.Join("RIGHT JOIN", table, on)
|
||||
}
|
||||
|
||||
// OrderBy specifies the ORDER BY clause.
|
||||
// Column names will be properly quoted. A column name can contain "ASC" or "DESC" to indicate its ordering direction.
|
||||
func (s *SelectQuery) OrderBy(cols ...string) *SelectQuery {
|
||||
s.orderBy = cols
|
||||
return s
|
||||
}
|
||||
|
||||
// AndOrderBy appends additional columns to the existing ORDER BY clause.
|
||||
// Column names will be properly quoted. A column name can contain "ASC" or "DESC" to indicate its ordering direction.
|
||||
func (s *SelectQuery) AndOrderBy(cols ...string) *SelectQuery {
|
||||
s.orderBy = append(s.orderBy, cols...)
|
||||
return s
|
||||
}
|
||||
|
||||
// GroupBy specifies the GROUP BY clause.
|
||||
// Column names will be properly quoted.
|
||||
func (s *SelectQuery) GroupBy(cols ...string) *SelectQuery {
|
||||
s.groupBy = cols
|
||||
return s
|
||||
}
|
||||
|
||||
// AndGroupBy appends additional columns to the existing GROUP BY clause.
|
||||
// Column names will be properly quoted.
|
||||
func (s *SelectQuery) AndGroupBy(cols ...string) *SelectQuery {
|
||||
s.groupBy = append(s.groupBy, cols...)
|
||||
return s
|
||||
}
|
||||
|
||||
// Having specifies the HAVING clause.
|
||||
func (s *SelectQuery) Having(e Expression) *SelectQuery {
|
||||
s.having = e
|
||||
return s
|
||||
}
|
||||
|
||||
// AndHaving concatenates a new HAVING condition with the existing one (if any) using "AND".
|
||||
func (s *SelectQuery) AndHaving(e Expression) *SelectQuery {
|
||||
s.having = And(s.having, e)
|
||||
return s
|
||||
}
|
||||
|
||||
// OrHaving concatenates a new HAVING condition with the existing one (if any) using "OR".
|
||||
func (s *SelectQuery) OrHaving(e Expression) *SelectQuery {
|
||||
s.having = Or(s.having, e)
|
||||
return s
|
||||
}
|
||||
|
||||
// Union specifies a UNION clause.
|
||||
func (s *SelectQuery) Union(q *Query) *SelectQuery {
|
||||
s.union = append(s.union, UnionInfo{false, q})
|
||||
return s
|
||||
}
|
||||
|
||||
// UnionAll specifies a UNION ALL clause.
|
||||
func (s *SelectQuery) UnionAll(q *Query) *SelectQuery {
|
||||
s.union = append(s.union, UnionInfo{true, q})
|
||||
return s
|
||||
}
|
||||
|
||||
// Limit specifies the LIMIT clause.
|
||||
// A negative limit means no limit.
|
||||
func (s *SelectQuery) Limit(limit int64) *SelectQuery {
|
||||
s.limit = limit
|
||||
return s
|
||||
}
|
||||
|
||||
// Offset specifies the OFFSET clause.
|
||||
// A negative offset means no offset.
|
||||
func (s *SelectQuery) Offset(offset int64) *SelectQuery {
|
||||
s.offset = offset
|
||||
return s
|
||||
}
|
||||
|
||||
// Bind specifies the parameter values to be bound to the query.
|
||||
func (s *SelectQuery) Bind(params Params) *SelectQuery {
|
||||
s.params = params
|
||||
return s
|
||||
}
|
||||
|
||||
// AndBind appends additional parameters to be bound to the query.
|
||||
func (s *SelectQuery) AndBind(params Params) *SelectQuery {
|
||||
if len(s.params) == 0 {
|
||||
s.params = params
|
||||
} else {
|
||||
for k, v := range params {
|
||||
s.params[k] = v
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Build builds the SELECT query and returns an executable Query object.
|
||||
func (s *SelectQuery) Build() *Query {
|
||||
params := Params{}
|
||||
for k, v := range s.params {
|
||||
params[k] = v
|
||||
}
|
||||
|
||||
qb := s.builder.QueryBuilder()
|
||||
|
||||
clauses := []string{
|
||||
s.preFragment,
|
||||
qb.BuildSelect(s.selects, s.distinct, s.selectOption),
|
||||
qb.BuildFrom(s.from),
|
||||
qb.BuildJoin(s.join, params),
|
||||
qb.BuildWhere(s.where, params),
|
||||
qb.BuildGroupBy(s.groupBy),
|
||||
qb.BuildHaving(s.having, params),
|
||||
}
|
||||
|
||||
sql := ""
|
||||
for _, clause := range clauses {
|
||||
if clause != "" {
|
||||
if sql == "" {
|
||||
sql = clause
|
||||
} else {
|
||||
sql += " " + clause
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sql = qb.BuildOrderByAndLimit(sql, s.orderBy, s.limit, s.offset)
|
||||
|
||||
if s.postFragment != "" {
|
||||
sql += " " + s.postFragment
|
||||
}
|
||||
|
||||
if union := qb.BuildUnion(s.union, params); union != "" {
|
||||
sql = fmt.Sprintf("(%v) %v", sql, union)
|
||||
}
|
||||
|
||||
query := s.builder.NewQuery(sql).Bind(params).WithContext(s.ctx)
|
||||
|
||||
if s.buildHook != nil {
|
||||
s.buildHook(query)
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
// One executes the SELECT query and populates the first row of the result into the specified variable.
|
||||
//
|
||||
// If the query does not specify a "from" clause, the method will try to infer the name of the table
|
||||
// to be selected from by calling getTableName() which will return either the variable type name
|
||||
// or the TableName() method if the variable implements the TableModel interface.
|
||||
//
|
||||
// Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned.
|
||||
func (s *SelectQuery) One(a interface{}) error {
|
||||
if len(s.from) == 0 {
|
||||
if tableName := s.TableMapper(a); tableName != "" {
|
||||
s.from = []string{tableName}
|
||||
}
|
||||
}
|
||||
|
||||
return s.Build().One(a)
|
||||
}
|
||||
|
||||
// Model selects the row with the specified primary key and populates the model with the row data.
|
||||
//
|
||||
// The model variable should be a pointer to a struct. If the query does not specify a "from" clause,
|
||||
// it will use the model struct to determine which table to select data from. It will also use the model
|
||||
// to infer the name of the primary key column. Only simple primary key is supported. For composite primary keys,
|
||||
// please use Where() to specify the filtering condition.
|
||||
func (s *SelectQuery) Model(pk, model interface{}) error {
|
||||
t := reflect.TypeOf(model)
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
if t.Kind() != reflect.Struct {
|
||||
return VarTypeError("must be a pointer to a struct")
|
||||
}
|
||||
|
||||
si := getStructInfo(t, s.FieldMapper)
|
||||
if len(si.pkNames) == 1 {
|
||||
return s.AndWhere(HashExp{si.nameMap[si.pkNames[0]].dbName: pk}).One(model)
|
||||
}
|
||||
if len(si.pkNames) == 0 {
|
||||
return MissingPKError
|
||||
}
|
||||
|
||||
return CompositePKError
|
||||
}
|
||||
|
||||
// All executes the SELECT query and populates all rows of the result into a slice.
|
||||
//
|
||||
// Note that the slice must be passed in as a pointer.
|
||||
//
|
||||
// If the query does not specify a "from" clause, the method will try to infer the name of the table
|
||||
// to be selected from by calling getTableName() which will return either the type name of the slice elements
|
||||
// or the TableName() method if the slice element implements the TableModel interface.
|
||||
func (s *SelectQuery) All(slice interface{}) error {
|
||||
if len(s.from) == 0 {
|
||||
if tableName := s.TableMapper(slice); tableName != "" {
|
||||
s.from = []string{tableName}
|
||||
}
|
||||
}
|
||||
|
||||
return s.Build().All(slice)
|
||||
}
|
||||
|
||||
// Rows builds and executes the SELECT query and returns a Rows object for data retrieval purpose.
|
||||
// This is a shortcut to SelectQuery.Build().Rows()
|
||||
func (s *SelectQuery) Rows() (*Rows, error) {
|
||||
return s.Build().Rows()
|
||||
}
|
||||
|
||||
// Row builds and executes the SELECT query and populates the first row of the result into the specified variables.
|
||||
// This is a shortcut to SelectQuery.Build().Row()
|
||||
func (s *SelectQuery) Row(a ...interface{}) error {
|
||||
return s.Build().Row(a...)
|
||||
}
|
||||
|
||||
// Column builds and executes the SELECT statement and populates the first column of the result into a slice.
|
||||
// Note that the parameter must be a pointer to a slice.
|
||||
// This is a shortcut to SelectQuery.Build().Column()
|
||||
func (s *SelectQuery) Column(a interface{}) error {
|
||||
return s.Build().Column(a)
|
||||
}
|
||||
|
||||
// QueryInfo represents a debug/info struct with exported SelectQuery fields.
|
||||
type QueryInfo struct {
|
||||
PreFragment string
|
||||
PostFragment string
|
||||
Builder Builder
|
||||
Selects []string
|
||||
Distinct bool
|
||||
SelectOption string
|
||||
From []string
|
||||
Where Expression
|
||||
Join []JoinInfo
|
||||
OrderBy []string
|
||||
GroupBy []string
|
||||
Having Expression
|
||||
Union []UnionInfo
|
||||
Limit int64
|
||||
Offset int64
|
||||
Params Params
|
||||
Context context.Context
|
||||
BuildHook BuildHookFunc
|
||||
}
|
||||
|
||||
// Info exports common SelectQuery fields allowing to inspect the
|
||||
// current select query options.
|
||||
func (s *SelectQuery) Info() *QueryInfo {
|
||||
return &QueryInfo{
|
||||
Builder: s.builder,
|
||||
PreFragment: s.preFragment,
|
||||
PostFragment: s.postFragment,
|
||||
Selects: s.selects,
|
||||
Distinct: s.distinct,
|
||||
SelectOption: s.selectOption,
|
||||
From: s.from,
|
||||
Where: s.where,
|
||||
Join: s.join,
|
||||
OrderBy: s.orderBy,
|
||||
GroupBy: s.groupBy,
|
||||
Having: s.having,
|
||||
Union: s.union,
|
||||
Limit: s.limit,
|
||||
Offset: s.offset,
|
||||
Params: s.params,
|
||||
Context: s.ctx,
|
||||
BuildHook: s.buildHook,
|
||||
}
|
||||
}
|
179
select_test.go
Normal file
179
select_test.go
Normal file
|
@ -0,0 +1,179 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"database/sql"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSelectQuery(t *testing.T) {
|
||||
db := getDB()
|
||||
|
||||
// minimal select query
|
||||
q := db.Select().From("users").Build()
|
||||
expected := "SELECT * FROM `users`"
|
||||
assert.Equal(t, q.SQL(), expected, "t1")
|
||||
assert.Equal(t, len(q.Params()), 0, "t2")
|
||||
|
||||
// a full select query
|
||||
q = db.Select("id", "name").
|
||||
PreFragment("pre").
|
||||
PostFragment("post").
|
||||
AndSelect("age").
|
||||
Distinct(true).
|
||||
SelectOption("CALC").
|
||||
From("users").
|
||||
Where(NewExp("age>30")).
|
||||
AndWhere(NewExp("status=1")).
|
||||
OrWhere(NewExp("type=2")).
|
||||
InnerJoin("profile", NewExp("user.id=profile.id")).
|
||||
LeftJoin("team", nil).
|
||||
RightJoin("dept", nil).
|
||||
OrderBy("age DESC", "type").
|
||||
AndOrderBy("id").
|
||||
GroupBy("id").
|
||||
AndGroupBy("age").
|
||||
Having(NewExp("id>10")).
|
||||
AndHaving(NewExp("id<20")).
|
||||
OrHaving(NewExp("type=3")).
|
||||
Limit(10).
|
||||
Offset(20).
|
||||
Bind(Params{"id": 1}).
|
||||
AndBind(Params{"age": 30}).
|
||||
Build()
|
||||
|
||||
expected = "pre SELECT DISTINCT CALC `id`, `name`, `age` FROM `users` INNER JOIN `profile` ON user.id=profile.id LEFT JOIN `team` RIGHT JOIN `dept` WHERE ((age>30) AND (status=1)) OR (type=2) GROUP BY `id`, `age` HAVING ((id>10) AND (id<20)) OR (type=3) ORDER BY `age` DESC, `type`, `id` LIMIT 10 OFFSET 20 post"
|
||||
assert.Equal(t, q.SQL(), expected, "t3")
|
||||
assert.Equal(t, len(q.Params()), 2, "t4")
|
||||
|
||||
q3 := db.Select().AndBind(Params{"id": 1}).Build()
|
||||
assert.Equal(t, len(q3.Params()), 1)
|
||||
|
||||
// union
|
||||
q1 := db.Select().From("users").PreFragment("pre_q1").Build()
|
||||
q2 := db.Select().From("posts").PostFragment("post_q2").Build()
|
||||
q = db.Select().From("profiles").Union(q1).UnionAll(q2).PreFragment("pre").PostFragment("post").Build()
|
||||
expected = "(pre SELECT * FROM `profiles` post) UNION (pre_q1 SELECT * FROM `users`) UNION ALL (SELECT * FROM `posts` post_q2)"
|
||||
assert.Equal(t, q.SQL(), expected, "t5")
|
||||
}
|
||||
|
||||
func TestSelectQuery_Data(t *testing.T) {
|
||||
db := getPreparedDB()
|
||||
defer db.Close()
|
||||
|
||||
q := db.Select("id", "email").From("customer").OrderBy("id")
|
||||
|
||||
var customer Customer
|
||||
q.One(&customer)
|
||||
assert.Equal(t, customer.Email, "user1@example.com", "customer.Email")
|
||||
|
||||
var customers []Customer
|
||||
q.All(&customers)
|
||||
assert.Equal(t, len(customers), 3, "len(customers)")
|
||||
|
||||
rows, _ := q.Rows()
|
||||
customer.Email = ""
|
||||
rows.one(&customer)
|
||||
assert.Equal(t, customer.Email, "user1@example.com", "customer.Email")
|
||||
|
||||
var id, email string
|
||||
q.Row(&id, &email)
|
||||
assert.Equal(t, id, "1", "id")
|
||||
assert.Equal(t, email, "user1@example.com", "email")
|
||||
|
||||
var emails []string
|
||||
err := db.Select("email").From("customer").Column(&emails)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, 3, len(emails))
|
||||
}
|
||||
|
||||
var e int
|
||||
err = db.Select().From("customer").One(&e)
|
||||
assert.NotNil(t, err)
|
||||
err = db.Select().From("customer").All(&e)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestSelectQuery_Model(t *testing.T) {
|
||||
db := getPreparedDB()
|
||||
defer db.Close()
|
||||
|
||||
{
|
||||
// One without specifying FROM
|
||||
var customer CustomerPtr
|
||||
err := db.Select().OrderBy("id").One(&customer)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, "user1@example.com", *customer.Email)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// All without specifying FROM
|
||||
var customers []CustomerPtr
|
||||
err := db.Select().OrderBy("id").All(&customers)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, 3, len(customers))
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// Model without specifying FROM
|
||||
var customer CustomerPtr
|
||||
err := db.Select().Model(2, &customer)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, "user2@example.com", *customer.Email)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// Model with WHERE
|
||||
var customer CustomerPtr
|
||||
err := db.Select().Where(HashExp{"id": 1}).Model(2, &customer)
|
||||
assert.Equal(t, sql.ErrNoRows, err)
|
||||
|
||||
err = db.Select().Where(HashExp{"id": 2}).Model(2, &customer)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
{
|
||||
// errors
|
||||
var i int
|
||||
err := db.Select().Model(1, &i)
|
||||
assert.Equal(t, VarTypeError("must be a pointer to a struct"), err)
|
||||
|
||||
var a struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
err = db.Select().Model(1, &a)
|
||||
assert.Equal(t, MissingPKError, err)
|
||||
var b struct {
|
||||
ID1 string `db:"pk"`
|
||||
ID2 string `db:"pk"`
|
||||
}
|
||||
err = db.Select().Model(1, &b)
|
||||
assert.Equal(t, CompositePKError, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectWithBuildHook(t *testing.T) {
|
||||
db := getPreparedDB()
|
||||
defer db.Close()
|
||||
|
||||
var buildSQL string
|
||||
|
||||
db.Select("id").
|
||||
From("user").
|
||||
WithBuildHook(func(q *Query) {
|
||||
buildSQL = q.SQL()
|
||||
}).
|
||||
Build()
|
||||
|
||||
assert.Equal(t, "SELECT `id` FROM `user`", buildSQL)
|
||||
}
|
281
struct.go
Normal file
281
struct.go
Normal file
|
@ -0,0 +1,281 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type (
|
||||
// FieldMapFunc converts a struct field name into a DB column name.
|
||||
FieldMapFunc func(string) string
|
||||
|
||||
// TableMapFunc converts a sample struct into a DB table name.
|
||||
TableMapFunc func(a interface{}) string
|
||||
|
||||
structInfo struct {
|
||||
nameMap map[string]*fieldInfo // mapping from struct field names to field infos
|
||||
dbNameMap map[string]*fieldInfo // mapping from db column names to field infos
|
||||
pkNames []string // struct field names representing PKs
|
||||
}
|
||||
|
||||
structValue struct {
|
||||
*structInfo
|
||||
value reflect.Value // the struct value
|
||||
tableName string // the db table name for the struct
|
||||
}
|
||||
|
||||
fieldInfo struct {
|
||||
name string // field name
|
||||
dbName string // db column name
|
||||
path []int // index path to the struct field reflection
|
||||
}
|
||||
|
||||
structInfoMapKey struct {
|
||||
t reflect.Type
|
||||
m reflect.Value
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
// DbTag is the name of the struct tag used to specify the column name for the associated struct field
|
||||
DbTag = "db"
|
||||
|
||||
fieldRegex = regexp.MustCompile(`([^A-Z_])([A-Z])`)
|
||||
scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
|
||||
postScannerType = reflect.TypeOf((*PostScanner)(nil)).Elem()
|
||||
structInfoMap = make(map[structInfoMapKey]*structInfo)
|
||||
muStructInfoMap sync.Mutex
|
||||
)
|
||||
|
||||
// PostScanner is an optional interface used by ScanStruct.
|
||||
type PostScanner interface {
|
||||
// PostScan executes right after the struct has been populated
|
||||
// with the DB values, allowing you to further normalize or validate
|
||||
// the loaded data.
|
||||
PostScan() error
|
||||
}
|
||||
|
||||
// DefaultFieldMapFunc maps a field name to a DB column name.
|
||||
// The mapping rule set by this method is that words in a field name will be separated by underscores
|
||||
// and the name will be turned into lower case. For example, "FirstName" maps to "first_name", and "MyID" becomes "my_id".
|
||||
// See DB.FieldMapper for more details.
|
||||
func DefaultFieldMapFunc(f string) string {
|
||||
return strings.ToLower(fieldRegex.ReplaceAllString(f, "${1}_$2"))
|
||||
}
|
||||
|
||||
func getStructInfo(a reflect.Type, mapper FieldMapFunc) *structInfo {
|
||||
muStructInfoMap.Lock()
|
||||
defer muStructInfoMap.Unlock()
|
||||
|
||||
key := structInfoMapKey{a, reflect.ValueOf(mapper)}
|
||||
if si, ok := structInfoMap[key]; ok {
|
||||
return si
|
||||
}
|
||||
|
||||
si := &structInfo{
|
||||
nameMap: map[string]*fieldInfo{},
|
||||
dbNameMap: map[string]*fieldInfo{},
|
||||
}
|
||||
si.build(a, make([]int, 0), "", "", mapper)
|
||||
structInfoMap[key] = si
|
||||
|
||||
return si
|
||||
}
|
||||
|
||||
func newStructValue(model interface{}, fieldMapFunc FieldMapFunc, tableMapFunc TableMapFunc) *structValue {
|
||||
value := reflect.ValueOf(model)
|
||||
if value.Kind() != reflect.Ptr || value.Elem().Kind() != reflect.Struct || value.IsNil() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &structValue{
|
||||
structInfo: getStructInfo(reflect.TypeOf(model).Elem(), fieldMapFunc),
|
||||
value: value.Elem(),
|
||||
tableName: tableMapFunc(model),
|
||||
}
|
||||
}
|
||||
|
||||
// pk returns the primary key values indexed by the corresponding primary key column names.
|
||||
func (s *structValue) pk() map[string]interface{} {
|
||||
if len(s.pkNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.columns(s.pkNames, nil)
|
||||
}
|
||||
|
||||
// columns returns the struct field values indexed by their corresponding DB column names.
|
||||
func (s *structValue) columns(include, exclude []string) map[string]interface{} {
|
||||
v := make(map[string]interface{}, len(s.nameMap))
|
||||
if len(include) == 0 {
|
||||
for _, fi := range s.nameMap {
|
||||
v[fi.dbName] = fi.getValue(s.value)
|
||||
}
|
||||
} else {
|
||||
for _, attr := range include {
|
||||
if fi, ok := s.nameMap[attr]; ok {
|
||||
v[fi.dbName] = fi.getValue(s.value)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(exclude) > 0 {
|
||||
for _, name := range exclude {
|
||||
if fi, ok := s.nameMap[name]; ok {
|
||||
delete(v, fi.dbName)
|
||||
}
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// getValue returns the field value for the given struct value.
|
||||
func (fi *fieldInfo) getValue(a reflect.Value) interface{} {
|
||||
for _, i := range fi.path {
|
||||
a = a.Field(i)
|
||||
if a.Kind() == reflect.Ptr {
|
||||
if a.IsNil() {
|
||||
return nil
|
||||
}
|
||||
a = a.Elem()
|
||||
}
|
||||
}
|
||||
return a.Interface()
|
||||
}
|
||||
|
||||
// getField returns the reflection value of the field for the given struct value.
|
||||
func (fi *fieldInfo) getField(a reflect.Value) reflect.Value {
|
||||
i := 0
|
||||
for ; i < len(fi.path)-1; i++ {
|
||||
a = indirect(a.Field(fi.path[i]))
|
||||
}
|
||||
return a.Field(fi.path[i])
|
||||
}
|
||||
|
||||
func (si *structInfo) build(a reflect.Type, path []int, namePrefix, dbNamePrefix string, mapper FieldMapFunc) {
|
||||
n := a.NumField()
|
||||
for i := 0; i < n; i++ {
|
||||
field := a.Field(i)
|
||||
tag := field.Tag.Get(DbTag)
|
||||
|
||||
// only handle anonymous or exported fields
|
||||
if !field.Anonymous && field.PkgPath != "" || tag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
path2 := make([]int, len(path), len(path)+1)
|
||||
copy(path2, path)
|
||||
path2 = append(path2, i)
|
||||
|
||||
ft := field.Type
|
||||
if ft.Kind() == reflect.Ptr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
|
||||
name := field.Name
|
||||
dbName, isPK := parseTag(tag)
|
||||
if dbName == "" && !field.Anonymous {
|
||||
if mapper != nil {
|
||||
dbName = mapper(field.Name)
|
||||
} else {
|
||||
dbName = field.Name
|
||||
}
|
||||
}
|
||||
if field.Anonymous {
|
||||
name = ""
|
||||
}
|
||||
|
||||
if isNestedStruct(ft) {
|
||||
// dive into non-scanner struct
|
||||
si.build(ft, path2, concat(namePrefix, name), concat(dbNamePrefix, dbName), mapper)
|
||||
} else if dbName != "" {
|
||||
// non-anonymous scanner or struct field
|
||||
fi := &fieldInfo{
|
||||
name: concat(namePrefix, name),
|
||||
dbName: concat(dbNamePrefix, dbName),
|
||||
path: path2,
|
||||
}
|
||||
// a field in an anonymous struct may be shadowed
|
||||
if _, ok := si.nameMap[fi.name]; !ok || len(path2) < len(si.nameMap[fi.name].path) {
|
||||
si.nameMap[fi.name] = fi
|
||||
si.dbNameMap[fi.dbName] = fi
|
||||
if isPK {
|
||||
si.pkNames = append(si.pkNames, fi.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(si.pkNames) == 0 {
|
||||
if _, ok := si.nameMap["ID"]; ok {
|
||||
si.pkNames = append(si.pkNames, "ID")
|
||||
} else if _, ok := si.nameMap["Id"]; ok {
|
||||
si.pkNames = append(si.pkNames, "Id")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isNestedStruct(t reflect.Type) bool {
|
||||
if t.PkgPath() == "time" && t.Name() == "Time" {
|
||||
return false
|
||||
}
|
||||
return t.Kind() == reflect.Struct && !reflect.PtrTo(t).Implements(scannerType)
|
||||
}
|
||||
|
||||
func parseTag(tag string) (string, bool) {
|
||||
if tag == "pk" {
|
||||
return "", true
|
||||
}
|
||||
if strings.HasPrefix(tag, "pk,") {
|
||||
return tag[3:], true
|
||||
}
|
||||
return tag, false
|
||||
}
|
||||
|
||||
func concat(s1, s2 string) string {
|
||||
if s1 == "" {
|
||||
return s2
|
||||
} else if s2 == "" {
|
||||
return s1
|
||||
} else {
|
||||
return s1 + "." + s2
|
||||
}
|
||||
}
|
||||
|
||||
// indirect dereferences pointers and returns the actual value it points to.
|
||||
// If a pointer is nil, it will be initialized with a new value.
|
||||
func indirect(v reflect.Value) reflect.Value {
|
||||
for v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// GetTableName implements the default way of determining the table name corresponding to the given model struct
|
||||
// or slice of structs. To get the actual table name for a model, you should use DB.TableMapFunc() instead.
|
||||
// Do not call this method in a model's TableName() method because it will cause infinite loop.
|
||||
func GetTableName(a interface{}) string {
|
||||
if tm, ok := a.(TableModel); ok {
|
||||
v := reflect.ValueOf(a)
|
||||
if v.Kind() == reflect.Ptr && v.IsNil() {
|
||||
a = reflect.New(v.Type().Elem()).Interface()
|
||||
return a.(TableModel).TableName()
|
||||
}
|
||||
return tm.TableName()
|
||||
}
|
||||
t := reflect.TypeOf(a)
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
if t.Kind() == reflect.Slice {
|
||||
return GetTableName(reflect.Zero(t.Elem()).Interface())
|
||||
}
|
||||
return DefaultFieldMapFunc(t.Name())
|
||||
}
|
154
struct_test.go
Normal file
154
struct_test.go
Normal file
|
@ -0,0 +1,154 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDefaultFieldMapFunc(t *testing.T) {
|
||||
tests := []struct {
|
||||
input, output string
|
||||
}{
|
||||
{"Name", "name"},
|
||||
{"FirstName", "first_name"},
|
||||
{"Name0", "name0"},
|
||||
{"ID", "id"},
|
||||
{"UserID", "user_id"},
|
||||
{"User0ID", "user0_id"},
|
||||
{"MyURL", "my_url"},
|
||||
{"URLPath", "urlpath"},
|
||||
{"MyURLPath", "my_urlpath"},
|
||||
{"First_Name", "first_name"},
|
||||
{"first_name", "first_name"},
|
||||
{"_FirstName", "_first_name"},
|
||||
{"_First_Name", "_first_name"},
|
||||
}
|
||||
for _, test := range tests {
|
||||
r := DefaultFieldMapFunc(test.input)
|
||||
assert.Equal(t, test.output, r, test.input)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_concat(t *testing.T) {
|
||||
assert.Equal(t, "a.b", concat("a", "b"))
|
||||
assert.Equal(t, "a", concat("a", ""))
|
||||
assert.Equal(t, "b", concat("", "b"))
|
||||
}
|
||||
|
||||
func Test_parseTag(t *testing.T) {
|
||||
name, pk := parseTag("abc")
|
||||
assert.Equal(t, "abc", name)
|
||||
assert.False(t, pk)
|
||||
|
||||
name, pk = parseTag("pk,abc")
|
||||
assert.Equal(t, "abc", name)
|
||||
assert.True(t, pk)
|
||||
|
||||
name, pk = parseTag("pk")
|
||||
assert.Equal(t, "", name)
|
||||
assert.True(t, pk)
|
||||
}
|
||||
|
||||
func Test_indirect(t *testing.T) {
|
||||
var a int
|
||||
assert.Equal(t, reflect.ValueOf(a).Kind(), indirect(reflect.ValueOf(a)).Kind())
|
||||
var b *int
|
||||
bi := indirect(reflect.ValueOf(&b))
|
||||
assert.Equal(t, reflect.ValueOf(a).Kind(), bi.Kind())
|
||||
if assert.NotNil(t, b) {
|
||||
assert.Equal(t, 0, *b)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_structValue_columns(t *testing.T) {
|
||||
customer := Customer{
|
||||
ID: 1,
|
||||
Name: "abc",
|
||||
Status: 2,
|
||||
Email: "abc@example.com",
|
||||
}
|
||||
sv := newStructValue(&customer, DefaultFieldMapFunc, GetTableName)
|
||||
cols := sv.columns(nil, nil)
|
||||
assert.Equal(t, map[string]interface{}{"id": 1, "name": "abc", "status": 2, "email": "abc@example.com", "address": sql.NullString{}}, cols)
|
||||
|
||||
cols = sv.columns([]string{"ID", "name"}, nil)
|
||||
assert.Equal(t, map[string]interface{}{"id": 1}, cols)
|
||||
|
||||
cols = sv.columns([]string{"ID", "Name"}, []string{"ID"})
|
||||
assert.Equal(t, map[string]interface{}{"name": "abc"}, cols)
|
||||
|
||||
cols = sv.columns(nil, []string{"ID", "Address"})
|
||||
assert.Equal(t, map[string]interface{}{"name": "abc", "status": 2, "email": "abc@example.com"}, cols)
|
||||
|
||||
sv = newStructValue(&customer, nil, GetTableName)
|
||||
cols = sv.columns([]string{"ID", "Name"}, []string{"ID"})
|
||||
assert.Equal(t, map[string]interface{}{"Name": "abc"}, cols)
|
||||
}
|
||||
|
||||
func TestIssue37(t *testing.T) {
|
||||
customer := Customer{
|
||||
ID: 1,
|
||||
Name: "abc",
|
||||
Status: 2,
|
||||
Email: "abc@example.com",
|
||||
}
|
||||
ev := struct {
|
||||
Customer
|
||||
Status string
|
||||
}{customer, "20"}
|
||||
sv := newStructValue(&ev, nil, GetTableName)
|
||||
cols := sv.columns([]string{"ID", "Status"}, nil)
|
||||
assert.Equal(t, map[string]interface{}{"ID": 1, "Status": "20"}, cols)
|
||||
|
||||
ev2 := struct {
|
||||
Status string
|
||||
Customer
|
||||
}{"20", customer}
|
||||
sv = newStructValue(&ev2, nil, GetTableName)
|
||||
cols = sv.columns([]string{"ID", "Status"}, nil)
|
||||
assert.Equal(t, map[string]interface{}{"ID": 1, "Status": "20"}, cols)
|
||||
}
|
||||
|
||||
type MyCustomer struct{}
|
||||
|
||||
func TestGetTableName(t *testing.T) {
|
||||
var c1 Customer
|
||||
assert.Equal(t, "customer", GetTableName(c1))
|
||||
|
||||
var c2 *Customer
|
||||
assert.Equal(t, "customer", GetTableName(c2))
|
||||
|
||||
var c3 MyCustomer
|
||||
assert.Equal(t, "my_customer", GetTableName(c3))
|
||||
|
||||
var c4 []Customer
|
||||
assert.Equal(t, "customer", GetTableName(c4))
|
||||
|
||||
var c5 *[]Customer
|
||||
assert.Equal(t, "customer", GetTableName(c5))
|
||||
|
||||
var c6 []MyCustomer
|
||||
assert.Equal(t, "my_customer", GetTableName(c6))
|
||||
|
||||
var c7 []CustomerPtr
|
||||
assert.Equal(t, "customer", GetTableName(c7))
|
||||
|
||||
var c8 **int
|
||||
assert.Equal(t, "", GetTableName(c8))
|
||||
}
|
||||
|
||||
type FA struct {
|
||||
A1 string
|
||||
A2 int
|
||||
}
|
||||
|
||||
type FB struct {
|
||||
B1 string
|
||||
}
|
81
testdata/mysql.sql
vendored
Normal file
81
testdata/mysql.sql
vendored
Normal file
|
@ -0,0 +1,81 @@
|
|||
/**
|
||||
* This is the database schema for testing MySQL support of ozzo-dbx.
|
||||
* The following database setup is required in order to run the test:
|
||||
* - host: 127.0.0.1
|
||||
* - user: travis
|
||||
* - pass: <none>
|
||||
* - database: pocketbase_dbx_test
|
||||
*/
|
||||
|
||||
DROP TABLE IF EXISTS `order_item` CASCADE;
|
||||
DROP TABLE IF EXISTS `item` CASCADE;
|
||||
DROP TABLE IF EXISTS `order` CASCADE;
|
||||
DROP TABLE IF EXISTS `customer` CASCADE;
|
||||
DROP TABLE IF EXISTS `user` CASCADE;
|
||||
|
||||
CREATE TABLE `customer` (
|
||||
`id` int(11) NOT NULL AUTO_INCREMENT,
|
||||
`email` varchar(128) NOT NULL,
|
||||
`name` varchar(128),
|
||||
`address` text,
|
||||
`status` int (11) DEFAULT 0,
|
||||
PRIMARY KEY (`id`)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
|
||||
|
||||
CREATE TABLE `user` (
|
||||
`id` int(11) NOT NULL AUTO_INCREMENT,
|
||||
`email` varchar(128) NOT NULL,
|
||||
`created` date,
|
||||
`updated` date,
|
||||
PRIMARY KEY (`id`)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
|
||||
|
||||
CREATE TABLE `item` (
|
||||
`id` int(11) NOT NULL AUTO_INCREMENT,
|
||||
`name` varchar(128) NOT NULL,
|
||||
PRIMARY KEY (`id`)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
|
||||
|
||||
CREATE TABLE `order` (
|
||||
`id` int(11) NOT NULL AUTO_INCREMENT,
|
||||
`customer_id` int(11) NOT NULL,
|
||||
`created_at` int(11) NOT NULL,
|
||||
`total` decimal(10,0) NOT NULL,
|
||||
PRIMARY KEY (`id`),
|
||||
CONSTRAINT `FK_order_customer_id` FOREIGN KEY (`customer_id`) REFERENCES `customer` (`id`) ON DELETE CASCADE
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
|
||||
|
||||
CREATE TABLE `order_item` (
|
||||
`order_id` int(11) NOT NULL,
|
||||
`item_id` int(11) NOT NULL,
|
||||
`quantity` int(11) NOT NULL,
|
||||
`subtotal` decimal(10,0) NOT NULL,
|
||||
PRIMARY KEY (`order_id`,`item_id`),
|
||||
KEY `FK_order_item_item_id` (`item_id`),
|
||||
CONSTRAINT `FK_order_item_order_id` FOREIGN KEY (`order_id`) REFERENCES `order` (`id`) ON DELETE CASCADE,
|
||||
CONSTRAINT `FK_order_item_item_id` FOREIGN KEY (`item_id`) REFERENCES `item` (`id`) ON DELETE CASCADE
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
|
||||
|
||||
INSERT INTO `customer` (email, name, address, status) VALUES ('user1@example.com', 'user1', 'address1', 1);
|
||||
INSERT INTO `customer` (email, name, address, status) VALUES ('user2@example.com', 'user2', NULL, 1);
|
||||
INSERT INTO `customer` (email, name, address, status) VALUES ('user3@example.com', 'user3', 'address3', 2);
|
||||
|
||||
INSERT INTO `user` (email, created) VALUES ('user1@example.com', '2015-01-02');
|
||||
INSERT INTO `user` (email, created) VALUES ('user2@example.com', now());
|
||||
|
||||
INSERT INTO `item` (name) VALUES ('The Go Programming Language');
|
||||
INSERT INTO `item` (name) VALUES ('Go in Action');
|
||||
INSERT INTO `item` (name) VALUES ('Go Programming Blueprints');
|
||||
INSERT INTO `item` (name) VALUES ('Building Microservices');
|
||||
INSERT INTO `item` (name) VALUES ('Go Web Programming');
|
||||
|
||||
INSERT INTO `order` (customer_id, created_at, total) VALUES (1, 1325282384, 110.0);
|
||||
INSERT INTO `order` (customer_id, created_at, total) VALUES (2, 1325334482, 33.0);
|
||||
INSERT INTO `order` (customer_id, created_at, total) VALUES (2, 1325502201, 40.0);
|
||||
|
||||
INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (1, 1, 1, 30.0);
|
||||
INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (1, 2, 2, 40.0);
|
||||
INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (2, 4, 1, 10.0);
|
||||
INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (2, 5, 1, 15.0);
|
||||
INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (2, 3, 1, 8.0);
|
||||
INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (3, 2, 1, 40.0);
|
23
tx.go
Normal file
23
tx.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
// Copyright 2016 Qiang Xue. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dbx
|
||||
|
||||
import "database/sql"
|
||||
|
||||
// Tx enhances sql.Tx with additional querying methods.
|
||||
type Tx struct {
|
||||
Builder
|
||||
tx *sql.Tx
|
||||
}
|
||||
|
||||
// Commit commits the transaction.
|
||||
func (t *Tx) Commit() error {
|
||||
return t.tx.Commit()
|
||||
}
|
||||
|
||||
// Rollback aborts the transaction.
|
||||
func (t *Tx) Rollback() error {
|
||||
return t.tx.Rollback()
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue