384 lines
11 KiB
Go
384 lines
11 KiB
Go
// Copyright 2016 Qiang Xue. All rights reserved.
|
|
// Use of this source code is governed by a MIT-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package dbx
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// ExecHookFunc executes before op allowing custom handling like auto fail/retry.
|
|
type ExecHookFunc func(q *Query, op func() error) error
|
|
|
|
// OneHookFunc executes right before the query populate the row result from One() call (aka. op).
|
|
type OneHookFunc func(q *Query, a interface{}, op func(b interface{}) error) error
|
|
|
|
// AllHookFunc executes right before the query populate the row result from All() call (aka. op).
|
|
type AllHookFunc func(q *Query, sliceA interface{}, op func(sliceB interface{}) error) error
|
|
|
|
// Params holds the values to be bound to the named parameters of a SQL statement.
// Each key is a parameter name and each value is the value bound to that parameter.
type Params map[string]interface{}
|
|
|
|
// Executor is the set of operations a Query needs in order to prepare,
// execute, or query a SQL statement.
type Executor interface {
	// Exec executes a SQL statement without returning rows.
	Exec(query string, args ...interface{}) (sql.Result, error)

	// ExecContext is like Exec but carries the given context.
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)

	// Query runs a SQL statement that returns rows.
	Query(query string, args ...interface{}) (*sql.Rows, error)

	// QueryContext is like Query but carries the given context.
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)

	// Prepare creates a prepared statement for the given SQL.
	Prepare(query string) (*sql.Stmt, error)
}
|
|
|
|
// Query represents a SQL statement to be executed.
|
|
type Query struct {
|
|
executor Executor
|
|
|
|
sql, rawSQL string
|
|
placeholders []string
|
|
params Params
|
|
|
|
stmt *sql.Stmt
|
|
ctx context.Context
|
|
|
|
// hooks
|
|
execHook ExecHookFunc
|
|
oneHook OneHookFunc
|
|
allHook AllHookFunc
|
|
|
|
// FieldMapper maps struct field names to DB column names.
|
|
FieldMapper FieldMapFunc
|
|
// LastError contains the last error (if any) of the query.
|
|
// LastError is cleared by Execute(), Row(), Rows(), One(), and All().
|
|
LastError error
|
|
// LogFunc is used to log the SQL statement being executed.
|
|
LogFunc LogFunc
|
|
// PerfFunc is used to log the SQL execution time. It is ignored if nil.
|
|
// Deprecated: Please use QueryLogFunc and ExecLogFunc instead.
|
|
PerfFunc PerfFunc
|
|
// QueryLogFunc is called each time when performing a SQL query that returns data.
|
|
QueryLogFunc QueryLogFunc
|
|
// ExecLogFunc is called each time when a SQL statement is executed.
|
|
ExecLogFunc ExecLogFunc
|
|
}
|
|
|
|
// NewQuery creates a new Query with the given SQL statement.
|
|
func NewQuery(db *DB, executor Executor, sql string) *Query {
|
|
rawSQL, placeholders := db.processSQL(sql)
|
|
return &Query{
|
|
executor: executor,
|
|
sql: sql,
|
|
rawSQL: rawSQL,
|
|
placeholders: placeholders,
|
|
params: Params{},
|
|
ctx: db.ctx,
|
|
FieldMapper: db.FieldMapper,
|
|
LogFunc: db.LogFunc,
|
|
PerfFunc: db.PerfFunc,
|
|
QueryLogFunc: db.QueryLogFunc,
|
|
ExecLogFunc: db.ExecLogFunc,
|
|
}
|
|
}
|
|
|
|
// SQL returns the original SQL used to create the query.
|
|
// The actual SQL (RawSQL) being executed is obtained by replacing the named
|
|
// parameter placeholders with anonymous ones.
|
|
func (q *Query) SQL() string {
|
|
return q.sql
|
|
}
|
|
|
|
// Context returns the context associated with the query.
|
|
func (q *Query) Context() context.Context {
|
|
return q.ctx
|
|
}
|
|
|
|
// WithContext associates a context with the query.
|
|
func (q *Query) WithContext(ctx context.Context) *Query {
|
|
q.ctx = ctx
|
|
return q
|
|
}
|
|
|
|
// WithExecHook associates the provided exec hook function with the query.
|
|
//
|
|
// It is called for every Query resolver (Execute(), One(), All(), Row(), Column()),
|
|
// allowing you to implement auto fail/retry or any other additional handling.
|
|
func (q *Query) WithExecHook(fn ExecHookFunc) *Query {
|
|
q.execHook = fn
|
|
return q
|
|
}
|
|
|
|
// WithOneHook associates the provided hook function with the query,
|
|
// called on q.One(), allowing you to implement custom struct scan based
|
|
// on the One() argument and/or result.
|
|
func (q *Query) WithOneHook(fn OneHookFunc) *Query {
|
|
q.oneHook = fn
|
|
return q
|
|
}
|
|
|
|
// WithOneHook associates the provided hook function with the query,
|
|
// called on q.All(), allowing you to implement custom slice scan based
|
|
// on the All() argument and/or result.
|
|
func (q *Query) WithAllHook(fn AllHookFunc) *Query {
|
|
q.allHook = fn
|
|
return q
|
|
}
|
|
|
|
// logSQL returns the SQL statement with parameters being replaced with the actual values.
|
|
// The result is only for logging purpose and should not be used to execute.
|
|
func (q *Query) logSQL() string {
|
|
s := q.sql
|
|
for k, v := range q.params {
|
|
if valuer, ok := v.(driver.Valuer); ok && valuer != nil {
|
|
v, _ = valuer.Value()
|
|
}
|
|
var sv string
|
|
if str, ok := v.(string); ok {
|
|
sv = "'" + strings.Replace(str, "'", "''", -1) + "'"
|
|
} else if bs, ok := v.([]byte); ok {
|
|
sv = "0x" + hex.EncodeToString(bs)
|
|
} else {
|
|
sv = fmt.Sprintf("%v", v)
|
|
}
|
|
s = strings.Replace(s, "{:"+k+"}", sv, -1)
|
|
}
|
|
return s
|
|
}
|
|
|
|
// Params returns the parameters to be bound to the SQL statement represented by this query.
|
|
func (q *Query) Params() Params {
|
|
return q.params
|
|
}
|
|
|
|
// Prepare creates a prepared statement for later queries or executions.
|
|
// Close() should be called after finishing all queries.
|
|
func (q *Query) Prepare() *Query {
|
|
stmt, err := q.executor.Prepare(q.rawSQL)
|
|
if err != nil {
|
|
q.LastError = err
|
|
return q
|
|
}
|
|
q.stmt = stmt
|
|
return q
|
|
}
|
|
|
|
// Close closes the underlying prepared statement.
|
|
// Close does nothing if the query has not been prepared before.
|
|
func (q *Query) Close() error {
|
|
if q.stmt == nil {
|
|
return nil
|
|
}
|
|
|
|
err := q.stmt.Close()
|
|
q.stmt = nil
|
|
return err
|
|
}
|
|
|
|
// Bind sets the parameters that should be bound to the SQL statement.
|
|
// The parameter placeholders in the SQL statement are in the format of "{:ParamName}".
|
|
func (q *Query) Bind(params Params) *Query {
|
|
if len(q.params) == 0 {
|
|
q.params = params
|
|
} else {
|
|
for k, v := range params {
|
|
q.params[k] = v
|
|
}
|
|
}
|
|
return q
|
|
}
|
|
|
|
// Execute executes the SQL statement without retrieving data.
|
|
func (q *Query) Execute() (sql.Result, error) {
|
|
var result sql.Result
|
|
|
|
execErr := q.execWrap(func() error {
|
|
var err error
|
|
result, err = q.execute()
|
|
return err
|
|
})
|
|
|
|
return result, execErr
|
|
}
|
|
|
|
func (q *Query) execute() (result sql.Result, err error) {
|
|
err = q.LastError
|
|
q.LastError = nil
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
var params []interface{}
|
|
params, err = replacePlaceholders(q.placeholders, q.params)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
start := time.Now()
|
|
|
|
if q.ctx == nil {
|
|
if q.stmt == nil {
|
|
result, err = q.executor.Exec(q.rawSQL, params...)
|
|
} else {
|
|
result, err = q.stmt.Exec(params...)
|
|
}
|
|
} else {
|
|
if q.stmt == nil {
|
|
result, err = q.executor.ExecContext(q.ctx, q.rawSQL, params...)
|
|
} else {
|
|
result, err = q.stmt.ExecContext(q.ctx, params...)
|
|
}
|
|
}
|
|
|
|
if q.ExecLogFunc != nil {
|
|
q.ExecLogFunc(q.ctx, time.Now().Sub(start), q.logSQL(), result, err)
|
|
}
|
|
if q.LogFunc != nil {
|
|
q.LogFunc("[%.2fms] Execute SQL: %v", float64(time.Now().Sub(start).Milliseconds()), q.logSQL())
|
|
}
|
|
if q.PerfFunc != nil {
|
|
q.PerfFunc(time.Now().Sub(start).Nanoseconds(), q.logSQL(), true)
|
|
}
|
|
return
|
|
}
|
|
|
|
// One executes the SQL statement and populates the first row of the result into a struct or NullStringMap.
|
|
// Refer to Rows.ScanStruct() and Rows.ScanMap() for more details on how to specify
|
|
// the variable to be populated.
|
|
// Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned.
|
|
func (q *Query) One(a interface{}) error {
|
|
return q.execWrap(func() error {
|
|
rows, err := q.Rows()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if q.oneHook != nil {
|
|
return q.oneHook(q, a, rows.one)
|
|
}
|
|
|
|
return rows.one(a)
|
|
})
|
|
}
|
|
|
|
// All executes the SQL statement and populates all the resulting rows into a slice of struct or NullStringMap.
|
|
// The slice must be given as a pointer. Each slice element must be either a struct or a NullStringMap.
|
|
// Refer to Rows.ScanStruct() and Rows.ScanMap() for more details on how each slice element can be.
|
|
// If the query returns no row, the slice will be an empty slice (not nil).
|
|
func (q *Query) All(slice interface{}) error {
|
|
return q.execWrap(func() error {
|
|
rows, err := q.Rows()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if q.allHook != nil {
|
|
return q.allHook(q, slice, rows.all)
|
|
}
|
|
|
|
return rows.all(slice)
|
|
})
|
|
}
|
|
|
|
// Row executes the SQL statement and populates the first row of the result into a list of variables.
|
|
// Note that the number of the variables should match to that of the columns in the query result.
|
|
// Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned.
|
|
func (q *Query) Row(a ...interface{}) error {
|
|
return q.execWrap(func() error {
|
|
rows, err := q.Rows()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return rows.row(a...)
|
|
})
|
|
}
|
|
|
|
// Column executes the SQL statement and populates the first column of the result into a slice.
|
|
// Note that the parameter must be a pointer to a slice.
|
|
func (q *Query) Column(a interface{}) error {
|
|
return q.execWrap(func() error {
|
|
rows, err := q.Rows()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return rows.column(a)
|
|
})
|
|
}
|
|
|
|
// Rows executes the SQL statement and returns a Rows object to allow retrieving data row by row.
|
|
func (q *Query) Rows() (rows *Rows, err error) {
|
|
err = q.LastError
|
|
q.LastError = nil
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
var params []interface{}
|
|
params, err = replacePlaceholders(q.placeholders, q.params)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
start := time.Now()
|
|
|
|
var rr *sql.Rows
|
|
if q.ctx == nil {
|
|
if q.stmt == nil {
|
|
rr, err = q.executor.Query(q.rawSQL, params...)
|
|
} else {
|
|
rr, err = q.stmt.Query(params...)
|
|
}
|
|
} else {
|
|
if q.stmt == nil {
|
|
rr, err = q.executor.QueryContext(q.ctx, q.rawSQL, params...)
|
|
} else {
|
|
rr, err = q.stmt.QueryContext(q.ctx, params...)
|
|
}
|
|
}
|
|
rows = &Rows{rr, q.FieldMapper}
|
|
|
|
if q.QueryLogFunc != nil {
|
|
q.QueryLogFunc(q.ctx, time.Now().Sub(start), q.logSQL(), rr, err)
|
|
}
|
|
if q.LogFunc != nil {
|
|
q.LogFunc("[%.2fms] Query SQL: %v", float64(time.Now().Sub(start).Milliseconds()), q.logSQL())
|
|
}
|
|
if q.PerfFunc != nil {
|
|
q.PerfFunc(time.Now().Sub(start).Nanoseconds(), q.logSQL(), false)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (q *Query) execWrap(op func() error) error {
|
|
if q.execHook != nil {
|
|
return q.execHook(q, op)
|
|
}
|
|
return op()
|
|
}
|
|
|
|
// replacePlaceholders converts a list of named parameters into a list of anonymous parameters.
|
|
func replacePlaceholders(placeholders []string, params Params) ([]interface{}, error) {
|
|
if len(placeholders) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
var result []interface{}
|
|
for _, name := range placeholders {
|
|
if value, ok := params[name]; ok {
|
|
result = append(result, value)
|
|
} else {
|
|
return nil, errors.New("Named parameter not found: " + name)
|
|
}
|
|
}
|
|
return result, nil
|
|
}
|