Skip to content

Commit bb07ce3

Browse files
committed
rlp: add "tail" struct tag
1 parent fdb936e commit bb07ce3

File tree

6 files changed

+163
-26
lines changed

6 files changed

+163
-26
lines changed

rlp/decode.go

+34-17
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,12 @@ type Decoder interface {
6363
// must contain an element for each decoded field. Decode returns an
6464
// error if there are too few or too many elements.
6565
//
66-
// The decoding of struct fields honours one particular struct tag,
67-
// "nil". This tag applies to pointer-typed fields and changes the
66+
// The decoding of struct fields honours two struct tags, "tail" and
67+
// "nil". For an explanation of "tail", see the example.
68+
// The "nil" tag applies to pointer-typed fields and changes the
6869
// decoding rules for the field such that input values of size zero
69-
// decode as a nil pointer. This tag can be useful when decoding recursive
70-
// types.
70+
// decode as a nil pointer. This tag can be useful when decoding
71+
// recursive types.
7172
//
7273
// type StructWithEmptyOK struct {
7374
// Foo *[20]byte `rlp:"nil"`
@@ -190,7 +191,7 @@ func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) {
190191
case kind == reflect.String:
191192
return decodeString, nil
192193
case kind == reflect.Slice || kind == reflect.Array:
193-
return makeListDecoder(typ)
194+
return makeListDecoder(typ, tags)
194195
case kind == reflect.Struct:
195196
return makeStructDecoder(typ)
196197
case kind == reflect.Ptr:
@@ -264,7 +265,7 @@ func decodeBigInt(s *Stream, val reflect.Value) error {
264265
return nil
265266
}
266267

267-
func makeListDecoder(typ reflect.Type) (decoder, error) {
268+
func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) {
268269
etype := typ.Elem()
269270
if etype.Kind() == reflect.Uint8 && !reflect.PtrTo(etype).Implements(decoderInterface) {
270271
if typ.Kind() == reflect.Array {
@@ -277,15 +278,26 @@ func makeListDecoder(typ reflect.Type) (decoder, error) {
277278
if err != nil {
278279
return nil, err
279280
}
280-
281-
isArray := typ.Kind() == reflect.Array
282-
return func(s *Stream, val reflect.Value) error {
283-
if isArray {
281+
var dec decoder
282+
switch {
283+
case typ.Kind() == reflect.Array:
284+
dec = func(s *Stream, val reflect.Value) error {
284285
return decodeListArray(s, val, etypeinfo.decoder)
285-
} else {
286+
}
287+
case tag.tail:
288+
// A slice with "tail" tag can occur as the last field
289+
// of a struct and is upposed to swallow all remaining
290+
// list elements. The struct decoder already called s.List,
291+
// proceed directly to decoding the elements.
292+
dec = func(s *Stream, val reflect.Value) error {
293+
return decodeSliceElems(s, val, etypeinfo.decoder)
294+
}
295+
default:
296+
dec = func(s *Stream, val reflect.Value) error {
286297
return decodeListSlice(s, val, etypeinfo.decoder)
287298
}
288-
}, nil
299+
}
300+
return dec, nil
289301
}
290302

291303
func decodeListSlice(s *Stream, val reflect.Value, elemdec decoder) error {
@@ -297,7 +309,13 @@ func decodeListSlice(s *Stream, val reflect.Value, elemdec decoder) error {
297309
val.Set(reflect.MakeSlice(val.Type(), 0, 0))
298310
return s.ListEnd()
299311
}
312+
if err := decodeSliceElems(s, val, elemdec); err != nil {
313+
return err
314+
}
315+
return s.ListEnd()
316+
}
300317

318+
func decodeSliceElems(s *Stream, val reflect.Value, elemdec decoder) error {
301319
i := 0
302320
for ; ; i++ {
303321
// grow slice if necessary
@@ -323,12 +341,11 @@ func decodeListSlice(s *Stream, val reflect.Value, elemdec decoder) error {
323341
if i < val.Len() {
324342
val.SetLen(i)
325343
}
326-
return s.ListEnd()
344+
return nil
327345
}
328346

329347
func decodeListArray(s *Stream, val reflect.Value, elemdec decoder) error {
330-
_, err := s.List()
331-
if err != nil {
348+
if _, err := s.List(); err != nil {
332349
return wrapStreamError(err, val.Type())
333350
}
334351
vlen := val.Len()
@@ -398,11 +415,11 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
398415
return nil, err
399416
}
400417
dec := func(s *Stream, val reflect.Value) (err error) {
401-
if _, err = s.List(); err != nil {
418+
if _, err := s.List(); err != nil {
402419
return wrapStreamError(err, typ)
403420
}
404421
for _, f := range fields {
405-
err = f.info.decoder(s, val.Field(f.index))
422+
err := f.info.decoder(s, val.Field(f.index))
406423
if err == EOL {
407424
return &decodeError{msg: "too few elements", typ: typ}
408425
} else if err != nil {

rlp/decode_tail_test.go

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package rlp
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
)
7+
8+
type structWithTail struct {
9+
A, B uint
10+
C []uint `rlp:"tail"`
11+
}
12+
13+
func ExampleDecode_structTagTail() {
14+
// In this example, the "tail" struct tag is used to decode lists of
15+
// differing length into a struct.
16+
var val structWithTail
17+
18+
err := Decode(bytes.NewReader([]byte{0xC4, 0x01, 0x02, 0x03, 0x04}), &val)
19+
fmt.Printf("with 4 elements: err=%v val=%v\n", err, val)
20+
21+
err = Decode(bytes.NewReader([]byte{0xC6, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06}), &val)
22+
fmt.Printf("with 6 elements: err=%v val=%v\n", err, val)
23+
24+
// Note that at least two list elements must be present to
25+
// fill fields A and B:
26+
err = Decode(bytes.NewReader([]byte{0xC1, 0x01}), &val)
27+
fmt.Printf("with 1 element: err=%q\n", err)
28+
29+
// Output:
30+
// with 4 elements: err=<nil> val={1 2 [3 4]}
31+
// with 6 elements: err=<nil> val={1 2 [3 4 5 6]}
32+
// with 1 element: err="rlp: too few elements for rlp.structWithTail"
33+
}

rlp/decode_test.go

+52
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,26 @@ type recstruct struct {
312312
Child *recstruct `rlp:"nil"`
313313
}
314314

315+
type invalidTail1 struct {
316+
A uint `rlp:"tail"`
317+
B string
318+
}
319+
320+
type invalidTail2 struct {
321+
A uint
322+
B string `rlp:"tail"`
323+
}
324+
325+
type tailRaw struct {
326+
A uint
327+
Tail []RawValue `rlp:"tail"`
328+
}
329+
330+
type tailUint struct {
331+
A uint
332+
Tail []uint `rlp:"tail"`
333+
}
334+
315335
var (
316336
veryBigInt = big.NewInt(0).Add(
317337
big.NewInt(0).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16),
@@ -437,6 +457,38 @@ var decodeTests = []decodeTest{
437457
ptr: new(recstruct),
438458
error: "rlp: expected input string or byte for uint, decoding into (rlp.recstruct).Child.I",
439459
},
460+
{
461+
input: "C0",
462+
ptr: new(invalidTail1),
463+
error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail1.A (must be on last field)",
464+
},
465+
{
466+
input: "C0",
467+
ptr: new(invalidTail2),
468+
error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail2.B (field type is not slice)",
469+
},
470+
{
471+
input: "C50102C20102",
472+
ptr: new(tailUint),
473+
error: "rlp: expected input string or byte for uint, decoding into (rlp.tailUint).Tail[1]",
474+
},
475+
476+
// struct tag "tail"
477+
{
478+
input: "C3010203",
479+
ptr: new(tailRaw),
480+
value: tailRaw{A: 1, Tail: []RawValue{unhex("02"), unhex("03")}},
481+
},
482+
{
483+
input: "C20102",
484+
ptr: new(tailRaw),
485+
value: tailRaw{A: 1, Tail: []RawValue{unhex("02")}},
486+
},
487+
{
488+
input: "C101",
489+
ptr: new(tailRaw),
490+
value: tailRaw{A: 1, Tail: []RawValue{}},
491+
},
440492

441493
// RawValue
442494
{input: "01", ptr: new(RawValue), value: RawValue(unhex("01"))},

rlp/encode.go

+6-5
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ var (
345345
)
346346

347347
// makeWriter creates a writer function for the given type.
348-
func makeWriter(typ reflect.Type) (writer, error) {
348+
func makeWriter(typ reflect.Type, ts tags) (writer, error) {
349349
kind := typ.Kind()
350350
switch {
351351
case typ == rawValueType:
@@ -371,7 +371,7 @@ func makeWriter(typ reflect.Type) (writer, error) {
371371
case kind == reflect.Array && isByte(typ.Elem()):
372372
return writeByteArray, nil
373373
case kind == reflect.Slice || kind == reflect.Array:
374-
return makeSliceWriter(typ)
374+
return makeSliceWriter(typ, ts)
375375
case kind == reflect.Struct:
376376
return makeStructWriter(typ)
377377
case kind == reflect.Ptr:
@@ -507,20 +507,21 @@ func writeInterface(val reflect.Value, w *encbuf) error {
507507
return ti.writer(eval, w)
508508
}
509509

510-
func makeSliceWriter(typ reflect.Type) (writer, error) {
510+
func makeSliceWriter(typ reflect.Type, ts tags) (writer, error) {
511511
etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{})
512512
if err != nil {
513513
return nil, err
514514
}
515515
writer := func(val reflect.Value, w *encbuf) error {
516-
lh := w.list()
516+
if !ts.tail {
517+
defer w.listEnd(w.list())
518+
}
517519
vlen := val.Len()
518520
for i := 0; i < vlen; i++ {
519521
if err := etypeinfo.writer(val.Index(i), w); err != nil {
520522
return err
521523
}
522524
}
523-
w.listEnd(lh)
524525
return nil
525526
}
526527
return writer, nil

rlp/encode_test.go

+4
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,10 @@ var encTests = []encTest{
214214
{val: simplestruct{A: 3, B: "foo"}, output: "C50383666F6F"},
215215
{val: &recstruct{5, nil}, output: "C205C0"},
216216
{val: &recstruct{5, &recstruct{4, &recstruct{3, nil}}}, output: "C605C404C203C0"},
217+
{val: &tailRaw{A: 1, Tail: []RawValue{unhex("02"), unhex("03")}}, output: "C3010203"},
218+
{val: &tailRaw{A: 1, Tail: []RawValue{unhex("02")}}, output: "C20102"},
219+
{val: &tailRaw{A: 1, Tail: []RawValue{}}, output: "C101"},
220+
{val: &tailRaw{A: 1, Tail: nil}, output: "C101"},
217221

218222
// nil
219223
{val: (*uint)(nil), output: "80"},

rlp/typecache.go

+34-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
package rlp
1818

1919
import (
20+
"fmt"
2021
"reflect"
22+
"strings"
2123
"sync"
2224
)
2325

@@ -33,7 +35,13 @@ type typeinfo struct {
3335

3436
// represents struct tags
3537
type tags struct {
38+
// rlp:"nil" controls whether empty input results in a nil pointer.
3639
nilOK bool
40+
41+
// rlp:"tail" controls whether this field swallows additional list
42+
// elements. It can only be set for the last field, which must be
43+
// of slice type.
44+
tail bool
3745
}
3846

3947
type typekey struct {
@@ -89,7 +97,10 @@ type field struct {
8997
func structFields(typ reflect.Type) (fields []field, err error) {
9098
for i := 0; i < typ.NumField(); i++ {
9199
if f := typ.Field(i); f.PkgPath == "" { // exported
92-
tags := parseStructTag(f.Tag.Get("rlp"))
100+
tags, err := parseStructTag(typ, i)
101+
if err != nil {
102+
return nil, err
103+
}
93104
info, err := cachedTypeInfo1(f.Type, tags)
94105
if err != nil {
95106
return nil, err
@@ -100,16 +111,35 @@ func structFields(typ reflect.Type) (fields []field, err error) {
100111
return fields, nil
101112
}
102113

103-
func parseStructTag(tag string) tags {
104-
return tags{nilOK: tag == "nil"}
114+
func parseStructTag(typ reflect.Type, fi int) (tags, error) {
115+
f := typ.Field(fi)
116+
var ts tags
117+
for _, t := range strings.Split(f.Tag.Get("rlp"), ",") {
118+
switch t = strings.TrimSpace(t); t {
119+
case "":
120+
case "nil":
121+
ts.nilOK = true
122+
case "tail":
123+
ts.tail = true
124+
if fi != typ.NumField()-1 {
125+
return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (must be on last field)`, typ, f.Name)
126+
}
127+
if f.Type.Kind() != reflect.Slice {
128+
return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (field type is not slice)`, typ, f.Name)
129+
}
130+
default:
131+
return ts, fmt.Errorf("rlp: unknown struct tag %q on %v.%s", t, typ, f.Name)
132+
}
133+
}
134+
return ts, nil
105135
}
106136

107137
func genTypeInfo(typ reflect.Type, tags tags) (info *typeinfo, err error) {
108138
info = new(typeinfo)
109139
if info.decoder, err = makeDecoder(typ, tags); err != nil {
110140
return nil, err
111141
}
112-
if info.writer, err = makeWriter(typ); err != nil {
142+
if info.writer, err = makeWriter(typ, tags); err != nil {
113143
return nil, err
114144
}
115145
return info, nil

0 commit comments

Comments
 (0)