Skip to content

Commit

Permalink
reflect/protoregistry: centralize MessageSet extension resolution logic
Browse files Browse the repository at this point in the history
Centralize the MessageSet extension resolution logic in the registry.
This avoids needless replication of this exact logic in multiple places
(for JSON and text) and elsewhere.

Change-Id: I70bfea899e295e8c589f418965bf0dd099f93628
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/240077
Reviewed-by: Herbie Ong <herbie@google.com>
  • Loading branch information
dsnet committed Jul 1, 2020
1 parent 1726b83 commit b783214
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 46 deletions.
11 changes: 1 addition & 10 deletions encoding/protojson/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func (d decoder) unmarshalFields(m pref.Message, skipTypeURL bool) error {
if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") {
// Only extension names are in [name] format.
extName := pref.FullName(name[1 : len(name)-1])
extType, err := d.findExtension(extName)
extType, err := d.opts.Resolver.FindExtensionByName(extName)
if err != nil && err != protoregistry.NotFound {
return d.newError(tok.Pos(), "unable to resolve %s: %v", tok.RawString(), err)
}
Expand Down Expand Up @@ -257,15 +257,6 @@ func (d decoder) unmarshalFields(m pref.Message, skipTypeURL bool) error {
}
}

// findExtension returns protoreflect.ExtensionType from the resolver if found.
func (d decoder) findExtension(xtName pref.FullName) (pref.ExtensionType, error) {
xt, err := d.opts.Resolver.FindExtensionByName(xtName)
if err == nil {
return xt, nil
}
return messageset.FindMessageSetExtension(d.opts.Resolver, xtName)
}

func isKnownValue(fd pref.FieldDescriptor) bool {
md := fd.Message()
return md != nil && md.FullName() == genid.Value_message_fullname
Expand Down
2 changes: 1 addition & 1 deletion encoding/protojson/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1403,7 +1403,7 @@ func TestUnmarshal(t *testing.T) {
"optString": "not a messageset extension"
}
}`,
wantErr: `unknown field "[pb2.FakeMessageSetExtension]"`,
wantErr: `unable to resolve "[pb2.FakeMessageSetExtension]": found wrong type`,
skip: !flags.ProtoLegacy,
}, {
desc: "not real MessageSet 3",
Expand Down
11 changes: 1 addition & 10 deletions encoding/prototext/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (d decoder) unmarshalMessage(m pref.Message, checkDelims bool) error {

case text.TypeName:
// Handle extensions only. This code path is not for Any.
xt, xtErr = d.findExtension(pref.FullName(tok.TypeName()))
xt, xtErr = d.opts.Resolver.FindExtensionByName(pref.FullName(tok.TypeName()))

case text.FieldNumber:
isFieldNumberName = true
Expand Down Expand Up @@ -269,15 +269,6 @@ func (d decoder) unmarshalMessage(m pref.Message, checkDelims bool) error {
return nil
}

// findExtension returns protoreflect.ExtensionType from the Resolver if found.
func (d decoder) findExtension(xtName pref.FullName) (pref.ExtensionType, error) {
xt, err := d.opts.Resolver.FindExtensionByName(xtName)
if err == nil {
return xt, nil
}
return messageset.FindMessageSetExtension(d.opts.Resolver, xtName)
}

// unmarshalSingular unmarshals a non-repeated field value specified by the
// given FieldDescriptor.
func (d decoder) unmarshalSingular(fd pref.FieldDescriptor, m pref.Message) error {
Expand Down
2 changes: 1 addition & 1 deletion encoding/prototext/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1508,7 +1508,7 @@ opt_int32: 42
opt_string: "not a messageset extension"
}
`,
wantErr: "unknown field: [pb2.FakeMessageSetExtension]",
wantErr: `unable to resolve [[pb2.FakeMessageSetExtension]]: found wrong type`,
skip: !flags.ProtoLegacy,
}, {
desc: "not real MessageSet 3",
Expand Down
31 changes: 7 additions & 24 deletions internal/encoding/messageset/messageset.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/errors"
pref "google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry"
)

// The MessageSet wire format is equivalent to a message defiend as follows,
Expand Down Expand Up @@ -48,33 +47,17 @@ func IsMessageSet(md pref.MessageDescriptor) bool {
return ok && xmd.IsMessageSet()
}

// IsMessageSetExtension reports this field extends a MessageSet.
// IsMessageSetExtension reports this field properly extends a MessageSet.
func IsMessageSetExtension(fd pref.FieldDescriptor) bool {
if fd.Name() != ExtensionName {
switch {
case fd.Name() != ExtensionName:
return false
}
if fd.FullName().Parent() != fd.Message().FullName() {
case !IsMessageSet(fd.ContainingMessage()):
return false
case fd.FullName().Parent() != fd.Message().FullName():
return false
}
return IsMessageSet(fd.ContainingMessage())
}

// FindMessageSetExtension locates a MessageSet extension field by name.
// In text and JSON formats, the extension name used is the message itself.
// The extension field name is derived by appending ExtensionName.
func FindMessageSetExtension(r preg.ExtensionTypeResolver, s pref.FullName) (pref.ExtensionType, error) {
name := s.Append(ExtensionName)
xt, err := r.FindExtensionByName(name)
if err != nil {
if err == preg.NotFound {
return nil, err
}
return nil, errors.Wrap(err, "%q", name)
}
if !IsMessageSetExtension(xt.TypeDescriptor()) {
return nil, preg.NotFound
}
return xt, nil
return true
}

// SizeField returns the size of a MessageSet item field containing an extension
Expand Down
22 changes: 22 additions & 0 deletions reflect/protoregistry/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import (
"strings"
"sync"

"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/reflect/protoreflect"
)

Expand Down Expand Up @@ -613,6 +615,26 @@ func (r *Types) FindExtensionByName(field protoreflect.FullName) (protoreflect.E
if xt, _ := v.(protoreflect.ExtensionType); xt != nil {
return xt, nil
}

// MessageSet extensions are special in that the name of the extension
// is the name of the message type used to extend the MessageSet.
// This naming scheme is used by text and JSON serialization.
//
// This feature is protected by the ProtoLegacy flag since MessageSets
// are a proto1 feature that is long deprecated.
if flags.ProtoLegacy {
if _, ok := v.(protoreflect.MessageType); ok {
field := field.Append(messageset.ExtensionName)
if v := r.typesByName[field]; v != nil {
if xt, _ := v.(protoreflect.ExtensionType); xt != nil {
if messageset.IsMessageSetExtension(xt.TypeDescriptor()) {
return xt, nil
}
}
}
}
}

return nil, errors.New("found wrong type: got %v, want extension", typeName(v))
}
return nil, NotFound
Expand Down

0 comments on commit b783214

Please sign in to comment.