// Copyright 2016 Qiang Xue. All rights reserved. // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. package dbx import ( "context" "database/sql" "database/sql/driver" "encoding/hex" "errors" "fmt" "strings" "time" ) // ExecHookFunc executes before op allowing custom handling like auto fail/retry. type ExecHookFunc func(q *Query, op func() error) error // OneHookFunc executes right before the query populate the row result from One() call (aka. op). type OneHookFunc func(q *Query, a interface{}, op func(b interface{}) error) error // AllHookFunc executes right before the query populate the row result from All() call (aka. op). type AllHookFunc func(q *Query, sliceA interface{}, op func(sliceB interface{}) error) error // Params represents a list of parameter values to be bound to a SQL statement. // The map keys are the parameter names while the map values are the corresponding parameter values. type Params map[string]interface{} // Executor prepares, executes, or queries a SQL statement. type Executor interface { // Exec executes a SQL statement Exec(query string, args ...interface{}) (sql.Result, error) // ExecContext executes a SQL statement with the given context ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) // Query queries a SQL statement Query(query string, args ...interface{}) (*sql.Rows, error) // QueryContext queries a SQL statement with the given context QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) // Prepare creates a prepared statement Prepare(query string) (*sql.Stmt, error) } // Query represents a SQL statement to be executed. type Query struct { executor Executor sql, rawSQL string placeholders []string params Params stmt *sql.Stmt ctx context.Context // hooks execHook ExecHookFunc oneHook OneHookFunc allHook AllHookFunc // FieldMapper maps struct field names to DB column names. FieldMapper FieldMapFunc // LastError contains the last error (if any) of the query. // LastError is cleared by Execute(), Row(), Rows(), One(), and All(). LastError error // LogFunc is used to log the SQL statement being executed. LogFunc LogFunc // PerfFunc is used to log the SQL execution time. It is ignored if nil. // Deprecated: Please use QueryLogFunc and ExecLogFunc instead. PerfFunc PerfFunc // QueryLogFunc is called each time when performing a SQL query that returns data. QueryLogFunc QueryLogFunc // ExecLogFunc is called each time when a SQL statement is executed. ExecLogFunc ExecLogFunc } // NewQuery creates a new Query with the given SQL statement. func NewQuery(db *DB, executor Executor, sql string) *Query { rawSQL, placeholders := db.processSQL(sql) return &Query{ executor: executor, sql: sql, rawSQL: rawSQL, placeholders: placeholders, params: Params{}, ctx: db.ctx, FieldMapper: db.FieldMapper, LogFunc: db.LogFunc, PerfFunc: db.PerfFunc, QueryLogFunc: db.QueryLogFunc, ExecLogFunc: db.ExecLogFunc, } } // SQL returns the original SQL used to create the query. // The actual SQL (RawSQL) being executed is obtained by replacing the named // parameter placeholders with anonymous ones. func (q *Query) SQL() string { return q.sql } // Context returns the context associated with the query. func (q *Query) Context() context.Context { return q.ctx } // WithContext associates a context with the query. func (q *Query) WithContext(ctx context.Context) *Query { q.ctx = ctx return q } // WithExecHook associates the provided exec hook function with the query. // // It is called for every Query resolver (Execute(), One(), All(), Row(), Column()), // allowing you to implement auto fail/retry or any other additional handling. func (q *Query) WithExecHook(fn ExecHookFunc) *Query { q.execHook = fn return q } // WithOneHook associates the provided hook function with the query, // called on q.One(), allowing you to implement custom struct scan based // on the One() argument and/or result. func (q *Query) WithOneHook(fn OneHookFunc) *Query { q.oneHook = fn return q } // WithOneHook associates the provided hook function with the query, // called on q.All(), allowing you to implement custom slice scan based // on the All() argument and/or result. func (q *Query) WithAllHook(fn AllHookFunc) *Query { q.allHook = fn return q } // logSQL returns the SQL statement with parameters being replaced with the actual values. // The result is only for logging purpose and should not be used to execute. func (q *Query) logSQL() string { s := q.sql for k, v := range q.params { if valuer, ok := v.(driver.Valuer); ok && valuer != nil { v, _ = valuer.Value() } var sv string if str, ok := v.(string); ok { sv = "'" + strings.Replace(str, "'", "''", -1) + "'" } else if bs, ok := v.([]byte); ok { sv = "0x" + hex.EncodeToString(bs) } else { sv = fmt.Sprintf("%v", v) } s = strings.Replace(s, "{:"+k+"}", sv, -1) } return s } // Params returns the parameters to be bound to the SQL statement represented by this query. func (q *Query) Params() Params { return q.params } // Prepare creates a prepared statement for later queries or executions. // Close() should be called after finishing all queries. func (q *Query) Prepare() *Query { stmt, err := q.executor.Prepare(q.rawSQL) if err != nil { q.LastError = err return q } q.stmt = stmt return q } // Close closes the underlying prepared statement. // Close does nothing if the query has not been prepared before. func (q *Query) Close() error { if q.stmt == nil { return nil } err := q.stmt.Close() q.stmt = nil return err } // Bind sets the parameters that should be bound to the SQL statement. // The parameter placeholders in the SQL statement are in the format of "{:ParamName}". func (q *Query) Bind(params Params) *Query { if len(q.params) == 0 { q.params = params } else { for k, v := range params { q.params[k] = v } } return q } // Execute executes the SQL statement without retrieving data. func (q *Query) Execute() (sql.Result, error) { var result sql.Result execErr := q.execWrap(func() error { var err error result, err = q.execute() return err }) return result, execErr } func (q *Query) execute() (result sql.Result, err error) { err = q.LastError q.LastError = nil if err != nil { return } var params []interface{} params, err = replacePlaceholders(q.placeholders, q.params) if err != nil { return } start := time.Now() if q.ctx == nil { if q.stmt == nil { result, err = q.executor.Exec(q.rawSQL, params...) } else { result, err = q.stmt.Exec(params...) } } else { if q.stmt == nil { result, err = q.executor.ExecContext(q.ctx, q.rawSQL, params...) } else { result, err = q.stmt.ExecContext(q.ctx, params...) } } if q.ExecLogFunc != nil { q.ExecLogFunc(q.ctx, time.Now().Sub(start), q.logSQL(), result, err) } if q.LogFunc != nil { q.LogFunc("[%.2fms] Execute SQL: %v", float64(time.Now().Sub(start).Milliseconds()), q.logSQL()) } if q.PerfFunc != nil { q.PerfFunc(time.Now().Sub(start).Nanoseconds(), q.logSQL(), true) } return } // One executes the SQL statement and populates the first row of the result into a struct or NullStringMap. // Refer to Rows.ScanStruct() and Rows.ScanMap() for more details on how to specify // the variable to be populated. // Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned. func (q *Query) One(a interface{}) error { return q.execWrap(func() error { rows, err := q.Rows() if err != nil { return err } if q.oneHook != nil { return q.oneHook(q, a, rows.one) } return rows.one(a) }) } // All executes the SQL statement and populates all the resulting rows into a slice of struct or NullStringMap. // The slice must be given as a pointer. Each slice element must be either a struct or a NullStringMap. // Refer to Rows.ScanStruct() and Rows.ScanMap() for more details on how each slice element can be. // If the query returns no row, the slice will be an empty slice (not nil). func (q *Query) All(slice interface{}) error { return q.execWrap(func() error { rows, err := q.Rows() if err != nil { return err } if q.allHook != nil { return q.allHook(q, slice, rows.all) } return rows.all(slice) }) } // Row executes the SQL statement and populates the first row of the result into a list of variables. // Note that the number of the variables should match to that of the columns in the query result. // Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned. func (q *Query) Row(a ...interface{}) error { return q.execWrap(func() error { rows, err := q.Rows() if err != nil { return err } return rows.row(a...) }) } // Column executes the SQL statement and populates the first column of the result into a slice. // Note that the parameter must be a pointer to a slice. func (q *Query) Column(a interface{}) error { return q.execWrap(func() error { rows, err := q.Rows() if err != nil { return err } return rows.column(a) }) } // Rows executes the SQL statement and returns a Rows object to allow retrieving data row by row. func (q *Query) Rows() (rows *Rows, err error) { err = q.LastError q.LastError = nil if err != nil { return } var params []interface{} params, err = replacePlaceholders(q.placeholders, q.params) if err != nil { return } start := time.Now() var rr *sql.Rows if q.ctx == nil { if q.stmt == nil { rr, err = q.executor.Query(q.rawSQL, params...) } else { rr, err = q.stmt.Query(params...) } } else { if q.stmt == nil { rr, err = q.executor.QueryContext(q.ctx, q.rawSQL, params...) } else { rr, err = q.stmt.QueryContext(q.ctx, params...) } } rows = &Rows{rr, q.FieldMapper} if q.QueryLogFunc != nil { q.QueryLogFunc(q.ctx, time.Now().Sub(start), q.logSQL(), rr, err) } if q.LogFunc != nil { q.LogFunc("[%.2fms] Query SQL: %v", float64(time.Now().Sub(start).Milliseconds()), q.logSQL()) } if q.PerfFunc != nil { q.PerfFunc(time.Now().Sub(start).Nanoseconds(), q.logSQL(), false) } return } func (q *Query) execWrap(op func() error) error { if q.execHook != nil { return q.execHook(q, op) } return op() } // replacePlaceholders converts a list of named parameters into a list of anonymous parameters. func replacePlaceholders(placeholders []string, params Params) ([]interface{}, error) { if len(placeholders) == 0 { return nil, nil } var result []interface{} for _, name := range placeholders { if value, ok := params[name]; ok { result = append(result, value) } else { return nil, errors.New("Named parameter not found: " + name) } } return result, nil }