diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fb6c1ae --- /dev/null +++ b/.gitignore @@ -0,0 +1,26 @@ +.idea + +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..ee38f9a --- /dev/null +++ b/.travis.yml @@ -0,0 +1,22 @@ +dist: bionic + +language: go + +go: + - 1.13.x + +services: + - mysql + +install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + - go get golang.org/x/lint/golint + +before_script: + - mysql -e 'CREATE DATABASE pocketbase_dbx_test;'; + +script: + - test -z "`gofmt -l -d .`" + - go test -v -covermode=count -coverprofile=coverage.out + - $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d235be9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,17 @@ +The MIT License (MIT) +Copyright (c) 2016, Qiang Xue + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software +and associated documentation files (the "Software"), to deal in the Software without restriction, +including without limitation the rights to use, copy, modify, merge, publish, distribute, +sublicense, and/or sell copies of the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or +substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..ac665eb --- /dev/null +++ b/README.md @@ -0,0 +1,748 @@ +dbx +[![Go Report Card](https://goreportcard.com/badge/github.com/pocketbase/dbx)](https://goreportcard.com/report/github.com/pocketbase/dbx) +[![GoDoc](https://godoc.org/github.com/pocketbase/dbx?status.svg)](https://pkg.go.dev/github.com/pocketbase/dbx) +================================================================================ + +> ⚠️ This is a maintained fork of [go-ozzo/ozzo-dbx](https://github.com/go-ozzo/ozzo-dbx) (see [#103](https://github.com/go-ozzo/ozzo-dbx/issues/103)). +> +> Currently, the changes are primarily related to better SQLite support and some other minor improvements, implementing [#99](https://github.com/go-ozzo/ozzo-dbx/pull/99), [#100](https://github.com/go-ozzo/ozzo-dbx/pull/100) and [#102](https://github.com/go-ozzo/ozzo-dbx/pull/102). + + +## Summary + +- [Description](#description) +- [Requirements](#requirements) +- [Installation](#installation) +- [Supported Databases](#supported-databases) +- [Getting Started](#getting-started) +- [Connecting to Database](#connecting-to-database) +- [Executing Queries](#executing-queries) +- [Binding Parameters](#binding-parameters) +- [Building Queries](#building-queries) + - [Building SELECT Queries](#building-select-queries) + - [Building Query Conditions](#building-query-conditions) + - [Building Data Manipulation Queries](#building-data-manipulation-queries) + - [Building Schema Manipulation Queries](#building-schema-manipulation-queries) +- [CRUD Operations](#crud-operations) + - [Create](#create) + - [Read](#read) + - [Update](#update) + - [Delete](#delete) + - [Null Handling](#null-handling) +- [Quoting Table and Column Names](#quoting-table-and-column-names) +- [Using Transactions](#using-transactions) +- [Logging Executed SQL Statements](#logging-executed-sql-statements) +- [Supporting New Databases](#supporting-new-databases) + + +## Description + +`dbx` is a Go package that enhances the standard `database/sql` package by providing powerful data retrieval methods +as well as DB-agnostic query building capabilities. `dbx` is not an ORM. It has the following features: + +* Populating data into structs and NullString maps +* Named parameter binding +* DB-agnostic query building methods, including SELECT queries, data manipulation queries, and schema manipulation queries +* Inserting, updating, and deleting model structs +* Powerful query condition building +* Open architecture allowing addition of new database support or customization of existing support +* Logging executed SQL statements +* Supporting major relational databases + +For an example on how this library is used in an application, please refer to [go-rest-api](https://github.com/qiangxue/go-rest-api) which is a starter kit for building RESTful APIs in Go. + +## Requirements + +Go 1.13 or above. + +## Installation + +Run the following command to install the package: + +``` +go get github.com/pocketbase/dbx +``` + +In addition, install the specific DB driver package for the kind of database to be used. Please refer to +[SQL database drivers](https://github.com/golang/go/wiki/SQLDrivers) for a complete list. For example, if you are +using MySQL, you may install the following package: + +```sh +go get github.com/go-sql-driver/mysql +``` + +and import it in your main code like the following: + +```go +import _ "github.com/go-sql-driver/mysql" +``` + +## Supported Databases + +The following databases are fully supported out of box: + +* SQLite +* MySQL +* PostgreSQL +* MS SQL Server (2012 or above) +* Oracle + +For other databases, the query building feature may not work as expected. You can create a custom builder to +solve the problem. Please see the last section for more details. + +## Getting Started + +The following code snippet shows how you can use this package in order to access data from a MySQL database. + +```go +package main + +import ( + "github.com/pocketbase/dbx" + _ "github.com/go-sql-driver/mysql" +) + +func main() { + db, _ := dbx.Open("mysql", "user:pass@/example") + + // create a new query + q := db.NewQuery("SELECT id, name FROM users LIMIT 10") + + // fetch all rows into a struct array + var users []struct { + ID, Name string + } + q.All(&users) + + // fetch a single row into a struct + var user struct { + ID, Name string + } + q.One(&user) + + // fetch a single row into a string map + data := dbx.NullStringMap{} + q.One(data) + + // fetch row by row + rows2, _ := q.Rows() + for rows2.Next() { + rows2.ScanStruct(&user) + // rows.ScanMap(data) + // rows.Scan(&id, &name) + } +} +``` + +And the following example shows how to use the query building capability of this package. + +```go +package main + +import ( + "github.com/pocketbase/dbx" + _ "github.com/go-sql-driver/mysql" +) + +func main() { + db, _ := dbx.Open("mysql", "user:pass@/example") + + // build a SELECT query + // SELECT `id`, `name` FROM `users` WHERE `name` LIKE '%Charles%' ORDER BY `id` + q := db.Select("id", "name"). + From("users"). + Where(dbx.Like("name", "Charles")). + OrderBy("id") + + // fetch all rows into a struct array + var users []struct { + ID, Name string + } + q.All(&users) + + // build an INSERT query + // INSERT INTO `users` (`name`) VALUES ('James') + db.Insert("users", dbx.Params{ + "name": "James", + }).Execute() +} +``` + +## Connecting to Database + +To connect to a database, call `dbx.Open()` in the same way as you would do with the `Open()` method in `database/sql`. + +```go +db, err := dbx.Open("mysql", "user:pass@hostname/db_name") +``` + +The method returns a `dbx.DB` instance which can be used to create and execute DB queries. Note that the method +does not really establish a connection until a query is made using the returned `dbx.DB` instance. It also +does not check the correctness of the data source name either. Call `dbx.MustOpen()` to make sure the data +source name is correct. + +## Executing Queries + +To execute a SQL statement, first create a `dbx.Query` instance by calling `DB.NewQuery()` with the SQL statement +to be executed. And then call `Query.Execute()` to execute the query if the query is not meant to retrieving data. +For example, + +```go +q := db.NewQuery("UPDATE users SET status=1 WHERE id=100") +result, err := q.Execute() +``` + +If the SQL statement does retrieve data (e.g. a SELECT statement), one of the following methods should be called, +which will execute the query and populate the result into the specified variable(s). + +* `Query.All()`: populate all rows of the result into a slice of structs or `NullString` maps. +* `Query.One()`: populate the first row of the result into a struct or a `NullString` map. +* `Query.Column()`: populate the first column of the result into a slice. +* `Query.Row()`: populate the first row of the result into a list of variables, one for each returning column. +* `Query.Rows()`: returns a `dbx.Rows` instance to allow retrieving data row by row. + +For example, + +```go +type User struct { + ID int + Name string +} + +var ( + users []User + user User + + row dbx.NullStringMap + + id int + name string + + err error +) + +q := db.NewQuery("SELECT id, name FROM users LIMIT 10") + +// populate all rows into a User slice +err = q.All(&users) +fmt.Println(users[0].ID, users[0].Name) + +// populate the first row into a User struct +err = q.One(&user) +fmt.Println(user.ID, user.Name) + +// populate the first row into a NullString map +err = q.One(&row) +fmt.Println(row["id"], row["name"]) + +var ids []int +err = q.Column(&ids) +fmt.Println(ids) + +// populate the first row into id and name +err = q.Row(&id, &name) + +// populate data row by row +rows, _ := q.Rows() +for rows.Next() { + _ = rows.ScanMap(&row) +} +``` + +When populating a struct, the following rules are used to determine which columns should go into which struct fields: + +* Only exported struct fields can be populated. +* A field receives data if its name is mapped to a column according to the field mapping function `Query.FieldMapper`. + The default field mapping function separates words in a field name by underscores and turns them into lower case. + For example, a field name `FirstName` will be mapped to the column name `first_name`, and `MyID` to `my_id`. +* If a field has a `db` tag, the tag value will be used as the corresponding column name. If the `db` tag is a dash `-`, + it means the field should NOT be populated. +* For anonymous fields that are of struct type, they will be expanded and their component fields will be populated + according to the rules described above. +* For named fields that are of struct type, they will also be expanded. But their component fields will be prefixed + with the struct names when being populated. + +An exception to the above struct expansion is that when a struct type implements `sql.Scanner` or when it is `time.Time`. +In this case, the field will be populated as a whole by the DB driver. Also, if a field is a pointer to some type, +the field will be allocated memory and populated with the query result if it is not null. + +The following example shows how fields are populated according to the rules above: + +```go +type User struct { + id int + Type int `db:"-"` + MyName string `db:"name"` + Profile + Address Address `db:"addr"` +} + +type Profile struct { + Age int +} + +type Address struct { + City string +} +``` + +* `User.id`: not populated because the field is not exported; +* `User.Type`: not populated because the `db` tag is `-`; +* `User.MyName`: to be populated from the `name` column, according to the `db` tag; +* `Profile.Age`: to be populated from the `age` column, since `Profile` is an anonymous field; +* `Address.City`: to be populated from the `addr.city` column, since `Address` is a named field of struct type + and its fields will be prefixed with `addr.` according to the `db` tag. + +Note that if a column in the result does not have a corresponding struct field, it will be ignored. Similarly, +if a struct field does not have a corresponding column in the result, it will not be populated. + +## Binding Parameters + +A SQL statement is usually parameterized with dynamic values. For example, you may want to select the user record +according to the user ID received from the client. Parameter binding should be used in this case, and it is almost +always preferred to prevent from SQL injection attacks. Unlike `database/sql` which does anonymous parameter binding, +`dbx` uses named parameter binding. *Anonymous parameter binding is not supported*, as it will mess up with named +parameters. For example, + +```go +q := db.NewQuery("SELECT id, name FROM users WHERE id={:id}") +q.Bind(dbx.Params{"id": 100}) +err := q.One(&user) +``` + +The above example will select the user record whose `id` is 100. The method `Query.Bind()` binds a set +of named parameters to a SQL statement which contains parameter placeholders in the format of `{:ParamName}`. + +If a SQL statement needs to be executed multiple times with different parameter values, it may be prepared +to improve the performance. For example, + +```go +q := db.NewQuery("SELECT id, name FROM users WHERE id={:id}") +q.Prepare() +defer q.Close() + +q.Bind(dbx.Params{"id": 100}) +err := q.One(&user) + +q.Bind(dbx.Params{"id": 200}) +err = q.One(&user) + +// ... +``` + + +## Cancelable Queries + +Queries are cancelable when they are used with `context.Context`. In particular, by calling `Query.WithContext()` you +can associate a context with a query and use the context to cancel the query while it is running. For example, + +```go +q := db.NewQuery("SELECT id, name FROM users") +err := q.WithContext(ctx).All(&users) +``` + + +## Building Queries + +Instead of writing plain SQLs, `dbx` allows you to build SQLs programmatically, which often leads to cleaner, +more secure, and DB-agnostic code. You can build three types of queries: the SELECT queries, the data manipulation +queries, and the schema manipulation queries. + +### Building SELECT Queries + +Building a SELECT query starts by calling `DB.Select()`. You can build different clauses of a SELECT query using +the corresponding query building methods. For example, + +```go +db, _ := dbx.Open("mysql", "user:pass@/example") +err := db.Select("id", "name"). + From("users"). + Where(dbx.HashExp{"id": 100}). + One(&user) +``` + +The above code will generate and execute the following SQL statement: + +```sql +SELECT `id`, `name` FROM `users` WHERE `id`={:p0} +``` + +Notice how the table and column names are properly quoted according to the currently using database type. +And parameter binding is used to populate the value of `p0` in the `WHERE` clause. + +Every SQL keyword has a corresponding query building method. For example, `SELECT` corresponds to `Select()`, +`FROM` corresponds to `From()`, `WHERE` corresponds to `Where()`, and so on. You can chain these method calls +together, just like you would do when writing a plain SQL. Each of these methods returns the query instance +(of type `dbx.SelectQuery`) that is being built. Once you finish building a query, you may call methods such as +`One()`, `All()` to execute the query and populate data into variables. You may also explicitly call `Build()` +to build the query and turn it into a `dbx.Query` instance which may allow you to get the SQL statement and do +other interesting work. + + +### Building Query Conditions + +`dbx` supports very flexible and powerful query condition building which can be used to build SQL clauses +such as `WHERE`, `HAVING`, etc. For example, + +```go +// id=100 +dbx.NewExp("id={:id}", dbx.Params{"id": 100}) + +// id=100 AND status=1 +dbx.HashExp{"id": 100, "status": 1} + +// status=1 OR age>30 +dbx.Or(dbx.HashExp{"status": 1}, dbx.NewExp("age>30")) + +// name LIKE '%admin%' AND name LIKE '%example%' +dbx.Like("name", "admin", "example") +``` + +When building a query condition expression, its parameter values will be populated using parameter binding, which +prevents SQL injection from happening. Also if an expression involves column names, they will be properly quoted. +The following condition building functions are available: + +* `dbx.NewExp()`: creating a condition using the given expression string and binding parameters. For example, +`dbx.NewExp("id={:id}", dbx.Params{"id":100})` would create the expression `id=100`. +* `dbx.HashExp`: a map type that represents name-value pairs concatenated by `AND` operators. For example, +`dbx.HashExp{"id":100, "status":1}` would create `id=100 AND status=1`. +* `dbx.Not()`: creating a `NOT` expression by prepending `NOT` to the given expression. +* `dbx.And()`: creating an `AND` expression by concatenating the given expressions with the `AND` operators. +* `dbx.Or()`: creating an `OR` expression by concatenating the given expressions with the `OR` operators. +* `dbx.In()`: creating an `IN` expression for the specified column and the range of values. +For example, `dbx.In("age", 30, 40, 50)` would create the expression `age IN (30, 40, 50)`. +Note that if the value range is empty, it will generate an expression representing a false value. +* `dbx.NotIn()`: creating an `NOT IN` expression. This is very similar to `dbx.In()`. +* `dbx.Like()`: creating a `LIKE` expression for the specified column and the range of values. For example, +`dbx.Like("title", "golang", "framework")` would create the expression `title LIKE "%golang%" AND title LIKE "%framework%"`. +You can further customize a LIKE expression by calling `Escape()` and/or `Match()` functions of the resulting expression. +Note that if the value range is empty, it will generate an empty expression. +* `dbx.NotLike()`: creating a `NOT LIKE` expression. This is very similar to `dbx.Like()`. +* `dbx.OrLike()`: creating a `LIKE` expression but concatenating different `LIKE` sub-expressions using `OR` instead of `AND`. +* `dbx.OrNotLike()`: creating a `NOT LIKE` expression and concatenating different `NOT LIKE` sub-expressions using `OR` instead of `AND`. +* `dbx.Exists()`: creating an `EXISTS` expression by prepending `EXISTS` to the given expression. +* `dbx.NotExists()`: creating a `NOT EXISTS` expression by prepending `NOT EXISTS` to the given expression. +* `dbx.Between()`: creating a `BETWEEN` expression. For example, `dbx.Between("age", 30, 40)` would create the +expression `age BETWEEN 30 AND 40`. +* `dbx.NotBetween()`: creating a `NOT BETWEEN` expression. For example + +You may also create other convenient functions to help building query conditions, as long as the functions return +an object implementing the `dbx.Expression` interface. + + +### Building Data Manipulation Queries + +Data manipulation queries are those changing the data in the database, such as INSERT, UPDATE, DELETE statements. +Such queries can be built by calling the corresponding methods of `DB`. For example, + +```go +db, _ := dbx.Open("mysql", "user:pass@/example") + +// INSERT INTO `users` (`name`, `email`) VALUES ({:p0}, {:p1}) +err := db.Insert("users", dbx.Params{ + "name": "James", + "email": "james@example.com", +}).Execute() + +// UPDATE `users` SET `status`={:p0} WHERE `id`={:p1} +err = db.Update("users", dbx.Params{"status": 1}, dbx.HashExp{"id": 100}).Execute() + +// DELETE FROM `users` WHERE `status`={:p0} +err = db.Delete("users", dbx.HashExp{"status": 2}).Execute() +``` + +When building data manipulation queries, remember to call `Execute()` at the end to execute the queries. + +### Building Schema Manipulation Queries + +Schema manipulation queries are those changing the database schema, such as creating a new table, adding a new column. +These queries can be built by calling the corresponding methods of `DB`. For example, + +```go +db, _ := dbx.Open("mysql", "user:pass@/example") + +// CREATE TABLE `users` (`id` int primary key, `name` varchar(255)) +q := db.CreateTable("users", map[string]string{ + "id": "int primary key", + "name": "varchar(255)", +}) +err := q.Execute() +``` + +## CRUD Operations + +Although `dbx` is not an ORM, it does provide a very convenient way to do typical CRUD (Create, Read, Update, Delete) +operations without the need of writing plain SQL statements. + +To use the CRUD feature, first define a struct type for a table. By default, a struct is associated with a table +whose name is the snake case version of the struct type name. For example, a struct named `MyCustomer` +corresponds to the table name `my_customer`. You may explicitly specify the table name for a struct by implementing +the `dbx.TableModel` interface. For example, + +```go +type MyCustomer struct{} + +func (c MyCustomer) TableName() string { + return "customer" +} +``` + +Note that the `TableName` method should be defined with a value receiver instead of a pointer receiver. + +If the struct has a field named `ID` or `Id`, by default the field will be treated as the primary key field. +If you want to use a different field as the primary key, tag it with `db:"pk"`. You may tag multiple fields +for composite primary keys. Note that if you also want to explicitly specify the column name for a primary key field, +you should use the tag format `db:"pk,col_name"`. + +You can give a common prefix or suffix to your table names by defining your own table name mapping via +`DB.TableMapFunc`. For example, the following code prefixes `tbl_` to all table names. + +```go +db.TableMapper = func(a interface{}) string { + return "tbl_" + GetTableName(a) +} +``` + +### Create + +To create (insert) a new row using a model, call the `ModelQuery.Insert()` method. For example, + +```go +type Customer struct { + ID int + Name string + Email string + Status int +} + +db, _ := dbx.Open("mysql", "user:pass@/example") + +customer := Customer{ + Name: "example", + Email: "test@example.com", +} +// INSERT INTO customer (name, email, status) VALUES ('example', 'test@example.com', 0) +err := db.Model(&customer).Insert() +``` + +This will insert a row using the values from *all* public fields (except the primary key field if it is empty) in the struct. +If a primary key field is zero (a integer zero or a nil pointer), it is assumed to be auto-incremental and +will be automatically filled with the last insertion ID after a successful insertion. + +You can explicitly specify the fields that should be inserted by passing the list of the field names to the `Insert()` method. +You can also exclude certain fields from being inserted by calling `Exclude()` before calling `Insert()`. For example, + +```go +db, _ := dbx.Open("mysql", "user:pass@/example") + +// insert only Name and Email fields +err := db.Model(&customer).Insert("Name", "Email") +// insert all public fields except Status +err = db.Model(&customer).Exclude("Status").Insert() +// insert only Name +err = db.Model(&customer).Exclude("Status").Insert("Name", "Status") +``` + +### Read + +To read a model by a given primary key value, call `SelectQuery.Model()`. + +```go +db, _ := dbx.Open("mysql", "user:pass@/example") + +var customer Customer +// SELECT * FROM customer WHERE id=100 +err := db.Select().Model(100, &customer) + +// SELECT name, email FROM customer WHERE status=1 AND id=100 +err = db.Select("name", "email").Where(dbx.HashExp{"status": 1}).Model(100, &customer) +``` + +Note that `SelectQuery.Model()` does not support composite primary keys. You should use `SelectQuery.One()` in this case. +For example, + +```go +db, _ := dbx.Open("mysql", "user:pass@/example") + +var orderItem OrderItem + +// SELECT * FROM order_item WHERE order_id=100 AND item_id=20 +err := db.Select().Where(dbx.HashExp{"order_id": 100, "item_id": 20}).One(&orderItem) +``` + +In the above queries, we do not call `From()` to specify which table to select data from. This is because the select +query automatically sets the table according to the model struct being populated. If the struct implements `TableModel`, +the value returned by its `TableName()` method will be used as the table name. Otherwise, the snake case version +of the struct type name will be the table name. + +You may also call `SelectQuery.All()` to read a list of model structs. Similarly, you do not need to call `From()` +if the table name can be inferred from the model structs. + + +### Update + +To update a model, call the `ModelQuery.Update()` method. Like `Insert()`, by default, the `Update()` method will +update *all* public fields except primary key fields of the model. You can explicitly specify which fields can +be updated and which cannot in the same way as described for the `Insert()` method. For example, + +```go +db, _ := dbx.Open("mysql", "user:pass@/example") + +// update all public fields of customer +err := db.Model(&customer).Update() +// update only Status +err = db.Model(&customer).Update("Status") +// update all public fields except Status +err = db.Model(&customer).Exclude("Status").Update() +``` + +Note that the `Update()` method assumes that the primary keys are immutable. It uses the primary key value of the model +to look for the row that should be updated. An error will be returned if a model does not have a primary key. + + +### Delete + +To delete a model, call the `ModelQuery.Delete()` method. The method deletes the row using the primary key value +specified by the model. If the model does not have a primary key, an error will be returned. For example, + +```go +db, _ := dbx.Open("mysql", "user:pass@/example") + +err := db.Model(&customer).Delete() +``` + +### Null Handling + +To represent a nullable database value, you can use a pointer type. If the pointer is nil, it means the corresponding +database value is null. + +Another option to represent a database null is to use `sql.NullXyz` types. For example, if a string column is nullable, +you may use `sql.NullString`. The `NullString.Valid` field indicates whether the value is a null or not, and +`NullString.String` returns the string value when it is not null. Because `sql.NulLXyz` types do not handle JSON +marshalling, you may use the [null package](https://github.com/guregu/null), instead. + +Below is an example of handling nulls: + +```go +type Customer struct { + ID int + Email string + FirstName *string // use pointer to represent null + LastName sql.NullString // use sql.NullString to represent null +} +``` + +## Quoting Table and Column Names + +Databases vary in quoting table and column names. To allow writing DB-agnostic SQLs, `dbx` introduces a special +syntax in quoting table and column names. A word enclosed within `{{` and `}}` is treated as a table name and will +be quoted according to the particular DB driver. Similarly, a word enclosed within `[[` and `]]` is treated as a +column name and will be quoted accordingly as well. For example, when working with a MySQL database, the following +query will be properly quoted: + +```go +// SELECT * FROM `users` WHERE `status`=1 +q := db.NewQuery("SELECT * FROM {{users}} WHERE [[status]]=1") +``` + +Note that if a table or column name contains a prefix, it will still be properly quoted. For example, `{{public.users}}` +will be quoted as `"public"."users"` for PostgreSQL. + +## Using Transactions + +You can use all aforementioned query execution and building methods with transaction. For example, + +```go +db, _ := dbx.Open("mysql", "user:pass@/example") + +tx, _ := db.Begin() + +_, err1 := tx.Insert("users", dbx.Params{ + "name": "user1", +}).Execute() +_, err2 := tx.Insert("users", dbx.Params{ + "name": "user2", +}).Execute() + +if err1 == nil && err2 == nil { + tx.Commit() +} else { + tx.Rollback() +} +``` + +You may use `DB.Transactional()` to simplify your transactional code without explicitly committing or rolling back +transactions. The method will start a transaction and automatically roll back the transaction if the callback +returns an error. Otherwise it will +automatically commit the transaction. + + +```go +db, _ := dbx.Open("mysql", "user:pass@/example") + +err := db.Transactional(func(tx *dbx.Tx) error { + var err error + _, err = tx.Insert("users", dbx.Params{ + "name": "user1", + }).Execute() + if err != nil { + return err + } + _, err = tx.Insert("users", dbx.Params{ + "name": "user2", + }).Execute() + return err +}) + +fmt.Println(err) +``` + +## Logging Executed SQL Statements + +You can log and instrument DB queries by installing loggers with a DB connection. There are three kinds of loggers you +can install: +* `DB.LogFunc`: this is called each time when a SQL statement is queried or executed. The function signature is the + same as that of `fmt.Printf`, which makes it very easy to use. +* `DB.QueryLogFunc`: this is called each time when querying with a SQL statement. +* `DB.ExecLogFunc`: this is called when executing a SQL statement. + +The following example shows how you can make use of these loggers. + +```go +package main + +import ( + "context" + "database/sql" + "log" + "time" + + "github.com/pocketbase/dbx" +) + +func main() { + db, _ := dbx.Open("mysql", "user:pass@/example") + + // simple logging + db.LogFunc = log.Printf + + // or you can use the following more flexible logging + db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) { + log.Printf("[%.2fms] Query SQL: %v", float64(t.Milliseconds()), sql) + } + db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) { + log.Printf("[%.2fms] Execute SQL: %v", float64(t.Milliseconds()), sql) + } + // ... +} +``` + +## Supporting New Databases + +While `dbx` provides out-of-box query building support for most major relational databases, its open architecture +allows you to add support for new databases. The effort of adding support for a new database involves: + +* Create a struct that implements the `QueryBuilder` interface. You may use `BaseQueryBuilder` directly or extend it + via composition. +* Create a struct that implements the `Builder` interface. You may extend `BaseBuilder` via composition. +* Write an `init()` function to register the new builder in `dbx.BuilderFuncMap`. diff --git a/builder.go b/builder.go new file mode 100644 index 0000000..04f46c3 --- /dev/null +++ b/builder.go @@ -0,0 +1,402 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "errors" + "fmt" + "sort" + "strings" +) + +// Builder supports building SQL statements in a DB-agnostic way. +// Builder mainly provides two sets of query building methods: those building SELECT statements +// and those manipulating DB data or schema (e.g. INSERT statements, CREATE TABLE statements). +type Builder interface { + // NewQuery creates a new Query object with the given SQL statement. + // The SQL statement may contain parameter placeholders which can be bound with actual parameter + // values before the statement is executed. + NewQuery(string) *Query + // Select returns a new SelectQuery object that can be used to build a SELECT statement. + // The parameters to this method should be the list column names to be selected. + // A column name may have an optional alias name. For example, Select("id", "my_name AS name"). + Select(...string) *SelectQuery + // ModelQuery returns a new ModelQuery object that can be used to perform model insertion, update, and deletion. + // The parameter to this method should be a pointer to the model struct that needs to be inserted, updated, or deleted. + Model(interface{}) *ModelQuery + + // GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID. + GeneratePlaceholder(int) string + + // Quote quotes a string so that it can be embedded in a SQL statement as a string value. + Quote(string) string + // QuoteSimpleTableName quotes a simple table name. + // A simple table name does not contain any schema prefix. + QuoteSimpleTableName(string) string + // QuoteSimpleColumnName quotes a simple column name. + // A simple column name does not contain any table prefix. + QuoteSimpleColumnName(string) string + + // QueryBuilder returns the query builder supporting the current DB. + QueryBuilder() QueryBuilder + + // Insert creates a Query that represents an INSERT SQL statement. + // The keys of cols are the column names, while the values of cols are the corresponding column + // values to be inserted. + Insert(table string, cols Params) *Query + // Upsert creates a Query that represents an UPSERT SQL statement. + // Upsert inserts a row into the table if the primary key or unique index is not found. + // Otherwise it will update the row with the new values. + // The keys of cols are the column names, while the values of cols are the corresponding column + // values to be inserted. + Upsert(table string, cols Params, constraints ...string) *Query + // Update creates a Query that represents an UPDATE SQL statement. + // The keys of cols are the column names, while the values of cols are the corresponding new column + // values. If the "where" expression is nil, the UPDATE SQL statement will have no WHERE clause + // (be careful in this case as the SQL statement will update ALL rows in the table). + Update(table string, cols Params, where Expression) *Query + // Delete creates a Query that represents a DELETE SQL statement. + // If the "where" expression is nil, the DELETE SQL statement will have no WHERE clause + // (be careful in this case as the SQL statement will delete ALL rows in the table). + Delete(table string, where Expression) *Query + + // CreateTable creates a Query that represents a CREATE TABLE SQL statement. + // The keys of cols are the column names, while the values of cols are the corresponding column types. + // The optional "options" parameters will be appended to the generated SQL statement. + CreateTable(table string, cols map[string]string, options ...string) *Query + // RenameTable creates a Query that can be used to rename a table. + RenameTable(oldName, newName string) *Query + // DropTable creates a Query that can be used to drop a table. + DropTable(table string) *Query + // TruncateTable creates a Query that can be used to truncate a table. + TruncateTable(table string) *Query + + // AddColumn creates a Query that can be used to add a column to a table. + AddColumn(table, col, typ string) *Query + // DropColumn creates a Query that can be used to drop a column from a table. + DropColumn(table, col string) *Query + // RenameColumn creates a Query that can be used to rename a column in a table. + RenameColumn(table, oldName, newName string) *Query + // AlterColumn creates a Query that can be used to change the definition of a table column. + AlterColumn(table, col, typ string) *Query + + // AddPrimaryKey creates a Query that can be used to specify primary key(s) for a table. + // The "name" parameter specifies the name of the primary key constraint. + AddPrimaryKey(table, name string, cols ...string) *Query + // DropPrimaryKey creates a Query that can be used to remove the named primary key constraint from a table. + DropPrimaryKey(table, name string) *Query + + // AddForeignKey creates a Query that can be used to add a foreign key constraint to a table. + // The length of cols and refCols must be the same as they refer to the primary and referential columns. + // The optional "options" parameters will be appended to the SQL statement. They can be used to + // specify options such as "ON DELETE CASCADE". + AddForeignKey(table, name string, cols, refCols []string, refTable string, options ...string) *Query + // DropForeignKey creates a Query that can be used to remove the named foreign key constraint from a table. + DropForeignKey(table, name string) *Query + + // CreateIndex creates a Query that can be used to create an index for a table. + CreateIndex(table, name string, cols ...string) *Query + // CreateUniqueIndex creates a Query that can be used to create a unique index for a table. + CreateUniqueIndex(table, name string, cols ...string) *Query + // DropIndex creates a Query that can be used to remove the named index from a table. + DropIndex(table, name string) *Query +} + +// BaseBuilder provides a basic implementation of the Builder interface. +type BaseBuilder struct { + db *DB + executor Executor +} + +// NewBaseBuilder creates a new BaseBuilder instance. +func NewBaseBuilder(db *DB, executor Executor) *BaseBuilder { + return &BaseBuilder{db, executor} +} + +// DB returns the DB instance that this builder is associated with. +func (b *BaseBuilder) DB() *DB { + return b.db +} + +// Executor returns the executor object (a DB instance or a transaction) for executing SQL statements. +func (b *BaseBuilder) Executor() Executor { + return b.executor +} + +// NewQuery creates a new Query object with the given SQL statement. +// The SQL statement may contain parameter placeholders which can be bound with actual parameter +// values before the statement is executed. +func (b *BaseBuilder) NewQuery(sql string) *Query { + return NewQuery(b.db, b.executor, sql) +} + +// GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID. +func (b *BaseBuilder) GeneratePlaceholder(int) string { + return "?" +} + +// Quote quotes a string so that it can be embedded in a SQL statement as a string value. +func (b *BaseBuilder) Quote(s string) string { + return "'" + strings.Replace(s, "'", "''", -1) + "'" +} + +// QuoteSimpleTableName quotes a simple table name. +// A simple table name does not contain any schema prefix. +func (b *BaseBuilder) QuoteSimpleTableName(s string) string { + if strings.Contains(s, `"`) { + return s + } + return `"` + s + `"` +} + +// QuoteSimpleColumnName quotes a simple column name. +// A simple column name does not contain any table prefix. +func (b *BaseBuilder) QuoteSimpleColumnName(s string) string { + if strings.Contains(s, `"`) || s == "*" { + return s + } + return `"` + s + `"` +} + +// Insert creates a Query that represents an INSERT SQL statement. +// The keys of cols are the column names, while the values of cols are the corresponding column +// values to be inserted. +func (b *BaseBuilder) Insert(table string, cols Params) *Query { + names := make([]string, 0, len(cols)) + for name := range cols { + names = append(names, name) + } + sort.Strings(names) + + params := Params{} + columns := make([]string, 0, len(names)) + values := make([]string, 0, len(names)) + for _, name := range names { + columns = append(columns, b.db.QuoteColumnName(name)) + value := cols[name] + if e, ok := value.(Expression); ok { + values = append(values, e.Build(b.db, params)) + } else { + values = append(values, fmt.Sprintf("{:p%v}", len(params))) + params[fmt.Sprintf("p%v", len(params))] = value + } + } + + var sql string + if len(names) == 0 { + sql = fmt.Sprintf("INSERT INTO %v DEFAULT VALUES", b.db.QuoteTableName(table)) + } else { + sql = fmt.Sprintf("INSERT INTO %v (%v) VALUES (%v)", + b.db.QuoteTableName(table), + strings.Join(columns, ", "), + strings.Join(values, ", "), + ) + } + + return b.NewQuery(sql).Bind(params) +} + +// Upsert creates a Query that represents an UPSERT SQL statement. +// Upsert inserts a row into the table if the primary key or unique index is not found. +// Otherwise it will update the row with the new values. +// The keys of cols are the column names, while the values of cols are the corresponding column +// values to be inserted. +func (b *BaseBuilder) Upsert(table string, cols Params, constraints ...string) *Query { + q := b.NewQuery("") + q.LastError = errors.New("Upsert is not supported") + return q +} + +// Update creates a Query that represents an UPDATE SQL statement. +// The keys of cols are the column names, while the values of cols are the corresponding new column +// values. If the "where" expression is nil, the UPDATE SQL statement will have no WHERE clause +// (be careful in this case as the SQL statement will update ALL rows in the table). +func (b *BaseBuilder) Update(table string, cols Params, where Expression) *Query { + names := make([]string, 0, len(cols)) + for name := range cols { + names = append(names, name) + } + sort.Strings(names) + + params := Params{} + lines := make([]string, 0, len(names)) + for _, name := range names { + value := cols[name] + name = b.db.QuoteColumnName(name) + if e, ok := value.(Expression); ok { + lines = append(lines, name+"="+e.Build(b.db, params)) + } else { + lines = append(lines, fmt.Sprintf("%v={:p%v}", name, len(params))) + params[fmt.Sprintf("p%v", len(params))] = value + } + } + + sql := fmt.Sprintf("UPDATE %v SET %v", b.db.QuoteTableName(table), strings.Join(lines, ", ")) + if where != nil { + w := where.Build(b.db, params) + if w != "" { + sql += " WHERE " + w + } + } + + return b.NewQuery(sql).Bind(params) +} + +// Delete creates a Query that represents a DELETE SQL statement. +// If the "where" expression is nil, the DELETE SQL statement will have no WHERE clause +// (be careful in this case as the SQL statement will delete ALL rows in the table). +func (b *BaseBuilder) Delete(table string, where Expression) *Query { + sql := "DELETE FROM " + b.db.QuoteTableName(table) + params := Params{} + if where != nil { + w := where.Build(b.db, params) + if w != "" { + sql += " WHERE " + w + } + } + return b.NewQuery(sql).Bind(params) +} + +// CreateTable creates a Query that represents a CREATE TABLE SQL statement. +// The keys of cols are the column names, while the values of cols are the corresponding column types. +// The optional "options" parameters will be appended to the generated SQL statement. +func (b *BaseBuilder) CreateTable(table string, cols map[string]string, options ...string) *Query { + names := []string{} + for name := range cols { + names = append(names, name) + } + sort.Strings(names) + + columns := []string{} + for _, name := range names { + columns = append(columns, b.db.QuoteColumnName(name)+" "+cols[name]) + } + + sql := fmt.Sprintf("CREATE TABLE %v (%v)", b.db.QuoteTableName(table), strings.Join(columns, ", ")) + for _, opt := range options { + sql += " " + opt + } + + return b.NewQuery(sql) +} + +// RenameTable creates a Query that can be used to rename a table. +func (b *BaseBuilder) RenameTable(oldName, newName string) *Query { + sql := fmt.Sprintf("RENAME TABLE %v TO %v", b.db.QuoteTableName(oldName), b.db.QuoteTableName(newName)) + return b.NewQuery(sql) +} + +// DropTable creates a Query that can be used to drop a table. +func (b *BaseBuilder) DropTable(table string) *Query { + sql := "DROP TABLE " + b.db.QuoteTableName(table) + return b.NewQuery(sql) +} + +// TruncateTable creates a Query that can be used to truncate a table. +func (b *BaseBuilder) TruncateTable(table string) *Query { + sql := "TRUNCATE TABLE " + b.db.QuoteTableName(table) + return b.NewQuery(sql) +} + +// AddColumn creates a Query that can be used to add a column to a table. +func (b *BaseBuilder) AddColumn(table, col, typ string) *Query { + sql := fmt.Sprintf("ALTER TABLE %v ADD %v %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(col), typ) + return b.NewQuery(sql) +} + +// DropColumn creates a Query that can be used to drop a column from a table. +func (b *BaseBuilder) DropColumn(table, col string) *Query { + sql := fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(col)) + return b.NewQuery(sql) +} + +// RenameColumn creates a Query that can be used to rename a column in a table. +func (b *BaseBuilder) RenameColumn(table, oldName, newName string) *Query { + sql := fmt.Sprintf("ALTER TABLE %v RENAME COLUMN %v TO %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(oldName), b.db.QuoteColumnName(newName)) + return b.NewQuery(sql) +} + +// AlterColumn creates a Query that can be used to change the definition of a table column. +func (b *BaseBuilder) AlterColumn(table, col, typ string) *Query { + col = b.db.QuoteColumnName(col) + sql := fmt.Sprintf("ALTER TABLE %v CHANGE %v %v %v", b.db.QuoteTableName(table), col, col, typ) + return b.NewQuery(sql) +} + +// AddPrimaryKey creates a Query that can be used to specify primary key(s) for a table. +// The "name" parameter specifies the name of the primary key constraint. +func (b *BaseBuilder) AddPrimaryKey(table, name string, cols ...string) *Query { + sql := fmt.Sprintf("ALTER TABLE %v ADD CONSTRAINT %v PRIMARY KEY (%v)", + b.db.QuoteTableName(table), + b.db.QuoteColumnName(name), + b.quoteColumns(cols)) + return b.NewQuery(sql) +} + +// DropPrimaryKey creates a Query that can be used to remove the named primary key constraint from a table. +func (b *BaseBuilder) DropPrimaryKey(table, name string) *Query { + sql := fmt.Sprintf("ALTER TABLE %v DROP CONSTRAINT %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(name)) + return b.NewQuery(sql) +} + +// AddForeignKey creates a Query that can be used to add a foreign key constraint to a table. +// The length of cols and refCols must be the same as they refer to the primary and referential columns. +// The optional "options" parameters will be appended to the SQL statement. They can be used to +// specify options such as "ON DELETE CASCADE". +func (b *BaseBuilder) AddForeignKey(table, name string, cols, refCols []string, refTable string, options ...string) *Query { + sql := fmt.Sprintf("ALTER TABLE %v ADD CONSTRAINT %v FOREIGN KEY (%v) REFERENCES %v (%v)", + b.db.QuoteTableName(table), + b.db.QuoteColumnName(name), + b.quoteColumns(cols), + b.db.QuoteTableName(refTable), + b.quoteColumns(refCols)) + for _, opt := range options { + sql += " " + opt + } + return b.NewQuery(sql) +} + +// DropForeignKey creates a Query that can be used to remove the named foreign key constraint from a table. +func (b *BaseBuilder) DropForeignKey(table, name string) *Query { + sql := fmt.Sprintf("ALTER TABLE %v DROP CONSTRAINT %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(name)) + return b.NewQuery(sql) +} + +// CreateIndex creates a Query that can be used to create an index for a table. +func (b *BaseBuilder) CreateIndex(table, name string, cols ...string) *Query { + sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v)", + b.db.QuoteColumnName(name), + b.db.QuoteTableName(table), + b.quoteColumns(cols)) + return b.NewQuery(sql) +} + +// CreateUniqueIndex creates a Query that can be used to create a unique index for a table. +func (b *BaseBuilder) CreateUniqueIndex(table, name string, cols ...string) *Query { + sql := fmt.Sprintf("CREATE UNIQUE INDEX %v ON %v (%v)", + b.db.QuoteColumnName(name), + b.db.QuoteTableName(table), + b.quoteColumns(cols)) + return b.NewQuery(sql) +} + +// DropIndex creates a Query that can be used to remove the named index from a table. +func (b *BaseBuilder) DropIndex(table, name string) *Query { + sql := fmt.Sprintf("DROP INDEX %v ON %v", b.db.QuoteColumnName(name), b.db.QuoteTableName(table)) + return b.NewQuery(sql) +} + +// quoteColumns quotes a list of columns and concatenates them with commas. +func (b *BaseBuilder) quoteColumns(cols []string) string { + s := "" + for i, col := range cols { + if i == 0 { + s = b.db.QuoteColumnName(col) + } else { + s += ", " + b.db.QuoteColumnName(col) + } + } + return s +} diff --git a/builder_mssql.go b/builder_mssql.go new file mode 100644 index 0000000..ccb6e59 --- /dev/null +++ b/builder_mssql.go @@ -0,0 +1,115 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "fmt" + "strings" +) + +// MssqlBuilder is the builder for SQL Server databases. +type MssqlBuilder struct { + *BaseBuilder + qb *MssqlQueryBuilder +} + +var _ Builder = &MssqlBuilder{} + +// MssqlQueryBuilder is the query builder for SQL Server databases. +type MssqlQueryBuilder struct { + *BaseQueryBuilder +} + +// NewMssqlBuilder creates a new MssqlBuilder instance. +func NewMssqlBuilder(db *DB, executor Executor) Builder { + return &MssqlBuilder{ + NewBaseBuilder(db, executor), + &MssqlQueryBuilder{NewBaseQueryBuilder(db)}, + } +} + +// QueryBuilder returns the query builder supporting the current DB. +func (b *MssqlBuilder) QueryBuilder() QueryBuilder { + return b.qb +} + +// Select returns a new SelectQuery object that can be used to build a SELECT statement. +// The parameters to this method should be the list column names to be selected. +// A column name may have an optional alias name. For example, Select("id", "my_name AS name"). +func (b *MssqlBuilder) Select(cols ...string) *SelectQuery { + return NewSelectQuery(b, b.db).Select(cols...) +} + +// Model returns a new ModelQuery object that can be used to perform model-based DB operations. +// The model passed to this method should be a pointer to a model struct. +func (b *MssqlBuilder) Model(model interface{}) *ModelQuery { + return NewModelQuery(model, b.db.FieldMapper, b.db, b) +} + +// QuoteSimpleTableName quotes a simple table name. +// A simple table name does not contain any schema prefix. +func (b *MssqlBuilder) QuoteSimpleTableName(s string) string { + if strings.Contains(s, `[`) { + return s + } + return `[` + s + `]` +} + +// QuoteSimpleColumnName quotes a simple column name. +// A simple column name does not contain any table prefix. +func (b *MssqlBuilder) QuoteSimpleColumnName(s string) string { + if strings.Contains(s, `[`) || s == "*" { + return s + } + return `[` + s + `]` +} + +// RenameTable creates a Query that can be used to rename a table. +func (b *MssqlBuilder) RenameTable(oldName, newName string) *Query { + sql := fmt.Sprintf("sp_name '%v', '%v'", oldName, newName) + return b.NewQuery(sql) +} + +// RenameColumn creates a Query that can be used to rename a column in a table. +func (b *MssqlBuilder) RenameColumn(table, oldName, newName string) *Query { + sql := fmt.Sprintf("sp_name '%v.%v', '%v', 'COLUMN'", table, oldName, newName) + return b.NewQuery(sql) +} + +// AlterColumn creates a Query that can be used to change the definition of a table column. +func (b *MssqlBuilder) AlterColumn(table, col, typ string) *Query { + col = b.db.QuoteColumnName(col) + sql := fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", b.db.QuoteTableName(table), col, typ) + return b.NewQuery(sql) +} + +// BuildOrderByAndLimit generates the ORDER BY and LIMIT clauses. +func (q *MssqlQueryBuilder) BuildOrderByAndLimit(sql string, cols []string, limit int64, offset int64) string { + orderBy := q.BuildOrderBy(cols) + if limit < 0 && offset < 0 { + if orderBy == "" { + return sql + } + return sql + "\n" + orderBy + } + + // only SQL SERVER 2012 or newer are supported by this method + + if orderBy == "" { + // ORDER BY clause is required when FETCH and OFFSET are in the SQL + orderBy = "ORDER BY (SELECT NULL)" + } + sql += "\n" + orderBy + + // http://technet.microsoft.com/en-us/library/gg699618.aspx + if offset < 0 { + offset = 0 + } + sql += "\n" + fmt.Sprintf("OFFSET %v ROWS", offset) + if limit >= 0 { + sql += "\n" + fmt.Sprintf("FETCH NEXT %v ROWS ONLY", limit) + } + return sql +} diff --git a/builder_mssql_test.go b/builder_mssql_test.go new file mode 100644 index 0000000..b440ec7 --- /dev/null +++ b/builder_mssql_test.go @@ -0,0 +1,73 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMssqlBuilder_QuoteSimpleTableName(t *testing.T) { + b := getMssqlBuilder() + assert.Equal(t, b.QuoteSimpleTableName(`abc`), "[abc]", "t1") + assert.Equal(t, b.QuoteSimpleTableName("[abc]"), "[abc]", "t2") + assert.Equal(t, b.QuoteSimpleTableName(`{{abc}}`), "[{{abc}}]", "t3") + assert.Equal(t, b.QuoteSimpleTableName(`a.bc`), "[a.bc]", "t4") +} + +func TestMssqlBuilder_QuoteSimpleColumnName(t *testing.T) { + b := getMssqlBuilder() + assert.Equal(t, b.QuoteSimpleColumnName(`abc`), "[abc]", "t1") + assert.Equal(t, b.QuoteSimpleColumnName("[abc]"), "[abc]", "t2") + assert.Equal(t, b.QuoteSimpleColumnName(`{{abc}}`), "[{{abc}}]", "t3") + assert.Equal(t, b.QuoteSimpleColumnName(`a.bc`), "[a.bc]", "t4") + assert.Equal(t, b.QuoteSimpleColumnName(`*`), `*`, "t5") +} + +func TestMssqlBuilder_RenameTable(t *testing.T) { + b := getMssqlBuilder() + q := b.RenameTable("users", "user") + assert.Equal(t, q.SQL(), `sp_name 'users', 'user'`, "t1") +} + +func TestMssqlBuilder_RenameColumn(t *testing.T) { + b := getMssqlBuilder() + q := b.RenameColumn("users", "name", "username") + assert.Equal(t, q.SQL(), `sp_name 'users.name', 'username', 'COLUMN'`, "t1") +} + +func TestMssqlBuilder_AlterColumn(t *testing.T) { + b := getMssqlBuilder() + q := b.AlterColumn("users", "name", "int") + assert.Equal(t, q.SQL(), `ALTER TABLE [users] ALTER COLUMN [name] int`, "t1") +} + +func TestMssqlQueryBuilder_BuildOrderByAndLimit(t *testing.T) { + qb := getMssqlBuilder().QueryBuilder() + + sql := qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, 10, 2) + expected := "SELECT *\nORDER BY [name]\nOFFSET 2 ROWS\nFETCH NEXT 10 ROWS ONLY" + assert.Equal(t, sql, expected, "t1") + + sql = qb.BuildOrderByAndLimit("SELECT *", nil, -1, -1) + expected = "SELECT *" + assert.Equal(t, sql, expected, "t2") + + sql = qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, -1, -1) + expected = "SELECT *\nORDER BY [name]" + assert.Equal(t, sql, expected, "t3") + + sql = qb.BuildOrderByAndLimit("SELECT *", nil, 10, -1) + expected = "SELECT *\nORDER BY (SELECT NULL)\nOFFSET 0 ROWS\nFETCH NEXT 10 ROWS ONLY" + assert.Equal(t, sql, expected, "t4") +} + +func getMssqlBuilder() Builder { + db := getDB() + b := NewMssqlBuilder(db, db.sqlDB) + db.Builder = b + return b +} diff --git a/builder_mysql.go b/builder_mysql.go new file mode 100644 index 0000000..174d5f9 --- /dev/null +++ b/builder_mysql.go @@ -0,0 +1,133 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "fmt" + "regexp" + "sort" + "strings" +) + +// MysqlBuilder is the builder for MySQL databases. +type MysqlBuilder struct { + *BaseBuilder + qb *BaseQueryBuilder +} + +var _ Builder = &MysqlBuilder{} + +// NewMysqlBuilder creates a new MysqlBuilder instance. +func NewMysqlBuilder(db *DB, executor Executor) Builder { + return &MysqlBuilder{ + NewBaseBuilder(db, executor), + NewBaseQueryBuilder(db), + } +} + +// QueryBuilder returns the query builder supporting the current DB. +func (b *MysqlBuilder) QueryBuilder() QueryBuilder { + return b.qb +} + +// Select returns a new SelectQuery object that can be used to build a SELECT statement. +// The parameters to this method should be the list column names to be selected. +// A column name may have an optional alias name. For example, Select("id", "my_name AS name"). +func (b *MysqlBuilder) Select(cols ...string) *SelectQuery { + return NewSelectQuery(b, b.db).Select(cols...) +} + +// Model returns a new ModelQuery object that can be used to perform model-based DB operations. +// The model passed to this method should be a pointer to a model struct. +func (b *MysqlBuilder) Model(model interface{}) *ModelQuery { + return NewModelQuery(model, b.db.FieldMapper, b.db, b) +} + +// QuoteSimpleTableName quotes a simple table name. +// A simple table name does not contain any schema prefix. +func (b *MysqlBuilder) QuoteSimpleTableName(s string) string { + if strings.ContainsAny(s, "`") { + return s + } + return "`" + s + "`" +} + +// QuoteSimpleColumnName quotes a simple column name. +// A simple column name does not contain any table prefix. +func (b *MysqlBuilder) QuoteSimpleColumnName(s string) string { + if strings.Contains(s, "`") || s == "*" { + return s + } + return "`" + s + "`" +} + +// Upsert creates a Query that represents an UPSERT SQL statement. +// Upsert inserts a row into the table if the primary key or unique index is not found. +// Otherwise it will update the row with the new values. +// The keys of cols are the column names, while the values of cols are the corresponding column +// values to be inserted. +func (b *MysqlBuilder) Upsert(table string, cols Params, constraints ...string) *Query { + q := b.Insert(table, cols) + + names := []string{} + for name := range cols { + names = append(names, name) + } + sort.Strings(names) + + lines := []string{} + for _, name := range names { + value := cols[name] + name = b.db.QuoteColumnName(name) + if e, ok := value.(Expression); ok { + lines = append(lines, name+"="+e.Build(b.db, q.params)) + } else { + lines = append(lines, fmt.Sprintf("%v={:p%v}", name, len(q.params))) + q.params[fmt.Sprintf("p%v", len(q.params))] = value + } + } + + q.sql += " ON DUPLICATE KEY UPDATE " + strings.Join(lines, ", ") + + return q +} + +var mysqlColumnRegexp = regexp.MustCompile("(?m)^\\s*[`\"](.*?)[`\"]\\s+(.*?),?$") + +// RenameColumn creates a Query that can be used to rename a column in a table. +func (b *MysqlBuilder) RenameColumn(table, oldName, newName string) *Query { + qt := b.db.QuoteTableName(table) + sql := fmt.Sprintf("ALTER TABLE %v CHANGE %v %v", qt, b.db.QuoteColumnName(oldName), b.db.QuoteColumnName(newName)) + + var info struct { + SQL string `db:"Create Table"` + } + if err := b.db.NewQuery("SHOW CREATE TABLE " + qt).One(&info); err != nil { + return b.db.NewQuery(sql) + } + + if matches := mysqlColumnRegexp.FindAllStringSubmatch(info.SQL, -1); matches != nil { + for _, match := range matches { + if match[1] == oldName { + sql += " " + match[2] + break + } + } + } + + return b.db.NewQuery(sql) +} + +// DropPrimaryKey creates a Query that can be used to remove the named primary key constraint from a table. +func (b *MysqlBuilder) DropPrimaryKey(table, name string) *Query { + sql := fmt.Sprintf("ALTER TABLE %v DROP PRIMARY KEY", b.db.QuoteTableName(table)) + return b.db.NewQuery(sql) +} + +// DropForeignKey creates a Query that can be used to remove the named foreign key constraint from a table. +func (b *MysqlBuilder) DropForeignKey(table, name string) *Query { + sql := fmt.Sprintf("ALTER TABLE %v DROP FOREIGN KEY %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(name)) + return b.db.NewQuery(sql) +} diff --git a/builder_mysql_test.go b/builder_mysql_test.go new file mode 100644 index 0000000..298d3ff --- /dev/null +++ b/builder_mysql_test.go @@ -0,0 +1,69 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMysqlBuilder_QuoteSimpleTableName(t *testing.T) { + b := getMysqlBuilder() + assert.Equal(t, b.QuoteSimpleTableName(`abc`), "`abc`", "t1") + assert.Equal(t, b.QuoteSimpleTableName("`abc`"), "`abc`", "t2") + assert.Equal(t, b.QuoteSimpleTableName(`{{abc}}`), "`{{abc}}`", "t3") + assert.Equal(t, b.QuoteSimpleTableName(`a.bc`), "`a.bc`", "t4") +} + +func TestMysqlBuilder_QuoteSimpleColumnName(t *testing.T) { + b := getMysqlBuilder() + assert.Equal(t, b.QuoteSimpleColumnName(`abc`), "`abc`", "t1") + assert.Equal(t, b.QuoteSimpleColumnName("`abc`"), "`abc`", "t2") + assert.Equal(t, b.QuoteSimpleColumnName(`{{abc}}`), "`{{abc}}`", "t3") + assert.Equal(t, b.QuoteSimpleColumnName(`a.bc`), "`a.bc`", "t4") + assert.Equal(t, b.QuoteSimpleColumnName(`*`), `*`, "t5") +} + +func TestMysqlBuilder_Upsert(t *testing.T) { + getPreparedDB() + b := getMysqlBuilder() + q := b.Upsert("users", Params{ + "name": "James", + "age": 30, + }) + assert.Equal(t, q.SQL(), "INSERT INTO `users` (`age`, `name`) VALUES ({:p0}, {:p1}) ON DUPLICATE KEY UPDATE `age`={:p2}, `name`={:p3}", "t1") + assert.Equal(t, q.Params()["p0"], 30, "t2") + assert.Equal(t, q.Params()["p1"], "James", "t3") + assert.Equal(t, q.Params()["p2"], 30, "t2") + assert.Equal(t, q.Params()["p3"], "James", "t3") +} + +func TestMysqlBuilder_RenameColumn(t *testing.T) { + b := getMysqlBuilder() + q := b.RenameColumn("users", "name", "username") + assert.Equal(t, q.SQL(), "ALTER TABLE `users` CHANGE `name` `username`") + q = b.RenameColumn("customer", "email", "e") + assert.Equal(t, q.SQL(), "ALTER TABLE `customer` CHANGE `email` `e` varchar(128) NOT NULL") +} + +func TestMysqlBuilder_DropPrimaryKey(t *testing.T) { + b := getMysqlBuilder() + q := b.DropPrimaryKey("users", "pk") + assert.Equal(t, q.SQL(), "ALTER TABLE `users` DROP PRIMARY KEY", "t1") +} + +func TestMysqlBuilder_DropForeignKey(t *testing.T) { + b := getMysqlBuilder() + q := b.DropForeignKey("users", "fk") + assert.Equal(t, q.SQL(), "ALTER TABLE `users` DROP FOREIGN KEY `fk`", "t1") +} + +func getMysqlBuilder() Builder { + db := getDB() + b := NewMysqlBuilder(db, db.sqlDB) + db.Builder = b + return b +} diff --git a/builder_oci.go b/builder_oci.go new file mode 100644 index 0000000..19fa894 --- /dev/null +++ b/builder_oci.go @@ -0,0 +1,98 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "fmt" +) + +// OciBuilder is the builder for Oracle databases. +type OciBuilder struct { + *BaseBuilder + qb *OciQueryBuilder +} + +var _ Builder = &OciBuilder{} + +// OciQueryBuilder is the query builder for Oracle databases. +type OciQueryBuilder struct { + *BaseQueryBuilder +} + +// NewOciBuilder creates a new OciBuilder instance. +func NewOciBuilder(db *DB, executor Executor) Builder { + return &OciBuilder{ + NewBaseBuilder(db, executor), + &OciQueryBuilder{NewBaseQueryBuilder(db)}, + } +} + +// Select returns a new SelectQuery object that can be used to build a SELECT statement. +// The parameters to this method should be the list column names to be selected. +// A column name may have an optional alias name. For example, Select("id", "my_name AS name"). +func (b *OciBuilder) Select(cols ...string) *SelectQuery { + return NewSelectQuery(b, b.db).Select(cols...) +} + +// Model returns a new ModelQuery object that can be used to perform model-based DB operations. +// The model passed to this method should be a pointer to a model struct. +func (b *OciBuilder) Model(model interface{}) *ModelQuery { + return NewModelQuery(model, b.db.FieldMapper, b.db, b) +} + +// GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID. +func (b *OciBuilder) GeneratePlaceholder(i int) string { + return fmt.Sprintf(":p%v", i) +} + +// QueryBuilder returns the query builder supporting the current DB. +func (b *OciBuilder) QueryBuilder() QueryBuilder { + return b.qb +} + +// DropIndex creates a Query that can be used to remove the named index from a table. +func (b *OciBuilder) DropIndex(table, name string) *Query { + sql := fmt.Sprintf("DROP INDEX %v", b.db.QuoteColumnName(name)) + return b.NewQuery(sql) +} + +// RenameTable creates a Query that can be used to rename a table. +func (b *OciBuilder) RenameTable(oldName, newName string) *Query { + sql := fmt.Sprintf("ALTER TABLE %v RENAME TO %v", b.db.QuoteTableName(oldName), b.db.QuoteTableName(newName)) + return b.NewQuery(sql) +} + +// AlterColumn creates a Query that can be used to change the definition of a table column. +func (b *OciBuilder) AlterColumn(table, col, typ string) *Query { + col = b.db.QuoteColumnName(col) + sql := fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", b.db.QuoteTableName(table), col, typ) + return b.NewQuery(sql) +} + +// BuildOrderByAndLimit generates the ORDER BY and LIMIT clauses. +func (q *OciQueryBuilder) BuildOrderByAndLimit(sql string, cols []string, limit int64, offset int64) string { + if orderBy := q.BuildOrderBy(cols); orderBy != "" { + sql += "\n" + orderBy + } + + c := "" + if offset > 0 { + c = fmt.Sprintf("rowNumId > %v", offset) + } + if limit >= 0 { + if c != "" { + c += " AND " + } + c += fmt.Sprintf("rowNum <= %v", limit) + } + + if c == "" { + return sql + } + + return `WITH USER_SQL AS (` + sql + `), + PAGINATION AS (SELECT USER_SQL.*, rownum as rowNumId FROM USER_SQL) +SELECT * FROM PAGINATION WHERE ` + c +} diff --git a/builder_oci_test.go b/builder_oci_test.go new file mode 100644 index 0000000..e112fc1 --- /dev/null +++ b/builder_oci_test.go @@ -0,0 +1,56 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOciBuilder_DropIndex(t *testing.T) { + b := getOciBuilder() + q := b.DropIndex("users", "idx") + assert.Equal(t, q.SQL(), `DROP INDEX "idx"`, "t1") +} + +func TestOciBuilder_RenameTable(t *testing.T) { + b := getOciBuilder() + q := b.RenameTable("users", "user") + assert.Equal(t, q.SQL(), `ALTER TABLE "users" RENAME TO "user"`, "t1") +} + +func TestOciBuilder_AlterColumn(t *testing.T) { + b := getOciBuilder() + q := b.AlterColumn("users", "name", "int") + assert.Equal(t, q.SQL(), `ALTER TABLE "users" MODIFY "name" int`, "t1") +} + +func TestOciQueryBuilder_BuildOrderByAndLimit(t *testing.T) { + qb := getOciBuilder().QueryBuilder() + + sql := qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, 10, 2) + expected := "WITH USER_SQL AS (SELECT *\nORDER BY \"name\"),\n\tPAGINATION AS (SELECT USER_SQL.*, rownum as rowNumId FROM USER_SQL)\nSELECT * FROM PAGINATION WHERE rowNumId > 2 AND rowNum <= 10" + assert.Equal(t, sql, expected, "t1") + + sql = qb.BuildOrderByAndLimit("SELECT *", nil, -1, -1) + expected = "SELECT *" + assert.Equal(t, sql, expected, "t2") + + sql = qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, -1, -1) + expected = "SELECT *\nORDER BY \"name\"" + assert.Equal(t, sql, expected, "t3") + + sql = qb.BuildOrderByAndLimit("SELECT *", nil, 10, -1) + expected = "WITH USER_SQL AS (SELECT *),\n\tPAGINATION AS (SELECT USER_SQL.*, rownum as rowNumId FROM USER_SQL)\nSELECT * FROM PAGINATION WHERE rowNum <= 10" + assert.Equal(t, sql, expected, "t4") +} + +func getOciBuilder() Builder { + db := getDB() + b := NewOciBuilder(db, db.sqlDB) + db.Builder = b + return b +} diff --git a/builder_pgsql.go b/builder_pgsql.go new file mode 100644 index 0000000..44190bc --- /dev/null +++ b/builder_pgsql.go @@ -0,0 +1,105 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "fmt" + "sort" + "strings" +) + +// PgsqlBuilder is the builder for PostgreSQL databases. +type PgsqlBuilder struct { + *BaseBuilder + qb *BaseQueryBuilder +} + +var _ Builder = &PgsqlBuilder{} + +// NewPgsqlBuilder creates a new PgsqlBuilder instance. +func NewPgsqlBuilder(db *DB, executor Executor) Builder { + return &PgsqlBuilder{ + NewBaseBuilder(db, executor), + NewBaseQueryBuilder(db), + } +} + +// Select returns a new SelectQuery object that can be used to build a SELECT statement. +// The parameters to this method should be the list column names to be selected. +// A column name may have an optional alias name. For example, Select("id", "my_name AS name"). +func (b *PgsqlBuilder) Select(cols ...string) *SelectQuery { + return NewSelectQuery(b, b.db).Select(cols...) +} + +// Model returns a new ModelQuery object that can be used to perform model-based DB operations. +// The model passed to this method should be a pointer to a model struct. +func (b *PgsqlBuilder) Model(model interface{}) *ModelQuery { + return NewModelQuery(model, b.db.FieldMapper, b.db, b) +} + +// GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID. +func (b *PgsqlBuilder) GeneratePlaceholder(i int) string { + return fmt.Sprintf("$%v", i) +} + +// QueryBuilder returns the query builder supporting the current DB. +func (b *PgsqlBuilder) QueryBuilder() QueryBuilder { + return b.qb +} + +// Upsert creates a Query that represents an UPSERT SQL statement. +// Upsert inserts a row into the table if the primary key or unique index is not found. +// Otherwise it will update the row with the new values. +// The keys of cols are the column names, while the values of cols are the corresponding column +// values to be inserted. +func (b *PgsqlBuilder) Upsert(table string, cols Params, constraints ...string) *Query { + q := b.Insert(table, cols) + + names := []string{} + for name := range cols { + names = append(names, name) + } + sort.Strings(names) + + lines := []string{} + for _, name := range names { + value := cols[name] + name = b.db.QuoteColumnName(name) + if e, ok := value.(Expression); ok { + lines = append(lines, name+"="+e.Build(b.db, q.params)) + } else { + lines = append(lines, fmt.Sprintf("%v={:p%v}", name, len(q.params))) + q.params[fmt.Sprintf("p%v", len(q.params))] = value + } + } + + if len(constraints) > 0 { + c := b.quoteColumns(constraints) + q.sql += " ON CONFLICT (" + c + ") DO UPDATE SET " + strings.Join(lines, ", ") + } else { + q.sql += " ON CONFLICT DO UPDATE SET " + strings.Join(lines, ", ") + } + + return b.NewQuery(q.sql).Bind(q.params) +} + +// DropIndex creates a Query that can be used to remove the named index from a table. +func (b *PgsqlBuilder) DropIndex(table, name string) *Query { + sql := fmt.Sprintf("DROP INDEX %v", b.db.QuoteColumnName(name)) + return b.NewQuery(sql) +} + +// RenameTable creates a Query that can be used to rename a table. +func (b *PgsqlBuilder) RenameTable(oldName, newName string) *Query { + sql := fmt.Sprintf("ALTER TABLE %v RENAME TO %v", b.db.QuoteTableName(oldName), b.db.QuoteTableName(newName)) + return b.NewQuery(sql) +} + +// AlterColumn creates a Query that can be used to change the definition of a table column. +func (b *PgsqlBuilder) AlterColumn(table, col, typ string) *Query { + col = b.db.QuoteColumnName(col) + sql := fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", b.db.QuoteTableName(table), col, typ) + return b.NewQuery(sql) +} diff --git a/builder_pgsql_test.go b/builder_pgsql_test.go new file mode 100644 index 0000000..6f46152 --- /dev/null +++ b/builder_pgsql_test.go @@ -0,0 +1,49 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPgsqlBuilder_Upsert(t *testing.T) { + b := getPgsqlBuilder() + q := b.Upsert("users", Params{ + "name": "James", + "age": 30, + }, "id") + assert.Equal(t, q.sql, `INSERT INTO "users" ("age", "name") VALUES ({:p0}, {:p1}) ON CONFLICT ("id") DO UPDATE SET "age"={:p2}, "name"={:p3}`, "t1") + assert.Equal(t, q.rawSQL, `INSERT INTO "users" ("age", "name") VALUES ($1, $2) ON CONFLICT ("id") DO UPDATE SET "age"=$3, "name"=$4`, "t2") + assert.Equal(t, q.Params()["p0"], 30, "t3") + assert.Equal(t, q.Params()["p1"], "James", "t4") + assert.Equal(t, q.Params()["p2"], 30, "t5") + assert.Equal(t, q.Params()["p3"], "James", "t6") +} +func TestPgsqlBuilder_DropIndex(t *testing.T) { + b := getPgsqlBuilder() + q := b.DropIndex("users", "idx") + assert.Equal(t, q.SQL(), `DROP INDEX "idx"`, "t1") +} + +func TestPgsqlBuilder_RenameTable(t *testing.T) { + b := getPgsqlBuilder() + q := b.RenameTable("users", "user") + assert.Equal(t, q.SQL(), `ALTER TABLE "users" RENAME TO "user"`, "t1") +} + +func TestPgsqlBuilder_AlterColumn(t *testing.T) { + b := getPgsqlBuilder() + q := b.AlterColumn("users", "name", "int") + assert.Equal(t, q.SQL(), `ALTER TABLE "users" ALTER COLUMN "name" TYPE int`, "t1") +} + +func getPgsqlBuilder() Builder { + db := getDB() + b := NewPgsqlBuilder(db, db.sqlDB) + db.Builder = b + return b +} diff --git a/builder_sqlite.go b/builder_sqlite.go new file mode 100644 index 0000000..b69b3f3 --- /dev/null +++ b/builder_sqlite.go @@ -0,0 +1,120 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "errors" + "fmt" + "strings" +) + +// SqliteBuilder is the builder for SQLite databases. +type SqliteBuilder struct { + *BaseBuilder + qb *BaseQueryBuilder +} + +var _ Builder = &SqliteBuilder{} + +// NewSqliteBuilder creates a new SqliteBuilder instance. +func NewSqliteBuilder(db *DB, executor Executor) Builder { + return &SqliteBuilder{ + NewBaseBuilder(db, executor), + NewBaseQueryBuilder(db), + } +} + +// QueryBuilder returns the query builder supporting the current DB. +func (b *SqliteBuilder) QueryBuilder() QueryBuilder { + return b.qb +} + +// Select returns a new SelectQuery object that can be used to build a SELECT statement. +// The parameters to this method should be the list column names to be selected. +// A column name may have an optional alias name. For example, Select("id", "my_name AS name"). +func (b *SqliteBuilder) Select(cols ...string) *SelectQuery { + return NewSelectQuery(b, b.db).Select(cols...) +} + +// Model returns a new ModelQuery object that can be used to perform model-based DB operations. +// The model passed to this method should be a pointer to a model struct. +func (b *SqliteBuilder) Model(model interface{}) *ModelQuery { + return NewModelQuery(model, b.db.FieldMapper, b.db, b) +} + +// QuoteSimpleTableName quotes a simple table name. +// A simple table name does not contain any schema prefix. +func (b *SqliteBuilder) QuoteSimpleTableName(s string) string { + if strings.ContainsAny(s, "`") { + return s + } + return "`" + s + "`" +} + +// QuoteSimpleColumnName quotes a simple column name. +// A simple column name does not contain any table prefix. +func (b *SqliteBuilder) QuoteSimpleColumnName(s string) string { + if strings.Contains(s, "`") || s == "*" { + return s + } + return "`" + s + "`" +} + +// DropIndex creates a Query that can be used to remove the named index from a table. +func (b *SqliteBuilder) DropIndex(table, name string) *Query { + sql := fmt.Sprintf("DROP INDEX %v", b.db.QuoteColumnName(name)) + return b.NewQuery(sql) +} + +// TruncateTable creates a Query that can be used to truncate a table. +func (b *SqliteBuilder) TruncateTable(table string) *Query { + sql := "DELETE FROM " + b.db.QuoteTableName(table) + return b.NewQuery(sql) +} + +// RenameTable creates a Query that can be used to rename a table. +func (b *SqliteBuilder) RenameTable(oldName, newName string) *Query { + sql := fmt.Sprintf("ALTER TABLE %v RENAME TO %v", b.db.QuoteTableName(oldName), b.db.QuoteTableName(newName)) + return b.NewQuery(sql) +} + +// AlterColumn creates a Query that can be used to change the definition of a table column. +func (b *SqliteBuilder) AlterColumn(table, col, typ string) *Query { + q := b.NewQuery("") + q.LastError = errors.New("SQLite does not support altering column") + return q +} + +// AddPrimaryKey creates a Query that can be used to specify primary key(s) for a table. +// The "name" parameter specifies the name of the primary key constraint. +func (b *SqliteBuilder) AddPrimaryKey(table, name string, cols ...string) *Query { + q := b.NewQuery("") + q.LastError = errors.New("SQLite does not support adding primary key") + return q +} + +// DropPrimaryKey creates a Query that can be used to remove the named primary key constraint from a table. +func (b *SqliteBuilder) DropPrimaryKey(table, name string) *Query { + q := b.NewQuery("") + q.LastError = errors.New("SQLite does not support dropping primary key") + return q +} + +// AddForeignKey creates a Query that can be used to add a foreign key constraint to a table. +// The length of cols and refCols must be the same as they refer to the primary and referential columns. +// The optional "options" parameters will be appended to the SQL statement. They can be used to +// specify options such as "ON DELETE CASCADE". +func (b *SqliteBuilder) AddForeignKey(table, name string, cols, refCols []string, refTable string, options ...string) *Query { + q := b.NewQuery("") + q.LastError = errors.New("SQLite does not support adding foreign keys") + return q +} + +// DropForeignKey creates a Query that can be used to remove the named foreign key constraint from a table. +func (b *SqliteBuilder) DropForeignKey(table, name string) *Query { + q := b.NewQuery("") + q.LastError = errors.New("SQLite does not support dropping foreign keys") + return q +} diff --git a/builder_sqlite_test.go b/builder_sqlite_test.go new file mode 100644 index 0000000..e8f3202 --- /dev/null +++ b/builder_sqlite_test.go @@ -0,0 +1,83 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSqliteBuilder_QuoteSimpleTableName(t *testing.T) { + b := getSqliteBuilder() + assert.Equal(t, b.QuoteSimpleTableName(`abc`), "`abc`", "t1") + assert.Equal(t, b.QuoteSimpleTableName("`abc`"), "`abc`", "t2") + assert.Equal(t, b.QuoteSimpleTableName(`{{abc}}`), "`{{abc}}`", "t3") + assert.Equal(t, b.QuoteSimpleTableName(`a.bc`), "`a.bc`", "t4") +} + +func TestSqliteBuilder_QuoteSimpleColumnName(t *testing.T) { + b := getSqliteBuilder() + assert.Equal(t, b.QuoteSimpleColumnName(`abc`), "`abc`", "t1") + assert.Equal(t, b.QuoteSimpleColumnName("`abc`"), "`abc`", "t2") + assert.Equal(t, b.QuoteSimpleColumnName(`{{abc}}`), "`{{abc}}`", "t3") + assert.Equal(t, b.QuoteSimpleColumnName(`a.bc`), "`a.bc`", "t4") + assert.Equal(t, b.QuoteSimpleColumnName(`*`), `*`, "t5") +} + +func TestSqliteBuilder_DropIndex(t *testing.T) { + b := getSqliteBuilder() + q := b.DropIndex("users", "idx") + assert.Equal(t, q.SQL(), "DROP INDEX `idx`", "t1") +} + +func TestSqliteBuilder_TruncateTable(t *testing.T) { + b := getSqliteBuilder() + q := b.TruncateTable("users") + assert.Equal(t, q.SQL(), "DELETE FROM `users`", "t1") +} + +func TestSqliteBuilder_RenameTable(t *testing.T) { + b := getSqliteBuilder() + q := b.RenameTable("usersOld", "usersNew") + assert.Equal(t, q.SQL(), "ALTER TABLE `usersOld` RENAME TO `usersNew`", "t1") +} + +func TestSqliteBuilder_AlterColumn(t *testing.T) { + b := getSqliteBuilder() + q := b.AlterColumn("users", "name", "int") + assert.NotEqual(t, q.LastError, nil, "t1") +} + +func TestSqliteBuilder_AddPrimaryKey(t *testing.T) { + b := getSqliteBuilder() + q := b.AddPrimaryKey("users", "pk", "id1", "id2") + assert.NotEqual(t, q.LastError, nil, "t1") +} + +func TestSqliteBuilder_DropPrimaryKey(t *testing.T) { + b := getSqliteBuilder() + q := b.DropPrimaryKey("users", "pk") + assert.NotEqual(t, q.LastError, nil, "t1") +} + +func TestSqliteBuilder_AddForeignKey(t *testing.T) { + b := getSqliteBuilder() + q := b.AddForeignKey("users", "fk", []string{"p1", "p2"}, []string{"f1", "f2"}, "profile", "opt") + assert.NotEqual(t, q.LastError, nil, "t1") +} + +func TestSqliteBuilder_DropForeignKey(t *testing.T) { + b := getSqliteBuilder() + q := b.DropForeignKey("users", "fk") + assert.NotEqual(t, q.LastError, nil, "t1") +} + +func getSqliteBuilder() Builder { + db := getDB() + b := NewSqliteBuilder(db, db.sqlDB) + db.Builder = b + return b +} diff --git a/builder_standard.go b/builder_standard.go new file mode 100644 index 0000000..9af6724 --- /dev/null +++ b/builder_standard.go @@ -0,0 +1,39 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +// StandardBuilder is the builder that is used by DB for an unknown driver. +type StandardBuilder struct { + *BaseBuilder + qb *BaseQueryBuilder +} + +var _ Builder = &StandardBuilder{} + +// NewStandardBuilder creates a new StandardBuilder instance. +func NewStandardBuilder(db *DB, executor Executor) Builder { + return &StandardBuilder{ + NewBaseBuilder(db, executor), + NewBaseQueryBuilder(db), + } +} + +// QueryBuilder returns the query builder supporting the current DB. +func (b *StandardBuilder) QueryBuilder() QueryBuilder { + return b.qb +} + +// Select returns a new SelectQuery object that can be used to build a SELECT statement. +// The parameters to this method should be the list column names to be selected. +// A column name may have an optional alias name. For example, Select("id", "my_name AS name"). +func (b *StandardBuilder) Select(cols ...string) *SelectQuery { + return NewSelectQuery(b, b.db).Select(cols...) +} + +// Model returns a new ModelQuery object that can be used to perform model-based DB operations. +// The model passed to this method should be a pointer to a model struct. +func (b *StandardBuilder) Model(model interface{}) *ModelQuery { + return NewModelQuery(model, b.db.FieldMapper, b.db, b) +} diff --git a/builder_standard_test.go b/builder_standard_test.go new file mode 100644 index 0000000..e0593a6 --- /dev/null +++ b/builder_standard_test.go @@ -0,0 +1,183 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStandardBuilder_Quote(t *testing.T) { + b := getStandardBuilder() + assert.Equal(t, b.Quote(`abc`), `'abc'`, "t1") + assert.Equal(t, b.Quote(`I'm`), `'I''m'`, "t2") + assert.Equal(t, b.Quote(``), `''`, "t3") +} + +func TestStandardBuilder_QuoteSimpleTableName(t *testing.T) { + b := getStandardBuilder() + assert.Equal(t, b.QuoteSimpleTableName(`abc`), `"abc"`, "t1") + assert.Equal(t, b.QuoteSimpleTableName(`"abc"`), `"abc"`, "t2") + assert.Equal(t, b.QuoteSimpleTableName(`{{abc}}`), `"{{abc}}"`, "t3") + assert.Equal(t, b.QuoteSimpleTableName(`a.bc`), `"a.bc"`, "t4") +} + +func TestStandardBuilder_QuoteSimpleColumnName(t *testing.T) { + b := getStandardBuilder() + assert.Equal(t, b.QuoteSimpleColumnName(`abc`), `"abc"`, "t1") + assert.Equal(t, b.QuoteSimpleColumnName(`"abc"`), `"abc"`, "t2") + assert.Equal(t, b.QuoteSimpleColumnName(`{{abc}}`), `"{{abc}}"`, "t3") + assert.Equal(t, b.QuoteSimpleColumnName(`a.bc`), `"a.bc"`, "t4") + assert.Equal(t, b.QuoteSimpleColumnName(`*`), `*`, "t5") +} + +func TestStandardBuilder_Insert(t *testing.T) { + b := getStandardBuilder() + q := b.Insert("users", Params{ + "name": "James", + "age": 30, + }) + assert.Equal(t, q.SQL(), `INSERT INTO "users" ("age", "name") VALUES ({:p0}, {:p1})`, "t1") + assert.Equal(t, q.Params()["p0"], 30, "t2") + assert.Equal(t, q.Params()["p1"], "James", "t3") + + q = b.Insert("users", Params{}) + assert.Equal(t, q.SQL(), `INSERT INTO "users" DEFAULT VALUES`, "t2") +} + +func TestStandardBuilder_Upsert(t *testing.T) { + b := getStandardBuilder() + q := b.Upsert("users", Params{ + "name": "James", + "age": 30, + }) + assert.NotEqual(t, q.LastError, nil, "t1") +} + +func TestStandardBuilder_Update(t *testing.T) { + b := getStandardBuilder() + q := b.Update("users", Params{ + "name": "James", + "age": 30, + }, NewExp("id=10")) + assert.Equal(t, q.SQL(), `UPDATE "users" SET "age"={:p0}, "name"={:p1} WHERE id=10`, "t1") + assert.Equal(t, q.Params()["p0"], 30, "t2") + assert.Equal(t, q.Params()["p1"], "James", "t3") + + q = b.Update("users", Params{ + "name": "James", + "age": 30, + }, nil) + assert.Equal(t, q.SQL(), `UPDATE "users" SET "age"={:p0}, "name"={:p1}`, "t2") +} + +func TestStandardBuilder_Delete(t *testing.T) { + b := getStandardBuilder() + q := b.Delete("users", NewExp("id=10")) + assert.Equal(t, q.SQL(), `DELETE FROM "users" WHERE id=10`, "t1") + q = b.Delete("users", nil) + assert.Equal(t, q.SQL(), `DELETE FROM "users"`, "t2") +} + +func TestStandardBuilder_CreateTable(t *testing.T) { + b := getStandardBuilder() + q := b.CreateTable("users", map[string]string{ + "id": "int primary key", + "name": "varchar(255)", + }, "ON DELETE CASCADE") + assert.Equal(t, q.SQL(), "CREATE TABLE \"users\" (\"id\" int primary key, \"name\" varchar(255)) ON DELETE CASCADE", "t1") +} + +func TestStandardBuilder_RenameTable(t *testing.T) { + b := getStandardBuilder() + q := b.RenameTable("users", "user") + assert.Equal(t, q.SQL(), `RENAME TABLE "users" TO "user"`, "t1") +} + +func TestStandardBuilder_DropTable(t *testing.T) { + b := getStandardBuilder() + q := b.DropTable("users") + assert.Equal(t, q.SQL(), `DROP TABLE "users"`, "t1") +} + +func TestStandardBuilder_TruncateTable(t *testing.T) { + b := getStandardBuilder() + q := b.TruncateTable("users") + assert.Equal(t, q.SQL(), `TRUNCATE TABLE "users"`, "t1") +} + +func TestStandardBuilder_AddColumn(t *testing.T) { + b := getStandardBuilder() + q := b.AddColumn("users", "age", "int") + assert.Equal(t, q.SQL(), `ALTER TABLE "users" ADD "age" int`, "t1") +} + +func TestStandardBuilder_DropColumn(t *testing.T) { + b := getStandardBuilder() + q := b.DropColumn("users", "age") + assert.Equal(t, q.SQL(), `ALTER TABLE "users" DROP COLUMN "age"`, "t1") +} + +func TestStandardBuilder_RenameColumn(t *testing.T) { + b := getStandardBuilder() + q := b.RenameColumn("users", "name", "username") + assert.Equal(t, q.SQL(), `ALTER TABLE "users" RENAME COLUMN "name" TO "username"`, "t1") +} + +func TestStandardBuilder_AlterColumn(t *testing.T) { + b := getStandardBuilder() + q := b.AlterColumn("users", "name", "int") + assert.Equal(t, q.SQL(), `ALTER TABLE "users" CHANGE "name" "name" int`, "t1") +} + +func TestStandardBuilder_AddPrimaryKey(t *testing.T) { + b := getStandardBuilder() + q := b.AddPrimaryKey("users", "pk", "id1", "id2") + assert.Equal(t, q.SQL(), `ALTER TABLE "users" ADD CONSTRAINT "pk" PRIMARY KEY ("id1", "id2")`, "t1") +} + +func TestStandardBuilder_DropPrimaryKey(t *testing.T) { + b := getStandardBuilder() + q := b.DropPrimaryKey("users", "pk") + assert.Equal(t, q.SQL(), `ALTER TABLE "users" DROP CONSTRAINT "pk"`, "t1") +} + +func TestStandardBuilder_AddForeignKey(t *testing.T) { + b := getStandardBuilder() + q := b.AddForeignKey("users", "fk", []string{"p1", "p2"}, []string{"f1", "f2"}, "profile", "opt") + assert.Equal(t, q.SQL(), `ALTER TABLE "users" ADD CONSTRAINT "fk" FOREIGN KEY ("p1", "p2") REFERENCES "profile" ("f1", "f2") opt`, "t1") +} + +func TestStandardBuilder_DropForeignKey(t *testing.T) { + b := getStandardBuilder() + q := b.DropForeignKey("users", "fk") + assert.Equal(t, q.SQL(), `ALTER TABLE "users" DROP CONSTRAINT "fk"`, "t1") +} + +func TestStandardBuilder_CreateIndex(t *testing.T) { + b := getStandardBuilder() + q := b.CreateIndex("users", "idx", "id1", "id2") + assert.Equal(t, q.SQL(), `CREATE INDEX "idx" ON "users" ("id1", "id2")`, "t1") +} + +func TestStandardBuilder_CreateUniqueIndex(t *testing.T) { + b := getStandardBuilder() + q := b.CreateUniqueIndex("users", "idx", "id1", "id2") + assert.Equal(t, q.SQL(), `CREATE UNIQUE INDEX "idx" ON "users" ("id1", "id2")`, "t1") +} + +func TestStandardBuilder_DropIndex(t *testing.T) { + b := getStandardBuilder() + q := b.DropIndex("users", "idx") + assert.Equal(t, q.SQL(), `DROP INDEX "idx" ON "users"`, "t1") +} + +func getStandardBuilder() Builder { + db := getDB() + b := NewStandardBuilder(db, db.sqlDB) + db.Builder = b + return b +} diff --git a/db.go b/db.go new file mode 100644 index 0000000..b46f1d5 --- /dev/null +++ b/db.go @@ -0,0 +1,338 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +// Package dbx provides a set of DB-agnostic and easy-to-use query building methods for relational databases. +package dbx + +import ( + "bytes" + "context" + "database/sql" + "regexp" + "strings" + "time" +) + +type ( + // LogFunc logs a message for each SQL statement being executed. + // This method takes one or multiple parameters. If a single parameter + // is provided, it will be treated as the log message. If multiple parameters + // are provided, they will be passed to fmt.Sprintf() to generate the log message. + LogFunc func(format string, a ...interface{}) + + // PerfFunc is called when a query finishes execution. + // The query execution time is passed to this function so that the DB performance + // can be profiled. The "ns" parameter gives the number of nanoseconds that the + // SQL statement takes to execute, while the "execute" parameter indicates whether + // the SQL statement is executed or queried (usually SELECT statements). + PerfFunc func(ns int64, sql string, execute bool) + + // QueryLogFunc is called each time when performing a SQL query. + // The "t" parameter gives the time that the SQL statement takes to execute, + // while rows and err are the result of the query. + QueryLogFunc func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) + + // ExecLogFunc is called each time when a SQL statement is executed. + // The "t" parameter gives the time that the SQL statement takes to execute, + // while result and err refer to the result of the execution. + ExecLogFunc func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) + + // BuilderFunc creates a Builder instance using the given DB instance and Executor. + BuilderFunc func(*DB, Executor) Builder + + // DB enhances sql.DB by providing a set of DB-agnostic query building methods. + // DB allows easier query building and population of data into Go variables. + DB struct { + Builder + + // FieldMapper maps struct fields to DB columns. Defaults to DefaultFieldMapFunc. + FieldMapper FieldMapFunc + // TableMapper maps structs to table names. Defaults to GetTableName. + TableMapper TableMapFunc + // LogFunc logs the SQL statements being executed. Defaults to nil, meaning no logging. + LogFunc LogFunc + // PerfFunc logs the SQL execution time. Defaults to nil, meaning no performance profiling. + // Deprecated: Please use QueryLogFunc and ExecLogFunc instead. + PerfFunc PerfFunc + // QueryLogFunc is called each time when performing a SQL query that returns data. + QueryLogFunc QueryLogFunc + // ExecLogFunc is called each time when a SQL statement is executed. + ExecLogFunc ExecLogFunc + + sqlDB *sql.DB + driverName string + ctx context.Context + } + + // Errors represents a list of errors. + Errors []error +) + +// BuilderFuncMap lists supported BuilderFunc according to DB driver names. +// You may modify this variable to add the builder support for a new DB driver. +// If a DB driver is not listed here, the StandardBuilder will be used. +var BuilderFuncMap = map[string]BuilderFunc{ + "sqlite": NewSqliteBuilder, + "sqlite3": NewSqliteBuilder, + "mysql": NewMysqlBuilder, + "postgres": NewPgsqlBuilder, + "pgx": NewPgsqlBuilder, + "mssql": NewMssqlBuilder, + "oci8": NewOciBuilder, +} + +// NewFromDB encapsulates an existing database connection. +func NewFromDB(sqlDB *sql.DB, driverName string) *DB { + db := &DB{ + driverName: driverName, + sqlDB: sqlDB, + FieldMapper: DefaultFieldMapFunc, + TableMapper: GetTableName, + } + db.Builder = db.newBuilder(db.sqlDB) + return db +} + +// Open opens a database specified by a driver name and data source name (DSN). +// Note that Open does not check if DSN is specified correctly. It doesn't try to establish a DB connection either. +// Please refer to sql.Open() for more information. +func Open(driverName, dsn string) (*DB, error) { + sqlDB, err := sql.Open(driverName, dsn) + if err != nil { + return nil, err + } + + return NewFromDB(sqlDB, driverName), nil +} + +// MustOpen opens a database and establishes a connection to it. +// Please refer to sql.Open() and sql.Ping() for more information. +func MustOpen(driverName, dsn string) (*DB, error) { + db, err := Open(driverName, dsn) + if err != nil { + return nil, err + } + if err := db.sqlDB.Ping(); err != nil { + db.Close() + return nil, err + } + return db, nil +} + +// Clone makes a shallow copy of DB. +func (db *DB) Clone() *DB { + db2 := &DB{ + driverName: db.driverName, + sqlDB: db.sqlDB, + FieldMapper: db.FieldMapper, + TableMapper: db.TableMapper, + PerfFunc: db.PerfFunc, + LogFunc: db.LogFunc, + QueryLogFunc: db.QueryLogFunc, + ExecLogFunc: db.ExecLogFunc, + } + db2.Builder = db2.newBuilder(db.sqlDB) + return db2 +} + +// WithContext returns a new instance of DB associated with the given context. +func (db *DB) WithContext(ctx context.Context) *DB { + db2 := db.Clone() + db2.ctx = ctx + return db2 +} + +// Context returns the context associated with the DB instance. +// It returns nil if no context is associated. +func (db *DB) Context() context.Context { + return db.ctx +} + +// DB returns the sql.DB instance encapsulated by dbx.DB. +func (db *DB) DB() *sql.DB { + return db.sqlDB +} + +// Close closes the database, releasing any open resources. +// It is rare to Close a DB, as the DB handle is meant to be +// long-lived and shared between many goroutines. +func (db *DB) Close() error { + return db.sqlDB.Close() +} + +// Begin starts a transaction. +func (db *DB) Begin() (*Tx, error) { + var tx *sql.Tx + var err error + if db.ctx != nil { + tx, err = db.sqlDB.BeginTx(db.ctx, nil) + } else { + tx, err = db.sqlDB.Begin() + } + if err != nil { + return nil, err + } + return &Tx{db.newBuilder(tx), tx}, nil +} + +// BeginTx starts a transaction with the given context and transaction options. +func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { + tx, err := db.sqlDB.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &Tx{db.newBuilder(tx), tx}, nil +} + +// Wrap encapsulates an existing transaction. +func (db *DB) Wrap(sqlTx *sql.Tx) *Tx { + return &Tx{db.newBuilder(sqlTx), sqlTx} +} + +// Transactional starts a transaction and executes the given function. +// If the function returns an error, the transaction will be rolled back. +// Otherwise, the transaction will be committed. +func (db *DB) Transactional(f func(*Tx) error) (err error) { + tx, err := db.Begin() + if err != nil { + return err + } + + defer func() { + if p := recover(); p != nil { + tx.Rollback() + panic(p) + } else if err != nil { + if err2 := tx.Rollback(); err2 != nil { + if err2 == sql.ErrTxDone { + return + } + err = Errors{err, err2} + } + } else { + if err = tx.Commit(); err == sql.ErrTxDone { + err = nil + } + } + }() + + err = f(tx) + + return err +} + +// TransactionalContext starts a transaction and executes the given function with the given context and transaction options. +// If the function returns an error, the transaction will be rolled back. +// Otherwise, the transaction will be committed. +func (db *DB) TransactionalContext(ctx context.Context, opts *sql.TxOptions, f func(*Tx) error) (err error) { + tx, err := db.BeginTx(ctx, opts) + if err != nil { + return err + } + + defer func() { + if p := recover(); p != nil { + tx.Rollback() + panic(p) + } else if err != nil { + if err2 := tx.Rollback(); err2 != nil { + if err2 == sql.ErrTxDone { + return + } + err = Errors{err, err2} + } + } else { + if err = tx.Commit(); err == sql.ErrTxDone { + err = nil + } + } + }() + + err = f(tx) + + return err +} + +// DriverName returns the name of the DB driver. +func (db *DB) DriverName() string { + return db.driverName +} + +// QuoteTableName quotes the given table name appropriately. +// If the table name contains DB schema prefix, it will be handled accordingly. +// This method will do nothing if the table name is already quoted or if it contains parenthesis. +func (db *DB) QuoteTableName(s string) string { + if strings.Contains(s, "(") || strings.Contains(s, "{{") { + return s + } + if !strings.Contains(s, ".") { + return db.QuoteSimpleTableName(s) + } + parts := strings.Split(s, ".") + for i, part := range parts { + parts[i] = db.QuoteSimpleTableName(part) + } + return strings.Join(parts, ".") +} + +// QuoteColumnName quotes the given column name appropriately. +// If the table name contains table name prefix, it will be handled accordingly. +// This method will do nothing if the column name is already quoted or if it contains parenthesis. +func (db *DB) QuoteColumnName(s string) string { + if strings.Contains(s, "(") || strings.Contains(s, "{{") || strings.Contains(s, "[[") { + return s + } + prefix := "" + if pos := strings.LastIndex(s, "."); pos != -1 { + prefix = db.QuoteTableName(s[:pos]) + "." + s = s[pos+1:] + } + return prefix + db.QuoteSimpleColumnName(s) +} + +var ( + plRegex = regexp.MustCompile(`\{:\w+\}`) + quoteRegex = regexp.MustCompile(`(\{\{[\w\-\. ]+\}\}|\[\[[\w\-\. ]+\]\])`) +) + +// processSQL replaces the named param placeholders in the given SQL with anonymous ones. +// It also quotes table names and column names found in the SQL if these names are enclosed +// within double square/curly brackets. The method will return the updated SQL and the list of parameter names. +func (db *DB) processSQL(s string) (string, []string) { + var placeholders []string + count := 0 + s = plRegex.ReplaceAllStringFunc(s, func(m string) string { + count++ + placeholders = append(placeholders, m[2:len(m)-1]) + return db.GeneratePlaceholder(count) + }) + s = quoteRegex.ReplaceAllStringFunc(s, func(m string) string { + if m[0] == '{' { + return db.QuoteTableName(m[2 : len(m)-2]) + } + return db.QuoteColumnName(m[2 : len(m)-2]) + }) + return s, placeholders +} + +// newBuilder creates a query builder based on the current driver name. +func (db *DB) newBuilder(executor Executor) Builder { + builderFunc, ok := BuilderFuncMap[db.driverName] + if !ok { + builderFunc = NewStandardBuilder + } + return builderFunc(db, executor) +} + +// Error returns the error string of Errors. +func (errs Errors) Error() string { + var b bytes.Buffer + for i, e := range errs { + if i > 0 { + b.WriteRune('\n') + } + b.WriteString(e.Error()) + } + return b.String() +} diff --git a/db_test.go b/db_test.go new file mode 100644 index 0000000..e8e68cd --- /dev/null +++ b/db_test.go @@ -0,0 +1,380 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "context" + "database/sql" + "errors" + "io/ioutil" + "strings" + "testing" + + // @todo change to sqlite + _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" +) + +const ( + TestDSN = "travis:@/pocketbase_dbx_test?parseTime=true" + FixtureFile = "testdata/mysql.sql" +) + +func TestDB_NewFromDB(t *testing.T) { + sqlDB, err := sql.Open("mysql", TestDSN) + if assert.Nil(t, err) { + db := NewFromDB(sqlDB, "mysql") + assert.NotNil(t, db.sqlDB) + assert.NotNil(t, db.FieldMapper) + } +} + +func TestDB_Open(t *testing.T) { + db, err := Open("mysql", TestDSN) + assert.Nil(t, err) + if assert.NotNil(t, db) { + assert.NotNil(t, db.sqlDB) + assert.NotNil(t, db.FieldMapper) + db2 := db.Clone() + assert.NotEqual(t, db, db2) + assert.Equal(t, db.driverName, db2.driverName) + ctx := context.Background() + db3 := db.WithContext(ctx) + assert.Equal(t, ctx, db3.ctx) + assert.Equal(t, ctx, db3.Context()) + assert.NotEqual(t, db, db3) + } + + _, err = Open("xyz", TestDSN) + assert.NotNil(t, err) +} + +func TestDB_MustOpen(t *testing.T) { + _, err := MustOpen("mysql", TestDSN) + assert.Nil(t, err) + + _, err = MustOpen("mysql", "unknown:x@/test") + assert.NotNil(t, err) +} + +func TestDB_Close(t *testing.T) { + db := getDB() + assert.Nil(t, db.Close()) +} + +func TestDB_DriverName(t *testing.T) { + db := getDB() + assert.Equal(t, "mysql", db.DriverName()) +} + +func TestDB_QuoteTableName(t *testing.T) { + tests := []struct { + input, output string + }{ + {"users", "`users`"}, + {"`users`", "`users`"}, + {"(select)", "(select)"}, + {"{{users}}", "{{users}}"}, + {"public.db1.users", "`public`.`db1`.`users`"}, + } + db := getDB() + for _, test := range tests { + result := db.QuoteTableName(test.input) + assert.Equal(t, test.output, result, test.input) + } +} + +func TestDB_QuoteColumnName(t *testing.T) { + tests := []struct { + input, output string + }{ + {"*", "*"}, + {"users.*", "`users`.*"}, + {"name", "`name`"}, + {"`name`", "`name`"}, + {"(select)", "(select)"}, + {"{{name}}", "{{name}}"}, + {"[[name]]", "[[name]]"}, + {"public.db1.users", "`public`.`db1`.`users`"}, + } + db := getDB() + for _, test := range tests { + result := db.QuoteColumnName(test.input) + assert.Equal(t, test.output, result, test.input) + } +} + +func TestDB_ProcessSQL(t *testing.T) { + tests := []struct { + tag string + sql string // original SQL + mysql string // expected MySQL version + postgres string // expected PostgreSQL version + oci8 string // expected OCI version + params []string // expected params + }{ + { + "normal case", + `INSERT INTO employee (id, name, age) VALUES ({:id}, {:name}, {:age})`, + `INSERT INTO employee (id, name, age) VALUES (?, ?, ?)`, + `INSERT INTO employee (id, name, age) VALUES ($1, $2, $3)`, + `INSERT INTO employee (id, name, age) VALUES (:p1, :p2, :p3)`, + []string{"id", "name", "age"}, + }, + { + "the same placeholder is used twice", + `SELECT * FROM employee WHERE first_name LIKE {:keyword} OR last_name LIKE {:keyword}`, + `SELECT * FROM employee WHERE first_name LIKE ? OR last_name LIKE ?`, + `SELECT * FROM employee WHERE first_name LIKE $1 OR last_name LIKE $2`, + `SELECT * FROM employee WHERE first_name LIKE :p1 OR last_name LIKE :p2`, + []string{"keyword", "keyword"}, + }, + { + "non-matching placeholder", + `SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE {:keyword}`, + `SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE ?`, + `SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE $1`, + `SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE :p1`, + []string{"keyword"}, + }, + { + "quote table/column names", + `SELECT * FROM {{public.user}} WHERE [[user.id]]=1`, + "SELECT * FROM `public`.`user` WHERE `user`.`id`=1", + "SELECT * FROM \"public\".\"user\" WHERE \"user\".\"id\"=1", + "SELECT * FROM \"public\".\"user\" WHERE \"user\".\"id\"=1", + nil, + }, + } + + mysqlDB := getDB() + mysqlDB.Builder = NewMysqlBuilder(nil, nil) + pgsqlDB := getDB() + pgsqlDB.Builder = NewPgsqlBuilder(nil, nil) + ociDB := getDB() + ociDB.Builder = NewOciBuilder(nil, nil) + + for _, test := range tests { + s1, names := mysqlDB.processSQL(test.sql) + assert.Equal(t, test.mysql, s1, test.tag) + s2, _ := pgsqlDB.processSQL(test.sql) + assert.Equal(t, test.postgres, s2, test.tag) + s3, _ := ociDB.processSQL(test.sql) + assert.Equal(t, test.oci8, s3, test.tag) + + assert.Equal(t, test.params, names, test.tag) + } +} + +func TestDB_Begin(t *testing.T) { + tests := []struct { + makeTx func(db *DB) *Tx + desc string + }{ + { + makeTx: func(db *DB) *Tx { + tx, _ := db.Begin() + return tx + }, + desc: "Begin", + }, + { + makeTx: func(db *DB) *Tx { + sqlTx, _ := db.DB().Begin() + return db.Wrap(sqlTx) + }, + desc: "Wrap", + }, + { + makeTx: func(db *DB) *Tx { + tx, _ := db.BeginTx(context.Background(), nil) + return tx + }, + desc: "BeginTx", + }, + } + + db := getPreparedDB() + + var ( + lastID int + name string + tx *Tx + ) + db.NewQuery("SELECT MAX(id) FROM item").Row(&lastID) + + for _, test := range tests { + t.Log(test.desc) + + tx = test.makeTx(db) + _, err1 := tx.Insert("item", Params{ + "name": "name1", + }).Execute() + _, err2 := tx.Insert("item", Params{ + "name": "name2", + }).Execute() + if err1 == nil && err2 == nil { + tx.Commit() + } else { + t.Errorf("Unexpected TX rollback: %v, %v", err1, err2) + tx.Rollback() + } + + q := db.NewQuery("SELECT name FROM item WHERE id={:id}") + q.Bind(Params{"id": lastID + 1}).Row(&name) + assert.Equal(t, "name1", name) + q.Bind(Params{"id": lastID + 2}).Row(&name) + assert.Equal(t, "name2", name) + + tx = test.makeTx(db) + _, err3 := tx.NewQuery("DELETE FROM item WHERE id=7").Execute() + _, err4 := tx.NewQuery("DELETE FROM items WHERE id=7").Execute() + if err3 == nil && err4 == nil { + t.Error("Unexpected TX commit") + tx.Commit() + } else { + tx.Rollback() + } + } +} + +func TestDB_Transactional(t *testing.T) { + db := getPreparedDB() + + var ( + lastID int + name string + ) + db.NewQuery("SELECT MAX(id) FROM item").Row(&lastID) + + err := db.Transactional(func(tx *Tx) error { + _, err := tx.Insert("item", Params{ + "name": "name1", + }).Execute() + if err != nil { + return err + } + _, err = tx.Insert("item", Params{ + "name": "name2", + }).Execute() + if err != nil { + return err + } + return nil + }) + + if assert.Nil(t, err) { + q := db.NewQuery("SELECT name FROM item WHERE id={:id}") + q.Bind(Params{"id": lastID + 1}).Row(&name) + assert.Equal(t, "name1", name) + q.Bind(Params{"id": lastID + 2}).Row(&name) + assert.Equal(t, "name2", name) + } + + err = db.Transactional(func(tx *Tx) error { + _, err := tx.NewQuery("DELETE FROM item WHERE id=2").Execute() + if err != nil { + return err + } + _, err = tx.NewQuery("DELETE FROM items WHERE id=2").Execute() + if err != nil { + return err + } + return nil + }) + if assert.NotNil(t, err) { + db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name) + assert.Equal(t, "Go in Action", name) + } + + // Rollback called within Transactional and return error + err = db.Transactional(func(tx *Tx) error { + _, err := tx.NewQuery("DELETE FROM item WHERE id=2").Execute() + if err != nil { + return err + } + _, err = tx.NewQuery("DELETE FROM items WHERE id=2").Execute() + if err != nil { + tx.Rollback() + return err + } + return nil + }) + if assert.NotNil(t, err) { + db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name) + assert.Equal(t, "Go in Action", name) + } + + // Rollback called within Transactional without returning error + err = db.Transactional(func(tx *Tx) error { + _, err := tx.NewQuery("DELETE FROM item WHERE id=2").Execute() + if err != nil { + return err + } + _, err = tx.NewQuery("DELETE FROM items WHERE id=2").Execute() + if err != nil { + tx.Rollback() + return nil + } + return nil + }) + if assert.Nil(t, err) { + db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name) + assert.Equal(t, "Go in Action", name) + } +} + +func TestErrors_Error(t *testing.T) { + errs := Errors{} + assert.Equal(t, "", errs.Error()) + errs = Errors{errors.New("a")} + assert.Equal(t, "a", errs.Error()) + errs = Errors{errors.New("a"), errors.New("b")} + assert.Equal(t, "a\nb", errs.Error()) +} + +func getDB() *DB { + db, err := Open("mysql", TestDSN) + if err != nil { + panic(err) + } + return db +} + +func getPreparedDB() *DB { + db := getDB() + s, err := ioutil.ReadFile(FixtureFile) + if err != nil { + panic(err) + } + lines := strings.Split(string(s), ";") + for _, line := range lines { + if strings.TrimSpace(line) == "" { + continue + } + if _, err := db.NewQuery(line).Execute(); err != nil { + panic(err) + } + } + return db +} + +// Naming according to issue 49 ( https://github.com/pocketbase/dbx/issues/49 ) + +type ArtistDAO struct { + nickname string +} + +func (ArtistDAO) TableName() string { + return "artists" +} + +func Test_TableNameWithPrefix(t *testing.T) { + db := NewFromDB(nil, "mysql") + db.TableMapper = func(a interface{}) string { + return "tbl_" + GetTableName(a) + } + assert.Equal(t, "tbl_artists", db.TableMapper(ArtistDAO{})) +} diff --git a/example_test.go b/example_test.go new file mode 100644 index 0000000..c0909cd --- /dev/null +++ b/example_test.go @@ -0,0 +1,281 @@ +package dbx_test + +import ( + "fmt" + + "github.com/pocketbase/dbx" +) + +// This example shows how to populate DB data in different ways. +func Example_dbQueries() { + db, _ := dbx.Open("mysql", "user:pass@/example") + + // create a new query + q := db.NewQuery("SELECT id, name FROM users LIMIT 10") + + // fetch all rows into a struct array + var users []struct { + ID, Name string + } + q.All(&users) + + // fetch a single row into a struct + var user struct { + ID, Name string + } + q.One(&user) + + // fetch a single row into a string map + data := dbx.NullStringMap{} + q.One(data) + + // fetch row by row + rows2, _ := q.Rows() + for rows2.Next() { + rows2.ScanStruct(&user) + // rows.ScanMap(data) + // rows.Scan(&id, &name) + } +} + +// This example shows how to use query builder to build DB queries. +func Example_queryBuilder() { + db, _ := dbx.Open("mysql", "user:pass@/example") + + // build a SELECT query + // SELECT `id`, `name` FROM `users` WHERE `name` LIKE '%Charles%' ORDER BY `id` + q := db.Select("id", "name"). + From("users"). + Where(dbx.Like("name", "Charles")). + OrderBy("id") + + // fetch all rows into a struct array + var users []struct { + ID, Name string + } + q.All(&users) + + // build an INSERT query + // INSERT INTO `users` (name) VALUES ('James') + db.Insert("users", dbx.Params{ + "name": "James", + }).Execute() +} + +// This example shows how to use query builder in transactions. +func Example_transactions() { + db, _ := dbx.Open("mysql", "user:pass@/example") + + db.Transactional(func(tx *dbx.Tx) error { + _, err := tx.Insert("user", dbx.Params{ + "name": "user1", + }).Execute() + if err != nil { + return err + } + _, err = tx.Insert("user", dbx.Params{ + "name": "user2", + }).Execute() + return err + }) +} + +type Customer struct { + ID string + Name string +} + +// This example shows how to do CRUD operations. +func Example_crudOperations() { + db, _ := dbx.Open("mysql", "user:pass@/example") + + var customer Customer + + // read a customer: SELECT * FROM customer WHERE id=100 + db.Select().Model(100, &customer) + + // create a customer: INSERT INTO customer (name) VALUES ('test') + db.Model(&customer).Insert() + + // update a customer: UPDATE customer SET name='test' WHERE id=100 + db.Model(&customer).Update() + + // delete a customer: DELETE FROM customer WHERE id=100 + db.Model(&customer).Delete() +} + +func ExampleSchemaBuilder() { + db, _ := dbx.Open("mysql", "user:pass@/example") + + db.Insert("users", dbx.Params{ + "name": "James", + "age": 30, + }).Execute() +} + +func ExampleRows_ScanMap() { + db, _ := dbx.Open("mysql", "user:pass@/example") + + user := dbx.NullStringMap{} + + sql := "SELECT id, name FROM users LIMIT 10" + rows, _ := db.NewQuery(sql).Rows() + for rows.Next() { + rows.ScanMap(user) + // ... + } +} + +func ExampleRows_ScanStruct() { + db, _ := dbx.Open("mysql", "user:pass@/example") + + var user struct { + ID, Name string + } + + sql := "SELECT id, name FROM users LIMIT 10" + rows, _ := db.NewQuery(sql).Rows() + for rows.Next() { + rows.ScanStruct(&user) + // ... + } +} + +func ExampleQuery_All() { + db, _ := dbx.Open("mysql", "user:pass@/example") + sql := "SELECT id, name FROM users LIMIT 10" + + // fetches data into a slice of struct + var users []struct { + ID, Name string + } + db.NewQuery(sql).All(&users) + + // fetches data into a slice of NullStringMap + var users2 []dbx.NullStringMap + db.NewQuery(sql).All(&users2) + for _, user := range users2 { + fmt.Println(user["id"].String, user["name"].String) + } +} + +func ExampleQuery_One() { + db, _ := dbx.Open("mysql", "user:pass@/example") + sql := "SELECT id, name FROM users LIMIT 10" + + // fetches data into a struct + var user struct { + ID, Name string + } + db.NewQuery(sql).One(&user) + + // fetches data into a NullStringMap + var user2 dbx.NullStringMap + db.NewQuery(sql).All(user2) + fmt.Println(user2["id"].String, user2["name"].String) +} + +func ExampleQuery_Row() { + db, _ := dbx.Open("mysql", "user:pass@/example") + sql := "SELECT id, name FROM users LIMIT 10" + + // fetches data into a struct + var ( + id int + name string + ) + db.NewQuery(sql).Row(&id, &name) +} + +func ExampleQuery_Rows() { + var user struct { + ID, Name string + } + + db, _ := dbx.Open("mysql", "user:pass@/example") + sql := "SELECT id, name FROM users LIMIT 10" + + rows, _ := db.NewQuery(sql).Rows() + for rows.Next() { + rows.ScanStruct(&user) + // ... + } +} + +func ExampleQuery_Bind() { + var user struct { + ID, Name string + } + + db, _ := dbx.Open("mysql", "user:pass@/example") + sql := "SELECT id, name FROM users WHERE age>{:age} AND status={:status}" + + q := db.NewQuery(sql) + q.Bind(dbx.Params{"age": 30, "status": 1}).One(&user) +} + +func ExampleQuery_Prepare() { + var users1, users2, users3 []struct { + ID, Name string + } + + db, _ := dbx.Open("mysql", "user:pass@/example") + sql := "SELECT id, name FROM users WHERE age>{:age} AND status={:status}" + + q := db.NewQuery(sql).Prepare() + defer q.Close() + + q.Bind(dbx.Params{"age": 30, "status": 1}).All(&users1) + q.Bind(dbx.Params{"age": 20, "status": 1}).All(&users2) + q.Bind(dbx.Params{"age": 10, "status": 1}).All(&users3) +} + +func ExampleDB() { + db, _ := dbx.Open("mysql", "user:pass@/example") + + // queries data through a plain SQL + var users []struct { + ID, Name string + } + db.NewQuery("SELECT id, name FROM users WHERE age=30").All(&users) + + // queries data using query builder + db.Select("id", "name").From("users").Where(dbx.HashExp{"age": 30}).All(&users) + + // executes a plain SQL + db.NewQuery("INSERT INTO users (name) SET ({:name})").Bind(dbx.Params{"name": "James"}).Execute() + + // executes a SQL using query builder + db.Insert("users", dbx.Params{"name": "James"}).Execute() +} + +func ExampleDB_Open() { + db, err := dbx.Open("mysql", "user:pass@/example") + if err != nil { + panic(err) + } + + var users []dbx.NullStringMap + if err := db.NewQuery("SELECT * FROM users LIMIT 10").All(&users); err != nil { + panic(err) + } +} + +func ExampleDB_Begin() { + db, _ := dbx.Open("mysql", "user:pass@/example") + + tx, _ := db.Begin() + + _, err1 := tx.Insert("user", dbx.Params{ + "name": "user1", + }).Execute() + _, err2 := tx.Insert("user", dbx.Params{ + "name": "user2", + }).Execute() + + if err1 == nil && err2 == nil { + tx.Commit() + } else { + tx.Rollback() + } +} diff --git a/expression.go b/expression.go new file mode 100644 index 0000000..0356ad1 --- /dev/null +++ b/expression.go @@ -0,0 +1,421 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "fmt" + "sort" + "strings" +) + +// Expression represents a DB expression that can be embedded in a SQL statement. +type Expression interface { + // Build converts an expression into a SQL fragment. + // If the expression contains binding parameters, they will be added to the given Params. + Build(*DB, Params) string +} + +// HashExp represents a hash expression. +// +// A hash expression is a map whose keys are DB column names which need to be filtered according +// to the corresponding values. For example, HashExp{"level": 2, "dept": 10} will generate +// the SQL: "level"=2 AND "dept"=10. +// +// HashExp also handles nil values and slice values. For example, HashExp{"level": []interface{}{1, 2}, "dept": nil} +// will generate: "level" IN (1, 2) AND "dept" IS NULL. +type HashExp map[string]interface{} + +// NewExp generates an expression with the specified SQL fragment and the optional binding parameters. +func NewExp(e string, params ...Params) Expression { + if len(params) > 0 { + return &Exp{e, params[0]} + } + return &Exp{e, nil} +} + +// Not generates a NOT expression which prefixes "NOT" to the specified expression. +func Not(e Expression) Expression { + return &NotExp{e} +} + +// And generates an AND expression which concatenates the given expressions with "AND". +func And(exps ...Expression) Expression { + return &AndOrExp{exps, "AND"} +} + +// Or generates an OR expression which concatenates the given expressions with "OR". +func Or(exps ...Expression) Expression { + return &AndOrExp{exps, "OR"} +} + +// In generates an IN expression for the specified column and the list of allowed values. +// If values is empty, a SQL "0=1" will be generated which represents a false expression. +func In(col string, values ...interface{}) Expression { + return &InExp{col, values, false} +} + +// NotIn generates an NOT IN expression for the specified column and the list of disallowed values. +// If values is empty, an empty string will be returned indicating a true expression. +func NotIn(col string, values ...interface{}) Expression { + return &InExp{col, values, true} +} + +// DefaultLikeEscape specifies the default special character escaping for LIKE expressions +// The strings at 2i positions are the special characters to be escaped while those at 2i+1 positions +// are the corresponding escaped versions. +var DefaultLikeEscape = []string{"\\", "\\\\", "%", "\\%", "_", "\\_"} + +// Like generates a LIKE expression for the specified column and the possible strings that the column should be like. +// If multiple values are present, the column should be like *all* of them. For example, Like("name", "key", "word") +// will generate a SQL expression: "name" LIKE "%key%" AND "name" LIKE "%word%". +// +// By default, each value will be surrounded by "%" to enable partial matching. If a value contains special characters +// such as "%", "\", "_", they will also be properly escaped. +// +// You may call Escape() and/or Match() to change the default behavior. For example, Like("name", "key").Match(false, true) +// generates "name" LIKE "key%". +func Like(col string, values ...string) *LikeExp { + return &LikeExp{ + left: true, + right: true, + col: col, + values: values, + escape: DefaultLikeEscape, + Like: "LIKE", + } +} + +// NotLike generates a NOT LIKE expression. +// For example, NotLike("name", "key", "word") will generate a SQL expression: +// "name" NOT LIKE "%key%" AND "name" NOT LIKE "%word%". Please see Like() for more details. +func NotLike(col string, values ...string) *LikeExp { + return &LikeExp{ + left: true, + right: true, + col: col, + values: values, + escape: DefaultLikeEscape, + Like: "NOT LIKE", + } +} + +// OrLike generates an OR LIKE expression. +// This is similar to Like() except that the column should be like one of the possible values. +// For example, OrLike("name", "key", "word") will generate a SQL expression: +// "name" LIKE "%key%" OR "name" LIKE "%word%". Please see Like() for more details. +func OrLike(col string, values ...string) *LikeExp { + return &LikeExp{ + or: true, + left: true, + right: true, + col: col, + values: values, + escape: DefaultLikeEscape, + Like: "LIKE", + } +} + +// OrNotLike generates an OR NOT LIKE expression. +// For example, OrNotLike("name", "key", "word") will generate a SQL expression: +// "name" NOT LIKE "%key%" OR "name" NOT LIKE "%word%". Please see Like() for more details. +func OrNotLike(col string, values ...string) *LikeExp { + return &LikeExp{ + or: true, + left: true, + right: true, + col: col, + values: values, + escape: DefaultLikeEscape, + Like: "NOT LIKE", + } +} + +// Exists generates an EXISTS expression by prefixing "EXISTS" to the given expression. +func Exists(exp Expression) Expression { + return &ExistsExp{exp, false} +} + +// NotExists generates an EXISTS expression by prefixing "NOT EXISTS" to the given expression. +func NotExists(exp Expression) Expression { + return &ExistsExp{exp, true} +} + +// Between generates a BETWEEN expression. +// For example, Between("age", 10, 30) generates: "age" BETWEEN 10 AND 30 +func Between(col string, from, to interface{}) Expression { + return &BetweenExp{col, from, to, false} +} + +// NotBetween generates a NOT BETWEEN expression. +// For example, NotBetween("age", 10, 30) generates: "age" NOT BETWEEN 10 AND 30 +func NotBetween(col string, from, to interface{}) Expression { + return &BetweenExp{col, from, to, true} +} + +// Exp represents an expression with a SQL fragment and a list of optional binding parameters. +type Exp struct { + e string + params Params +} + +// Build converts an expression into a SQL fragment. +func (e *Exp) Build(db *DB, params Params) string { + if len(e.params) == 0 { + return e.e + } + for k, v := range e.params { + params[k] = v + } + return e.e +} + +// Build converts an expression into a SQL fragment. +func (e HashExp) Build(db *DB, params Params) string { + if len(e) == 0 { + return "" + } + + // ensure the hash exp generates the same SQL for different runs + names := []string{} + for name := range e { + names = append(names, name) + } + sort.Strings(names) + + var parts []string + for _, name := range names { + value := e[name] + switch value.(type) { + case nil: + name = db.QuoteColumnName(name) + parts = append(parts, name+" IS NULL") + case Expression: + if sql := value.(Expression).Build(db, params); sql != "" { + parts = append(parts, "("+sql+")") + } + case []interface{}: + in := In(name, value.([]interface{})...) + if sql := in.Build(db, params); sql != "" { + parts = append(parts, sql) + } + default: + pn := fmt.Sprintf("p%v", len(params)) + name = db.QuoteColumnName(name) + parts = append(parts, name+"={:"+pn+"}") + params[pn] = value + } + } + if len(parts) == 1 { + return parts[0] + } + return strings.Join(parts, " AND ") +} + +// NotExp represents an expression that should prefix "NOT" to a specified expression. +type NotExp struct { + e Expression +} + +// Build converts an expression into a SQL fragment. +func (e *NotExp) Build(db *DB, params Params) string { + if sql := e.e.Build(db, params); sql != "" { + return "NOT (" + sql + ")" + } + return "" +} + +// AndOrExp represents an expression that concatenates multiple expressions using either "AND" or "OR". +type AndOrExp struct { + exps []Expression + op string +} + +// Build converts an expression into a SQL fragment. +func (e *AndOrExp) Build(db *DB, params Params) string { + if len(e.exps) == 0 { + return "" + } + + var parts []string + for _, a := range e.exps { + if a == nil { + continue + } + if sql := a.Build(db, params); sql != "" { + parts = append(parts, sql) + } + } + if len(parts) == 1 { + return parts[0] + } + return "(" + strings.Join(parts, ") "+e.op+" (") + ")" +} + +// InExp represents an "IN" or "NOT IN" expression. +type InExp struct { + col string + values []interface{} + not bool +} + +// Build converts an expression into a SQL fragment. +func (e *InExp) Build(db *DB, params Params) string { + if len(e.values) == 0 { + if e.not { + return "" + } + return "0=1" + } + + var values []string + for _, value := range e.values { + switch value.(type) { + case nil: + values = append(values, "NULL") + case Expression: + sql := value.(Expression).Build(db, params) + values = append(values, sql) + default: + name := fmt.Sprintf("p%v", len(params)) + params[name] = value + values = append(values, "{:"+name+"}") + } + } + col := db.QuoteColumnName(e.col) + if len(values) == 1 { + if e.not { + return col + "<>" + values[0] + } + return col + "=" + values[0] + } + in := "IN" + if e.not { + in = "NOT IN" + } + return fmt.Sprintf("%v %v (%v)", col, in, strings.Join(values, ", ")) +} + +// LikeExp represents a variant of LIKE expressions. +type LikeExp struct { + or bool + left, right bool + col string + values []string + escape []string + + // Like stores the LIKE operator. It can be "LIKE", "NOT LIKE". + // It may also be customized as something like "ILIKE". + Like string +} + +// Escape specifies how a LIKE expression should be escaped. +// Each string at position 2i represents a special character and the string at position 2i+1 is +// the corresponding escaped version. +func (e *LikeExp) Escape(chars ...string) *LikeExp { + e.escape = chars + return e +} + +// Match specifies whether to do wildcard matching on the left and/or right of given strings. +func (e *LikeExp) Match(left, right bool) *LikeExp { + e.left, e.right = left, right + return e +} + +// Build converts an expression into a SQL fragment. +func (e *LikeExp) Build(db *DB, params Params) string { + if len(e.values) == 0 { + return "" + } + + if len(e.escape)%2 != 0 { + panic("LikeExp.Escape must be a slice of even number of strings") + } + + var parts []string + col := db.QuoteColumnName(e.col) + for _, value := range e.values { + name := fmt.Sprintf("p%v", len(params)) + for i := 0; i < len(e.escape); i += 2 { + value = strings.Replace(value, e.escape[i], e.escape[i+1], -1) + } + if e.left { + value = "%" + value + } + if e.right { + value += "%" + } + params[name] = value + parts = append(parts, fmt.Sprintf("%v %v {:%v}", col, e.Like, name)) + } + + if e.or { + return strings.Join(parts, " OR ") + } + return strings.Join(parts, " AND ") +} + +// ExistsExp represents an EXISTS or NOT EXISTS expression. +type ExistsExp struct { + exp Expression + not bool +} + +// Build converts an expression into a SQL fragment. +func (e *ExistsExp) Build(db *DB, params Params) string { + sql := e.exp.Build(db, params) + if sql == "" { + if e.not { + return "" + } + return "0=1" + } + if e.not { + return "NOT EXISTS (" + sql + ")" + } + return "EXISTS (" + sql + ")" +} + +// BetweenExp represents a BETWEEN or a NOT BETWEEN expression. +type BetweenExp struct { + col string + from, to interface{} + not bool +} + +// Build converts an expression into a SQL fragment. +func (e *BetweenExp) Build(db *DB, params Params) string { + between := "BETWEEN" + if e.not { + between = "NOT BETWEEN" + } + name1 := fmt.Sprintf("p%v", len(params)) + name2 := fmt.Sprintf("p%v", len(params)+1) + params[name1] = e.from + params[name2] = e.to + col := db.QuoteColumnName(e.col) + return fmt.Sprintf("%v %v {:%v} AND {:%v}", col, between, name1, name2) +} + +// Enclose surrounds the provided nonempty expression with parenthesis "()". +func Enclose(exp Expression) Expression { + return &EncloseExp{exp} +} + +// EncloseExp represents a parenthesis enclosed expression. +type EncloseExp struct { + exp Expression +} + +// Build converts an expression into a SQL fragment. +func (e *EncloseExp) Build(db *DB, params Params) string { + str := e.exp.Build(db, params) + + if str == "" { + return "" + } + + return "(" + str + ")" +} diff --git a/expression_test.go b/expression_test.go new file mode 100644 index 0000000..41cf26c --- /dev/null +++ b/expression_test.go @@ -0,0 +1,196 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExp(t *testing.T) { + params := Params{"k2": "v2"} + + e1 := NewExp("s1").(*Exp) + assert.Equal(t, e1.Build(nil, params), "s1", "e1.Build()") + assert.Equal(t, len(params), 1, `len(params)@1`) + + e2 := NewExp("s2", Params{"k1": "v1"}).(*Exp) + assert.Equal(t, e2.Build(nil, params), "s2", "e2.Build()") + assert.Equal(t, len(params), 2, `len(params)@2`) +} + +func TestHashExp(t *testing.T) { + e1 := HashExp{} + assert.Equal(t, e1.Build(nil, nil), "", `e1.Build()`) + + e2 := HashExp{ + "k1": nil, + "k2": NewExp("s1", Params{"ka": "va"}), + "k3": 1.1, + "k4": "abc", + "k5": []interface{}{1, 2}, + } + db := getDB() + params := Params{"k0": "v0"} + expected := "`k1` IS NULL AND (s1) AND `k3`={:p2} AND `k4`={:p3} AND `k5` IN ({:p4}, {:p5})" + + assert.Equal(t, e2.Build(db, params), expected, `e2.Build()`) + assert.Equal(t, len(params), 6, `len(params)`) + assert.Equal(t, params["p5"].(int), 2, `params["p5"]`) +} + +func TestNotExp(t *testing.T) { + e1 := Not(NewExp("s1")) + assert.Equal(t, e1.Build(nil, nil), "NOT (s1)", `e1.Build()`) + + e2 := Not(NewExp("")) + assert.Equal(t, e2.Build(nil, nil), "", `e2.Build()`) +} + +func TestAndOrExp(t *testing.T) { + e1 := And(NewExp("s1", Params{"k1": "v1"}), NewExp(""), NewExp("s2", Params{"k2": "v2"})) + params := Params{} + assert.Equal(t, e1.Build(nil, params), "(s1) AND (s2)", `e1.Build()`) + assert.Equal(t, len(params), 2, `len(params)`) + + e2 := Or(NewExp("s1"), NewExp("s2")) + assert.Equal(t, e2.Build(nil, nil), "(s1) OR (s2)", `e2.Build()`) + + e3 := And() + assert.Equal(t, e3.Build(nil, nil), "", `e3.Build()`) + + e4 := And(NewExp("s1")) + assert.Equal(t, e4.Build(nil, nil), "s1", `e4.Build()`) + + e5 := And(NewExp("s1"), nil) + assert.Equal(t, e5.Build(nil, nil), "s1", `e5.Build()`) +} + +func TestInExp(t *testing.T) { + db := getDB() + + e1 := In("age", 1, 2, 3) + params := Params{} + assert.Equal(t, e1.Build(db, params), "`age` IN ({:p0}, {:p1}, {:p2})", `e1.Build()`) + assert.Equal(t, len(params), 3, `len(params)@1`) + + e2 := In("age", 1) + params = Params{} + assert.Equal(t, e2.Build(db, params), "`age`={:p0}", `e2.Build()`) + assert.Equal(t, len(params), 1, `len(params)@2`) + + e3 := NotIn("age", 1, 2, 3) + params = Params{} + assert.Equal(t, e3.Build(db, params), "`age` NOT IN ({:p0}, {:p1}, {:p2})", `e3.Build()`) + assert.Equal(t, len(params), 3, `len(params)@3`) + + e4 := NotIn("age", 1) + params = Params{} + assert.Equal(t, e4.Build(db, params), "`age`<>{:p0}", `e4.Build()`) + assert.Equal(t, len(params), 1, `len(params)@4`) + + e5 := In("age") + assert.Equal(t, e5.Build(db, nil), "0=1", `e5.Build()`) + + e6 := NotIn("age") + assert.Equal(t, e6.Build(db, nil), "", `e6.Build()`) +} + +func TestLikeExp(t *testing.T) { + db := getDB() + + e1 := Like("name", "a", "b", "c") + params := Params{} + assert.Equal(t, e1.Build(db, params), "`name` LIKE {:p0} AND `name` LIKE {:p1} AND `name` LIKE {:p2}", `e1.Build()`) + assert.Equal(t, len(params), 3, `len(params)@1`) + + e2 := Like("name", "a") + params = Params{} + assert.Equal(t, e2.Build(db, params), "`name` LIKE {:p0}", `e2.Build()`) + assert.Equal(t, len(params), 1, `len(params)@2`) + + e3 := Like("name") + assert.Equal(t, e3.Build(db, nil), "", `e3.Build()`) + + e4 := NotLike("name", "a", "b", "c") + params = Params{} + assert.Equal(t, e4.Build(db, params), "`name` NOT LIKE {:p0} AND `name` NOT LIKE {:p1} AND `name` NOT LIKE {:p2}", `e4.Build()`) + assert.Equal(t, len(params), 3, `len(params)@4`) + + e5 := OrLike("name", "a", "b", "c") + params = Params{} + assert.Equal(t, e5.Build(db, params), "`name` LIKE {:p0} OR `name` LIKE {:p1} OR `name` LIKE {:p2}", `e5.Build()`) + assert.Equal(t, len(params), 3, `len(params)@5`) + + e6 := OrNotLike("name", "a", "b", "c") + params = Params{} + assert.Equal(t, e6.Build(db, params), "`name` NOT LIKE {:p0} OR `name` NOT LIKE {:p1} OR `name` NOT LIKE {:p2}", `e6.Build()`) + assert.Equal(t, len(params), 3, `len(params)@6`) + + e7 := Like("name", "a_\\%") + params = Params{} + e7.Build(db, params) + assert.Equal(t, params["p0"], "%a\\_\\\\\\%%", `params["p0"]@1`) + + e8 := Like("name", "a").Match(false, true) + params = Params{} + e8.Build(db, params) + assert.Equal(t, params["p0"], "a%", `params["p0"]@2`) + + e9 := Like("name", "a").Match(true, false) + params = Params{} + e9.Build(db, params) + assert.Equal(t, params["p0"], "%a", `params["p0"]@3`) + + e10 := Like("name", "a").Match(false, false) + params = Params{} + e10.Build(db, params) + assert.Equal(t, params["p0"], "a", `params["p0"]@4`) + + e11 := Like("name", "%a").Match(false, false).Escape() + params = Params{} + e11.Build(db, params) + assert.Equal(t, params["p0"], "%a", `params["p0"]@5`) +} + +func TestBetweenExp(t *testing.T) { + db := getDB() + + e1 := Between("age", 30, 40) + params := Params{} + assert.Equal(t, e1.Build(db, params), "`age` BETWEEN {:p0} AND {:p1}", `e1.Build()`) + assert.Equal(t, len(params), 2, `len(params)@1`) + + e2 := NotBetween("age", 30, 40) + params = Params{} + assert.Equal(t, e2.Build(db, params), "`age` NOT BETWEEN {:p0} AND {:p1}", `e2.Build()`) + assert.Equal(t, len(params), 2, `len(params)@2`) +} + +func TestExistsExp(t *testing.T) { + e1 := Exists(NewExp("s1")) + assert.Equal(t, e1.Build(nil, nil), "EXISTS (s1)", `e1.Build()`) + + e2 := NotExists(NewExp("s1")) + assert.Equal(t, e2.Build(nil, nil), "NOT EXISTS (s1)", `e2.Build()`) + + e3 := Exists(NewExp("")) + assert.Equal(t, e3.Build(nil, nil), "0=1", `e3.Build()`) + + e4 := NotExists(NewExp("")) + assert.Equal(t, e4.Build(nil, nil), "", `e4.Build()`) +} + +func TestEncloseExp(t *testing.T) { + e1 := Enclose(NewExp("")) + assert.Equal(t, e1.Build(nil, nil), "", `e1.Build()`) + + e2 := Enclose(NewExp("s1")) + assert.Equal(t, e2.Build(nil, nil), "(s1)", `e2.Build()`) + + e3 := Enclose(NewExp("(s1)")) + assert.Equal(t, e3.Build(nil, nil), "((s1))", `e3.Build()`) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..e9cb0e5 --- /dev/null +++ b/go.mod @@ -0,0 +1,9 @@ +module github.com/pocketbase/dbx + +go 1.13 + +require ( + github.com/go-sql-driver/mysql v1.4.1 + github.com/stretchr/testify v1.4.0 + google.golang.org/appengine v1.6.5 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..3fe1447 --- /dev/null +++ b/go.sum @@ -0,0 +1,22 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= +github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= +google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/model_query.go b/model_query.go new file mode 100644 index 0000000..97bb2d9 --- /dev/null +++ b/model_query.go @@ -0,0 +1,174 @@ +package dbx + +import ( + "context" + "errors" + "fmt" + "reflect" +) + +type ( + // TableModel is the interface that should be implemented by models which have unconventional table names. + TableModel interface { + TableName() string + } + + // ModelQuery represents a query associated with a struct model. + ModelQuery struct { + db *DB + ctx context.Context + builder Builder + model *structValue + exclude []string + lastError error + } +) + +var ( + MissingPKError = errors.New("missing primary key declaration") + CompositePKError = errors.New("composite primary key is not supported") +) + +func NewModelQuery(model interface{}, fieldMapFunc FieldMapFunc, db *DB, builder Builder) *ModelQuery { + q := &ModelQuery{ + db: db, + ctx: db.ctx, + builder: builder, + model: newStructValue(model, fieldMapFunc, db.TableMapper), + } + if q.model == nil { + q.lastError = VarTypeError("must be a pointer to a struct representing the model") + } + return q +} + +// Context returns the context associated with the query. +func (q *ModelQuery) Context() context.Context { + return q.ctx +} + +// WithContext associates a context with the query. +func (q *ModelQuery) WithContext(ctx context.Context) *ModelQuery { + q.ctx = ctx + return q +} + +// Exclude excludes the specified struct fields from being inserted/updated into the DB table. +func (q *ModelQuery) Exclude(attrs ...string) *ModelQuery { + q.exclude = attrs + return q +} + +// Insert inserts a row in the table using the struct model associated with this query. +// +// By default, it inserts *all* public fields into the table, including those nil or empty ones. +// You may pass a list of the fields to this method to indicate that only those fields should be inserted. +// You may also call Exclude to exclude some fields from being inserted. +// +// If a model has an empty primary key, it is considered auto-incremental and the corresponding struct +// field will be filled with the generated primary key value after a successful insertion. +func (q *ModelQuery) Insert(attrs ...string) error { + if q.lastError != nil { + return q.lastError + } + cols := q.model.columns(attrs, q.exclude) + pkName := "" + for name, value := range q.model.pk() { + if isAutoInc(value) { + delete(cols, name) + pkName = name + break + } + } + + if pkName == "" { + _, err := q.builder.Insert(q.model.tableName, Params(cols)).WithContext(q.ctx).Execute() + return err + } + + // handle auto-incremental PK + query := q.builder.Insert(q.model.tableName, Params(cols)).WithContext(q.ctx) + pkValue, err := insertAndReturnPK(q.db, query, pkName) + if err != nil { + return err + } + + pkField := indirect(q.model.dbNameMap[pkName].getField(q.model.value)) + switch pkField.Kind() { + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + pkField.SetUint(uint64(pkValue)) + default: + pkField.SetInt(pkValue) + } + + return nil +} + +func insertAndReturnPK(db *DB, query *Query, pkName string) (int64, error) { + if db.DriverName() != "postgres" && db.DriverName() != "pgx" { + result, err := query.Execute() + if err != nil { + return 0, err + } + return result.LastInsertId() + } + + // specially handle postgres (lib/pq) as it doesn't support LastInsertId + returning := fmt.Sprintf(" RETURNING %s", db.QuoteColumnName(pkName)) + query.sql += returning + query.rawSQL += returning + var pkValue int64 + err := query.Row(&pkValue) + return pkValue, err +} + +func isAutoInc(value interface{}) bool { + v := reflect.ValueOf(value) + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Ptr: + return v.IsNil() || isAutoInc(v.Elem()) + case reflect.Invalid: + return true + } + return false +} + +// Update updates a row in the table using the struct model associated with this query. +// The row being updated has the same primary key as specified by the model. +// +// By default, it updates *all* public fields in the table, including those nil or empty ones. +// You may pass a list of the fields to this method to indicate that only those fields should be updated. +// You may also call Exclude to exclude some fields from being updated. +func (q *ModelQuery) Update(attrs ...string) error { + if q.lastError != nil { + return q.lastError + } + pk := q.model.pk() + if len(pk) == 0 { + return MissingPKError + } + + cols := q.model.columns(attrs, q.exclude) + for name := range pk { + delete(cols, name) + } + _, err := q.builder.Update(q.model.tableName, Params(cols), HashExp(pk)).WithContext(q.ctx).Execute() + return err +} + +// Delete deletes a row in the table using the primary key specified by the struct model associated with this query. +func (q *ModelQuery) Delete() error { + if q.lastError != nil { + return q.lastError + } + pk := q.model.pk() + if len(pk) == 0 { + return MissingPKError + } + _, err := q.builder.Delete(q.model.tableName, HashExp(pk)).WithContext(q.ctx).Execute() + return err +} diff --git a/model_query_test.go b/model_query_test.go new file mode 100644 index 0000000..039c592 --- /dev/null +++ b/model_query_test.go @@ -0,0 +1,238 @@ +package dbx + +import ( + "database/sql" + "testing" + + "github.com/stretchr/testify/assert" +) + +type Item struct { + ID2 int + Name string +} + +func TestModelQuery_Insert(t *testing.T) { + db := getPreparedDB() + defer db.Close() + + name := "test" + email := "test@example.com" + + { + // inserting normally + customer := Customer{ + Name: name, + Email: email, + } + err := db.Model(&customer).Insert() + if assert.Nil(t, err) { + assert.Equal(t, 4, customer.ID) + var c Customer + db.Select().From("customer").Where(HashExp{"ID": 4}).One(&c) + assert.Equal(t, name, c.Name) + assert.Equal(t, email, c.Email) + assert.Equal(t, 0, c.Status) + assert.False(t, c.Address.Valid) + } + } + + { + // inserting with pointer-typed fields + customer := CustomerPtr{ + Name: name, + Email: &email, + } + err := db.Model(&customer).Insert() + if assert.Nil(t, err) && assert.NotNil(t, customer.ID) { + assert.Equal(t, 5, *customer.ID) + var c CustomerPtr + db.Select().From("customer").Where(HashExp{"ID": 4}).One(&c) + assert.Equal(t, name, c.Name) + if assert.NotNil(t, c.Email) { + assert.Equal(t, email, *c.Email) + } + if assert.NotNil(t, c.Status) { + assert.Equal(t, 0, *c.Status) + } + assert.Nil(t, c.Address) + } + } + + { + // inserting with null-typed fields + customer := CustomerNull{ + Name: name, + Email: sql.NullString{email, true}, + } + err := db.Model(&customer).Insert() + if assert.Nil(t, err) { + // potential todo: need to check if the field implements sql.Scanner + // assert.Equal(t, int64(6), customer.ID.Int64) + var c CustomerNull + db.Select().From("customer").Where(HashExp{"ID": 4}).One(&c) + assert.Equal(t, name, c.Name) + assert.Equal(t, email, c.Email.String) + if assert.NotNil(t, c.Status) { + assert.Equal(t, int64(0), c.Status.Int64) + } + assert.False(t, c.Address.Valid) + } + } + + { + // inserting with embedded structures + customer := CustomerEmbedded{ + Id: 100, + Email: &email, + InnerCustomer: InnerCustomer{ + Name: &name, + Status: sql.NullInt64{1, true}, + }, + } + err := db.Model(&customer).Insert() + if assert.Nil(t, err) { + assert.Equal(t, 100, customer.Id) + var c CustomerEmbedded + db.Select().From("customer").Where(HashExp{"ID": 100}).One(&c) + assert.Equal(t, name, *c.Name) + assert.Equal(t, email, *c.Email) + if assert.NotNil(t, c.Status) { + assert.Equal(t, int64(1), c.Status.Int64) + } + assert.False(t, c.Address.Valid) + } + } + + { + // inserting with include/exclude fields + customer := Customer{ + Name: name, + Email: email, + Status: 1, + } + err := db.Model(&customer).Exclude("Name").Insert("Name", "Email") + if assert.Nil(t, err) { + assert.Equal(t, 101, customer.ID) + var c Customer + db.Select().From("customer").Where(HashExp{"ID": 101}).One(&c) + assert.Equal(t, "", c.Name) + assert.Equal(t, email, c.Email) + assert.Equal(t, 0, c.Status) + assert.False(t, c.Address.Valid) + } + } + + var a int + assert.NotNil(t, db.Model(&a).Insert()) +} + +func TestModelQuery_Update(t *testing.T) { + db := getPreparedDB() + defer db.Close() + + id := 2 + name := "test" + email := "test@example.com" + { + // updating normally + customer := Customer{ + ID: id, + Name: name, + Email: email, + } + err := db.Model(&customer).Update() + if assert.Nil(t, err) { + var c Customer + db.Select().From("customer").Where(HashExp{"ID": id}).One(&c) + assert.Equal(t, name, c.Name) + assert.Equal(t, email, c.Email) + assert.Equal(t, 0, c.Status) + } + } + + { + // updating without primary keys + item2 := Item{ + Name: name, + } + err := db.Model(&item2).Update() + assert.Equal(t, MissingPKError, err) + } + + { + // updating all fields + customer := CustomerPtr{ + ID: &id, + Name: name, + Email: &email, + } + err := db.Model(&customer).Update() + if assert.Nil(t, err) { + assert.Equal(t, id, *customer.ID) + var c CustomerPtr + db.Select().From("customer").Where(HashExp{"ID": id}).One(&c) + assert.Equal(t, name, c.Name) + if assert.NotNil(t, c.Email) { + assert.Equal(t, email, *c.Email) + } + assert.Nil(t, c.Status) + } + } + + { + // updating selected fields only + id = 3 + customer := CustomerPtr{ + ID: &id, + Name: name, + Email: &email, + } + err := db.Model(&customer).Update("Name", "Email") + if assert.Nil(t, err) { + assert.Equal(t, id, *customer.ID) + var c CustomerPtr + db.Select().From("customer").Where(HashExp{"ID": id}).One(&c) + assert.Equal(t, name, c.Name) + if assert.NotNil(t, c.Email) { + assert.Equal(t, email, *c.Email) + } + if assert.NotNil(t, c.Status) { + assert.Equal(t, 2, *c.Status) + } + } + } + + { + // updating non-struct + var a int + assert.NotNil(t, db.Model(&a).Update()) + } +} + +func TestModelQuery_Delete(t *testing.T) { + db := getPreparedDB() + defer db.Close() + + customer := Customer{ + ID: 2, + } + err := db.Model(&customer).Delete() + if assert.Nil(t, err) { + var m Customer + err := db.Select().From("customer").Where(HashExp{"ID": 2}).One(&m) + assert.NotNil(t, err) + } + + { + // deleting without primary keys + item2 := Item{ + Name: "", + } + err := db.Model(&item2).Delete() + assert.Equal(t, MissingPKError, err) + } + + var a int + assert.NotNil(t, db.Model(&a).Delete()) +} diff --git a/query.go b/query.go new file mode 100644 index 0000000..34d1076 --- /dev/null +++ b/query.go @@ -0,0 +1,384 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/hex" + "errors" + "fmt" + "strings" + "time" +) + +// ExecHookFunc executes before op allowing custom handling like auto fail/retry. +type ExecHookFunc func(q *Query, op func() error) error + +// OneHookFunc executes right before the query populate the row result from One() call (aka. op). +type OneHookFunc func(q *Query, a interface{}, op func(b interface{}) error) error + +// AllHookFunc executes right before the query populate the row result from All() call (aka. op). +type AllHookFunc func(q *Query, sliceA interface{}, op func(sliceB interface{}) error) error + +// Params represents a list of parameter values to be bound to a SQL statement. +// The map keys are the parameter names while the map values are the corresponding parameter values. +type Params map[string]interface{} + +// Executor prepares, executes, or queries a SQL statement. +type Executor interface { + // Exec executes a SQL statement + Exec(query string, args ...interface{}) (sql.Result, error) + // ExecContext executes a SQL statement with the given context + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + // Query queries a SQL statement + Query(query string, args ...interface{}) (*sql.Rows, error) + // QueryContext queries a SQL statement with the given context + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + // Prepare creates a prepared statement + Prepare(query string) (*sql.Stmt, error) +} + +// Query represents a SQL statement to be executed. +type Query struct { + executor Executor + + sql, rawSQL string + placeholders []string + params Params + + stmt *sql.Stmt + ctx context.Context + + // hooks + execHook ExecHookFunc + oneHook OneHookFunc + allHook AllHookFunc + + // FieldMapper maps struct field names to DB column names. + FieldMapper FieldMapFunc + // LastError contains the last error (if any) of the query. + // LastError is cleared by Execute(), Row(), Rows(), One(), and All(). + LastError error + // LogFunc is used to log the SQL statement being executed. + LogFunc LogFunc + // PerfFunc is used to log the SQL execution time. It is ignored if nil. + // Deprecated: Please use QueryLogFunc and ExecLogFunc instead. + PerfFunc PerfFunc + // QueryLogFunc is called each time when performing a SQL query that returns data. + QueryLogFunc QueryLogFunc + // ExecLogFunc is called each time when a SQL statement is executed. + ExecLogFunc ExecLogFunc +} + +// NewQuery creates a new Query with the given SQL statement. +func NewQuery(db *DB, executor Executor, sql string) *Query { + rawSQL, placeholders := db.processSQL(sql) + return &Query{ + executor: executor, + sql: sql, + rawSQL: rawSQL, + placeholders: placeholders, + params: Params{}, + ctx: db.ctx, + FieldMapper: db.FieldMapper, + LogFunc: db.LogFunc, + PerfFunc: db.PerfFunc, + QueryLogFunc: db.QueryLogFunc, + ExecLogFunc: db.ExecLogFunc, + } +} + +// SQL returns the original SQL used to create the query. +// The actual SQL (RawSQL) being executed is obtained by replacing the named +// parameter placeholders with anonymous ones. +func (q *Query) SQL() string { + return q.sql +} + +// Context returns the context associated with the query. +func (q *Query) Context() context.Context { + return q.ctx +} + +// WithContext associates a context with the query. +func (q *Query) WithContext(ctx context.Context) *Query { + q.ctx = ctx + return q +} + +// WithExecHook associates the provided exec hook function with the query. +// +// It is called for every Query resolver (Execute(), One(), All(), Row(), Column()), +// allowing you to implement auto fail/retry or any other additional handling. +func (q *Query) WithExecHook(fn ExecHookFunc) *Query { + q.execHook = fn + return q +} + +// WithOneHook associates the provided hook function with the query, +// called on q.One(), allowing you to implement custom struct scan based +// on the One() argument and/or result. +func (q *Query) WithOneHook(fn OneHookFunc) *Query { + q.oneHook = fn + return q +} + +// WithOneHook associates the provided hook function with the query, +// called on q.All(), allowing you to implement custom slice scan based +// on the All() argument and/or result. +func (q *Query) WithAllHook(fn AllHookFunc) *Query { + q.allHook = fn + return q +} + +// logSQL returns the SQL statement with parameters being replaced with the actual values. +// The result is only for logging purpose and should not be used to execute. +func (q *Query) logSQL() string { + s := q.sql + for k, v := range q.params { + if valuer, ok := v.(driver.Valuer); ok && valuer != nil { + v, _ = valuer.Value() + } + var sv string + if str, ok := v.(string); ok { + sv = "'" + strings.Replace(str, "'", "''", -1) + "'" + } else if bs, ok := v.([]byte); ok { + sv = "0x" + hex.EncodeToString(bs) + } else { + sv = fmt.Sprintf("%v", v) + } + s = strings.Replace(s, "{:"+k+"}", sv, -1) + } + return s +} + +// Params returns the parameters to be bound to the SQL statement represented by this query. +func (q *Query) Params() Params { + return q.params +} + +// Prepare creates a prepared statement for later queries or executions. +// Close() should be called after finishing all queries. +func (q *Query) Prepare() *Query { + stmt, err := q.executor.Prepare(q.rawSQL) + if err != nil { + q.LastError = err + return q + } + q.stmt = stmt + return q +} + +// Close closes the underlying prepared statement. +// Close does nothing if the query has not been prepared before. +func (q *Query) Close() error { + if q.stmt == nil { + return nil + } + + err := q.stmt.Close() + q.stmt = nil + return err +} + +// Bind sets the parameters that should be bound to the SQL statement. +// The parameter placeholders in the SQL statement are in the format of "{:ParamName}". +func (q *Query) Bind(params Params) *Query { + if len(q.params) == 0 { + q.params = params + } else { + for k, v := range params { + q.params[k] = v + } + } + return q +} + +// Execute executes the SQL statement without retrieving data. +func (q *Query) Execute() (sql.Result, error) { + var result sql.Result + + execErr := q.execWrap(func() error { + var err error + result, err = q.execute() + return err + }) + + return result, execErr +} + +func (q *Query) execute() (result sql.Result, err error) { + err = q.LastError + q.LastError = nil + if err != nil { + return + } + + var params []interface{} + params, err = replacePlaceholders(q.placeholders, q.params) + if err != nil { + return + } + + start := time.Now() + + if q.ctx == nil { + if q.stmt == nil { + result, err = q.executor.Exec(q.rawSQL, params...) + } else { + result, err = q.stmt.Exec(params...) + } + } else { + if q.stmt == nil { + result, err = q.executor.ExecContext(q.ctx, q.rawSQL, params...) + } else { + result, err = q.stmt.ExecContext(q.ctx, params...) + } + } + + if q.ExecLogFunc != nil { + q.ExecLogFunc(q.ctx, time.Now().Sub(start), q.logSQL(), result, err) + } + if q.LogFunc != nil { + q.LogFunc("[%.2fms] Execute SQL: %v", float64(time.Now().Sub(start).Milliseconds()), q.logSQL()) + } + if q.PerfFunc != nil { + q.PerfFunc(time.Now().Sub(start).Nanoseconds(), q.logSQL(), true) + } + return +} + +// One executes the SQL statement and populates the first row of the result into a struct or NullStringMap. +// Refer to Rows.ScanStruct() and Rows.ScanMap() for more details on how to specify +// the variable to be populated. +// Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned. +func (q *Query) One(a interface{}) error { + return q.execWrap(func() error { + rows, err := q.Rows() + if err != nil { + return err + } + + if q.oneHook != nil { + return q.oneHook(q, a, rows.one) + } + + return rows.one(a) + }) +} + +// All executes the SQL statement and populates all the resulting rows into a slice of struct or NullStringMap. +// The slice must be given as a pointer. Each slice element must be either a struct or a NullStringMap. +// Refer to Rows.ScanStruct() and Rows.ScanMap() for more details on how each slice element can be. +// If the query returns no row, the slice will be an empty slice (not nil). +func (q *Query) All(slice interface{}) error { + return q.execWrap(func() error { + rows, err := q.Rows() + if err != nil { + return err + } + + if q.allHook != nil { + return q.allHook(q, slice, rows.all) + } + + return rows.all(slice) + }) +} + +// Row executes the SQL statement and populates the first row of the result into a list of variables. +// Note that the number of the variables should match to that of the columns in the query result. +// Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned. +func (q *Query) Row(a ...interface{}) error { + return q.execWrap(func() error { + rows, err := q.Rows() + if err != nil { + return err + } + return rows.row(a...) + }) +} + +// Column executes the SQL statement and populates the first column of the result into a slice. +// Note that the parameter must be a pointer to a slice. +func (q *Query) Column(a interface{}) error { + return q.execWrap(func() error { + rows, err := q.Rows() + if err != nil { + return err + } + return rows.column(a) + }) +} + +// Rows executes the SQL statement and returns a Rows object to allow retrieving data row by row. +func (q *Query) Rows() (rows *Rows, err error) { + err = q.LastError + q.LastError = nil + if err != nil { + return + } + + var params []interface{} + params, err = replacePlaceholders(q.placeholders, q.params) + if err != nil { + return + } + + start := time.Now() + + var rr *sql.Rows + if q.ctx == nil { + if q.stmt == nil { + rr, err = q.executor.Query(q.rawSQL, params...) + } else { + rr, err = q.stmt.Query(params...) + } + } else { + if q.stmt == nil { + rr, err = q.executor.QueryContext(q.ctx, q.rawSQL, params...) + } else { + rr, err = q.stmt.QueryContext(q.ctx, params...) + } + } + rows = &Rows{rr, q.FieldMapper} + + if q.QueryLogFunc != nil { + q.QueryLogFunc(q.ctx, time.Now().Sub(start), q.logSQL(), rr, err) + } + if q.LogFunc != nil { + q.LogFunc("[%.2fms] Query SQL: %v", float64(time.Now().Sub(start).Milliseconds()), q.logSQL()) + } + if q.PerfFunc != nil { + q.PerfFunc(time.Now().Sub(start).Nanoseconds(), q.logSQL(), false) + } + return +} + +func (q *Query) execWrap(op func() error) error { + if q.execHook != nil { + return q.execHook(q, op) + } + return op() +} + +// replacePlaceholders converts a list of named parameters into a list of anonymous parameters. +func replacePlaceholders(placeholders []string, params Params) ([]interface{}, error) { + if len(placeholders) == 0 { + return nil, nil + } + + var result []interface{} + for _, name := range placeholders { + if value, ok := params[name]; ok { + result = append(result, value) + } else { + return nil, errors.New("Named parameter not found: " + name) + } + } + return result, nil +} diff --git a/query_builder.go b/query_builder.go new file mode 100644 index 0000000..0310a0d --- /dev/null +++ b/query_builder.go @@ -0,0 +1,244 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "bytes" + "fmt" + "regexp" + "strings" +) + +// QueryBuilder builds different clauses for a SELECT SQL statement. +type QueryBuilder interface { + // BuildSelect generates a SELECT clause from the given selected column names. + BuildSelect(cols []string, distinct bool, option string) string + // BuildFrom generates a FROM clause from the given tables. + BuildFrom(tables []string) string + // BuildGroupBy generates a GROUP BY clause from the given group-by columns. + BuildGroupBy(cols []string) string + // BuildJoin generates a JOIN clause from the given join information. + BuildJoin([]JoinInfo, Params) string + // BuildWhere generates a WHERE clause from the given expression. + BuildWhere(Expression, Params) string + // BuildHaving generates a HAVING clause from the given expression. + BuildHaving(Expression, Params) string + // BuildOrderByAndLimit generates the ORDER BY and LIMIT clauses. + BuildOrderByAndLimit(string, []string, int64, int64) string + // BuildUnion generates a UNION clause from the given union information. + BuildUnion([]UnionInfo, Params) string +} + +// BaseQueryBuilder provides a basic implementation of QueryBuilder. +type BaseQueryBuilder struct { + db *DB +} + +var _ QueryBuilder = &BaseQueryBuilder{} + +// NewBaseQueryBuilder creates a new BaseQueryBuilder instance. +func NewBaseQueryBuilder(db *DB) *BaseQueryBuilder { + return &BaseQueryBuilder{db} +} + +// DB returns the DB instance associated with the query builder. +func (q *BaseQueryBuilder) DB() *DB { + return q.db +} + +// the regexp for columns and tables. +var selectRegex = regexp.MustCompile(`(?i:\s+as\s+|\s+)([\w\-_\.]+)$`) + +// BuildSelect generates a SELECT clause from the given selected column names. +func (q *BaseQueryBuilder) BuildSelect(cols []string, distinct bool, option string) string { + var s bytes.Buffer + s.WriteString("SELECT ") + if distinct { + s.WriteString("DISTINCT ") + } + if option != "" { + s.WriteString(option) + s.WriteString(" ") + } + if len(cols) == 0 { + s.WriteString("*") + return s.String() + } + + for i, col := range cols { + if i > 0 { + s.WriteString(", ") + } + matches := selectRegex.FindStringSubmatch(col) + if len(matches) == 0 { + s.WriteString(q.db.QuoteColumnName(col)) + } else { + col := col[:len(col)-len(matches[0])] + alias := matches[1] + s.WriteString(q.db.QuoteColumnName(col) + " AS " + q.db.QuoteSimpleColumnName(alias)) + } + } + + return s.String() +} + +// BuildFrom generates a FROM clause from the given tables. +func (q *BaseQueryBuilder) BuildFrom(tables []string) string { + if len(tables) == 0 { + return "" + } + s := "" + for _, table := range tables { + table = q.quoteTableNameAndAlias(table) + if s == "" { + s = table + } else { + s += ", " + table + } + } + return "FROM " + s +} + +// BuildJoin generates a JOIN clause from the given join information. +func (q *BaseQueryBuilder) BuildJoin(joins []JoinInfo, params Params) string { + if len(joins) == 0 { + return "" + } + parts := []string{} + for _, join := range joins { + sql := join.Join + " " + q.quoteTableNameAndAlias(join.Table) + on := "" + if join.On != nil { + on = join.On.Build(q.db, params) + } + if on != "" { + sql += " ON " + on + } + parts = append(parts, sql) + } + return strings.Join(parts, " ") +} + +// BuildWhere generates a WHERE clause from the given expression. +func (q *BaseQueryBuilder) BuildWhere(e Expression, params Params) string { + if e != nil { + if c := e.Build(q.db, params); c != "" { + return "WHERE " + c + } + } + return "" +} + +// BuildHaving generates a HAVING clause from the given expression. +func (q *BaseQueryBuilder) BuildHaving(e Expression, params Params) string { + if e != nil { + if c := e.Build(q.db, params); c != "" { + return "HAVING " + c + } + } + return "" +} + +// BuildGroupBy generates a GROUP BY clause from the given group-by columns. +func (q *BaseQueryBuilder) BuildGroupBy(cols []string) string { + if len(cols) == 0 { + return "" + } + s := "" + for i, col := range cols { + if i == 0 { + s = q.db.QuoteColumnName(col) + } else { + s += ", " + q.db.QuoteColumnName(col) + } + } + return "GROUP BY " + s +} + +// BuildOrderByAndLimit generates the ORDER BY and LIMIT clauses. +func (q *BaseQueryBuilder) BuildOrderByAndLimit(sql string, cols []string, limit int64, offset int64) string { + if orderBy := q.BuildOrderBy(cols); orderBy != "" { + sql += " " + orderBy + } + if limit := q.BuildLimit(limit, offset); limit != "" { + return sql + " " + limit + } + return sql +} + +// BuildUnion generates a UNION clause from the given union information. +func (q *BaseQueryBuilder) BuildUnion(unions []UnionInfo, params Params) string { + if len(unions) == 0 { + return "" + } + sql := "" + for i, union := range unions { + if i > 0 { + sql += " " + } + for k, v := range union.Query.params { + params[k] = v + } + u := "UNION" + if union.All { + u = "UNION ALL" + } + sql += fmt.Sprintf("%v (%v)", u, union.Query.sql) + } + return sql +} + +var orderRegex = regexp.MustCompile(`\s+((?i)ASC|DESC)$`) + +// BuildOrderBy generates the ORDER BY clause. +func (q *BaseQueryBuilder) BuildOrderBy(cols []string) string { + if len(cols) == 0 { + return "" + } + s := "" + for i, col := range cols { + if i > 0 { + s += ", " + } + matches := orderRegex.FindStringSubmatch(col) + if len(matches) == 0 { + s += q.db.QuoteColumnName(col) + } else { + col := col[:len(col)-len(matches[0])] + dir := matches[1] + s += q.db.QuoteColumnName(col) + " " + dir + } + } + return "ORDER BY " + s +} + +// BuildLimit generates the LIMIT clause. +func (q *BaseQueryBuilder) BuildLimit(limit int64, offset int64) string { + if limit < 0 && offset > 0 { + // most DBMS requires LIMIT when OFFSET is present + limit = 9223372036854775807 // 2^63 - 1 + } + + sql := "" + if limit >= 0 { + sql = fmt.Sprintf("LIMIT %v", limit) + } + if offset <= 0 { + return sql + } + if sql != "" { + sql += " " + } + return sql + fmt.Sprintf("OFFSET %v", offset) +} + +func (q *BaseQueryBuilder) quoteTableNameAndAlias(table string) string { + matches := selectRegex.FindStringSubmatch(table) + if len(matches) == 0 { + return q.db.QuoteTableName(table) + } + table = table[:len(table)-len(matches[0])] + return q.db.QuoteTableName(table) + " " + q.db.QuoteSimpleTableName(matches[1]) +} diff --git a/query_builder_test.go b/query_builder_test.go new file mode 100644 index 0000000..2fde9a7 --- /dev/null +++ b/query_builder_test.go @@ -0,0 +1,228 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestQB_BuildSelect(t *testing.T) { + tests := []struct { + tag string + cols []string + distinct bool + option string + expected string + }{ + {"empty", []string{}, false, "", "SELECT *"}, + {"empty distinct", []string{}, true, "CALC_ROWS", "SELECT DISTINCT CALC_ROWS *"}, + {"multi-columns", []string{"name", "DOB1"}, false, "", "SELECT `name`, `DOB1`"}, + {"aliased columns", []string{"name As Name", "users.last_name", "u.first1 first"}, false, "", "SELECT `name` AS `Name`, `users`.`last_name`, `u`.`first1` AS `first`"}, + } + + db := getDB() + qb := db.QueryBuilder() + for _, test := range tests { + s := qb.BuildSelect(test.cols, test.distinct, test.option) + assert.Equal(t, test.expected, s, test.tag) + } + assert.Equal(t, qb.(*BaseQueryBuilder).DB(), db) +} + +func TestQB_BuildFrom(t *testing.T) { + tests := []struct { + tag string + tables []string + expected string + }{ + {"empty", []string{}, ""}, + {"single table", []string{"users"}, "FROM `users`"}, + {"multiple tables", []string{"users", "posts"}, "FROM `users`, `posts`"}, + {"table alias", []string{"users u", "posts as p"}, "FROM `users` `u`, `posts` `p`"}, + {"table prefix and alias", []string{"pub.users p.u", "posts AS p1"}, "FROM `pub`.`users` `p.u`, `posts` `p1`"}, + } + + qb := getDB().QueryBuilder() + for _, test := range tests { + s := qb.BuildFrom(test.tables) + assert.Equal(t, test.expected, s, test.tag) + } +} + +func TestQB_BuildGroupBy(t *testing.T) { + tests := []struct { + tag string + cols []string + expected string + }{ + {"empty", []string{}, ""}, + {"single column", []string{"name"}, "GROUP BY `name`"}, + {"multiple columns", []string{"name", "age"}, "GROUP BY `name`, `age`"}, + } + + qb := getDB().QueryBuilder() + for _, test := range tests { + s := qb.BuildGroupBy(test.cols) + assert.Equal(t, test.expected, s, test.tag) + } +} + +func TestQB_BuildWhere(t *testing.T) { + tests := []struct { + exp Expression + expected string + count int + tag string + }{ + {HashExp{"age": 30, "dept": "marketing"}, "WHERE `age`={:p0} AND `dept`={:p1}", 2, "t1"}, + {nil, "", 0, "t2"}, + {NewExp(""), "", 0, "t3"}, + } + + qb := getDB().QueryBuilder() + for _, test := range tests { + params := Params{} + s := qb.BuildWhere(test.exp, params) + assert.Equal(t, test.expected, s, test.tag) + assert.Equal(t, test.count, len(params), test.tag) + } +} + +func TestQB_BuildHaving(t *testing.T) { + tests := []struct { + exp Expression + expected string + count int + tag string + }{ + {HashExp{"age": 30, "dept": "marketing"}, "HAVING `age`={:p0} AND `dept`={:p1}", 2, "t1"}, + {nil, "", 0, "t2"}, + {NewExp(""), "", 0, "t3"}, + } + + qb := getDB().QueryBuilder() + for _, test := range tests { + params := Params{} + s := qb.BuildHaving(test.exp, params) + assert.Equal(t, test.expected, s, test.tag) + assert.Equal(t, test.count, len(params), test.tag) + } +} + +func TestQB_BuildOrderBy(t *testing.T) { + tests := []struct { + tag string + cols []string + expected string + }{ + {"empty", []string{}, ""}, + {"single column", []string{"name"}, "ORDER BY `name`"}, + {"multiple columns", []string{"name ASC", "age DESC", "id desc"}, "ORDER BY `name` ASC, `age` DESC, `id` desc"}, + } + qb := getDB().QueryBuilder().(*BaseQueryBuilder) + for _, test := range tests { + s := qb.BuildOrderBy(test.cols) + assert.Equal(t, test.expected, s, test.tag) + } +} + +func TestQB_BuildLimit(t *testing.T) { + tests := []struct { + tag string + limit, offset int64 + expected string + }{ + {"t1", 10, -1, "LIMIT 10"}, + {"t2", 10, 0, "LIMIT 10"}, + {"t3", 10, 2, "LIMIT 10 OFFSET 2"}, + {"t4", 0, 2, "LIMIT 0 OFFSET 2"}, + {"t5", -1, 2, "LIMIT 9223372036854775807 OFFSET 2"}, + {"t6", -1, 0, ""}, + } + qb := getDB().QueryBuilder().(*BaseQueryBuilder) + for _, test := range tests { + s := qb.BuildLimit(test.limit, test.offset) + assert.Equal(t, test.expected, s, test.tag) + } +} + +func TestQB_BuildOrderByAndLimit(t *testing.T) { + qb := getDB().QueryBuilder() + + sql := qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, 10, 2) + expected := "SELECT * ORDER BY `name` LIMIT 10 OFFSET 2" + assert.Equal(t, sql, expected, "t1") + + sql = qb.BuildOrderByAndLimit("SELECT *", nil, -1, -1) + expected = "SELECT *" + assert.Equal(t, sql, expected, "t2") + + sql = qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, -1, -1) + expected = "SELECT * ORDER BY `name`" + assert.Equal(t, sql, expected, "t3") + + sql = qb.BuildOrderByAndLimit("SELECT *", nil, 10, -1) + expected = "SELECT * LIMIT 10" + assert.Equal(t, sql, expected, "t4") +} + +func TestQB_BuildJoin(t *testing.T) { + qb := getDB().QueryBuilder() + + params := Params{} + ji := JoinInfo{"LEFT JOIN", "users u", NewExp("id=u.id", Params{"id": 1})} + sql := qb.BuildJoin([]JoinInfo{ji}, params) + expected := "LEFT JOIN `users` `u` ON id=u.id" + assert.Equal(t, sql, expected, "BuildJoin@1") + assert.Equal(t, len(params), 1, "len(params)@1") + + params = Params{} + ji = JoinInfo{"INNER JOIN", "users", nil} + sql = qb.BuildJoin([]JoinInfo{ji}, params) + expected = "INNER JOIN `users`" + assert.Equal(t, sql, expected, "BuildJoin@2") + assert.Equal(t, len(params), 0, "len(params)@2") + + sql = qb.BuildJoin([]JoinInfo{}, nil) + expected = "" + assert.Equal(t, sql, expected, "BuildJoin@3") + + ji = JoinInfo{"INNER JOIN", "users", nil} + ji2 := JoinInfo{"LEFT JOIN", "posts", nil} + sql = qb.BuildJoin([]JoinInfo{ji, ji2}, nil) + expected = "INNER JOIN `users` LEFT JOIN `posts`" + assert.Equal(t, sql, expected, "BuildJoin@3") +} + +func TestQB_BuildUnion(t *testing.T) { + db := getDB() + qb := db.QueryBuilder() + + params := Params{} + ui := UnionInfo{false, db.NewQuery("SELECT names").Bind(Params{"id": 1})} + sql := qb.BuildUnion([]UnionInfo{ui}, params) + expected := "UNION (SELECT names)" + assert.Equal(t, sql, expected, "BuildUnion@1") + assert.Equal(t, len(params), 1, "len(params)@1") + + params = Params{} + ui = UnionInfo{true, db.NewQuery("SELECT names")} + sql = qb.BuildUnion([]UnionInfo{ui}, params) + expected = "UNION ALL (SELECT names)" + assert.Equal(t, sql, expected, "BuildUnion@2") + assert.Equal(t, len(params), 0, "len(params)@2") + + sql = qb.BuildUnion([]UnionInfo{}, nil) + expected = "" + assert.Equal(t, sql, expected, "BuildUnion@3") + + ui = UnionInfo{true, db.NewQuery("SELECT names")} + ui2 := UnionInfo{false, db.NewQuery("SELECT ages")} + sql = qb.BuildUnion([]UnionInfo{ui, ui2}, nil) + expected = "UNION ALL (SELECT names) UNION (SELECT ages)" + assert.Equal(t, sql, expected, "BuildUnion@4") +} diff --git a/query_test.go b/query_test.go new file mode 100644 index 0000000..d7a4c2a --- /dev/null +++ b/query_test.go @@ -0,0 +1,600 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + ss "database/sql" + "encoding/json" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type City struct { + ID int + Name string +} + +func TestNewQuery(t *testing.T) { + db := getDB() + sql := "SELECT * FROM users WHERE id={:id}" + q := NewQuery(db, db.sqlDB, sql) + assert.Equal(t, q.SQL(), sql, "q.SQL()") + assert.Equal(t, q.rawSQL, "SELECT * FROM users WHERE id=?", "q.RawSQL()") + + assert.Equal(t, len(q.Params()), 0, "len(q.Params())@1") + q.Bind(Params{"id": 1}) + assert.Equal(t, len(q.Params()), 1, "len(q.Params())@2") +} + +func TestQuery_Execute(t *testing.T) { + db := getPreparedDB() + defer db.Close() + + result, err := db.NewQuery("INSERT INTO item (name) VALUES ('test')").Execute() + if assert.Nil(t, err) { + rows, _ := result.RowsAffected() + assert.Equal(t, rows, int64(1), "Result.RowsAffected()") + lastID, _ := result.LastInsertId() + assert.Equal(t, lastID, int64(6), "Result.LastInsertId()") + } +} + +type Customer struct { + scanned bool + + ID int + Email string + Status int + Name string + Address ss.NullString +} + +func (m Customer) TableName() string { + return "customer" +} + +func (m *Customer) PostScan() error { + m.scanned = true + return nil +} + +type CustomerPtr struct { + ID *int `db:"pk"` + Email *string + Status *int + Name string + Address *string +} + +func (m CustomerPtr) TableName() string { + return "customer" +} + +type CustomerNull struct { + ID ss.NullInt64 `db:"pk,id"` + Email ss.NullString + Status *ss.NullInt64 + Name string + Address ss.NullString +} + +func (m CustomerNull) TableName() string { + return "customer" +} + +type CustomerEmbedded struct { + Id int + Email *string + InnerCustomer +} + +func (m CustomerEmbedded) TableName() string { + return "customer" +} + +type CustomerEmbedded2 struct { + ID int + Email *string + Inner InnerCustomer +} + +type InnerCustomer struct { + Status ss.NullInt64 + Name *string + Address ss.NullString +} + +func TestQuery_Rows(t *testing.T) { + db := getPreparedDB() + defer db.Close() + + var ( + sql string + err error + ) + + // Query.All() + var customers []Customer + sql = `SELECT * FROM customer ORDER BY id` + err = db.NewQuery(sql).All(&customers) + if assert.Nil(t, err) { + assert.Equal(t, len(customers), 3, "len(customers)") + assert.Equal(t, customers[2].ID, 3, "customers[2].ID") + assert.Equal(t, customers[2].Email, `user3@example.com`, "customers[2].Email") + assert.Equal(t, customers[2].Status, 2, "customers[2].Status") + assert.Equal(t, customers[0].scanned, true, "customers[0].scanned") + assert.Equal(t, customers[1].scanned, true, "customers[1].scanned") + assert.Equal(t, customers[2].scanned, true, "customers[2].scanned") + } + + // Query.All() with slice of pointers + var customersPtrSlice []*Customer + sql = `SELECT * FROM customer ORDER BY id` + err = db.NewQuery(sql).All(&customersPtrSlice) + if assert.Nil(t, err) { + assert.Equal(t, len(customersPtrSlice), 3, "len(customersPtrSlice)") + assert.Equal(t, customersPtrSlice[2].ID, 3, "customersPtrSlice[2].ID") + assert.Equal(t, customersPtrSlice[2].Email, `user3@example.com`, "customersPtrSlice[2].Email") + assert.Equal(t, customersPtrSlice[2].Status, 2, "customersPtrSlice[2].Status") + assert.Equal(t, customersPtrSlice[0].scanned, true, "customersPtrSlice[0].scanned") + assert.Equal(t, customersPtrSlice[1].scanned, true, "customersPtrSlice[1].scanned") + assert.Equal(t, customersPtrSlice[2].scanned, true, "customersPtrSlice[2].scanned") + } + + var customers2 []NullStringMap + err = db.NewQuery(sql).All(&customers2) + if assert.Nil(t, err) { + assert.Equal(t, len(customers2), 3, "len(customers2)") + assert.Equal(t, customers2[1]["id"].String, "2", "customers2[1][id]") + assert.Equal(t, customers2[1]["email"].String, `user2@example.com`, "customers2[1][email]") + assert.Equal(t, customers2[1]["status"].String, "1", "customers2[1][status]") + } + err = db.NewQuery(sql).All(customers) + assert.NotNil(t, err) + + var customers3 []string + err = db.NewQuery(sql).All(&customers3) + assert.NotNil(t, err) + + var customers4 string + err = db.NewQuery(sql).All(&customers4) + assert.NotNil(t, err) + + var customers5 []Customer + err = db.NewQuery(`SELECT * FROM customer WHERE id=999`).All(&customers5) + if assert.Nil(t, err) { + assert.NotNil(t, customers5) + assert.Zero(t, len(customers5)) + } + + // One + var customer Customer + sql = `SELECT * FROM customer WHERE id={:id}` + err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customer) + if assert.Nil(t, err) { + assert.Equal(t, customer.ID, 2, "customer.ID") + assert.Equal(t, customer.Email, `user2@example.com`, "customer.Email") + assert.Equal(t, customer.Status, 1, "customer.Status") + } + + var customerPtr2 CustomerPtr + sql = `SELECT id, email, address FROM customer WHERE id=2` + rows2, err := db.sqlDB.Query(sql) + defer rows2.Close() + assert.Nil(t, err) + rows2.Next() + err = rows2.Scan(&customerPtr2.ID, &customerPtr2.Email, &customerPtr2.Address) + if assert.Nil(t, err) { + assert.Equal(t, *customerPtr2.ID, 2, "customer.ID") + assert.Equal(t, *customerPtr2.Email, `user2@example.com`) + assert.Nil(t, customerPtr2.Address) + } + + // struct fields are pointers + var customerPtr CustomerPtr + sql = `SELECT * FROM customer WHERE id={:id}` + err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerPtr) + if assert.Nil(t, err) { + assert.Equal(t, *customerPtr.ID, 2, "customer.ID") + assert.Equal(t, *customerPtr.Email, `user2@example.com`, "customer.Email") + assert.Equal(t, *customerPtr.Status, 1, "customer.Status") + } + + // struct fields are null types + var customerNull CustomerNull + sql = `SELECT * FROM customer WHERE id={:id}` + err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerNull) + if assert.Nil(t, err) { + assert.Equal(t, customerNull.ID.Int64, int64(2), "customer.ID") + assert.Equal(t, customerNull.Email.String, `user2@example.com`, "customer.Email") + assert.Equal(t, customerNull.Status.Int64, int64(1), "customer.Status") + } + + // embedded with anonymous struct + var customerEmbedded CustomerEmbedded + sql = `SELECT * FROM customer WHERE id={:id}` + err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerEmbedded) + if assert.Nil(t, err) { + assert.Equal(t, customerEmbedded.Id, 2, "customer.ID") + assert.Equal(t, *customerEmbedded.Email, `user2@example.com`, "customer.Email") + assert.Equal(t, customerEmbedded.Status.Int64, int64(1), "customer.Status") + } + + // embedded with named struct + var customerEmbedded2 CustomerEmbedded2 + sql = `SELECT id, email, status as "inner.status" FROM customer WHERE id={:id}` + err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerEmbedded2) + if assert.Nil(t, err) { + assert.Equal(t, customerEmbedded2.ID, 2, "customer.ID") + assert.Equal(t, *customerEmbedded2.Email, `user2@example.com`, "customer.Email") + assert.Equal(t, customerEmbedded2.Inner.Status.Int64, int64(1), "customer.Status") + } + + customer2 := NullStringMap{} + sql = `SELECT * FROM customer WHERE id={:id}` + err = db.NewQuery(sql).Bind(Params{"id": 1}).One(customer2) + if assert.Nil(t, err) { + assert.Equal(t, customer2["id"].String, "1", "customer2[id]") + assert.Equal(t, customer2["email"].String, `user1@example.com`, "customer2[email]") + assert.Equal(t, customer2["status"].String, "1", "customer2[status]") + } + + err = db.NewQuery(sql).Bind(Params{"id": 2}).One(customer) + assert.NotNil(t, err) + + var customer3 NullStringMap + err = db.NewQuery(sql).Bind(Params{"id": 2}).One(customer3) + assert.NotNil(t, err) + + err = db.NewQuery(sql).Bind(Params{"id": 1}).One(&customer3) + if assert.Nil(t, err) { + assert.Equal(t, customer3["id"].String, "1", "customer3[id]") + } + + // Rows + sql = `SELECT * FROM customer ORDER BY id DESC` + rows, err := db.NewQuery(sql).Rows() + if assert.Nil(t, err) { + s := "" + for rows.Next() { + rows.ScanStruct(&customer) + s += customer.Email + "," + } + assert.Equal(t, s, "user3@example.com,user2@example.com,user1@example.com,", "Rows().Next()") + } + + // FieldMapper + var a struct { + MyID string `db:"id"` + name string + } + sql = `SELECT * FROM customer WHERE id=2` + err = db.NewQuery(sql).One(&a) + if assert.Nil(t, err) { + assert.Equal(t, a.MyID, "2", "a.MyID") + // unexported field is not populated + assert.Equal(t, a.name, "", "a.name") + } + + // prepared statement + sql = `SELECT * FROM customer WHERE id={:id}` + q := db.NewQuery(sql).Prepare() + q.Bind(Params{"id": 1}).One(&customer) + assert.Equal(t, customer.ID, 1, "prepared@1") + err = q.Bind(Params{"id": 20}).One(&customer) + assert.Equal(t, err, ss.ErrNoRows, "prepared@2") + q.Bind(Params{"id": 3}).One(&customer) + assert.Equal(t, customer.ID, 3, "prepared@3") + + sql = `SELECT name FROM customer WHERE id={:id}` + var name string + q = db.NewQuery(sql).Prepare() + q.Bind(Params{"id": 1}).Row(&name) + assert.Equal(t, name, "user1", "prepared2@1") + err = q.Bind(Params{"id": 20}).Row(&name) + assert.Equal(t, err, ss.ErrNoRows, "prepared2@2") + q.Bind(Params{"id": 3}).Row(&name) + assert.Equal(t, name, "user3", "prepared2@3") + + // Query.LastError + sql = `SELECT * FROM a` + q = db.NewQuery(sql).Prepare() + customer.ID = 100 + err = q.Bind(Params{"id": 1}).One(&customer) + assert.NotEqual(t, err, nil, "LastError@0") + assert.Equal(t, customer.ID, 100, "LastError@1") + assert.Equal(t, q.LastError, nil, "LastError@2") + + // Query.Column + sql = `SELECT name, id FROM customer ORDER BY id` + var names []string + err = db.NewQuery(sql).Column(&names) + if assert.Nil(t, err) && assert.Equal(t, 3, len(names)) { + assert.Equal(t, "user1", names[0]) + assert.Equal(t, "user2", names[1]) + assert.Equal(t, "user3", names[2]) + } + err = db.NewQuery(sql).Column(names) + assert.NotNil(t, err) +} + +func TestQuery_logSQL(t *testing.T) { + db := getDB() + q := db.NewQuery("SELECT * FROM users WHERE type={:type} AND id={:id} AND bytes={:bytes}").Bind(Params{ + "id": 1, + "type": "a", + "bytes": []byte("test"), + }) + expected := "SELECT * FROM users WHERE type='a' AND id=1 AND bytes=0x74657374" + assert.Equal(t, q.logSQL(), expected, "logSQL()") +} + +func TestReplacePlaceholders(t *testing.T) { + tests := []struct { + ID string + Placeholders []string + Params Params + ExpectedParams string + HasError bool + }{ + {"t1", nil, nil, "null", false}, + {"t2", []string{"id", "name"}, Params{"id": 1, "name": "xyz"}, `[1,"xyz"]`, false}, + {"t3", []string{"id", "name"}, Params{"id": 1}, `null`, true}, + {"t4", []string{"id", "name"}, Params{"id": 1, "name": "xyz", "age": 30}, `[1,"xyz"]`, false}, + } + for _, test := range tests { + params, err := replacePlaceholders(test.Placeholders, test.Params) + result, _ := json.Marshal(params) + assert.Equal(t, string(result), test.ExpectedParams, "params@"+test.ID) + assert.Equal(t, err != nil, test.HasError, "error@"+test.ID) + } +} + +func TestIssue6(t *testing.T) { + db := getPreparedDB() + q := db.Select("*").From("customer").Where(HashExp{"id": 1}) + var customer Customer + assert.Equal(t, q.One(&customer), nil) + assert.Equal(t, 1, customer.ID) +} + +type User struct { + ID int64 + Email string + Created time.Time + Updated *time.Time +} + +func TestIssue13(t *testing.T) { + db := getPreparedDB() + var user User + err := db.Select().From("user").Where(HashExp{"id": 1}).One(&user) + if assert.Nil(t, err) { + assert.NotZero(t, user.Created) + assert.Nil(t, user.Updated) + } + + now := time.Now() + + user2 := User{ + Email: "now@example.com", + Created: now, + } + err = db.Model(&user2).Insert() + if assert.Nil(t, err) { + assert.NotZero(t, user2.ID) + } + + user3 := User{ + Email: "now@example.com", + Created: now, + Updated: &now, + } + err = db.Model(&user3).Insert() + if assert.Nil(t, err) { + assert.NotZero(t, user2.ID) + } +} + +func TestQueryWithExecHook(t *testing.T) { + db := getPreparedDB() + defer db.Close() + + // error return + { + err := db.NewQuery("select * from user"). + WithExecHook(func(q *Query, op func() error) error { + return errors.New("test") + }). + Row() + + assert.Error(t, err) + } + + // Row() + { + calls := 0 + err := db.NewQuery("select * from user"). + WithExecHook(func(q *Query, op func() error) error { + calls++ + return nil + }). + Row() + assert.Nil(t, err) + assert.Equal(t, 1, calls, "Row()") + } + + // One() + { + calls := 0 + err := db.NewQuery("select * from user"). + WithExecHook(func(q *Query, op func() error) error { + calls++ + return nil + }). + One(nil) + assert.Nil(t, err) + assert.Equal(t, 1, calls, "One()") + } + + // All() + { + calls := 0 + err := db.NewQuery("select * from user"). + WithExecHook(func(q *Query, op func() error) error { + calls++ + return nil + }). + All(nil) + assert.Nil(t, err) + assert.Equal(t, 1, calls, "All()") + } + + // Column() + { + calls := 0 + err := db.NewQuery("select * from user"). + WithExecHook(func(q *Query, op func() error) error { + calls++ + return nil + }). + Column(nil) + assert.Nil(t, err) + assert.Equal(t, 1, calls, "Column()") + } + + // Execute() + { + calls := 0 + _, err := db.NewQuery("select * from user"). + WithExecHook(func(q *Query, op func() error) error { + calls++ + return nil + }). + Execute() + assert.Nil(t, err) + assert.Equal(t, 1, calls, "Execute()") + } + + // op call + { + calls := 0 + var id int + err := db.NewQuery("select id from user where id = 2"). + WithExecHook(func(q *Query, op func() error) error { + calls++ + return op() + }). + Row(&id) + assert.Nil(t, err) + assert.Equal(t, 1, calls, "op hook calls") + assert.Equal(t, 2, id, "id mismatch") + } +} + +func TestQueryWithOneHook(t *testing.T) { + db := getPreparedDB() + defer db.Close() + + // error return + { + err := db.NewQuery("select * from user"). + WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error { + return errors.New("test") + }). + One(nil) + + assert.Error(t, err) + } + + // hooks call order + { + hookCalls := []string{} + err := db.NewQuery("select * from user"). + WithExecHook(func(q *Query, op func() error) error { + hookCalls = append(hookCalls, "exec") + return op() + }). + WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error { + hookCalls = append(hookCalls, "one") + return nil + }). + One(nil) + + assert.Nil(t, err) + assert.Equal(t, hookCalls, []string{"exec", "one"}) + } + + // op call + { + calls := 0 + other := User{} + err := db.NewQuery("select id from user where id = 2"). + WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error { + calls++ + return op(&other) + }). + One(nil) + + assert.Nil(t, err) + assert.Equal(t, 1, calls, "hook calls") + assert.Equal(t, int64(2), other.ID, "replaced scan struct") + } +} + +func TestQueryWithAllHook(t *testing.T) { + db := getPreparedDB() + defer db.Close() + + // error return + { + err := db.NewQuery("select * from user"). + WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error { + return errors.New("test") + }). + All(nil) + + assert.Error(t, err) + } + + // hooks call order + { + hookCalls := []string{} + err := db.NewQuery("select * from user"). + WithExecHook(func(q *Query, op func() error) error { + hookCalls = append(hookCalls, "exec") + return op() + }). + WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error { + hookCalls = append(hookCalls, "all") + return nil + }). + All(nil) + + assert.Nil(t, err) + assert.Equal(t, hookCalls, []string{"exec", "all"}) + } + + // op call + { + calls := 0 + other := []User{} + err := db.NewQuery("select id from user order by id asc"). + WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error { + calls++ + return op(&other) + }). + All(nil) + + assert.Nil(t, err) + assert.Equal(t, 1, calls, "hook calls") + assert.Equal(t, 2, len(other), "users length") + assert.Equal(t, int64(1), other[0].ID, "user 1 id check") + assert.Equal(t, int64(2), other[1].ID, "user 2 id check") + } +} diff --git a/rows.go b/rows.go new file mode 100644 index 0000000..2c2773d --- /dev/null +++ b/rows.go @@ -0,0 +1,301 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "database/sql" + "reflect" +) + +// VarTypeError indicates a variable type error when trying to populating a variable with DB result. +type VarTypeError string + +// Error returns the error message. +func (s VarTypeError) Error() string { + return "Invalid variable type: " + string(s) +} + +// NullStringMap is a map of sql.NullString that can be used to hold DB query result. +// The map keys correspond to the DB column names, while the map values are their corresponding column values. +type NullStringMap map[string]sql.NullString + +// Rows enhances sql.Rows by providing additional data query methods. +// Rows can be obtained by calling Query.Rows(). It is mainly used to populate data row by row. +type Rows struct { + *sql.Rows + fieldMapFunc FieldMapFunc +} + +// ScanMap populates the current row of data into a NullStringMap. +// Note that the NullStringMap must not be nil, or it will panic. +// The NullStringMap will be populated using column names as keys and their values as +// the corresponding element values. +func (r *Rows) ScanMap(a NullStringMap) error { + cols, _ := r.Columns() + var refs []interface{} + for i := 0; i < len(cols); i++ { + var t sql.NullString + refs = append(refs, &t) + } + if err := r.Scan(refs...); err != nil { + return err + } + + for i, col := range cols { + a[col] = *refs[i].(*sql.NullString) + } + + return nil +} + +// ScanStruct populates the current row of data into a struct. +// The struct must be given as a pointer. +// +// ScanStruct associates struct fields with DB table columns through a field mapping function. +// It populates a struct field with the data of its associated column. +// Note that only exported struct fields will be populated. +// +// By default, DefaultFieldMapFunc() is used to map struct fields to table columns. +// This function separates each word in a field name with a underscore and turns every letter into lower case. +// For example, "LastName" is mapped to "last_name", "MyID" is mapped to "my_id", and so on. +// To change the default behavior, set DB.FieldMapper with your custom mapping function. +// You may also set Query.FieldMapper to change the behavior for particular queries. +func (r *Rows) ScanStruct(a interface{}) error { + rv := reflect.ValueOf(a) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return VarTypeError("must be a pointer") + } + rv = indirect(rv) + if rv.Kind() != reflect.Struct { + return VarTypeError("must be a pointer to a struct") + } + + si := getStructInfo(rv.Type(), r.fieldMapFunc) + + cols, _ := r.Columns() + refs := make([]interface{}, len(cols)) + + for i, col := range cols { + if fi, ok := si.dbNameMap[col]; ok { + refs[i] = fi.getField(rv).Addr().Interface() + } else { + refs[i] = &sql.NullString{} + } + } + + if err := r.Scan(refs...); err != nil { + return err + } + + // check for PostScanner + if rv.CanAddr() { + addr := rv.Addr() + if addr.CanInterface() { + if ps, ok := addr.Interface().(PostScanner); ok { + if err := ps.PostScan(); err != nil { + return err + } + } + } + } + + return nil +} + +// all populates all rows of query result into a slice of struct or NullStringMap. +// Note that the slice must be given as a pointer. +func (r *Rows) all(slice interface{}) error { + defer r.Close() + + v := reflect.ValueOf(slice) + if v.Kind() != reflect.Ptr || v.IsNil() { + return VarTypeError("must be a pointer") + } + v = indirect(v) + + if v.Kind() != reflect.Slice { + return VarTypeError("must be a slice of struct or NullStringMap") + } + + if v.IsNil() { + // create an empty slice + v.Set(reflect.MakeSlice(v.Type(), 0, 0)) + } + + et := v.Type().Elem() + + if et.Kind() == reflect.Map { + for r.Next() { + ev, ok := reflect.MakeMap(et).Interface().(NullStringMap) + if !ok { + return VarTypeError("must be a slice of struct or NullStringMap") + } + if err := r.ScanMap(ev); err != nil { + return err + } + v.Set(reflect.Append(v, reflect.ValueOf(ev))) + } + return r.Close() + } + + var isSliceOfPointers bool + if et.Kind() == reflect.Ptr { + isSliceOfPointers = true + et = et.Elem() + } + + if et.Kind() != reflect.Struct { + return VarTypeError("must be a slice of struct or NullStringMap") + } + + etPtr := reflect.PtrTo(et) + implementsPostScanner := etPtr.Implements(postScannerType) + + si := getStructInfo(et, r.fieldMapFunc) + + cols, _ := r.Columns() + for r.Next() { + ev := reflect.New(et).Elem() + refs := make([]interface{}, len(cols)) + for i, col := range cols { + if fi, ok := si.dbNameMap[col]; ok { + refs[i] = fi.getField(ev).Addr().Interface() + } else { + refs[i] = &sql.NullString{} + } + } + if err := r.Scan(refs...); err != nil { + return err + } + + if isSliceOfPointers { + ev = ev.Addr() + } + + // check for PostScanner + if implementsPostScanner { + evAddr := ev + if ev.CanAddr() { + evAddr = ev.Addr() + } + if evAddr.CanInterface() { + if ps, ok := evAddr.Interface().(PostScanner); ok { + if err := ps.PostScan(); err != nil { + return err + } + } + } + } + + v.Set(reflect.Append(v, ev)) + } + + return r.Close() +} + +// column populates the given slice with the first column of the query result. +// Note that the slice must be given as a pointer. +func (r *Rows) column(slice interface{}) error { + defer r.Close() + + v := reflect.ValueOf(slice) + if v.Kind() != reflect.Ptr || v.IsNil() { + return VarTypeError("must be a pointer to a slice") + } + v = indirect(v) + + if v.Kind() != reflect.Slice { + return VarTypeError("must be a pointer to a slice") + } + + et := v.Type().Elem() + + cols, _ := r.Columns() + for r.Next() { + ev := reflect.New(et) + refs := make([]interface{}, len(cols)) + for i := range cols { + if i == 0 { + refs[i] = ev.Interface() + } else { + refs[i] = &sql.NullString{} + } + } + if err := r.Scan(refs...); err != nil { + return err + } + v.Set(reflect.Append(v, ev.Elem())) + } + + return r.Close() +} + +// one populates a single row of query result into a struct or a NullStringMap. +// Note that if a struct is given, it should be a pointer. +func (r *Rows) one(a interface{}) error { + defer r.Close() + + if !r.Next() { + if err := r.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + + var err error + + rt := reflect.TypeOf(a) + if rt.Kind() == reflect.Ptr && rt.Elem().Kind() == reflect.Map { + // pointer to map + v := indirect(reflect.ValueOf(a)) + if v.IsNil() { + v.Set(reflect.MakeMap(v.Type())) + } + a = v.Interface() + rt = reflect.TypeOf(a) + } + + if rt.Kind() == reflect.Map { + v, ok := a.(NullStringMap) + if !ok { + return VarTypeError("must be a NullStringMap") + } + if v == nil { + return VarTypeError("NullStringMap is nil") + } + err = r.ScanMap(v) + } else { + err = r.ScanStruct(a) + } + + if err != nil { + return err + } + + return r.Close() +} + +// row populates a single row of query result into a list of variables. +func (r *Rows) row(a ...interface{}) error { + defer r.Close() + + for _, dp := range a { + if _, ok := dp.(*sql.RawBytes); ok { + return VarTypeError("RawBytes isn't allowed on Row()") + } + } + + if !r.Next() { + if err := r.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + if err := r.Scan(a...); err != nil { + return err + } + + return r.Close() +} diff --git a/select.go b/select.go new file mode 100644 index 0000000..485f74d --- /dev/null +++ b/select.go @@ -0,0 +1,445 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "context" + "fmt" + "reflect" +) + +// BuildHookFunc defines a callback function that is executed on Query creation. +type BuildHookFunc func(q *Query) + +// SelectQuery represents a DB-agnostic SELECT query. +// It can be built into a DB-specific query by calling the Build() method. +type SelectQuery struct { + // FieldMapper maps struct field names to DB column names. + FieldMapper FieldMapFunc + // TableMapper maps structs to DB table names. + TableMapper TableMapFunc + + builder Builder + ctx context.Context + buildHook BuildHookFunc + + preFragment string + postFragment string + selects []string + distinct bool + selectOption string + from []string + where Expression + join []JoinInfo + orderBy []string + groupBy []string + having Expression + union []UnionInfo + limit int64 + offset int64 + params Params +} + +// JoinInfo contains the specification for a JOIN clause. +type JoinInfo struct { + Join string + Table string + On Expression +} + +// UnionInfo contains the specification for a UNION clause. +type UnionInfo struct { + All bool + Query *Query +} + +// NewSelectQuery creates a new SelectQuery instance. +func NewSelectQuery(builder Builder, db *DB) *SelectQuery { + return &SelectQuery{ + builder: builder, + selects: []string{}, + from: []string{}, + join: []JoinInfo{}, + orderBy: []string{}, + groupBy: []string{}, + union: []UnionInfo{}, + limit: -1, + params: Params{}, + ctx: db.ctx, + FieldMapper: db.FieldMapper, + TableMapper: db.TableMapper, + } +} + +// WithBuildHook runs the provided hook function with the query created on Build(). +func (q *SelectQuery) WithBuildHook(fn BuildHookFunc) *SelectQuery { + q.buildHook = fn + return q +} + +// Context returns the context associated with the query. +func (q *SelectQuery) Context() context.Context { + return q.ctx +} + +// WithContext associates a context with the query. +func (q *SelectQuery) WithContext(ctx context.Context) *SelectQuery { + q.ctx = ctx + return q +} + +// PreFragment sets SQL fragment that should be prepended before the select query (e.g. WITH clause). +func (s *SelectQuery) PreFragment(fragment string) *SelectQuery { + s.preFragment = fragment + return s +} + +// PostFragment sets SQL fragment that should be appended at the end of the select query. +func (s *SelectQuery) PostFragment(fragment string) *SelectQuery { + s.postFragment = fragment + return s +} + +// Select specifies the columns to be selected. +// Column names will be automatically quoted. +func (s *SelectQuery) Select(cols ...string) *SelectQuery { + s.selects = cols + return s +} + +// AndSelect adds additional columns to be selected. +// Column names will be automatically quoted. +func (s *SelectQuery) AndSelect(cols ...string) *SelectQuery { + s.selects = append(s.selects, cols...) + return s +} + +// Distinct specifies whether to select columns distinctively. +// By default, distinct is false. +func (s *SelectQuery) Distinct(v bool) *SelectQuery { + s.distinct = v + return s +} + +// SelectOption specifies additional option that should be append to "SELECT". +func (s *SelectQuery) SelectOption(option string) *SelectQuery { + s.selectOption = option + return s +} + +// From specifies which tables to select from. +// Table names will be automatically quoted. +func (s *SelectQuery) From(tables ...string) *SelectQuery { + s.from = tables + return s +} + +// Where specifies the WHERE condition. +func (s *SelectQuery) Where(e Expression) *SelectQuery { + s.where = e + return s +} + +// AndWhere concatenates a new WHERE condition with the existing one (if any) using "AND". +func (s *SelectQuery) AndWhere(e Expression) *SelectQuery { + s.where = And(s.where, e) + return s +} + +// OrWhere concatenates a new WHERE condition with the existing one (if any) using "OR". +func (s *SelectQuery) OrWhere(e Expression) *SelectQuery { + s.where = Or(s.where, e) + return s +} + +// Join specifies a JOIN clause. +// The "typ" parameter specifies the JOIN type (e.g. "INNER JOIN", "LEFT JOIN"). +func (s *SelectQuery) Join(typ string, table string, on Expression) *SelectQuery { + s.join = append(s.join, JoinInfo{typ, table, on}) + return s +} + +// InnerJoin specifies an INNER JOIN clause. +// This is a shortcut method for Join. +func (s *SelectQuery) InnerJoin(table string, on Expression) *SelectQuery { + return s.Join("INNER JOIN", table, on) +} + +// LeftJoin specifies a LEFT JOIN clause. +// This is a shortcut method for Join. +func (s *SelectQuery) LeftJoin(table string, on Expression) *SelectQuery { + return s.Join("LEFT JOIN", table, on) +} + +// RightJoin specifies a RIGHT JOIN clause. +// This is a shortcut method for Join. +func (s *SelectQuery) RightJoin(table string, on Expression) *SelectQuery { + return s.Join("RIGHT JOIN", table, on) +} + +// OrderBy specifies the ORDER BY clause. +// Column names will be properly quoted. A column name can contain "ASC" or "DESC" to indicate its ordering direction. +func (s *SelectQuery) OrderBy(cols ...string) *SelectQuery { + s.orderBy = cols + return s +} + +// AndOrderBy appends additional columns to the existing ORDER BY clause. +// Column names will be properly quoted. A column name can contain "ASC" or "DESC" to indicate its ordering direction. +func (s *SelectQuery) AndOrderBy(cols ...string) *SelectQuery { + s.orderBy = append(s.orderBy, cols...) + return s +} + +// GroupBy specifies the GROUP BY clause. +// Column names will be properly quoted. +func (s *SelectQuery) GroupBy(cols ...string) *SelectQuery { + s.groupBy = cols + return s +} + +// AndGroupBy appends additional columns to the existing GROUP BY clause. +// Column names will be properly quoted. +func (s *SelectQuery) AndGroupBy(cols ...string) *SelectQuery { + s.groupBy = append(s.groupBy, cols...) + return s +} + +// Having specifies the HAVING clause. +func (s *SelectQuery) Having(e Expression) *SelectQuery { + s.having = e + return s +} + +// AndHaving concatenates a new HAVING condition with the existing one (if any) using "AND". +func (s *SelectQuery) AndHaving(e Expression) *SelectQuery { + s.having = And(s.having, e) + return s +} + +// OrHaving concatenates a new HAVING condition with the existing one (if any) using "OR". +func (s *SelectQuery) OrHaving(e Expression) *SelectQuery { + s.having = Or(s.having, e) + return s +} + +// Union specifies a UNION clause. +func (s *SelectQuery) Union(q *Query) *SelectQuery { + s.union = append(s.union, UnionInfo{false, q}) + return s +} + +// UnionAll specifies a UNION ALL clause. +func (s *SelectQuery) UnionAll(q *Query) *SelectQuery { + s.union = append(s.union, UnionInfo{true, q}) + return s +} + +// Limit specifies the LIMIT clause. +// A negative limit means no limit. +func (s *SelectQuery) Limit(limit int64) *SelectQuery { + s.limit = limit + return s +} + +// Offset specifies the OFFSET clause. +// A negative offset means no offset. +func (s *SelectQuery) Offset(offset int64) *SelectQuery { + s.offset = offset + return s +} + +// Bind specifies the parameter values to be bound to the query. +func (s *SelectQuery) Bind(params Params) *SelectQuery { + s.params = params + return s +} + +// AndBind appends additional parameters to be bound to the query. +func (s *SelectQuery) AndBind(params Params) *SelectQuery { + if len(s.params) == 0 { + s.params = params + } else { + for k, v := range params { + s.params[k] = v + } + } + return s +} + +// Build builds the SELECT query and returns an executable Query object. +func (s *SelectQuery) Build() *Query { + params := Params{} + for k, v := range s.params { + params[k] = v + } + + qb := s.builder.QueryBuilder() + + clauses := []string{ + s.preFragment, + qb.BuildSelect(s.selects, s.distinct, s.selectOption), + qb.BuildFrom(s.from), + qb.BuildJoin(s.join, params), + qb.BuildWhere(s.where, params), + qb.BuildGroupBy(s.groupBy), + qb.BuildHaving(s.having, params), + } + + sql := "" + for _, clause := range clauses { + if clause != "" { + if sql == "" { + sql = clause + } else { + sql += " " + clause + } + } + } + + sql = qb.BuildOrderByAndLimit(sql, s.orderBy, s.limit, s.offset) + + if s.postFragment != "" { + sql += " " + s.postFragment + } + + if union := qb.BuildUnion(s.union, params); union != "" { + sql = fmt.Sprintf("(%v) %v", sql, union) + } + + query := s.builder.NewQuery(sql).Bind(params).WithContext(s.ctx) + + if s.buildHook != nil { + s.buildHook(query) + } + + return query +} + +// One executes the SELECT query and populates the first row of the result into the specified variable. +// +// If the query does not specify a "from" clause, the method will try to infer the name of the table +// to be selected from by calling getTableName() which will return either the variable type name +// or the TableName() method if the variable implements the TableModel interface. +// +// Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned. +func (s *SelectQuery) One(a interface{}) error { + if len(s.from) == 0 { + if tableName := s.TableMapper(a); tableName != "" { + s.from = []string{tableName} + } + } + + return s.Build().One(a) +} + +// Model selects the row with the specified primary key and populates the model with the row data. +// +// The model variable should be a pointer to a struct. If the query does not specify a "from" clause, +// it will use the model struct to determine which table to select data from. It will also use the model +// to infer the name of the primary key column. Only simple primary key is supported. For composite primary keys, +// please use Where() to specify the filtering condition. +func (s *SelectQuery) Model(pk, model interface{}) error { + t := reflect.TypeOf(model) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return VarTypeError("must be a pointer to a struct") + } + + si := getStructInfo(t, s.FieldMapper) + if len(si.pkNames) == 1 { + return s.AndWhere(HashExp{si.nameMap[si.pkNames[0]].dbName: pk}).One(model) + } + if len(si.pkNames) == 0 { + return MissingPKError + } + + return CompositePKError +} + +// All executes the SELECT query and populates all rows of the result into a slice. +// +// Note that the slice must be passed in as a pointer. +// +// If the query does not specify a "from" clause, the method will try to infer the name of the table +// to be selected from by calling getTableName() which will return either the type name of the slice elements +// or the TableName() method if the slice element implements the TableModel interface. +func (s *SelectQuery) All(slice interface{}) error { + if len(s.from) == 0 { + if tableName := s.TableMapper(slice); tableName != "" { + s.from = []string{tableName} + } + } + + return s.Build().All(slice) +} + +// Rows builds and executes the SELECT query and returns a Rows object for data retrieval purpose. +// This is a shortcut to SelectQuery.Build().Rows() +func (s *SelectQuery) Rows() (*Rows, error) { + return s.Build().Rows() +} + +// Row builds and executes the SELECT query and populates the first row of the result into the specified variables. +// This is a shortcut to SelectQuery.Build().Row() +func (s *SelectQuery) Row(a ...interface{}) error { + return s.Build().Row(a...) +} + +// Column builds and executes the SELECT statement and populates the first column of the result into a slice. +// Note that the parameter must be a pointer to a slice. +// This is a shortcut to SelectQuery.Build().Column() +func (s *SelectQuery) Column(a interface{}) error { + return s.Build().Column(a) +} + +// QueryInfo represents a debug/info struct with exported SelectQuery fields. +type QueryInfo struct { + PreFragment string + PostFragment string + Builder Builder + Selects []string + Distinct bool + SelectOption string + From []string + Where Expression + Join []JoinInfo + OrderBy []string + GroupBy []string + Having Expression + Union []UnionInfo + Limit int64 + Offset int64 + Params Params + Context context.Context + BuildHook BuildHookFunc +} + +// Info exports common SelectQuery fields allowing to inspect the +// current select query options. +func (s *SelectQuery) Info() *QueryInfo { + return &QueryInfo{ + Builder: s.builder, + PreFragment: s.preFragment, + PostFragment: s.postFragment, + Selects: s.selects, + Distinct: s.distinct, + SelectOption: s.selectOption, + From: s.from, + Where: s.where, + Join: s.join, + OrderBy: s.orderBy, + GroupBy: s.groupBy, + Having: s.having, + Union: s.union, + Limit: s.limit, + Offset: s.offset, + Params: s.params, + Context: s.ctx, + BuildHook: s.buildHook, + } +} diff --git a/select_test.go b/select_test.go new file mode 100644 index 0000000..1dc3dfd --- /dev/null +++ b/select_test.go @@ -0,0 +1,179 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "testing" + + "database/sql" + + "github.com/stretchr/testify/assert" +) + +func TestSelectQuery(t *testing.T) { + db := getDB() + + // minimal select query + q := db.Select().From("users").Build() + expected := "SELECT * FROM `users`" + assert.Equal(t, q.SQL(), expected, "t1") + assert.Equal(t, len(q.Params()), 0, "t2") + + // a full select query + q = db.Select("id", "name"). + PreFragment("pre"). + PostFragment("post"). + AndSelect("age"). + Distinct(true). + SelectOption("CALC"). + From("users"). + Where(NewExp("age>30")). + AndWhere(NewExp("status=1")). + OrWhere(NewExp("type=2")). + InnerJoin("profile", NewExp("user.id=profile.id")). + LeftJoin("team", nil). + RightJoin("dept", nil). + OrderBy("age DESC", "type"). + AndOrderBy("id"). + GroupBy("id"). + AndGroupBy("age"). + Having(NewExp("id>10")). + AndHaving(NewExp("id<20")). + OrHaving(NewExp("type=3")). + Limit(10). + Offset(20). + Bind(Params{"id": 1}). + AndBind(Params{"age": 30}). + Build() + + expected = "pre SELECT DISTINCT CALC `id`, `name`, `age` FROM `users` INNER JOIN `profile` ON user.id=profile.id LEFT JOIN `team` RIGHT JOIN `dept` WHERE ((age>30) AND (status=1)) OR (type=2) GROUP BY `id`, `age` HAVING ((id>10) AND (id<20)) OR (type=3) ORDER BY `age` DESC, `type`, `id` LIMIT 10 OFFSET 20 post" + assert.Equal(t, q.SQL(), expected, "t3") + assert.Equal(t, len(q.Params()), 2, "t4") + + q3 := db.Select().AndBind(Params{"id": 1}).Build() + assert.Equal(t, len(q3.Params()), 1) + + // union + q1 := db.Select().From("users").PreFragment("pre_q1").Build() + q2 := db.Select().From("posts").PostFragment("post_q2").Build() + q = db.Select().From("profiles").Union(q1).UnionAll(q2).PreFragment("pre").PostFragment("post").Build() + expected = "(pre SELECT * FROM `profiles` post) UNION (pre_q1 SELECT * FROM `users`) UNION ALL (SELECT * FROM `posts` post_q2)" + assert.Equal(t, q.SQL(), expected, "t5") +} + +func TestSelectQuery_Data(t *testing.T) { + db := getPreparedDB() + defer db.Close() + + q := db.Select("id", "email").From("customer").OrderBy("id") + + var customer Customer + q.One(&customer) + assert.Equal(t, customer.Email, "user1@example.com", "customer.Email") + + var customers []Customer + q.All(&customers) + assert.Equal(t, len(customers), 3, "len(customers)") + + rows, _ := q.Rows() + customer.Email = "" + rows.one(&customer) + assert.Equal(t, customer.Email, "user1@example.com", "customer.Email") + + var id, email string + q.Row(&id, &email) + assert.Equal(t, id, "1", "id") + assert.Equal(t, email, "user1@example.com", "email") + + var emails []string + err := db.Select("email").From("customer").Column(&emails) + if assert.Nil(t, err) { + assert.Equal(t, 3, len(emails)) + } + + var e int + err = db.Select().From("customer").One(&e) + assert.NotNil(t, err) + err = db.Select().From("customer").All(&e) + assert.NotNil(t, err) +} + +func TestSelectQuery_Model(t *testing.T) { + db := getPreparedDB() + defer db.Close() + + { + // One without specifying FROM + var customer CustomerPtr + err := db.Select().OrderBy("id").One(&customer) + if assert.Nil(t, err) { + assert.Equal(t, "user1@example.com", *customer.Email) + } + } + + { + // All without specifying FROM + var customers []CustomerPtr + err := db.Select().OrderBy("id").All(&customers) + if assert.Nil(t, err) { + assert.Equal(t, 3, len(customers)) + } + } + + { + // Model without specifying FROM + var customer CustomerPtr + err := db.Select().Model(2, &customer) + if assert.Nil(t, err) { + assert.Equal(t, "user2@example.com", *customer.Email) + } + } + + { + // Model with WHERE + var customer CustomerPtr + err := db.Select().Where(HashExp{"id": 1}).Model(2, &customer) + assert.Equal(t, sql.ErrNoRows, err) + + err = db.Select().Where(HashExp{"id": 2}).Model(2, &customer) + assert.Nil(t, err) + } + + { + // errors + var i int + err := db.Select().Model(1, &i) + assert.Equal(t, VarTypeError("must be a pointer to a struct"), err) + + var a struct { + Name string + } + + err = db.Select().Model(1, &a) + assert.Equal(t, MissingPKError, err) + var b struct { + ID1 string `db:"pk"` + ID2 string `db:"pk"` + } + err = db.Select().Model(1, &b) + assert.Equal(t, CompositePKError, err) + } +} + +func TestSelectWithBuildHook(t *testing.T) { + db := getPreparedDB() + defer db.Close() + + var buildSQL string + + db.Select("id"). + From("user"). + WithBuildHook(func(q *Query) { + buildSQL = q.SQL() + }). + Build() + + assert.Equal(t, "SELECT `id` FROM `user`", buildSQL) +} diff --git a/struct.go b/struct.go new file mode 100644 index 0000000..71aebdd --- /dev/null +++ b/struct.go @@ -0,0 +1,281 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "database/sql" + "reflect" + "regexp" + "strings" + "sync" +) + +type ( + // FieldMapFunc converts a struct field name into a DB column name. + FieldMapFunc func(string) string + + // TableMapFunc converts a sample struct into a DB table name. + TableMapFunc func(a interface{}) string + + structInfo struct { + nameMap map[string]*fieldInfo // mapping from struct field names to field infos + dbNameMap map[string]*fieldInfo // mapping from db column names to field infos + pkNames []string // struct field names representing PKs + } + + structValue struct { + *structInfo + value reflect.Value // the struct value + tableName string // the db table name for the struct + } + + fieldInfo struct { + name string // field name + dbName string // db column name + path []int // index path to the struct field reflection + } + + structInfoMapKey struct { + t reflect.Type + m reflect.Value + } +) + +var ( + // DbTag is the name of the struct tag used to specify the column name for the associated struct field + DbTag = "db" + + fieldRegex = regexp.MustCompile(`([^A-Z_])([A-Z])`) + scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + postScannerType = reflect.TypeOf((*PostScanner)(nil)).Elem() + structInfoMap = make(map[structInfoMapKey]*structInfo) + muStructInfoMap sync.Mutex +) + +// PostScanner is an optional interface used by ScanStruct. +type PostScanner interface { + // PostScan executes right after the struct has been populated + // with the DB values, allowing you to further normalize or validate + // the loaded data. + PostScan() error +} + +// DefaultFieldMapFunc maps a field name to a DB column name. +// The mapping rule set by this method is that words in a field name will be separated by underscores +// and the name will be turned into lower case. For example, "FirstName" maps to "first_name", and "MyID" becomes "my_id". +// See DB.FieldMapper for more details. +func DefaultFieldMapFunc(f string) string { + return strings.ToLower(fieldRegex.ReplaceAllString(f, "${1}_$2")) +} + +func getStructInfo(a reflect.Type, mapper FieldMapFunc) *structInfo { + muStructInfoMap.Lock() + defer muStructInfoMap.Unlock() + + key := structInfoMapKey{a, reflect.ValueOf(mapper)} + if si, ok := structInfoMap[key]; ok { + return si + } + + si := &structInfo{ + nameMap: map[string]*fieldInfo{}, + dbNameMap: map[string]*fieldInfo{}, + } + si.build(a, make([]int, 0), "", "", mapper) + structInfoMap[key] = si + + return si +} + +func newStructValue(model interface{}, fieldMapFunc FieldMapFunc, tableMapFunc TableMapFunc) *structValue { + value := reflect.ValueOf(model) + if value.Kind() != reflect.Ptr || value.Elem().Kind() != reflect.Struct || value.IsNil() { + return nil + } + + return &structValue{ + structInfo: getStructInfo(reflect.TypeOf(model).Elem(), fieldMapFunc), + value: value.Elem(), + tableName: tableMapFunc(model), + } +} + +// pk returns the primary key values indexed by the corresponding primary key column names. +func (s *structValue) pk() map[string]interface{} { + if len(s.pkNames) == 0 { + return nil + } + return s.columns(s.pkNames, nil) +} + +// columns returns the struct field values indexed by their corresponding DB column names. +func (s *structValue) columns(include, exclude []string) map[string]interface{} { + v := make(map[string]interface{}, len(s.nameMap)) + if len(include) == 0 { + for _, fi := range s.nameMap { + v[fi.dbName] = fi.getValue(s.value) + } + } else { + for _, attr := range include { + if fi, ok := s.nameMap[attr]; ok { + v[fi.dbName] = fi.getValue(s.value) + } + } + } + if len(exclude) > 0 { + for _, name := range exclude { + if fi, ok := s.nameMap[name]; ok { + delete(v, fi.dbName) + } + } + } + return v +} + +// getValue returns the field value for the given struct value. +func (fi *fieldInfo) getValue(a reflect.Value) interface{} { + for _, i := range fi.path { + a = a.Field(i) + if a.Kind() == reflect.Ptr { + if a.IsNil() { + return nil + } + a = a.Elem() + } + } + return a.Interface() +} + +// getField returns the reflection value of the field for the given struct value. +func (fi *fieldInfo) getField(a reflect.Value) reflect.Value { + i := 0 + for ; i < len(fi.path)-1; i++ { + a = indirect(a.Field(fi.path[i])) + } + return a.Field(fi.path[i]) +} + +func (si *structInfo) build(a reflect.Type, path []int, namePrefix, dbNamePrefix string, mapper FieldMapFunc) { + n := a.NumField() + for i := 0; i < n; i++ { + field := a.Field(i) + tag := field.Tag.Get(DbTag) + + // only handle anonymous or exported fields + if !field.Anonymous && field.PkgPath != "" || tag == "-" { + continue + } + + path2 := make([]int, len(path), len(path)+1) + copy(path2, path) + path2 = append(path2, i) + + ft := field.Type + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + + name := field.Name + dbName, isPK := parseTag(tag) + if dbName == "" && !field.Anonymous { + if mapper != nil { + dbName = mapper(field.Name) + } else { + dbName = field.Name + } + } + if field.Anonymous { + name = "" + } + + if isNestedStruct(ft) { + // dive into non-scanner struct + si.build(ft, path2, concat(namePrefix, name), concat(dbNamePrefix, dbName), mapper) + } else if dbName != "" { + // non-anonymous scanner or struct field + fi := &fieldInfo{ + name: concat(namePrefix, name), + dbName: concat(dbNamePrefix, dbName), + path: path2, + } + // a field in an anonymous struct may be shadowed + if _, ok := si.nameMap[fi.name]; !ok || len(path2) < len(si.nameMap[fi.name].path) { + si.nameMap[fi.name] = fi + si.dbNameMap[fi.dbName] = fi + if isPK { + si.pkNames = append(si.pkNames, fi.name) + } + } + } + } + if len(si.pkNames) == 0 { + if _, ok := si.nameMap["ID"]; ok { + si.pkNames = append(si.pkNames, "ID") + } else if _, ok := si.nameMap["Id"]; ok { + si.pkNames = append(si.pkNames, "Id") + } + } +} + +func isNestedStruct(t reflect.Type) bool { + if t.PkgPath() == "time" && t.Name() == "Time" { + return false + } + return t.Kind() == reflect.Struct && !reflect.PtrTo(t).Implements(scannerType) +} + +func parseTag(tag string) (string, bool) { + if tag == "pk" { + return "", true + } + if strings.HasPrefix(tag, "pk,") { + return tag[3:], true + } + return tag, false +} + +func concat(s1, s2 string) string { + if s1 == "" { + return s2 + } else if s2 == "" { + return s1 + } else { + return s1 + "." + s2 + } +} + +// indirect dereferences pointers and returns the actual value it points to. +// If a pointer is nil, it will be initialized with a new value. +func indirect(v reflect.Value) reflect.Value { + for v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + return v +} + +// GetTableName implements the default way of determining the table name corresponding to the given model struct +// or slice of structs. To get the actual table name for a model, you should use DB.TableMapFunc() instead. +// Do not call this method in a model's TableName() method because it will cause infinite loop. +func GetTableName(a interface{}) string { + if tm, ok := a.(TableModel); ok { + v := reflect.ValueOf(a) + if v.Kind() == reflect.Ptr && v.IsNil() { + a = reflect.New(v.Type().Elem()).Interface() + return a.(TableModel).TableName() + } + return tm.TableName() + } + t := reflect.TypeOf(a) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() == reflect.Slice { + return GetTableName(reflect.Zero(t.Elem()).Interface()) + } + return DefaultFieldMapFunc(t.Name()) +} diff --git a/struct_test.go b/struct_test.go new file mode 100644 index 0000000..82f1751 --- /dev/null +++ b/struct_test.go @@ -0,0 +1,154 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import ( + "database/sql" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDefaultFieldMapFunc(t *testing.T) { + tests := []struct { + input, output string + }{ + {"Name", "name"}, + {"FirstName", "first_name"}, + {"Name0", "name0"}, + {"ID", "id"}, + {"UserID", "user_id"}, + {"User0ID", "user0_id"}, + {"MyURL", "my_url"}, + {"URLPath", "urlpath"}, + {"MyURLPath", "my_urlpath"}, + {"First_Name", "first_name"}, + {"first_name", "first_name"}, + {"_FirstName", "_first_name"}, + {"_First_Name", "_first_name"}, + } + for _, test := range tests { + r := DefaultFieldMapFunc(test.input) + assert.Equal(t, test.output, r, test.input) + } +} + +func Test_concat(t *testing.T) { + assert.Equal(t, "a.b", concat("a", "b")) + assert.Equal(t, "a", concat("a", "")) + assert.Equal(t, "b", concat("", "b")) +} + +func Test_parseTag(t *testing.T) { + name, pk := parseTag("abc") + assert.Equal(t, "abc", name) + assert.False(t, pk) + + name, pk = parseTag("pk,abc") + assert.Equal(t, "abc", name) + assert.True(t, pk) + + name, pk = parseTag("pk") + assert.Equal(t, "", name) + assert.True(t, pk) +} + +func Test_indirect(t *testing.T) { + var a int + assert.Equal(t, reflect.ValueOf(a).Kind(), indirect(reflect.ValueOf(a)).Kind()) + var b *int + bi := indirect(reflect.ValueOf(&b)) + assert.Equal(t, reflect.ValueOf(a).Kind(), bi.Kind()) + if assert.NotNil(t, b) { + assert.Equal(t, 0, *b) + } +} + +func Test_structValue_columns(t *testing.T) { + customer := Customer{ + ID: 1, + Name: "abc", + Status: 2, + Email: "abc@example.com", + } + sv := newStructValue(&customer, DefaultFieldMapFunc, GetTableName) + cols := sv.columns(nil, nil) + assert.Equal(t, map[string]interface{}{"id": 1, "name": "abc", "status": 2, "email": "abc@example.com", "address": sql.NullString{}}, cols) + + cols = sv.columns([]string{"ID", "name"}, nil) + assert.Equal(t, map[string]interface{}{"id": 1}, cols) + + cols = sv.columns([]string{"ID", "Name"}, []string{"ID"}) + assert.Equal(t, map[string]interface{}{"name": "abc"}, cols) + + cols = sv.columns(nil, []string{"ID", "Address"}) + assert.Equal(t, map[string]interface{}{"name": "abc", "status": 2, "email": "abc@example.com"}, cols) + + sv = newStructValue(&customer, nil, GetTableName) + cols = sv.columns([]string{"ID", "Name"}, []string{"ID"}) + assert.Equal(t, map[string]interface{}{"Name": "abc"}, cols) +} + +func TestIssue37(t *testing.T) { + customer := Customer{ + ID: 1, + Name: "abc", + Status: 2, + Email: "abc@example.com", + } + ev := struct { + Customer + Status string + }{customer, "20"} + sv := newStructValue(&ev, nil, GetTableName) + cols := sv.columns([]string{"ID", "Status"}, nil) + assert.Equal(t, map[string]interface{}{"ID": 1, "Status": "20"}, cols) + + ev2 := struct { + Status string + Customer + }{"20", customer} + sv = newStructValue(&ev2, nil, GetTableName) + cols = sv.columns([]string{"ID", "Status"}, nil) + assert.Equal(t, map[string]interface{}{"ID": 1, "Status": "20"}, cols) +} + +type MyCustomer struct{} + +func TestGetTableName(t *testing.T) { + var c1 Customer + assert.Equal(t, "customer", GetTableName(c1)) + + var c2 *Customer + assert.Equal(t, "customer", GetTableName(c2)) + + var c3 MyCustomer + assert.Equal(t, "my_customer", GetTableName(c3)) + + var c4 []Customer + assert.Equal(t, "customer", GetTableName(c4)) + + var c5 *[]Customer + assert.Equal(t, "customer", GetTableName(c5)) + + var c6 []MyCustomer + assert.Equal(t, "my_customer", GetTableName(c6)) + + var c7 []CustomerPtr + assert.Equal(t, "customer", GetTableName(c7)) + + var c8 **int + assert.Equal(t, "", GetTableName(c8)) +} + +type FA struct { + A1 string + A2 int +} + +type FB struct { + B1 string +} diff --git a/testdata/mysql.sql b/testdata/mysql.sql new file mode 100644 index 0000000..b09a0bd --- /dev/null +++ b/testdata/mysql.sql @@ -0,0 +1,81 @@ +/** + * This is the database schema for testing MySQL support of ozzo-dbx. + * The following database setup is required in order to run the test: + * - host: 127.0.0.1 + * - user: travis + * - pass: + * - database: pocketbase_dbx_test + */ + +DROP TABLE IF EXISTS `order_item` CASCADE; +DROP TABLE IF EXISTS `item` CASCADE; +DROP TABLE IF EXISTS `order` CASCADE; +DROP TABLE IF EXISTS `customer` CASCADE; +DROP TABLE IF EXISTS `user` CASCADE; + +CREATE TABLE `customer` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + `email` varchar(128) NOT NULL, + `name` varchar(128), + `address` text, + `status` int (11) DEFAULT 0, + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + +CREATE TABLE `user` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + `email` varchar(128) NOT NULL, + `created` date, + `updated` date, + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + +CREATE TABLE `item` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + `name` varchar(128) NOT NULL, + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + +CREATE TABLE `order` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + `customer_id` int(11) NOT NULL, + `created_at` int(11) NOT NULL, + `total` decimal(10,0) NOT NULL, + PRIMARY KEY (`id`), + CONSTRAINT `FK_order_customer_id` FOREIGN KEY (`customer_id`) REFERENCES `customer` (`id`) ON DELETE CASCADE +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + +CREATE TABLE `order_item` ( + `order_id` int(11) NOT NULL, + `item_id` int(11) NOT NULL, + `quantity` int(11) NOT NULL, + `subtotal` decimal(10,0) NOT NULL, + PRIMARY KEY (`order_id`,`item_id`), + KEY `FK_order_item_item_id` (`item_id`), + CONSTRAINT `FK_order_item_order_id` FOREIGN KEY (`order_id`) REFERENCES `order` (`id`) ON DELETE CASCADE, + CONSTRAINT `FK_order_item_item_id` FOREIGN KEY (`item_id`) REFERENCES `item` (`id`) ON DELETE CASCADE +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + +INSERT INTO `customer` (email, name, address, status) VALUES ('user1@example.com', 'user1', 'address1', 1); +INSERT INTO `customer` (email, name, address, status) VALUES ('user2@example.com', 'user2', NULL, 1); +INSERT INTO `customer` (email, name, address, status) VALUES ('user3@example.com', 'user3', 'address3', 2); + +INSERT INTO `user` (email, created) VALUES ('user1@example.com', '2015-01-02'); +INSERT INTO `user` (email, created) VALUES ('user2@example.com', now()); + +INSERT INTO `item` (name) VALUES ('The Go Programming Language'); +INSERT INTO `item` (name) VALUES ('Go in Action'); +INSERT INTO `item` (name) VALUES ('Go Programming Blueprints'); +INSERT INTO `item` (name) VALUES ('Building Microservices'); +INSERT INTO `item` (name) VALUES ('Go Web Programming'); + +INSERT INTO `order` (customer_id, created_at, total) VALUES (1, 1325282384, 110.0); +INSERT INTO `order` (customer_id, created_at, total) VALUES (2, 1325334482, 33.0); +INSERT INTO `order` (customer_id, created_at, total) VALUES (2, 1325502201, 40.0); + +INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (1, 1, 1, 30.0); +INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (1, 2, 2, 40.0); +INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (2, 4, 1, 10.0); +INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (2, 5, 1, 15.0); +INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (2, 3, 1, 8.0); +INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (3, 2, 1, 40.0); diff --git a/tx.go b/tx.go new file mode 100644 index 0000000..6eae9bb --- /dev/null +++ b/tx.go @@ -0,0 +1,23 @@ +// Copyright 2016 Qiang Xue. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbx + +import "database/sql" + +// Tx enhances sql.Tx with additional querying methods. +type Tx struct { + Builder + tx *sql.Tx +} + +// Commit commits the transaction. +func (t *Tx) Commit() error { + return t.tx.Commit() +} + +// Rollback aborts the transaction. +func (t *Tx) Rollback() error { + return t.tx.Rollback() +}