Skip to content
This repository has been archived by the owner on Aug 1, 2024. It is now read-only.
/ cosmos-sdk Public archive
forked from cosmos/cosmos-sdk

Commit

Permalink
feat: add rapidproto generator (cosmos#14849)
Browse files Browse the repository at this point in the history
  • Loading branch information
kocubinski authored and tsenart committed Apr 12, 2023
1 parent 41adc54 commit a8e180a
Show file tree
Hide file tree
Showing 7 changed files with 259 additions and 0 deletions.
221 changes: 221 additions & 0 deletions testutil/rapidproto/rapidproto.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
package rapidproto

import (
"fmt"

"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"gotest.tools/v3/assert"
"pgregory.net/rapid"
)

func MessageGenerator[T proto.Message](x T, options GeneratorOptions) *rapid.Generator[T] {
msgType := x.ProtoReflect().Type()
return rapid.Custom(func(t *rapid.T) T {
msg := msgType.New()

options.setFields(t, msg, 0)

return msg.Interface().(T)
})
}

type GeneratorOptions struct {
AnyTypeURLs []string
Resolver protoregistry.MessageTypeResolver
}

const depthLimit = 10

func (opts GeneratorOptions) setFields(t *rapid.T, msg protoreflect.Message, depth int) bool {
// to avoid stack overflow we limit the depth of nested messages
if depth > depthLimit {
return false
}

descriptor := msg.Descriptor()
fullName := descriptor.FullName()
switch fullName {
case timestampFullName:
opts.genTimestamp(t, msg)
return true
case durationFullName:
opts.genDuration(t, msg)
return true
case anyFullName:
return opts.genAny(t, msg, depth)
case fieldMaskFullName:
opts.genFieldMask(t, msg)
return true
default:
fields := descriptor.Fields()
n := fields.Len()
for i := 0; i < n; i++ {
field := fields.Get(i)
if !rapid.Bool().Draw(t, fmt.Sprintf("gen-%s", field.Name())) {
continue
}

opts.setFieldValue(t, msg, field, depth)
}
return true
}
}

const (
timestampFullName = "google.protobuf.Timestamp"
durationFullName = "google.protobuf.Duration"
anyFullName = "google.protobuf.Any"
fieldMaskFullName = "google.protobuf.FieldMask"
)

func (opts GeneratorOptions) setFieldValue(t *rapid.T, msg protoreflect.Message, field protoreflect.FieldDescriptor, depth int) {
name := string(field.Name())
kind := field.Kind()

switch {
case field.IsList():
list := msg.Mutable(field).List()
n := rapid.IntRange(0, 10).Draw(t, fmt.Sprintf("%sN", name))
for i := 0; i < n; i++ {
if kind == protoreflect.MessageKind || kind == protoreflect.GroupKind {
if !opts.setFields(t, list.AppendMutable().Message(), depth+1) {
list.Truncate(i)
}
} else {
list.Append(opts.genScalarFieldValue(t, field, fmt.Sprintf("%s%d", name, i)))
}
}
case field.IsMap():
m := msg.Mutable(field).Map()
n := rapid.IntRange(0, 10).Draw(t, fmt.Sprintf("%sN", name))
for i := 0; i < n; i++ {
keyField := field.MapKey()
valueField := field.MapValue()
valueKind := valueField.Kind()
key := opts.genScalarFieldValue(t, keyField, fmt.Sprintf("%s%d-key", name, i))
if valueKind == protoreflect.MessageKind || valueKind == protoreflect.GroupKind {
if !opts.setFields(t, m.Mutable(key.MapKey()).Message(), depth+1) {
m.Clear(key.MapKey())
}
} else {
value := opts.genScalarFieldValue(t, valueField, fmt.Sprintf("%s%d-key", name, i))
m.Set(key.MapKey(), value)
}
}
default:
if kind == protoreflect.MessageKind || kind == protoreflect.GroupKind {
if !opts.setFields(t, msg.Mutable(field).Message(), depth+1) {
msg.Clear(field)
}
} else {
msg.Set(field, opts.genScalarFieldValue(t, field, name))
}
}
}

func (opts GeneratorOptions) genScalarFieldValue(t *rapid.T, field protoreflect.FieldDescriptor, name string) protoreflect.Value {
switch field.Kind() {
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
return protoreflect.ValueOfInt32(rapid.Int32().Draw(t, name))
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
return protoreflect.ValueOfUint32(rapid.Uint32().Draw(t, name))
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
return protoreflect.ValueOfInt64(rapid.Int64().Draw(t, name))
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
return protoreflect.ValueOfUint64(rapid.Uint64().Draw(t, name))
case protoreflect.BoolKind:
return protoreflect.ValueOfBool(rapid.Bool().Draw(t, name))
case protoreflect.BytesKind:
return protoreflect.ValueOfBytes(rapid.SliceOf(rapid.Byte()).Draw(t, name))
case protoreflect.FloatKind:
return protoreflect.ValueOfFloat32(rapid.Float32().Draw(t, name))
case protoreflect.DoubleKind:
return protoreflect.ValueOfFloat64(rapid.Float64().Draw(t, name))
case protoreflect.EnumKind:
enumValues := field.Enum().Values()
val := rapid.Int32Range(0, int32(enumValues.Len()-1)).Draw(t, name)
return protoreflect.ValueOfEnum(protoreflect.EnumNumber(val))
case protoreflect.StringKind:
return protoreflect.ValueOfString(rapid.String().Draw(t, name))
default:
t.Fatalf("unexpected %v", field)
return protoreflect.Value{}
}
}

const (
secondsName = "seconds"
nanosName = "nanos"
)

func (opts GeneratorOptions) genTimestamp(t *rapid.T, msg protoreflect.Message) {
seconds := rapid.Int64Range(-9999999999, 9999999999).Draw(t, "seconds")
nanos := rapid.Int32Range(0, 999999999).Draw(t, "nanos")
setSecondsNanosFields(t, msg, seconds, nanos)
}

func (opts GeneratorOptions) genDuration(t *rapid.T, msg protoreflect.Message) {
seconds := rapid.Int64Range(0, 315576000000).Draw(t, "seconds")
nanos := rapid.Int32Range(0, 999999999).Draw(t, "nanos")
setSecondsNanosFields(t, msg, seconds, nanos)
}

func setSecondsNanosFields(t *rapid.T, message protoreflect.Message, seconds int64, nanos int32) {
fields := message.Descriptor().Fields()

secondsField := fields.ByName(secondsName)
assert.Assert(t, secondsField != nil)
message.Set(secondsField, protoreflect.ValueOfInt64(seconds))

nanosField := fields.ByName(nanosName)
assert.Assert(t, nanosField != nil)
message.Set(nanosField, protoreflect.ValueOfInt32(nanos))
}

const (
typeURLName = "type_url"
valueName = "value"
)

func (opts GeneratorOptions) genAny(t *rapid.T, msg protoreflect.Message, depth int) bool {
if len(opts.AnyTypeURLs) == 0 {
return false
}

fields := msg.Descriptor().Fields()

typeURL := rapid.SampledFrom(opts.AnyTypeURLs).Draw(t, "type_url")
typ, err := opts.Resolver.FindMessageByURL(typeURL)
assert.NilError(t, err)

typeURLField := fields.ByName(typeURLName)
assert.Assert(t, typeURLField != nil)
msg.Set(typeURLField, protoreflect.ValueOfString(typeURL))

valueMsg := typ.New()
opts.setFields(t, valueMsg, depth+1)
valueBz, err := proto.Marshal(valueMsg.Interface())
assert.NilError(t, err)

valueField := fields.ByName(valueName)
assert.Assert(t, valueField != nil)
msg.Set(valueField, protoreflect.ValueOfBytes(valueBz))

return true
}

const (
pathsName = "paths"
)

func (opts GeneratorOptions) genFieldMask(t *rapid.T, msg protoreflect.Message) {
paths := rapid.SliceOfN(rapid.StringMatching("[a-z]+([.][a-z]+){0,2}"), 1, 5).Draw(t, "paths")
pathsField := msg.Descriptor().Fields().ByName(pathsName)
assert.Assert(t, pathsField != nil)
pathsList := msg.NewField(pathsField).List()
for _, path := range paths {
pathsList.Append(protoreflect.ValueOfString(path))
}
}
33 changes: 33 additions & 0 deletions testutil/rapidproto/rapidproto_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package rapidproto_test

import (
"fmt"
"testing"

"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"gotest.tools/v3/assert"
"gotest.tools/v3/golden"
"pgregory.net/rapid"

"github.com/cosmos/cosmos-proto/testpb"

"github.com/cosmos/cosmos-sdk/testutil/rapidproto"
)

// TestRegression checks that the generator still produces the same output
// for the same random seeds, assuming that this data has been hand expected
// to generally look good.
func TestRegression(t *testing.T) {
gen := rapidproto.MessageGenerator(&testpb.A{}, rapidproto.GeneratorOptions{})
for i := 0; i < 5; i++ {
testRegressionSeed(t, i, gen)
}
}

func testRegressionSeed[X proto.Message](t *testing.T, seed int, generator *rapid.Generator[X]) {
x := generator.Example(seed)
bz, err := protojson.Marshal(x)
assert.NilError(t, err)
golden.Assert(t, string(bz), fmt.Sprintf("seed%d.json", seed))
}
1 change: 1 addition & 0 deletions testutil/rapidproto/testdata/seed0.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"enum":"Two", "someBoolean":true, "INT32":6, "SINT32":-53, "INT64":"-261", "SFIXED32":3, "FIXED32":65302, "FIXED64":"45044", "STRING":"󳲠~Âaႃ#", "MESSAGE":{"x":"ʰ="}, "MAP":{"":{"x":""}, "%󠇯º$&.":{"x":"-"}, "=A":{}, "AA|𞀠":{"x":"a\u0000"}}, "LIST":[{}], "ONEOFSTRING":"", "imported":{}}
1 change: 1 addition & 0 deletions testutil/rapidproto/testdata/seed1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"UINT32":177, "INT64":"-139958413", "SFIXED32":41418, "FIXED32":25381940, "FLOAT":-8.336453e+31, "SFIXED64":"-2503553836720", "DOUBLE":-0.03171187036377887, "STRING":"?˄~ע", "MESSAGE":{"x":"dDž#"}, "MAP":{"Ⱥa<":{"x":"+["}, "֑Ⱥ|@!`":{}}, "ONEOFSTRING":"\u0012\t?A", "imported":{}, "type":"A�=*ى~~‮Ⱥ*ᾈാȺAᶊ?"}
1 change: 1 addition & 0 deletions testutil/rapidproto/testdata/seed2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"INT32":-48, "UINT32":246, "INT64":"-21558176502", "SING64":"5030347", "UINT64":"28", "FIXED32":92, "DOUBLE":2.3547259926790202e-142, "STRING":"", "LIST":[{}, {}, {}, {}, {"x":" ᾚ DzA{˭҄\nA ^$?ᾦ,:<\"?_\u0014;|"}], "ONEOFSTRING":"𝟠Ÿ", "LISTENUM":["Two", "One", "One"]}
1 change: 1 addition & 0 deletions testutil/rapidproto/testdata/seed3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"INT32":22525032, "SINT32":897, "INT64":"-301128487533312", "SFIXED64":"-71", "FIXED64":"14", "DOUBLE":-2.983041182946181, "STRING":"-A^'", "MESSAGE":{"x":"#ऻ;́\r‮⋁"}, "LIST":[{}, {}, {}, {}, {}], "ONEOFSTRING":"", "imported":{}, "type":"\u0000^৴~౽  NjAৈ􁇸⃠𝖜ೄ"}
1 change: 1 addition & 0 deletions testutil/rapidproto/testdata/seed4.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"SINT32":1, "INT64":"-9223372036854775808", "SING64":"1", "FLOAT":-0.00013906474, "SFIXED64":"71414010", "STRING":"ף̂", "MESSAGE":{"x":""}, "LIST":[{}], "ONEOFSTRING":"#¯∑Ⱥ�", "LISTENUM":["One", "One", "Two", "Two", "One", "One", "One", "Two"], "imported":{}, "type":"\u001b<ʰ+`𑱐@\u001b*Dž‮\u0000#₻\u0000"}

0 comments on commit a8e180a

Please sign in to comment.