Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for embedded structs when inserting or updating. #13

Merged
merged 1 commit into from
Aug 17, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions dataset_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package goqu
import (
"reflect"
"sort"
"time"
)

//Generates the default INSERT statement. If Prepared has been called with true then the statement will not be interpolated. See examples.
Expand Down Expand Up @@ -94,16 +95,7 @@ func (me *Dataset) getInsertColsAndVals(rows ...interface{}) (columns ColumnList
rowCols []interface{}
rowVals []interface{}
)
for j := 0; j < newRowValue.NumField(); j++ {
f := newRowValue.Field(j)
t := newRowValue.Type().Field(j)
if me.canInsertField(t) {
if columns == nil {
rowCols = append(rowCols, t.Tag.Get("db"))
}
rowVals = append(rowVals, f.Interface())
}
}
rowCols, rowVals = me.getFieldsValues(newRowValue)
if columns == nil {
columns = cols(rowCols...)
}
Expand All @@ -115,6 +107,29 @@ func (me *Dataset) getInsertColsAndVals(rows ...interface{}) (columns ColumnList
return columns, vals, nil
}

func (me *Dataset) getFieldsValues(value reflect.Value) (rowCols []interface{}, rowVals []interface{}) {
if value.IsValid() {
for i := 0; i < value.NumField(); i++ {
v := value.Field(i)

kind := v.Kind()
if (reflect.TypeOf(v.Interface()).Name() == reflect.TypeOf((*time.Time)(nil)).Elem().Name()) || ((kind != reflect.Struct) && (kind != reflect.Ptr)) {
t := value.Type().Field(i)
if me.canInsertField(t) {
rowCols = append(rowCols, t.Tag.Get("db"))
rowVals = append(rowVals, v.Interface())
}
} else {
cols, vals := me.getFieldsValues(reflect.Indirect(reflect.ValueOf(v.Interface())))
rowCols = append(rowCols, cols...)
rowVals = append(rowVals, vals...)
}
}
}

return rowCols, rowVals
}

//Creates an INSERT statement with the columns and values passed in
func (me *Dataset) insertSql(cols ColumnList, values [][]interface{}, prepared bool) (string, []interface{}, error) {
buf := NewSqlBuilder(prepared)
Expand Down
70 changes: 63 additions & 7 deletions dataset_insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package goqu
import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/technotronicoz/testify/assert"

"time"
)

