diff --git a/marshal.go b/marshal.go index 5354ad091..254da307a 100644 --- a/marshal.go +++ b/marshal.go @@ -1521,6 +1521,11 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { udt := info.(UDTTypeInfo) for _, e := range udt.Elements { + if len(data) < 4 { + // UDT def does not match the column value + return nil + } + size := readInt(data[:4]) data = data[4:] @@ -1532,7 +1537,7 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { } if !f.IsValid() || !f.CanAddr() { - return unmarshalErrorf("cannot unmarshal %s into %T", info, value) + return unmarshalErrorf("cannot unmarshal %s into %T: field %v is not valid", info, value, e.Name) } fk := f.Addr().Interface() diff --git a/udt_test.go b/udt_test.go index 441d96f1f..f4d281ec3 100644 --- a/udt_test.go +++ b/udt_test.go @@ -424,3 +424,60 @@ func TestUDT_EmptyCollections(t *testing.T) { t.Fatal(err) } } + +func TestUDT_UpdateField(t *testing.T) { + if *flagProto < protoVersion3 { + t.Skip("UDT are only available on protocol >= 3") + } + + session := createSession(t) + defer session.Close() + + err := createTable(session, `CREATE TYPE gocql_test.update_field_udt( + name text, + owner text);`) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, `CREATE TABLE gocql_test.update_field( + id uuid, + udt_col frozen, + + primary key(id) + );`) + if err != nil { + t.Fatal(err) + } + + type col struct { + Name string `cql:"name"` + Owner string `cql:"owner"` + Data string `cql:"data"` + } + + writeCol := &col{ + Name: "test-name", + Owner: "test-owner", + } + + id := TimeUUID() + err = session.Query("INSERT INTO update_field(id, udt_col) VALUES(?, ?)", id, writeCol).Exec() + if err != nil { + t.Fatal(err) + } + + if err := createTable(session, `ALTER TYPE gocql_test.update_field_udt ADD data text;`); err != nil { + t.Fatal(err) + } + + readCol := &col{} + err = session.Query("SELECT udt_col FROM update_field WHERE id = ?", id).Scan(readCol) + if err != nil { + t.Fatal(err) + } + + if *readCol != *writeCol { + t.Errorf("expected %+v: got %+v", *writeCol, *readCol) + } +}