Skip to content

Commit e526eba

Browse files
committed
Add support for Maps to the reflect_codec
Support to serialize/deserialize maps with the reflect_codec. This is needed in the subnet EMV to use it to exchange information about the upgrade configs. This PR adds support for maps and tests it. The internal binary format is quite simple ``` [length: 4 bytes][key1][value1][key2][value2] ``` The length is the total number of elements and keys, therefore it should always be a pair number.
1 parent 0ec52a9 commit e526eba

File tree

2 files changed

+161
-0
lines changed

2 files changed

+161
-0
lines changed

codec/reflectcodec/type_codec.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,23 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) {
204204
}
205205
return size, constSize, nil
206206

207+
case reflect.Map:
208+
keys := value.MapKeys()
209+
size := wrappers.IntLen
210+
for _, key := range keys {
211+
innerSize, _, err := c.size(key)
212+
if err != nil {
213+
return 0, false, err
214+
}
215+
size += innerSize
216+
innerSize, _, err = c.size(value.MapIndex(key))
217+
if err != nil {
218+
return 0, false, err
219+
}
220+
size += innerSize
221+
}
222+
return size, false, nil
223+
207224
default:
208225
return 0, false, fmt.Errorf("can't evaluate marshal length of unknown kind %s", valueKind)
209226
}
@@ -332,6 +349,30 @@ func (c *genericCodec) marshal(value reflect.Value, p *wrappers.Packer, maxSlice
332349
return err
333350
}
334351
}
352+
return nil
353+
case reflect.Map:
354+
keys := value.MapKeys()
355+
numElts := len(keys) * 2
356+
if uint32(numElts) > maxSliceLen {
357+
return fmt.Errorf("%w; slice length, %d, exceeds maximum length, %d",
358+
codec.ErrMaxSliceLenExceeded,
359+
numElts,
360+
maxSliceLen,
361+
)
362+
}
363+
p.PackInt(uint32(numElts)) // pack # elements
364+
365+
for _, key := range keys {
366+
// serialize key
367+
if err := c.marshal(key, p, c.maxSliceLen); err != nil {
368+
return err
369+
}
370+
// serialize value
371+
if err := c.marshal(value.MapIndex(key), p, c.maxSliceLen); err != nil {
372+
return err
373+
}
374+
}
375+
335376
return nil
336377
default:
337378
return fmt.Errorf("%w: %s", codec.ErrUnsupportedType, valueKind)
@@ -520,6 +561,36 @@ func (c *genericCodec) unmarshal(p *wrappers.Packer, value reflect.Value, maxSli
520561
// Assign to the top-level struct's member
521562
value.Set(v)
522563
return nil
564+
case reflect.Map:
565+
numElts32 := p.UnpackInt()
566+
if numElts32 > c.maxSliceLen || numElts32%2 != 0 {
567+
return fmt.Errorf("%w; array length, %d, exceeds maximum length, %d",
568+
codec.ErrMaxSliceLenExceeded,
569+
numElts32,
570+
c.maxSliceLen,
571+
)
572+
}
573+
574+
numElts := int(numElts32 / 2)
575+
value.Set(reflect.MakeMapWithSize(value.Type(), numElts))
576+
keyType := value.Type().Key()
577+
valueType := value.Type().Elem()
578+
579+
for i := 0; i < numElts; i++ {
580+
keyValue := reflect.New(keyType).Elem()
581+
valueValue := reflect.New(valueType).Elem()
582+
583+
if err := c.unmarshal(p, keyValue, c.maxSliceLen); err != nil {
584+
return fmt.Errorf("couldn't unmarshal map key: %w", err)
585+
}
586+
if err := c.unmarshal(p, valueValue, c.maxSliceLen); err != nil {
587+
return fmt.Errorf("couldn't unmarshal map element: %w", err)
588+
}
589+
value.SetMapIndex(keyValue, valueValue)
590+
}
591+
592+
return nil
593+
523594
default:
524595
return fmt.Errorf("can't unmarshal unknown type %s", value.Kind().String())
525596
}

codec/test_codec.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ var Tests = []func(c GeneralCodec, t testing.TB){
4141
TestRestrictedSlice,
4242
TestExtraSpace,
4343
TestSliceLengthOverflow,
44+
TestMap,
45+
TestMap2,
46+
TestMap3,
4447
}
4548

4649
var MultipleTagsTests = []func(c GeneralCodec, t testing.TB){
@@ -887,3 +890,90 @@ func TestMultipleTags(codec GeneralCodec, t testing.TB) {
887890
require.Empty(output.NoTags)
888891
}
889892
}
893+
894+
func TestMap(codec GeneralCodec, t testing.TB) {
895+
require := require.New(t)
896+
897+
data := make(map[string]int32)
898+
data["test"] = 12
899+
data["bar"] = 33
900+
901+
manager := NewDefaultManager()
902+
require.NoError(manager.RegisterCodec(0, codec))
903+
904+
bytes, err := manager.Marshal(0, data)
905+
require.NoError(err)
906+
907+
bytesLen, err := manager.Size(0, data)
908+
require.NoError(err)
909+
require.Equal(len(bytes), bytesLen)
910+
911+
var output map[string]int32
912+
_, err = manager.Unmarshal(bytes, &output)
913+
require.NoError(err)
914+
915+
require.Equal(data, output)
916+
}
917+
918+
func TestMap2(codec GeneralCodec, t testing.TB) {
919+
require := require.New(t)
920+
921+
type Foo struct {
922+
A int32 `serialize:"true"`
923+
B string `serialize:"true"`
924+
}
925+
926+
data := make(map[int32]Foo)
927+
data[12] = Foo{A: 1, B: "test"}
928+
data[13] = Foo{A: 2, B: "more test"}
929+
930+
manager := NewDefaultManager()
931+
require.NoError(manager.RegisterCodec(0, codec))
932+
933+
bytes, err := manager.Marshal(0, data)
934+
require.NoError(err)
935+
936+
bytesLen, err := manager.Size(0, data)
937+
require.NoError(err)
938+
require.Equal(len(bytes), bytesLen)
939+
940+
var output map[int32]Foo
941+
_, err = manager.Unmarshal(bytes, &output)
942+
require.NoError(err)
943+
944+
require.Equal(data, output)
945+
}
946+
947+
func TestMap3(codec GeneralCodec, t testing.TB) {
948+
require := require.New(t)
949+
950+
type Foo struct {
951+
A int32 `serialize:"true"`
952+
B string `serialize:"true"`
953+
E map[int32]string `serialize:"true"`
954+
}
955+
956+
data := Foo{
957+
A: 1,
958+
B: "test",
959+
E: make(map[int32]string, 2),
960+
}
961+
data.E[12] = "test"
962+
data.E[13] = "test"
963+
964+
manager := NewDefaultManager()
965+
require.NoError(manager.RegisterCodec(0, codec))
966+
967+
bytes, err := manager.Marshal(0, data)
968+
require.NoError(err)
969+
970+
bytesLen, err := manager.Size(0, data)
971+
require.NoError(err)
972+
require.Equal(len(bytes), bytesLen)
973+
974+
var output Foo
975+
_, err = manager.Unmarshal(bytes, &output)
976+
require.NoError(err)
977+
978+
require.Equal(data, output)
979+
}

0 commit comments

Comments
 (0)