174 lines
4.9 KiB
Go
174 lines
4.9 KiB
Go
package dbx
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
)
|
|
|
|
type (
|
|
// TableModel is the interface that should be implemented by models which have unconventional table names.
|
|
TableModel interface {
|
|
TableName() string
|
|
}
|
|
|
|
// ModelQuery represents a query associated with a struct model.
|
|
ModelQuery struct {
|
|
db *DB
|
|
ctx context.Context
|
|
builder Builder
|
|
model *structValue
|
|
exclude []string
|
|
lastError error
|
|
}
|
|
)
|
|
|
|
// MissingPKError is returned by Update and Delete when the model does not
// declare any primary key.
var MissingPKError = errors.New("missing primary key declaration")

// CompositePKError reports that a composite primary key was used where it
// is not supported.
var CompositePKError = errors.New("composite primary key is not supported")
|
|
|
|
func NewModelQuery(model interface{}, fieldMapFunc FieldMapFunc, db *DB, builder Builder) *ModelQuery {
|
|
q := &ModelQuery{
|
|
db: db,
|
|
ctx: db.ctx,
|
|
builder: builder,
|
|
model: newStructValue(model, fieldMapFunc, db.TableMapper),
|
|
}
|
|
if q.model == nil {
|
|
q.lastError = VarTypeError("must be a pointer to a struct representing the model")
|
|
}
|
|
return q
|
|
}
|
|
|
|
// Context returns the context associated with the query.
|
|
func (q *ModelQuery) Context() context.Context {
|
|
return q.ctx
|
|
}
|
|
|
|
// WithContext associates a context with the query.
|
|
func (q *ModelQuery) WithContext(ctx context.Context) *ModelQuery {
|
|
q.ctx = ctx
|
|
return q
|
|
}
|
|
|
|
// Exclude excludes the specified struct fields from being inserted/updated into the DB table.
|
|
func (q *ModelQuery) Exclude(attrs ...string) *ModelQuery {
|
|
q.exclude = attrs
|
|
return q
|
|
}
|
|
|
|
// Insert inserts a row in the table using the struct model associated with this query.
|
|
//
|
|
// By default, it inserts *all* public fields into the table, including those nil or empty ones.
|
|
// You may pass a list of the fields to this method to indicate that only those fields should be inserted.
|
|
// You may also call Exclude to exclude some fields from being inserted.
|
|
//
|
|
// If a model has an empty primary key, it is considered auto-incremental and the corresponding struct
|
|
// field will be filled with the generated primary key value after a successful insertion.
|
|
func (q *ModelQuery) Insert(attrs ...string) error {
|
|
if q.lastError != nil {
|
|
return q.lastError
|
|
}
|
|
cols := q.model.columns(attrs, q.exclude)
|
|
pkName := ""
|
|
for name, value := range q.model.pk() {
|
|
if isAutoInc(value) {
|
|
delete(cols, name)
|
|
pkName = name
|
|
break
|
|
}
|
|
}
|
|
|
|
if pkName == "" {
|
|
_, err := q.builder.Insert(q.model.tableName, Params(cols)).WithContext(q.ctx).Execute()
|
|
return err
|
|
}
|
|
|
|
// handle auto-incremental PK
|
|
query := q.builder.Insert(q.model.tableName, Params(cols)).WithContext(q.ctx)
|
|
pkValue, err := insertAndReturnPK(q.db, query, pkName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
pkField := indirect(q.model.dbNameMap[pkName].getField(q.model.value))
|
|
switch pkField.Kind() {
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
pkField.SetUint(uint64(pkValue))
|
|
default:
|
|
pkField.SetInt(pkValue)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func insertAndReturnPK(db *DB, query *Query, pkName string) (int64, error) {
|
|
if db.DriverName() != "postgres" && db.DriverName() != "pgx" {
|
|
result, err := query.Execute()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return result.LastInsertId()
|
|
}
|
|
|
|
// specially handle postgres (lib/pq) as it doesn't support LastInsertId
|
|
returning := fmt.Sprintf(" RETURNING %s", db.QuoteColumnName(pkName))
|
|
query.sql += returning
|
|
query.rawSQL += returning
|
|
var pkValue int64
|
|
err := query.Row(&pkValue)
|
|
return pkValue, err
|
|
}
|
|
|
|
// isAutoInc reports whether value is an "empty" primary key, i.e. one that
// should be generated by the database.
//
// A value is considered empty when it is an untyped nil, a nil pointer, a
// zero signed or unsigned integer, or a pointer (at any depth) to one of
// those. Every other kind of value yields false.
func isAutoInc(value interface{}) bool {
	v := reflect.ValueOf(value)

	// Follow pointers within reflect. Passing v.Elem() back through
	// interface{} would wrap the reflect.Value itself (Kind Struct) and a
	// pointer to a zero integer would never be recognized.
	for v.Kind() == reflect.Ptr {
		if v.IsNil() {
			return true
		}
		v = v.Elem()
	}

	switch v.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return v.Int() == 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return v.Uint() == 0
	case reflect.Invalid:
		// reflect.ValueOf(nil) has Kind Invalid.
		return true
	}
	return false
}
|
|
|
|
// Update updates a row in the table using the struct model associated with this query.
|
|
// The row being updated has the same primary key as specified by the model.
|
|
//
|
|
// By default, it updates *all* public fields in the table, including those nil or empty ones.
|
|
// You may pass a list of the fields to this method to indicate that only those fields should be updated.
|
|
// You may also call Exclude to exclude some fields from being updated.
|
|
func (q *ModelQuery) Update(attrs ...string) error {
|
|
if q.lastError != nil {
|
|
return q.lastError
|
|
}
|
|
pk := q.model.pk()
|
|
if len(pk) == 0 {
|
|
return MissingPKError
|
|
}
|
|
|
|
cols := q.model.columns(attrs, q.exclude)
|
|
for name := range pk {
|
|
delete(cols, name)
|
|
}
|
|
_, err := q.builder.Update(q.model.tableName, Params(cols), HashExp(pk)).WithContext(q.ctx).Execute()
|
|
return err
|
|
}
|
|
|
|
// Delete deletes a row in the table using the primary key specified by the struct model associated with this query.
|
|
func (q *ModelQuery) Delete() error {
|
|
if q.lastError != nil {
|
|
return q.lastError
|
|
}
|
|
pk := q.model.pk()
|
|
if len(pk) == 0 {
|
|
return MissingPKError
|
|
}
|
|
_, err := q.builder.Delete(q.model.tableName, HashExp(pk)).WithContext(q.ctx).Execute()
|
|
return err
|
|
}
|