301 lines
7.1 KiB
Go
301 lines
7.1 KiB
Go
// Copyright 2016 Qiang Xue. All rights reserved.
|
|
// Use of this source code is governed by a MIT-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package dbx
|
|
|
|
import (
|
|
"database/sql"
|
|
"reflect"
|
|
)
|
|
|
|
// VarTypeError reports that a destination variable has a type that cannot be
// populated with a DB query result. Its string value describes what was expected.
type VarTypeError string

// Error implements the error interface by prefixing the description with
// "Invalid variable type: ".
func (e VarTypeError) Error() string {
	const prefix = "Invalid variable type: "
	return prefix + string(e)
}
|
|
|
|
// NullStringMap is a map of sql.NullString that can be used to hold DB query result.
|
|
// The map keys correspond to the DB column names, while the map values are their corresponding column values.
|
|
type NullStringMap map[string]sql.NullString
|
|
|
|
// Rows enhances sql.Rows by providing additional data query methods.
|
|
// Rows can be obtained by calling Query.Rows(). It is mainly used to populate data row by row.
|
|
type Rows struct {
|
|
*sql.Rows
|
|
fieldMapFunc FieldMapFunc
|
|
}
|
|
|
|
// ScanMap populates the current row of data into a NullStringMap.
|
|
// Note that the NullStringMap must not be nil, or it will panic.
|
|
// The NullStringMap will be populated using column names as keys and their values as
|
|
// the corresponding element values.
|
|
func (r *Rows) ScanMap(a NullStringMap) error {
|
|
cols, _ := r.Columns()
|
|
var refs []interface{}
|
|
for i := 0; i < len(cols); i++ {
|
|
var t sql.NullString
|
|
refs = append(refs, &t)
|
|
}
|
|
if err := r.Scan(refs...); err != nil {
|
|
return err
|
|
}
|
|
|
|
for i, col := range cols {
|
|
a[col] = *refs[i].(*sql.NullString)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ScanStruct populates the current row of data into a struct.
|
|
// The struct must be given as a pointer.
|
|
//
|
|
// ScanStruct associates struct fields with DB table columns through a field mapping function.
|
|
// It populates a struct field with the data of its associated column.
|
|
// Note that only exported struct fields will be populated.
|
|
//
|
|
// By default, DefaultFieldMapFunc() is used to map struct fields to table columns.
|
|
// This function separates each word in a field name with a underscore and turns every letter into lower case.
|
|
// For example, "LastName" is mapped to "last_name", "MyID" is mapped to "my_id", and so on.
|
|
// To change the default behavior, set DB.FieldMapper with your custom mapping function.
|
|
// You may also set Query.FieldMapper to change the behavior for particular queries.
|
|
func (r *Rows) ScanStruct(a interface{}) error {
|
|
rv := reflect.ValueOf(a)
|
|
if rv.Kind() != reflect.Ptr || rv.IsNil() {
|
|
return VarTypeError("must be a pointer")
|
|
}
|
|
rv = indirect(rv)
|
|
if rv.Kind() != reflect.Struct {
|
|
return VarTypeError("must be a pointer to a struct")
|
|
}
|
|
|
|
si := getStructInfo(rv.Type(), r.fieldMapFunc)
|
|
|
|
cols, _ := r.Columns()
|
|
refs := make([]interface{}, len(cols))
|
|
|
|
for i, col := range cols {
|
|
if fi, ok := si.dbNameMap[col]; ok {
|
|
refs[i] = fi.getField(rv).Addr().Interface()
|
|
} else {
|
|
refs[i] = &sql.NullString{}
|
|
}
|
|
}
|
|
|
|
if err := r.Scan(refs...); err != nil {
|
|
return err
|
|
}
|
|
|
|
// check for PostScanner
|
|
if rv.CanAddr() {
|
|
addr := rv.Addr()
|
|
if addr.CanInterface() {
|
|
if ps, ok := addr.Interface().(PostScanner); ok {
|
|
if err := ps.PostScan(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// all populates all rows of query result into a slice of struct or NullStringMap.
|
|
// Note that the slice must be given as a pointer.
|
|
func (r *Rows) all(slice interface{}) error {
|
|
defer r.Close()
|
|
|
|
v := reflect.ValueOf(slice)
|
|
if v.Kind() != reflect.Ptr || v.IsNil() {
|
|
return VarTypeError("must be a pointer")
|
|
}
|
|
v = indirect(v)
|
|
|
|
if v.Kind() != reflect.Slice {
|
|
return VarTypeError("must be a slice of struct or NullStringMap")
|
|
}
|
|
|
|
if v.IsNil() {
|
|
// create an empty slice
|
|
v.Set(reflect.MakeSlice(v.Type(), 0, 0))
|
|
}
|
|
|
|
et := v.Type().Elem()
|
|
|
|
if et.Kind() == reflect.Map {
|
|
for r.Next() {
|
|
ev, ok := reflect.MakeMap(et).Interface().(NullStringMap)
|
|
if !ok {
|
|
return VarTypeError("must be a slice of struct or NullStringMap")
|
|
}
|
|
if err := r.ScanMap(ev); err != nil {
|
|
return err
|
|
}
|
|
v.Set(reflect.Append(v, reflect.ValueOf(ev)))
|
|
}
|
|
return r.Close()
|
|
}
|
|
|
|
var isSliceOfPointers bool
|
|
if et.Kind() == reflect.Ptr {
|
|
isSliceOfPointers = true
|
|
et = et.Elem()
|
|
}
|
|
|
|
if et.Kind() != reflect.Struct {
|
|
return VarTypeError("must be a slice of struct or NullStringMap")
|
|
}
|
|
|
|
etPtr := reflect.PtrTo(et)
|
|
implementsPostScanner := etPtr.Implements(postScannerType)
|
|
|
|
si := getStructInfo(et, r.fieldMapFunc)
|
|
|
|
cols, _ := r.Columns()
|
|
for r.Next() {
|
|
ev := reflect.New(et).Elem()
|
|
refs := make([]interface{}, len(cols))
|
|
for i, col := range cols {
|
|
if fi, ok := si.dbNameMap[col]; ok {
|
|
refs[i] = fi.getField(ev).Addr().Interface()
|
|
} else {
|
|
refs[i] = &sql.NullString{}
|
|
}
|
|
}
|
|
if err := r.Scan(refs...); err != nil {
|
|
return err
|
|
}
|
|
|
|
if isSliceOfPointers {
|
|
ev = ev.Addr()
|
|
}
|
|
|
|
// check for PostScanner
|
|
if implementsPostScanner {
|
|
evAddr := ev
|
|
if ev.CanAddr() {
|
|
evAddr = ev.Addr()
|
|
}
|
|
if evAddr.CanInterface() {
|
|
if ps, ok := evAddr.Interface().(PostScanner); ok {
|
|
if err := ps.PostScan(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
v.Set(reflect.Append(v, ev))
|
|
}
|
|
|
|
return r.Close()
|
|
}
|
|
|
|
// column populates the given slice with the first column of the query result.
|
|
// Note that the slice must be given as a pointer.
|
|
func (r *Rows) column(slice interface{}) error {
|
|
defer r.Close()
|
|
|
|
v := reflect.ValueOf(slice)
|
|
if v.Kind() != reflect.Ptr || v.IsNil() {
|
|
return VarTypeError("must be a pointer to a slice")
|
|
}
|
|
v = indirect(v)
|
|
|
|
if v.Kind() != reflect.Slice {
|
|
return VarTypeError("must be a pointer to a slice")
|
|
}
|
|
|
|
et := v.Type().Elem()
|
|
|
|
cols, _ := r.Columns()
|
|
for r.Next() {
|
|
ev := reflect.New(et)
|
|
refs := make([]interface{}, len(cols))
|
|
for i := range cols {
|
|
if i == 0 {
|
|
refs[i] = ev.Interface()
|
|
} else {
|
|
refs[i] = &sql.NullString{}
|
|
}
|
|
}
|
|
if err := r.Scan(refs...); err != nil {
|
|
return err
|
|
}
|
|
v.Set(reflect.Append(v, ev.Elem()))
|
|
}
|
|
|
|
return r.Close()
|
|
}
|
|
|
|
// one populates a single row of query result into a struct or a NullStringMap.
|
|
// Note that if a struct is given, it should be a pointer.
|
|
func (r *Rows) one(a interface{}) error {
|
|
defer r.Close()
|
|
|
|
if !r.Next() {
|
|
if err := r.Err(); err != nil {
|
|
return err
|
|
}
|
|
return sql.ErrNoRows
|
|
}
|
|
|
|
var err error
|
|
|
|
rt := reflect.TypeOf(a)
|
|
if rt.Kind() == reflect.Ptr && rt.Elem().Kind() == reflect.Map {
|
|
// pointer to map
|
|
v := indirect(reflect.ValueOf(a))
|
|
if v.IsNil() {
|
|
v.Set(reflect.MakeMap(v.Type()))
|
|
}
|
|
a = v.Interface()
|
|
rt = reflect.TypeOf(a)
|
|
}
|
|
|
|
if rt.Kind() == reflect.Map {
|
|
v, ok := a.(NullStringMap)
|
|
if !ok {
|
|
return VarTypeError("must be a NullStringMap")
|
|
}
|
|
if v == nil {
|
|
return VarTypeError("NullStringMap is nil")
|
|
}
|
|
err = r.ScanMap(v)
|
|
} else {
|
|
err = r.ScanStruct(a)
|
|
}
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return r.Close()
|
|
}
|
|
|
|
// row populates a single row of query result into a list of variables.
|
|
func (r *Rows) row(a ...interface{}) error {
|
|
defer r.Close()
|
|
|
|
for _, dp := range a {
|
|
if _, ok := dp.(*sql.RawBytes); ok {
|
|
return VarTypeError("RawBytes isn't allowed on Row()")
|
|
}
|
|
}
|
|
|
|
if !r.Next() {
|
|
if err := r.Err(); err != nil {
|
|
return err
|
|
}
|
|
return sql.ErrNoRows
|
|
}
|
|
if err := r.Scan(a...); err != nil {
|
|
return err
|
|
}
|
|
|
|
return r.Close()
|
|
}
|