Skip to content

Commit

Permalink
marshal: support nested tuples (apache#937)
Browse files Browse the repository at this point in the history
Add support for marshalling/unmarshalling nested tuples via slice, array
and structs.

Also improve tuple handling so they can be used with slices, arrays and
structs for values.
  • Loading branch information
Zariel authored Jul 8, 2017
1 parent 566b74b commit bb83efe
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 5 deletions.
104 changes: 101 additions & 3 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -1595,12 +1595,11 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) {
case unsetColumn:
return nil, unmarshalErrorf("Invalid request: UnsetValue is unsupported for tuples")
case []interface{}:
var buf []byte

if len(v) != len(tuple.Elems) {
return nil, unmarshalErrorf("cannont marshal tuple: wrong number of elements")
}

var buf []byte
for i, elem := range v {
data, err := Marshal(tuple.Elems[i], elem)
if err != nil {
Expand All @@ -1615,7 +1614,51 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) {
return buf, nil
}

return nil, unmarshalErrorf("cannot marshal %T into %s", value, tuple)
rv := reflect.ValueOf(value)
t := rv.Type()
k := t.Kind()

switch k {
case reflect.Struct:
if v := t.NumField(); v != len(tuple.Elems) {
return nil, marshalErrorf("can not marshal tuple into struct %v, not enough fields have %d need %d", t, v, len(tuple.Elems))
}

var buf []byte
for i, elem := range tuple.Elems {
data, err := Marshal(elem, rv.Field(i).Interface())
if err != nil {
return nil, err
}

n := len(data)
buf = appendInt(buf, int32(n))
buf = append(buf, data...)
}

return buf, nil
case reflect.Slice, reflect.Array:
size := rv.Len()
if size != len(tuple.Elems) {
return nil, marshalErrorf("can not marshal tuple into %v of length %d need %d elements", k, size, len(tuple.Elems))
}

var buf []byte
for i, elem := range tuple.Elems {
data, err := Marshal(elem, rv.Index(i).Interface())
if err != nil {
return nil, err
}

n := len(data)
buf = appendInt(buf, int32(n))
buf = append(buf, data...)
}

return buf, nil
}

return nil, marshalErrorf("cannot marshal %T into %s", value, tuple)
}

// currently only support unmarshal into a list of values, this makes it possible
Expand Down Expand Up @@ -1644,6 +1687,61 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
return nil
}

rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
}

rv = rv.Elem()
t := rv.Type()
k := t.Kind()

switch k {
case reflect.Struct:
if v := t.NumField(); v != len(tuple.Elems) {
return unmarshalErrorf("can not unmarshal tuple into struct %v, not enough fields have %d need %d", t, v, len(tuple.Elems))
}

for i, elem := range tuple.Elems {
m := readInt(data)
data = data[4:]

v := elem.New()
if err := Unmarshal(elem, data[:m], v); err != nil {
return err
}
rv.Field(i).Set(reflect.ValueOf(v).Elem())

data = data[m:]
}

return nil
case reflect.Slice, reflect.Array:
if k == reflect.Array {
size := rv.Len()
if size != len(tuple.Elems) {
return unmarshalErrorf("can not unmarshal tuple into array of length %d need %d elements", size, len(tuple.Elems))
}
} else {
rv.Set(reflect.MakeSlice(t, len(tuple.Elems), len(tuple.Elems)))
}

for i, elem := range tuple.Elems {
m := readInt(data)
data = data[4:]

v := elem.New()
if err := Unmarshal(elem, data[:m], v); err != nil {
return err
}
rv.Index(i).Set(reflect.ValueOf(v).Elem())

data = data[m:]
}

return nil
}

return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
}

Expand Down
53 changes: 51 additions & 2 deletions tuple_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

package gocql

import "testing"
import (
"reflect"
"testing"
)

func TestTupleSimple(t *testing.T) {
session := createSession(t)
Expand Down Expand Up @@ -55,7 +58,6 @@ func TestTupleMapScan(t *testing.T) {
if session.cfg.ProtoVersion < protoVersion3 {
t.Skip("tuple types are only available of proto>=3")
}
defer session.Close()

err := createTable(session, `CREATE TABLE gocql_test.tuple_map_scan(
id int,
Expand All @@ -76,3 +78,50 @@ func TestTupleMapScan(t *testing.T) {
t.Fatal(err)
}
}

func TestTuple_NestedCollection(t *testing.T) {
session := createSession(t)
defer session.Close()
if session.cfg.ProtoVersion < protoVersion3 {
t.Skip("tuple types are only available of proto>=3")
}

err := createTable(session, `CREATE TABLE gocql_test.nested_tuples(
id int,
val list<frozen<tuple<int, text>>>,
primary key(id))`)
if err != nil {
t.Fatal(err)
}

type typ struct {
A int
B string
}

tests := []struct {
name string
val interface{}
}{
{name: "slice", val: [][]interface{}{{1, "2"}, {3, "4"}}},
{name: "array", val: [][2]interface{}{{1, "2"}, {3, "4"}}},
{name: "struct", val: []typ{{1, "2"}, {3, "4"}}},
}

for i, test := range tests {
t.Run(test.name, func(t *testing.T) {
if err := session.Query(`INSERT INTO nested_tuples (id, val) VALUES (?, ?);`, i, test.val).Exec(); err != nil {
t.Fatal(err)
}

rv := reflect.ValueOf(test.val)
res := reflect.New(rv.Type()).Elem().Addr().Interface()

err = session.Query(`SELECT val FROM nested_tuples WHERE id=?`, i).Scan(res)
if err != nil {
t.Fatal(err)
}
})
}
}

0 comments on commit bb83efe

Please sign in to comment.