Skip to content

Commit

Permalink
marshal: fix udt nil collection marshalling
Browse files Browse the repository at this point in the history
Fix reflect usage to correct marshal/unmarshal nil collections which are
embedded inside a udt column.
  • Loading branch information
Zariel committed Mar 20, 2016
1 parent f8d117a commit 6e6042c
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 38 deletions.
9 changes: 9 additions & 0 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -1646,6 +1646,15 @@ func (f *framer) writeByte(b byte) {
f.wbuf = append(f.wbuf, b)
}

func appendBytes(p []byte, d []byte) []byte {
if d == nil {
return appendInt(p, -1)
}
p = appendInt(p, int32(len(d)))
p = append(p, d...)
return p
}

func appendShort(p []byte, n uint16) []byte {
return append(p,
byte(n>>8),
Expand Down
57 changes: 20 additions & 37 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,10 @@ func marshalList(info TypeInfo, value interface{}) ([]byte, error) {
rv := reflect.ValueOf(value)
t := rv.Type()
k := t.Kind()
if k == reflect.Slice && rv.IsNil() {
return nil, nil
}

switch k {
case reflect.Slice, reflect.Array:
buf := &bytes.Buffer{}
Expand Down Expand Up @@ -994,6 +998,9 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error {
if k == reflect.Array {
return unmarshalErrorf("unmarshal list: can not store nil in array value")
}
if rv.IsNil() {
return nil
}
rv.Set(reflect.Zero(t))
return nil
}
Expand Down Expand Up @@ -1032,6 +1039,10 @@ func marshalMap(info TypeInfo, value interface{}) ([]byte, error) {
}

rv := reflect.ValueOf(value)
if rv.IsNil() {
return nil, nil
}

t := rv.Type()
if t.Kind() != reflect.Map {
return nil, marshalErrorf("can not marshal %T into %s", value, info)
Expand Down Expand Up @@ -1344,12 +1355,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
return nil, err
}

if data == nil && typeCanBeNull(e.Type) {
buf = appendInt(buf, -1)
} else {
buf = appendInt(buf, int32(len(data)))
buf = append(buf, data...)
}
buf = appendBytes(buf, data)
}

return buf, nil
Expand All @@ -1366,12 +1372,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
return nil, err
}

if data == nil && typeCanBeNull(e.Type) {
buf = appendInt(buf, -1)
} else {
buf = appendInt(buf, int32(len(data)))
buf = append(buf, data...)
}
buf = appendBytes(buf, data)
}

return buf, nil
Expand Down Expand Up @@ -1406,37 +1407,19 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
f = k.FieldByName(e.Name)
}

if !f.IsValid() {
if _, ok := e.Type.(CollectionType); ok {
f = reflect.Zero(goType(e.Type))
} else {
buf = appendInt(buf, -1)
continue
}
} else if f.Kind() == reflect.Ptr {
if f.IsNil() {
buf = appendInt(buf, -1)
continue
} else {
f = f.Elem()
var data []byte
if f.IsValid() && f.CanInterface() {
var err error
data, err = Marshal(e.Type, f.Interface())
if err != nil {
return nil, err
}
}

data, err := Marshal(e.Type, f.Interface())
if err != nil {
return nil, err
}

if data == nil && typeCanBeNull(e.Type) {
buf = appendInt(buf, -1)
} else {
buf = appendInt(buf, int32(len(data)))
buf = append(buf, data...)
}
buf = appendBytes(buf, data)
}

return buf, nil

}

func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error {
Expand Down
24 changes: 23 additions & 1 deletion udt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,11 +418,33 @@ func TestUDT_EmptyCollections(t *testing.T) {
t.Fatal(err)
}

type udt struct {
A []string `cql:"a"`
B map[string]string `cql:"b"`
C []string `cql:"c"`
}

id := TimeUUID()
err = session.Query("INSERT INTO nil_collections(id, udt_col) VALUES(?, ?)", id, &struct{}{}).Exec()
err = session.Query("INSERT INTO nil_collections(id, udt_col) VALUES(?, ?)", id, &udt{}).Exec()
if err != nil {
t.Fatal(err)
}

var val udt
err = session.Query("SELECT udt_col FROM nil_collections WHERE id=?", id).Scan(&val)
if err != nil {
t.Fatal(err)
}

if val.A != nil {
t.Errorf("expected to get nil got %#+v", val.A)
}
if val.B != nil {
t.Errorf("expected to get nil got %#+v", val.B)
}
if val.C != nil {
t.Errorf("expected to get nil got %#+v", val.C)
}
}

func TestUDT_UpdateField(t *testing.T) {
Expand Down

0 comments on commit 6e6042c

Please sign in to comment.