Skip to content

Commit

Permalink
replace most of the hook system with a simpler one that uses interfac…
Browse files Browse the repository at this point in the history
…es to pre-calculate hook support and uses a lot less reflect
  • Loading branch information
jmoiron committed May 26, 2013
1 parent 3017cc9 commit 0d53cff
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 48 deletions.
105 changes: 70 additions & 35 deletions gorp.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ type TableMap struct {
deletePlan bindPlan
getPlan bindPlan
dbmap *DbMap
// Cached capabilities for the struct mapped to this table
CanPreInsert bool
CanPostInsert bool
CanPostGet bool
CanPreUpdate bool
CanPostUpdate bool
CanPreDelete bool
CanPostDelete bool
}

// ResetSql removes cached insert/update/select/delete SQL strings
Expand Down Expand Up @@ -711,6 +719,10 @@ func hookedselect(m *DbMap, exec SqlExecutor, i interface{}, query string,
return nil, err
}

// FIXME: should PostGet hooks be run on regular selects? a PostGet
// hook has access to the object and the database, and I'd hate for
// a query to execute SQL on every row of a queryset.

// Determine where the results are: written to i, or returned in list
if t := toSliceType(i); t == nil {
for _, v := range list {
Expand All @@ -731,8 +743,7 @@ func hookedselect(m *DbMap, exec SqlExecutor, i interface{}, query string,
return list, nil
}

func rawselect(m *DbMap, exec SqlExecutor, i interface{}, query string,
args ...interface{}) ([]interface{}, error) {
func rawselect(m *DbMap, exec SqlExecutor, i interface{}, query string, args ...interface{}) ([]interface{}, error) {
appendToSlice := false // Write results to i directly?

// get type for i, verifying it's a struct or a pointer-to-slice
Expand Down Expand Up @@ -955,27 +966,35 @@ func get(m *DbMap, exec SqlExecutor, i interface{},
}
}

err = runHook("PostGet", v, hookArg(exec))
if err != nil {
return nil, err
vi := v.Interface()

if table.CanPostGet {
err = vi.(PostGetter).PostGet(exec)
if err != nil {
return nil, err
}
}

return v.Interface(), nil
return vi, nil
}

func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {
hookarg := hookArg(exec)
count := int64(0)
var err error
var table *TableMap
var elem reflect.Value
var count int64

for _, ptr := range list {
table, elem, err := m.tableForPointer(ptr, true)
table, elem, err = m.tableForPointer(ptr, true)
if err != nil {
return -1, err
}

eptr := elem.Addr()
err = runHook("PreDelete", eptr, hookarg)
if err != nil {
return -1, err
if table.CanPreDelete {
err = ptr.(PreDeleter).PreDelete(exec)
if err != nil {
return -1, err
}
}

bi := table.bindDelete(elem)
Expand All @@ -984,6 +1003,7 @@ func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {
if err != nil {
return -1, err
}

rows, err := res.RowsAffected()
if err != nil {
return -1, err
Expand All @@ -996,28 +1016,34 @@ func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {

count += rows

err = runHook("PostDelete", eptr, hookarg)
if err != nil {
return -1, err
if table.CanPostDelete {
err = ptr.(PostDeleter).PostDelete(exec)
if err != nil {
return -1, err
}
}
}

return count, nil
}

func update(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {
hookarg := hookArg(exec)
count := int64(0)
var err error
var table *TableMap
var elem reflect.Value
var count int64

for _, ptr := range list {
table, elem, err := m.tableForPointer(ptr, true)
table, elem, err = m.tableForPointer(ptr, true)
if err != nil {
return -1, err
}

eptr := elem.Addr()
err = runHook("PreUpdate", eptr, hookarg)
if err != nil {
return -1, err
if table.CanPreUpdate {
err = ptr.(PreUpdater).PreUpdate(exec)
if err != nil {
return -1, err
}
}

bi := table.bindUpdate(elem)
Expand Down Expand Up @@ -1046,26 +1072,33 @@ func update(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {

count += rows

err = runHook("PostUpdate", eptr, hookarg)
if err != nil {
return -1, err
if table.CanPostUpdate {
err = ptr.(PostUpdater).PostUpdate(exec)

if err != nil {
return -1, err
}
}
}
return count, nil
}

func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error {
hookarg := hookArg(exec)
var err error
var table *TableMap
var elem reflect.Value

for _, ptr := range list {
table, elem, err := m.tableForPointer(ptr, false)
table, elem, err = m.tableForPointer(ptr, false)
if err != nil {
return err
}

eptr := elem.Addr()
err = runHook("PreInsert", eptr, hookarg)
if err != nil {
return err
if table.CanPreInsert {
err = ptr.(PreInserter).PreInsert(exec)
if err != nil {
return err
}
}

bi := table.bindInsert(elem)
Expand All @@ -1092,9 +1125,11 @@ func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error {
}
}

err = runHook("PostInsert", eptr, hookarg)
if err != nil {
return err
if table.CanPostInsert {
err = ptr.(PostInserter).PostInsert(exec)
if err != nil {
return err
}
}
}
return nil
Expand Down
12 changes: 6 additions & 6 deletions gorp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv"
"log"
//"log"
"os"
"reflect"
"testing"
Expand Down Expand Up @@ -120,7 +120,7 @@ func TestCreateTablesIfNotExists(t *testing.T) {
func TestPersistentUser(t *testing.T) {
dbmap := newDbMap()
dbmap.Exec("drop table if exists PersistentUser")
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
//dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
table := dbmap.AddTable(PersistentUser{}).SetKeys(false, "Key")
table.ColMap("Key").Rename("mykey")
err := dbmap.CreateTablesIfNotExists()
Expand Down Expand Up @@ -300,7 +300,7 @@ func TestNullValues(t *testing.T) {

func TestColumnProps(t *testing.T) {
dbmap := newDbMap()
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
//dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
t1 := dbmap.AddTable(Invoice{}).SetKeys(true, "Id")
t1.ColMap("Created").Rename("date_created")
t1.ColMap("Updated").SetTransient(true)
Expand Down Expand Up @@ -548,7 +548,7 @@ func TestVersionMultipleRows(t *testing.T) {

func TestWithStringPk(t *testing.T) {
dbmap := newDbMap()
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
//dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
dbmap.AddTableWithName(WithStringPk{}, "string_pk_test").SetKeys(true, "Id")
_, err := dbmap.Exec("create table string_pk_test (Id varchar(255), Name varchar(255));")
if err != nil {
Expand Down Expand Up @@ -668,7 +668,7 @@ func initDbMapBench() *DbMap {

func initDbMap() *DbMap {
dbmap := newDbMap()
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
//dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
dbmap.AddTableWithName(Invoice{}, "invoice_test").SetKeys(true, "Id")
dbmap.AddTableWithName(Person{}, "person_test").SetKeys(true, "Id")
dbmap.AddTableWithName(WithIgnoredColumn{}, "ignored_column_test").SetKeys(true, "Id")
Expand All @@ -682,7 +682,7 @@ func initDbMap() *DbMap {

func initDbMapNulls() *DbMap {
dbmap := newDbMap()
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
//dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
dbmap.AddTable(TableWithNull{}).SetKeys(false, "Id")
err := dbmap.CreateTables()
if err != nil {
Expand Down
52 changes: 52 additions & 0 deletions hooks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package gorp

import (
"reflect"
)

type PreInserter interface {
PreInsert(SqlExecutor) error
}

type PostInserter interface {
PostInsert(SqlExecutor) error
}

type PostGetter interface {
PostGet(SqlExecutor) error
}

type PreUpdater interface {
PreUpdate(SqlExecutor) error
}

type PostUpdater interface {
PostUpdate(SqlExecutor) error
}

type PreDeleter interface {
PreDelete(SqlExecutor) error
}

type PostDeleter interface {
PostDelete(SqlExecutor) error
}

// Determine which hooks are supported by the mapper struct i
func (t *TableMap) setupHooks(i interface{}) {
// These hooks must be implemented on a pointer, so if a value is passed in
// we have to get a pointer for a new value of that type in order for the
// type assertions to pass.
ptr := i
if reflect.ValueOf(i).Kind() == reflect.Struct {
ptr = reflect.New(reflect.ValueOf(i).Type()).Interface()
}

_, t.CanPreInsert = ptr.(PreInserter)
_, t.CanPostInsert = ptr.(PostInserter)
_, t.CanPostGet = ptr.(PostGetter)
_, t.CanPreUpdate = ptr.(PreUpdater)
_, t.CanPostUpdate = ptr.(PostUpdater)
_, t.CanPreDelete = ptr.(PreDeleter)
_, t.CanPostDelete = ptr.(PostDeleter)
}
1 change: 1 addition & 0 deletions mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func (m *DbMap) AddTable(i interface{}, name ...string) *TableMap {
}

tmap := &TableMap{gotype: t, TableName: Name, dbmap: m}
tmap.setupHooks(i)

n := t.NumField()
tmap.columns = make([]*ColumnMap, 0, n)
Expand Down
14 changes: 7 additions & 7 deletions test_all.sh
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
#!/bin/sh

set -e
set -e

export GORP_TEST_DSN=gorptest/gorptest/gorptest
export GORP_TEST_DIALECT=mysql
export GORP_TEST_DSN="gorptest/gorptest/"
export GORP_TEST_DIALECT="mysql"
go test

export GORP_TEST_DSN="user=gorptest password=gorptest dbname=gorptest sslmode=disable"
export GORP_TEST_DIALECT=postgres
export GORP_TEST_DSN="user=$USER dbname=gorptest sslmode=disable"
export GORP_TEST_DIALECT="postgres"
go test

export GORP_TEST_DSN=/tmp/gorptest.bin
export GORP_TEST_DIALECT=sqlite
export GORP_TEST_DSN="/tmp/gorptest.bin"
export GORP_TEST_DIALECT="sqlite"
go test
6 changes: 6 additions & 0 deletions todo.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
- replace hook calling process with one that uses interfaces and reflect.CanInterface
- remove list return support
- replace reflect struct filling with structscan
- cache/store as much reflect stuff as possible
- add query builder

0 comments on commit 0d53cff

Please sign in to comment.