Skip to content

Commit

Permalink
jmoiron#22 - fix LastInsertId() with Postgresql. Also fixes quoting i…
Browse files Browse the repository at this point in the history
…ssue noted in jmoiron#21 and test errors in jmoiron#19
  • Loading branch information
coopernurse committed Mar 1, 2013
1 parent dcf1583 commit 84cf915
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 568 deletions.
66 changes: 51 additions & 15 deletions dialect.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package gorp

import (
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
)

// The Dialect interface encapsulates behaviors that differ across
Expand All @@ -21,11 +21,15 @@ type Dialect interface {
// string to append to primary key column definitions
AutoIncrStr() string

AutoIncrBindValue() string

AutoIncrInsertSuffix(col *ColumnMap) string

// string to append to "create table" statement for vendor specific
// table attributes
CreateTableSuffix() string

LastInsertId(res *sql.Result, table *TableMap, exec SqlExecutor) (int64, error)
InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error)

// bind variable string to use when forming SQL statements
// in many dbs it is "?", but Postgres appears to use $1
Expand All @@ -39,6 +43,14 @@ type Dialect interface {
QuoteField(field string) string
}

func standardInsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
res, err := exec.Exec(insertSql, params...)
if err != nil {
return 0, err
}
return res.LastInsertId()
}

///////////////////////////////////////////////////////
// sqlite3 //
/////////////
Expand Down Expand Up @@ -81,6 +93,14 @@ func (d SqliteDialect) AutoIncrStr() string {
return "autoincrement"
}

func (d SqliteDialect) AutoIncrBindValue() string {
return "null"
}

func (d SqliteDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return ""
}

// Returns suffix
func (d SqliteDialect) CreateTableSuffix() string {
return d.suffix
Expand All @@ -91,12 +111,12 @@ func (d SqliteDialect) BindVar(i int) string {
return "?"
}

func (d SqliteDialect) LastInsertId(res *sql.Result, table *TableMap, exec SqlExecutor) (int64, error) {
return (*res).LastInsertId()
func (d SqliteDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
return standardInsertAutoIncr(exec, insertSql, params...)
}

func (d SqliteDialect) QuoteField(f string) string {
return "'" + f + "'"
return `"` + f + `"`
}

///////////////////////////////////////////////////////
Expand Down Expand Up @@ -149,6 +169,14 @@ func (d PostgresDialect) AutoIncrStr() string {
return ""
}

func (d PostgresDialect) AutoIncrBindValue() string {
return "default"
}

func (d PostgresDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return " returning " + col.ColumnName
}

// Returns suffix
func (d PostgresDialect) CreateTableSuffix() string {
return d.suffix
Expand All @@ -159,24 +187,24 @@ func (d PostgresDialect) BindVar(i int) string {
return fmt.Sprintf("$%d", i+1)
}

func (d PostgresDialect) LastInsertId(res *sql.Result, table *TableMap, exec SqlExecutor) (int64, error) {
sql := fmt.Sprintf("select currval('%s_%s_seq')", table.TableName, table.keys[0].ColumnName)
rows, err := exec.query(sql)
func (d PostgresDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
rows, err := exec.query(insertSql, params...)
if err != nil {
return 0, err
}
defer rows.Close()

if rows.Next() {
var dest int64
err = rows.Scan(&dest)
return dest, nil
var id int64
err := rows.Scan(&id)
return id, err
}
return 0, errors.New(fmt.Sprintf("PostgresDialect: %s did not return a row", sql))

return 0, errors.New("No serial value returned for insert: " + insertSql)
}

func (d PostgresDialect) QuoteField(f string) string {
return `"` + f + `"`
return `"` + strings.ToLower(f) + `"`
}

///////////////////////////////////////////////////////
Expand Down Expand Up @@ -229,6 +257,14 @@ func (m MySQLDialect) AutoIncrStr() string {
return "auto_increment"
}

func (m MySQLDialect) AutoIncrBindValue() string {
return "null"
}

func (m MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return ""
}

// Returns engine=%s charset=%s based on values stored on struct
func (m MySQLDialect) CreateTableSuffix() string {
return fmt.Sprintf(" engine=%s charset=%s", m.Engine, m.Encoding)
Expand All @@ -239,8 +275,8 @@ func (m MySQLDialect) BindVar(i int) string {
return "?"
}

func (m MySQLDialect) LastInsertId(res *sql.Result, table *TableMap, exec SqlExecutor) (int64, error) {
return (*res).LastInsertId()
func (m MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
return standardInsertAutoIncr(exec, insertSql, params...)
}

func (d MySQLDialect) QuoteField(f string) string {
Expand Down
39 changes: 25 additions & 14 deletions gorp.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,31 +262,42 @@ func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) {
s.WriteString(fmt.Sprintf("insert into %s (", t.TableName))

x := 0
first := true
for y := range t.columns {
col := t.columns[y]
if col.isAutoIncr {
plan.autoIncrIdx = y
} else if !col.Transient {
if x > 0 {

if !col.Transient {
if !first {
s.WriteString(",")
s2.WriteString(",")
}
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
s2.WriteString(t.dbmap.Dialect.BindVar(x))

f := elem.FieldByName(col.fieldName)

if col == t.version {
f.SetInt(int64(1))
}

plan.argFields = append(plan.argFields, col.fieldName)
x++
if col.isAutoIncr {
s2.WriteString(t.dbmap.Dialect.AutoIncrBindValue())
plan.autoIncrIdx = y
} else {
s2.WriteString(t.dbmap.Dialect.BindVar(x))
plan.argFields = append(plan.argFields, col.fieldName)
x++
}

first = false
}
}
s.WriteString(") values (")
s.WriteString(s2.String())
s.WriteString(");")
s.WriteString(")")
if plan.autoIncrIdx > -1 {
s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(t.columns[plan.autoIncrIdx]))
}
s.WriteString(";")

plan.query = s.String()
t.insertPlan = plan
Expand Down Expand Up @@ -1347,17 +1358,17 @@ func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error {
return err
}

res, err := exec.Exec(bi.query, bi.args...)
if err != nil {
return err
}

if bi.autoIncrIdx > -1 {
id, err := m.Dialect.LastInsertId(&res, table, exec)
id, err := m.Dialect.InsertAutoIncr(exec, bi.query, bi.args...)
if err != nil {
return err
}
elem.Field(bi.autoIncrIdx).SetInt(id)
} else {
_, err := exec.Exec(bi.query, bi.args...)
if err != nil {
return err
}
}

err = runHook("PostInsert", eptr, hookarg)
Expand Down
Loading

0 comments on commit 84cf915

Please sign in to comment.