func (me *datasetTest) TestInsertSqlNoReturning() {
Expand Down Expand Up @@ -36,21 +38,75 @@ func (me *datasetTest) TestInsertSqlWithStructs() {
t := me.T()
ds1 := From("items")
type item struct {
Address string `db:"address"`
Name string `db:"name"`
Created time.Time `db:"created"`
}
created, _ := time.Parse("2006-01-02", "2015-01-01")
sql, _, err := ds1.ToInsertSql(item{Name: "Test", Address: "111 Test Addr", Created: created})
assert.NoError(t, err)
assert.Equal(t, sql, `INSERT INTO "items" ("address", "name", "created") VALUES ('111 Test Addr', 'Test', '`+created.Format(time.RFC3339Nano)+`')`)

sql, _, err = ds1.ToInsertSql(
item{Address: "111 Test Addr", Name: "Test1", Created: created},
item{Address: "211 Test Addr", Name: "Test2", Created: created},
item{Address: "311 Test Addr", Name: "Test3", Created: created},
item{Address: "411 Test Addr", Name: "Test4", Created: created},
)
assert.NoError(t, err)
assert.Equal(t, sql, `INSERT INTO "items" ("address", "name", "created") VALUES ('111 Test Addr', 'Test1', '`+created.Format(time.RFC3339Nano)+`'), ('211 Test Addr', 'Test2', '`+created.Format(time.RFC3339Nano)+`'), ('311 Test Addr', 'Test3', '`+created.Format(time.RFC3339Nano)+`'), ('411 Test Addr', 'Test4', '`+created.Format(time.RFC3339Nano)+`')`)
}

func (me *datasetTest) TestInsertSqlWithEmbeddedStruct() {
t := me.T()
ds1 := From("items")
type phone struct {
Primary string `db:"primary_phone"`
Home string `db:"home_phone"`
}
type item struct {
phone
Address string `db:"address"`
Name string `db:"name"`
}
sql, _, err := ds1.ToInsertSql(item{Name: "Test", Address: "111 Test Addr"})
sql, _, err := ds1.ToInsertSql(item{Name: "Test", Address: "111 Test Addr", phone: phone{Home: "123123", Primary: "456456"}})
assert.NoError(t, err)
assert.Equal(t, sql, `INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test')`)
assert.Equal(t, sql, `INSERT INTO "items" ("primary_phone", "home_phone", "address", "name") VALUES ('456456', '123123', '111 Test Addr', 'Test')`)

sql, _, err = ds1.ToInsertSql(
item{Address: "111 Test Addr", Name: "Test1"},
item{Address: "211 Test Addr", Name: "Test2"},
item{Address: "311 Test Addr", Name: "Test3"},
item{Address: "411 Test Addr", Name: "Test4"},
item{Address: "111 Test Addr", Name: "Test1", phone: phone{Home: "123123", Primary: "456456"}},
item{Address: "211 Test Addr", Name: "Test2", phone: phone{Home: "123123", Primary: "456456"}},
item{Address: "311 Test Addr", Name: "Test3", phone: phone{Home: "123123", Primary: "456456"}},
item{Address: "411 Test Addr", Name: "Test4", phone: phone{Home: "123123", Primary: "456456"}},
)
assert.NoError(t, err)
assert.Equal(t, sql, `INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test1'), ('211 Test Addr', 'Test2'), ('311 Test Addr', 'Test3'), ('411 Test Addr', 'Test4')`)
assert.Equal(t, sql, `INSERT INTO "items" ("primary_phone", "home_phone", "address", "name") VALUES ('456456', '123123', '111 Test Addr', 'Test1'), ('456456', '123123', '211 Test Addr', 'Test2'), ('456456', '123123', '311 Test Addr', 'Test3'), ('456456', '123123', '411 Test Addr', 'Test4')`)
}

func (me *datasetTest) TestInsertSqlWithEmbeddedStructPtr() {
t := me.T()
ds1 := From("items")
type phone struct {
Primary string `db:"primary_phone"`
Home string `db:"home_phone"`
}
type item struct {
*phone
Address string `db:"address"`
Name string `db:"name"`
}
sql, _, err := ds1.ToInsertSql(item{Name: "Test", Address: "111 Test Addr", phone: &phone{Home: "123123", Primary: "456456"}})
assert.NoError(t, err)
assert.Equal(t, sql, `INSERT INTO "items" ("primary_phone", "home_phone", "address", "name") VALUES ('456456', '123123', '111 Test Addr', 'Test')`)

sql, _, err = ds1.ToInsertSql(
item{Address: "111 Test Addr", Name: "Test1", phone: &phone{Home: "123123", Primary: "456456"}},
item{Address: "211 Test Addr", Name: "Test2", phone: &phone{Home: "123123", Primary: "456456"}},
item{Address: "311 Test Addr", Name: "Test3", phone: &phone{Home: "123123", Primary: "456456"}},
item{Address: "411 Test Addr", Name: "Test4", phone: &phone{Home: "123123", Primary: "456456"}},
)
assert.NoError(t, err)
assert.Equal(t, sql, `INSERT INTO "items" ("primary_phone", "home_phone", "address", "name") VALUES ('456456', '123123', '111 Test Addr', 'Test1'), ('456456', '123123', '211 Test Addr', 'Test2'), ('456456', '123123', '311 Test Addr', 'Test3'), ('456456', '123123', '411 Test Addr', 'Test4')`)
}

func (me *datasetTest) TestInsertSqlWithMaps() {
Expand Down
26 changes: 19 additions & 7 deletions dataset_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package goqu
import (
"reflect"
"sort"
"time"
)

func (me *Dataset) canUpdateField(field reflect.StructField) bool {
Expand Down Expand Up @@ -38,13 +39,7 @@ func (me *Dataset) ToUpdateSql(update interface{}) (string, []interface{}, error
updates = append(updates, I(key.String()).Set(updateValue.MapIndex(key).Interface()))
}
case reflect.Struct:
for j := 0; j < updateValue.NumField(); j++ {
f := updateValue.Field(j)
t := updateValue.Type().Field(j)
if me.canUpdateField(t) {
updates = append(updates, I(t.Tag.Get("db")).Set(f.Interface()))
}
}
updates = me.getUpdateExpression(updateValue)
default:
return "", nil, NewGoquError("Unsupported update interface type %+v", updateValue.Type())
}
Expand Down Expand Up @@ -81,3 +76,20 @@ func (me *Dataset) ToUpdateSql(update interface{}) (string, []interface{}, error
sql, args := buf.ToSql()
return sql, args, nil
}

func (me *Dataset) getUpdateExpression(value reflect.Value) (updates []UpdateExpression) {
for i := 0; i < value.NumField(); i++ {
v := value.Field(i)
kind := v.Kind()
if reflect.TypeOf(v.Interface()).Name() == reflect.TypeOf((*time.Time)(nil)).Elem().Name() || (kind != reflect.Struct && kind != reflect.Ptr) {
t := value.Type().Field(i)
if me.canUpdateField(t) {
updates = append(updates, I(t.Tag.Get("db")).Set(v.Interface()))
}
} else {
updates = append(updates, me.getUpdateExpression(reflect.Indirect(reflect.ValueOf(v.Interface())))...)
}
}

return updates
}
53 changes: 53 additions & 0 deletions dataset_update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package goqu
import (
"database/sql/driver"
"fmt"
"time"

"github.com/DATA-DOG/go-sqlmock"
"github.com/technotronicoz/testify/assert"
Expand Down Expand Up @@ -221,6 +222,58 @@ func (me *datasetTest) TestPreparedUpdateSqlWithSkipupdateTag() {
assert.Equal(t, sql, `UPDATE "items" SET "name"=?`)
}

func (me *datasetTest) TestPreparedUpdateSqlWithEmbeddedStruct() {
t := me.T()
ds1 := From("items")
type phone struct {
Primary string `db:"primary_phone"`
Home string `db:"home_phone"`
Created time.Time `db:"phone_created"`
}
type item struct {
phone
Address string `db:"address" goqu:"skipupdate"`
Name string `db:"name"`
Created time.Time `db:"created"`
}
created, _ := time.Parse("2006-01-02", "2015-01-01")

sql, args, err := ds1.Prepared(true).ToUpdateSql(item{Name: "Test", Address: "111 Test Addr", Created: created, phone: phone{
Home: "123123",
Primary: "456456",
Created: created,
}})
assert.NoError(t, err)
assert.Equal(t, args, []interface{}{"456456", "123123", created, "Test", created})
assert.Equal(t, sql, `UPDATE "items" SET "primary_phone"=?,"home_phone"=?,"phone_created"=?,"name"=?,"created"=?`)
}

func (me *datasetTest) TestPreparedUpdateSqlWithEmbeddedStructPtr() {
t := me.T()
ds1 := From("items")
type phone struct {
Primary string `db:"primary_phone"`
Home string `db:"home_phone"`
Created time.Time `db:"phone_created"`
}
type item struct {
*phone
Address string `db:"address" goqu:"skipupdate"`
Name string `db:"name"`
Created time.Time `db:"created"`
}
created, _ := time.Parse("2006-01-02", "2015-01-01")

sql, args, err := ds1.Prepared(true).ToUpdateSql(item{Name: "Test", Address: "111 Test Addr", Created: created, phone: &phone{
Home: "123123",
Primary: "456456",
Created: created,
}})
assert.NoError(t, err)
assert.Equal(t, args, []interface{}{"456456", "123123", created, "Test", created})
assert.Equal(t, sql, `UPDATE "items" SET "primary_phone"=?,"home_phone"=?,"phone_created"=?,"name"=?,"created"=?`)
}

func (me *datasetTest) TestPreparedUpdateSqlWithWhere() {
t := me.T()
ds1 := From("items")
Expand Down
2 changes: 1 addition & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ type GoquError struct {

func (me GoquError) Error() string {
return me.err
}
}