Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion common/types/map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,7 @@ func TestProtoMapConvertToNative(t *testing.T) {
if mapVal4.Equal(mapVal) != True || mapVal.Equal(mapVal4) != True {
t.Errorf("mapVal4.Equal(mapVal) returned false, wanted true")
}
convMap, err = mapVal.ConvertToNative(reflect.TypeOf(&pb.Map{}))
convMap, err = mapVal.ConvertToNative(reflect.TypeOf(map[string]string{}))
if err != nil {
t.Fatalf("mapVal.ConvertToNative() failed: %v", err)
}
Expand Down
97 changes: 88 additions & 9 deletions common/types/pb/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
dynamicpb "google.golang.org/protobuf/types/dynamicpb"
Expand Down Expand Up @@ -254,30 +255,108 @@ func (fd *FieldDescription) GetFrom(target any) (any, error) {
}
pbRef := v.ProtoReflect()
pbDesc := pbRef.Descriptor()
var fieldVal any
var fieldVal protoreflect.Value
if pbDesc == fd.desc.ContainingMessage() {
// When the target protobuf shares the same message descriptor instance as the field
// descriptor, use the cached field descriptor value.
fieldVal = pbRef.Get(fd.desc).Interface()
fieldVal = pbRef.Get(fd.desc)
} else {
// Otherwise, fallback to a dynamic lookup of the field descriptor from the target
// instance as an attempt to use the cached field descriptor will result in a panic.
fieldVal = pbRef.Get(pbDesc.Fields().ByName(protoreflect.Name(fd.Name()))).Interface()
fieldVal = pbRef.Get(pbDesc.Fields().ByName(protoreflect.Name(fd.Name())))
}
switch fv := fieldVal.(type) {
return fd.getNativeValue(fieldVal)
}

func (fd *FieldDescription) getNativeType(v protoreflect.Value) (reflect.Type, error) {
switch fv := v.Interface().(type) {
case protoreflect.Message:
// Make sure to unwrap well-known protobuf types before returning.
unwrapped, _, err := fd.MaybeUnwrapDynamic(fv)
return reflect.TypeOf(unwrapped), err
case protoreflect.EnumNumber:
enumType, err := protoregistry.GlobalTypes.FindEnumByName(fd.desc.Enum().FullName())
if err != nil {
return nil, err
}
return reflect.TypeOf(enumType.New(0)), nil
case protoreflect.List:
if fv == nil {
return nil, nil
}

element := fv.NewElement()
et, err := fd.getNativeType(element)
if err != nil {
return nil, err
}
return reflect.SliceOf(et), nil
case protoreflect.Map:
vt, err := fd.getNativeType(fv.NewValue())
if err != nil {
return nil, err
}
return reflect.MapOf(fd.KeyType.reflectType, vt), nil
default:
return reflect.TypeOf(fv), nil
}
}

func (fd *FieldDescription) getNativeValue(v protoreflect.Value) (any, error) {
switch fv := v.Interface().(type) {
// Fast-path return for primitive types.
case bool, []byte, float32, float64, int32, int64, string, uint32, uint64, protoreflect.List:
case bool, []byte, float32, float64, int32, int64, string, uint32, uint64:
return fv, nil
case protoreflect.EnumNumber:
return int64(fv), nil
enumType, err := protoregistry.GlobalTypes.FindEnumByName(fd.desc.Enum().FullName())
if err != nil {
return nil, err
}
return enumType.New(fv), nil
case protoreflect.Map:
// Return a wrapper around the protobuf-reflected Map types which carries additional
// information about the key and value definitions of the map.
return &Map{Map: fv, KeyType: fd.KeyType, ValueType: fd.ValueType}, nil
if fv == nil {
return nil, nil
}

mapType, err := fd.getNativeType(v)
if err != nil {
return nil, err
}

m := reflect.MakeMap(mapType)
fv.Range(func(mk protoreflect.MapKey, v protoreflect.Value) bool {
vv, err := fd.getNativeValue(v)
if err != nil {
return false
}
m.SetMapIndex(reflect.ValueOf(mk.Interface()), reflect.ValueOf(vv))
return true
})
return m.Interface(), nil
case protoreflect.Message:
// Make sure to unwrap well-known protobuf types before returning.
unwrapped, _, err := fd.MaybeUnwrapDynamic(fv)
return unwrapped, err
case protoreflect.List:
if fv == nil {
return nil, nil
}

sliceType, err := fd.getNativeType(v)
if err != nil {
return nil, err
}

slice := reflect.MakeSlice(sliceType, fv.Len(), fv.Len())

for i := 0; i < fv.Len(); i++ {
elementVal, err := fd.getNativeValue(fv.Get(i))
if err != nil {
return nil, err
}
slice.Index(i).Set(reflect.ValueOf(elementVal))
}
return slice.Interface(), nil
default:
return fv, nil
}
Expand Down
68 changes: 65 additions & 3 deletions common/types/pb/type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,31 @@ func TestFieldDescriptionGetFrom(t *testing.T) {
SingleStruct: jsonStruct(t, map[string]any{
"null": nil,
}),
RepeatedInt32: []int32{1, 2, 3},
RepeatedInt64: []int64{1, 2, 3},
RepeatedUint32: []uint32{1, 2, 3},
RepeatedUint64: []uint64{1, 2, 3},
RepeatedSint32: []int32{1, 2, 3},
RepeatedSint64: []int64{1, 2, 3},
RepeatedFixed32: []uint32{1, 2, 3},
RepeatedFixed64: []uint64{1, 2, 3},
RepeatedSfixed32: []int32{1, 2, 3},
RepeatedSfixed64: []int64{1, 2, 3},
RepeatedFloat: []float32{1.0, 2.0, 3.0},
RepeatedDouble: []float64{1.0, 2.0, 3.0},
RepeatedBool: []bool{true, false, true},
RepeatedString: []string{"a", "b", "c"},
RepeatedBytes: [][]byte{{1, 2, 3}, {4, 5, 6}},
RepeatedNestedMessage: []*proto3pb.TestAllTypes_NestedMessage{{Bb: 123}, {Bb: 456}},
RepeatedNestedEnum: []proto3pb.TestAllTypes_NestedEnum{proto3pb.TestAllTypes_BAR, proto3pb.TestAllTypes_BAZ},
RepeatedStringPiece: []string{"a", "b", "c"},
RepeatedCord: []string{"a", "b", "c"},
RepeatedLazyMessage: []*proto3pb.TestAllTypes_NestedMessage{{Bb: 123}, {Bb: 456}},
MapStringString: map[string]string{"a": "1", "b": "2"},
MapInt64NestedType: map[int64]*proto3pb.NestedTestAllTypes{
1: {Payload: &proto3pb.TestAllTypes{SingleUint64: 12}},
},
ImportedEnums: []proto3pb.ImportedGlobalEnum{proto3pb.ImportedGlobalEnum_IMPORT_FOO},
}
msgName := string(msg.ProtoReflect().Descriptor().FullName())
_, err := pbdb.RegisterMessage(msg)
Expand All @@ -182,11 +207,36 @@ func TestFieldDescriptionGetFrom(t *testing.T) {
"single_nested_message": &proto3pb.TestAllTypes_NestedMessage{
Bb: 123,
},
"standalone_enum": int64(1),
"standalone_enum": proto3pb.TestAllTypes_BAR,
"single_value": "hello world",
"single_struct": jsonStruct(t, map[string]any{
"null": nil,
}),
"repeated_int32": []int32{1, 2, 3},
"repeated_int64": []int64{1, 2, 3},
"repeated_uint32": []uint32{1, 2, 3},
"repeated_uint64": []uint64{1, 2, 3},
"repeated_sint32": []int32{1, 2, 3},
"repeated_sint64": []int64{1, 2, 3},
"repeated_fixed32": []uint32{1, 2, 3},
"repeated_fixed64": []uint64{1, 2, 3},
"repeated_sfixed32": []int32{1, 2, 3},
"repeated_sfixed64": []int64{1, 2, 3},
"repeated_float": []float32{1.0, 2.0, 3.0},
"repeated_double": []float64{1.0, 2.0, 3.0},
"repeated_bool": []bool{true, false, true},
"repeated_string": []string{"a", "b", "c"},
"repeated_bytes": [][]byte{{1, 2, 3}, {4, 5, 6}},
"repeated_nested_message": []*proto3pb.TestAllTypes_NestedMessage{{Bb: 123}, {Bb: 456}},
"repeated_nested_enum": []proto3pb.TestAllTypes_NestedEnum{proto3pb.TestAllTypes_BAR, proto3pb.TestAllTypes_BAZ},
"repeated_string_piece": []string{"a", "b", "c"},
"repeated_cord": []string{"a", "b", "c"},
"repeated_lazy_message": []*proto3pb.TestAllTypes_NestedMessage{{Bb: 123}, {Bb: 456}},
"map_string_string": map[string]string{"a": "1", "b": "2"},
"map_int64_nested_type": map[int64]*proto3pb.NestedTestAllTypes{
1: {Payload: &proto3pb.TestAllTypes{SingleUint64: 12}},
},
"imported_enums": []proto3pb.ImportedGlobalEnum{proto3pb.ImportedGlobalEnum_IMPORT_FOO},
}
for field, want := range expected {
f, found := td.FieldByName(field)
Expand All @@ -200,11 +250,23 @@ func TestFieldDescriptionGetFrom(t *testing.T) {
switch g := got.(type) {
case proto.Message:
if !proto.Equal(g, want.(proto.Message)) {
t.Errorf("got field %s value %v, wanted %v", field, g, want)
t.Errorf("got field %s type %T, value %v, wanted type %T, value %v", field, g, g, want, want)
}
case []*proto3pb.TestAllTypes_NestedMessage:
for i, gv := range g {
if !proto.Equal(gv, want.([]*proto3pb.TestAllTypes_NestedMessage)[i]) {
t.Errorf("got field %s[%d] type %T, value %v, wanted type %T, value %v", field, i, gv, gv, want.([]*proto3pb.TestAllTypes_NestedMessage)[i], want.([]*proto3pb.TestAllTypes_NestedMessage)[i])
}
}
case map[int64]*proto3pb.NestedTestAllTypes:
for k, gv := range g {
if !proto.Equal(gv, want.(map[int64]*proto3pb.NestedTestAllTypes)[k]) {
t.Errorf("got field %s[%d] type %T, value %v, wanted type %T, value %v", field, k, gv, gv, want.(map[int64]*proto3pb.NestedTestAllTypes)[k], want.(map[int64]*proto3pb.NestedTestAllTypes)[k])
}
}
default:
if !reflect.DeepEqual(g, want) {
t.Errorf("got field %s value %v, wanted %v", field, g, want)
t.Errorf("got field %s type %T, value %v, wanted type %T, value %v", field, g, g, want, want)
}
}
}
Expand Down