diff --git a/encoding/protojson/encode.go b/encoding/protojson/encode.go index 7d6193300..58bdebe31 100644 --- a/encoding/protojson/encode.go +++ b/encoding/protojson/encode.go @@ -7,13 +7,13 @@ package protojson import ( "encoding/base64" "fmt" - "sort" "google.golang.org/protobuf/internal/encoding/json" "google.golang.org/protobuf/internal/encoding/messageset" "google.golang.org/protobuf/internal/errors" "google.golang.org/protobuf/internal/flags" "google.golang.org/protobuf/internal/genid" + "google.golang.org/protobuf/internal/order" "google.golang.org/protobuf/internal/pragma" "google.golang.org/protobuf/proto" pref "google.golang.org/protobuf/reflect/protoreflect" @@ -160,61 +160,71 @@ func (e encoder) marshalMessage(m pref.Message) error { return nil } +// unpopulatedFieldRanger wraps a protoreflect.Message and modifies its Range +// method to additionally iterate over unpopulated fields. +type unpopulatedFieldRanger struct{ pref.Message } + +func (m unpopulatedFieldRanger) Range(f func(pref.FieldDescriptor, pref.Value) bool) { + fds := m.Descriptor().Fields() + for i := 0; i < fds.Len(); i++ { + fd := fds.Get(i) + if m.Has(fd) || fd.ContainingOneof() != nil { + continue // ignore populated fields and fields within a oneofs + } + + v := m.Get(fd) + isProto2Scalar := fd.Syntax() == pref.Proto2 && fd.Default().IsValid() + isSingularMessage := fd.Cardinality() != pref.Repeated && fd.Message() != nil + if isProto2Scalar || isSingularMessage { + v = pref.Value{} // use invalid value to emit null + } + if !f(fd, v) { + return + } + } + m.Message.Range(f) +} + // marshalFields marshals the fields in the given protoreflect.Message. func (e encoder) marshalFields(m pref.Message) error { - messageDesc := m.Descriptor() - if !flags.ProtoLegacy && messageset.IsMessageSet(messageDesc) { + if !flags.ProtoLegacy && messageset.IsMessageSet(m.Descriptor()) { return errors.New("no support for proto1 MessageSets") } - // Marshal out known fields. - fieldDescs := messageDesc.Fields() - for i := 0; i < fieldDescs.Len(); { - fd := fieldDescs.Get(i) - if od := fd.ContainingOneof(); od != nil { - fd = m.WhichOneof(od) - i += od.Fields().Len() - if fd == nil { - continue // unpopulated oneofs are not affected by EmitUnpopulated - } - } else { - i++ - } + var fields order.FieldRanger = m + if e.opts.EmitUnpopulated { + fields = unpopulatedFieldRanger{m} + } - val := m.Get(fd) - if !m.Has(fd) { - if !e.opts.EmitUnpopulated { - continue - } - isProto2Scalar := fd.Syntax() == pref.Proto2 && fd.Default().IsValid() - isSingularMessage := fd.Cardinality() != pref.Repeated && fd.Message() != nil - if isProto2Scalar || isSingularMessage { - // Use invalid value to emit null. - val = pref.Value{} + var err error + order.RangeFields(fields, order.IndexNameFieldOrder, func(fd pref.FieldDescriptor, v pref.Value) bool { + var name string + switch { + case fd.IsExtension(): + if messageset.IsMessageSetExtension(fd) { + name = "[" + string(fd.FullName().Parent()) + "]" + } else { + name = "[" + string(fd.FullName()) + "]" } - } - - name := fd.JSONName() - if e.opts.UseProtoNames { - name = string(fd.Name()) - // Use type name for group field name. + case e.opts.UseProtoNames: if fd.Kind() == pref.GroupKind { name = string(fd.Message().Name()) + } else { + name = string(fd.Name()) } + default: + name = fd.JSONName() } - if err := e.WriteName(name); err != nil { - return err + + if err = e.WriteName(name); err != nil { + return false } - if err := e.marshalValue(val, fd); err != nil { - return err + if err = e.marshalValue(v, fd); err != nil { + return false } - } - - // Marshal out extensions. - if err := e.marshalExtensions(m); err != nil { - return err - } - return nil + return true + }) + return err } // marshalValue marshals the given protoreflect.Value. @@ -305,98 +315,20 @@ func (e encoder) marshalList(list pref.List, fd pref.FieldDescriptor) error { return nil } -type mapEntry struct { - key pref.MapKey - value pref.Value -} - // marshalMap marshals given protoreflect.Map. func (e encoder) marshalMap(mmap pref.Map, fd pref.FieldDescriptor) error { e.StartObject() defer e.EndObject() - // Get a sorted list based on keyType first. - entries := make([]mapEntry, 0, mmap.Len()) - mmap.Range(func(key pref.MapKey, val pref.Value) bool { - entries = append(entries, mapEntry{key: key, value: val}) - return true - }) - sortMap(fd.MapKey().Kind(), entries) - - // Write out sorted list. - for _, entry := range entries { - if err := e.WriteName(entry.key.String()); err != nil { - return err - } - if err := e.marshalSingular(entry.value, fd.MapValue()); err != nil { - return err - } - } - return nil -} - -// sortMap orders list based on value of key field for deterministic ordering. -func sortMap(keyKind pref.Kind, values []mapEntry) { - sort.Slice(values, func(i, j int) bool { - switch keyKind { - case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind, - pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind: - return values[i].key.Int() < values[j].key.Int() - - case pref.Uint32Kind, pref.Fixed32Kind, - pref.Uint64Kind, pref.Fixed64Kind: - return values[i].key.Uint() < values[j].key.Uint() - } - return values[i].key.String() < values[j].key.String() - }) -} - -// marshalExtensions marshals extension fields. -func (e encoder) marshalExtensions(m pref.Message) error { - type entry struct { - key string - value pref.Value - desc pref.FieldDescriptor - } - - // Get a sorted list based on field key first. - var entries []entry - m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool { - if !fd.IsExtension() { - return true + var err error + order.RangeEntries(mmap, order.GenericKeyOrder, func(k pref.MapKey, v pref.Value) bool { + if err = e.WriteName(k.String()); err != nil { + return false } - - // For MessageSet extensions, the name used is the parent message. - name := fd.FullName() - if messageset.IsMessageSetExtension(fd) { - name = name.Parent() + if err = e.marshalSingular(v, fd.MapValue()); err != nil { + return false } - - // Use [name] format for JSON field name. - entries = append(entries, entry{ - key: string(name), - value: v, - desc: fd, - }) return true }) - - // Sort extensions lexicographically. - sort.Slice(entries, func(i, j int) bool { - return entries[i].key < entries[j].key - }) - - // Write out sorted list. - for _, entry := range entries { - // JSON field name is the proto field name enclosed in [], similar to - // textproto. This is consistent with Go v1 lib. C++ lib v3.7.0 does not - // marshal out extension fields. - if err := e.WriteName("[" + entry.key + "]"); err != nil { - return err - } - if err := e.marshalValue(entry.value, entry.desc); err != nil { - return err - } - } - return nil + return err } diff --git a/encoding/protojson/encode_test.go b/encoding/protojson/encode_test.go index b4cf3bb44..0cca93770 100644 --- a/encoding/protojson/encode_test.go +++ b/encoding/protojson/encode_test.go @@ -1060,12 +1060,12 @@ func TestMarshal(t *testing.T) { return m }(), want: `{ - "[pb2.MessageSetExtension]": { - "optString": "a messageset extension" - }, "[pb2.MessageSetExtension.ext_nested]": { "optString": "just a regular extension" }, + "[pb2.MessageSetExtension]": { + "optString": "a messageset extension" + }, "[pb2.MessageSetExtension.not_message_set_extension]": { "optString": "not a messageset extension" } @@ -2123,6 +2123,35 @@ func TestMarshal(t *testing.T) { "optNested": null } ] +}`, + }, { + desc: "EmitUnpopulated: with populated fields", + mo: protojson.MarshalOptions{EmitUnpopulated: true}, + input: &pb2.Scalars{ + OptInt32: proto.Int32(0xff), + OptUint32: proto.Uint32(47), + OptSint32: proto.Int32(-1001), + OptFixed32: proto.Uint32(32), + OptSfixed32: proto.Int32(-32), + OptFloat: proto.Float32(1.02), + OptBytes: []byte("谷歌"), + }, + want: `{ + "optBool": null, + "optInt32": 255, + "optInt64": null, + "optUint32": 47, + "optUint64": null, + "optSint32": -1001, + "optSint64": null, + "optFixed32": 32, + "optFixed64": null, + "optSfixed32": -32, + "optSfixed64": null, + "optFloat": 1.02, + "optDouble": null, + "optBytes": "6LC35q2M", + "optString": null }`, }, { desc: "UseEnumNumbers in singular field", diff --git a/encoding/prototext/encode.go b/encoding/prototext/encode.go index 0877d71c5..3171156de 100644 --- a/encoding/prototext/encode.go +++ b/encoding/prototext/encode.go @@ -6,7 +6,6 @@ package prototext import ( "fmt" - "sort" "strconv" "unicode/utf8" @@ -16,10 +15,11 @@ import ( "google.golang.org/protobuf/internal/errors" "google.golang.org/protobuf/internal/flags" "google.golang.org/protobuf/internal/genid" - "google.golang.org/protobuf/internal/mapsort" + "google.golang.org/protobuf/internal/order" "google.golang.org/protobuf/internal/pragma" "google.golang.org/protobuf/internal/strs" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" pref "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" ) @@ -169,35 +169,30 @@ func (e encoder) marshalMessage(m pref.Message, inclDelims bool) error { // If unable to expand, continue on to marshal Any as a regular message. } - // Marshal known fields. - fieldDescs := messageDesc.Fields() - size := fieldDescs.Len() - for i := 0; i < size; { - fd := fieldDescs.Get(i) - if od := fd.ContainingOneof(); od != nil { - fd = m.WhichOneof(od) - i += od.Fields().Len() + // Marshal fields. + var err error + order.RangeFields(m, order.IndexNameFieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { + var name string + if fd.IsExtension() { + if messageset.IsMessageSetExtension(fd) { + name = "[" + string(fd.FullName().Parent()) + "]" + } else { + name = "[" + string(fd.FullName()) + "]" + } } else { - i++ - } - - if fd == nil || !m.Has(fd) { - continue + if fd.Kind() == pref.GroupKind { + name = string(fd.Message().Name()) + } else { + name = string(fd.Name()) + } } - name := fd.Name() - // Use type name for group field name. - if fd.Kind() == pref.GroupKind { - name = fd.Message().Name() - } - val := m.Get(fd) - if err := e.marshalField(string(name), val, fd); err != nil { - return err + if err = e.marshalField(string(name), v, fd); err != nil { + return false } - } - - // Marshal extensions. - if err := e.marshalExtensions(m); err != nil { + return true + }) + if err != nil { return err } @@ -290,7 +285,7 @@ func (e encoder) marshalList(name string, list pref.List, fd pref.FieldDescripto // marshalMap marshals the given protoreflect.Map as multiple name-value fields. func (e encoder) marshalMap(name string, mmap pref.Map, fd pref.FieldDescriptor) error { var err error - mapsort.Range(mmap, fd.MapKey().Kind(), func(key pref.MapKey, val pref.Value) bool { + order.RangeEntries(mmap, order.GenericKeyOrder, func(key pref.MapKey, val pref.Value) bool { e.WriteName(name) e.StartMessage() defer e.EndMessage() @@ -311,48 +306,6 @@ func (e encoder) marshalMap(name string, mmap pref.Map, fd pref.FieldDescriptor) return err } -// marshalExtensions marshals extension fields. -func (e encoder) marshalExtensions(m pref.Message) error { - type entry struct { - key string - value pref.Value - desc pref.FieldDescriptor - } - - // Get a sorted list based on field key first. - var entries []entry - m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool { - if !fd.IsExtension() { - return true - } - // For MessageSet extensions, the name used is the parent message. - name := fd.FullName() - if messageset.IsMessageSetExtension(fd) { - name = name.Parent() - } - entries = append(entries, entry{ - key: string(name), - value: v, - desc: fd, - }) - return true - }) - // Sort extensions lexicographically. - sort.Slice(entries, func(i, j int) bool { - return entries[i].key < entries[j].key - }) - - // Write out sorted list. - for _, entry := range entries { - // Extension field name is the proto field name enclosed in []. - name := "[" + entry.key + "]" - if err := e.marshalField(name, entry.value, entry.desc); err != nil { - return err - } - } - return nil -} - // marshalUnknown parses the given []byte and marshals fields out. // This function assumes proper encoding in the given []byte. func (e encoder) marshalUnknown(b []byte) { diff --git a/encoding/prototext/encode_test.go b/encoding/prototext/encode_test.go index 4de385cd0..49fba1428 100644 --- a/encoding/prototext/encode_test.go +++ b/encoding/prototext/encode_test.go @@ -1158,12 +1158,12 @@ opt_int32: 42 }) return m }(), - want: `[pb2.MessageSetExtension]: { - opt_string: "a messageset extension" -} -[pb2.MessageSetExtension.ext_nested]: { + want: `[pb2.MessageSetExtension.ext_nested]: { opt_string: "just a regular extension" } +[pb2.MessageSetExtension]: { + opt_string: "a messageset extension" +} [pb2.MessageSetExtension.not_message_set_extension]: { opt_string: "not a messageset extension" } diff --git a/internal/fieldsort/fieldsort.go b/internal/fieldsort/fieldsort.go deleted file mode 100644 index 517c4e2a0..000000000 --- a/internal/fieldsort/fieldsort.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2019 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package fieldsort defines an ordering of fields. -// -// The ordering defined by this package matches the historic behavior of the proto -// package, placing extensions first and oneofs last. -// -// There is no guarantee about stability of the wire encoding, and users should not -// depend on the order defined in this package as it is subject to change without -// notice. -package fieldsort - -import ( - "google.golang.org/protobuf/reflect/protoreflect" -) - -// Less returns true if field a comes before field j in ordered wire marshal output. -func Less(a, b protoreflect.FieldDescriptor) bool { - ea := a.IsExtension() - eb := b.IsExtension() - oa := a.ContainingOneof() - ob := b.ContainingOneof() - switch { - case ea != eb: - return ea - case oa != nil && ob != nil: - if oa == ob { - return a.Number() < b.Number() - } - return oa.Index() < ob.Index() - case oa != nil && !oa.IsSynthetic(): - return false - case ob != nil && !ob.IsSynthetic(): - return true - default: - return a.Number() < b.Number() - } -} diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go index 0e176d565..733ad9f11 100644 --- a/internal/impl/codec_message.go +++ b/internal/impl/codec_message.go @@ -11,7 +11,7 @@ import ( "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/internal/encoding/messageset" - "google.golang.org/protobuf/internal/fieldsort" + "google.golang.org/protobuf/internal/order" pref "google.golang.org/protobuf/reflect/protoreflect" piface "google.golang.org/protobuf/runtime/protoiface" ) @@ -136,7 +136,7 @@ func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) { sort.Slice(mi.orderedCoderFields, func(i, j int) bool { fi := fields.ByNumber(mi.orderedCoderFields[i].num) fj := fields.ByNumber(mi.orderedCoderFields[j].num) - return fieldsort.Less(fi, fj) + return order.LegacyFieldOrder(fi, fj) }) } diff --git a/internal/mapsort/mapsort.go b/internal/mapsort/mapsort.go deleted file mode 100644 index a3de1cf32..000000000 --- a/internal/mapsort/mapsort.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2019 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package mapsort provides sorted access to maps. -package mapsort - -import ( - "sort" - - "google.golang.org/protobuf/reflect/protoreflect" -) - -// Range iterates over every map entry in sorted key order, -// calling f for each key and value encountered. -func Range(mapv protoreflect.Map, keyKind protoreflect.Kind, f func(protoreflect.MapKey, protoreflect.Value) bool) { - var keys []protoreflect.MapKey - mapv.Range(func(key protoreflect.MapKey, _ protoreflect.Value) bool { - keys = append(keys, key) - return true - }) - sort.Slice(keys, func(i, j int) bool { - switch keyKind { - case protoreflect.BoolKind: - return !keys[i].Bool() && keys[j].Bool() - case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind, - protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: - return keys[i].Int() < keys[j].Int() - case protoreflect.Uint32Kind, protoreflect.Fixed32Kind, - protoreflect.Uint64Kind, protoreflect.Fixed64Kind: - return keys[i].Uint() < keys[j].Uint() - case protoreflect.StringKind: - return keys[i].String() < keys[j].String() - default: - panic("invalid kind: " + keyKind.String()) - } - }) - for _, key := range keys { - if !f(key, mapv.Get(key)) { - break - } - } -} diff --git a/internal/mapsort/mapsort_test.go b/internal/mapsort/mapsort_test.go deleted file mode 100644 index 6d1794601..000000000 --- a/internal/mapsort/mapsort_test.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2019 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package mapsort_test - -import ( - "strconv" - "testing" - - "google.golang.org/protobuf/internal/mapsort" - pref "google.golang.org/protobuf/reflect/protoreflect" - - testpb "google.golang.org/protobuf/internal/testprotos/test" -) - -func TestRange(t *testing.T) { - m := (&testpb.TestAllTypes{ - MapBoolBool: map[bool]bool{ - false: false, - true: true, - }, - MapInt32Int32: map[int32]int32{ - 0: 0, - 1: 1, - 2: 2, - }, - MapUint64Uint64: map[uint64]uint64{ - 0: 0, - 1: 1, - 2: 2, - }, - MapStringString: map[string]string{ - "0": "0", - "1": "1", - "2": "2", - }, - }).ProtoReflect() - m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool { - mapv := v.Map() - var got []pref.MapKey - mapsort.Range(mapv, fd.MapKey().Kind(), func(key pref.MapKey, _ pref.Value) bool { - got = append(got, key) - return true - }) - for wanti, key := range got { - var goti int - switch x := mapv.Get(key).Interface().(type) { - case bool: - if x { - goti = 1 - } - case int32: - goti = int(x) - case uint64: - goti = int(x) - case string: - goti, _ = strconv.Atoi(x) - default: - t.Fatalf("unhandled map value type %T", x) - } - if wanti != goti { - t.Errorf("out of order range over map field %v: %v", fd.FullName(), got) - break - } - } - return true - }) -} diff --git a/internal/msgfmt/format.go b/internal/msgfmt/format.go index 9547a5301..f01cf606b 100644 --- a/internal/msgfmt/format.go +++ b/internal/msgfmt/format.go @@ -20,7 +20,7 @@ import ( "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/internal/detrand" "google.golang.org/protobuf/internal/genid" - "google.golang.org/protobuf/internal/mapsort" + "google.golang.org/protobuf/internal/order" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" @@ -64,25 +64,8 @@ func appendMessage(b []byte, m protoreflect.Message) []byte { return b2 } - var fds []protoreflect.FieldDescriptor - m.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool { - fds = append(fds, fd) - return true - }) - sort.Slice(fds, func(i, j int) bool { - fdi, fdj := fds[i], fds[j] - switch { - case !fdi.IsExtension() && !fdj.IsExtension(): - return fdi.Index() < fdj.Index() - case fdi.IsExtension() && fdj.IsExtension(): - return fdi.FullName() < fdj.FullName() - default: - return !fdi.IsExtension() && fdj.IsExtension() - } - }) - b = append(b, '{') - for _, fd := range fds { + order.RangeFields(m, order.IndexNameFieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { k := string(fd.Name()) if fd.IsExtension() { k = string("[" + fd.FullName() + "]") @@ -90,9 +73,10 @@ func appendMessage(b []byte, m protoreflect.Message) []byte { b = append(b, k...) b = append(b, ':') - b = appendValue(b, m.Get(fd), fd) + b = appendValue(b, v, fd) b = append(b, delim()...) - } + return true + }) b = appendUnknown(b, m.GetUnknown()) b = bytes.TrimRight(b, delim()) b = append(b, '}') @@ -247,19 +231,14 @@ func appendList(b []byte, v protoreflect.List, fd protoreflect.FieldDescriptor) } func appendMap(b []byte, v protoreflect.Map, fd protoreflect.FieldDescriptor) []byte { - var ks []protoreflect.MapKey - mapsort.Range(v, fd.MapKey().Kind(), func(k protoreflect.MapKey, _ protoreflect.Value) bool { - ks = append(ks, k) - return true - }) - b = append(b, '{') - for _, k := range ks { + order.RangeEntries(v, order.GenericKeyOrder, func(k protoreflect.MapKey, v protoreflect.Value) bool { b = appendValue(b, k.Value(), fd.MapKey()) b = append(b, ':') - b = appendValue(b, v.Get(k), fd.MapValue()) + b = appendValue(b, v, fd.MapValue()) b = append(b, delim()...) - } + return true + }) b = bytes.TrimRight(b, delim()) b = append(b, '}') return b diff --git a/internal/order/order.go b/internal/order/order.go new file mode 100644 index 000000000..2a24953f6 --- /dev/null +++ b/internal/order/order.go @@ -0,0 +1,89 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package order + +import ( + pref "google.golang.org/protobuf/reflect/protoreflect" +) + +// FieldOrder specifies the ordering to visit message fields. +// It is a function that reports whether x is ordered before y. +type FieldOrder func(x, y pref.FieldDescriptor) bool + +var ( + // AnyFieldOrder specifies no specific field ordering. + AnyFieldOrder FieldOrder = nil + + // LegacyFieldOrder sorts fields in the same ordering as emitted by + // wire serialization in the github.com/golang/protobuf implementation. + LegacyFieldOrder FieldOrder = func(x, y pref.FieldDescriptor) bool { + ox, oy := x.ContainingOneof(), y.ContainingOneof() + inOneof := func(od pref.OneofDescriptor) bool { + return od != nil && !od.IsSynthetic() + } + + // Extension fields sort before non-extension fields. + if x.IsExtension() != y.IsExtension() { + return x.IsExtension() && !y.IsExtension() + } + // Fields not within a oneof sort before those within a oneof. + if inOneof(ox) != inOneof(oy) { + return !inOneof(ox) && inOneof(oy) + } + // Fields in disjoint oneof sets are sorted by declaration index. + if ox != nil && oy != nil && ox != oy { + return ox.Index() < oy.Index() + } + // Fields sorted by field number. + return x.Number() < y.Number() + } + + // NumberFieldOrder sorts fields by their field number. + NumberFieldOrder FieldOrder = func(x, y pref.FieldDescriptor) bool { + return x.Number() < y.Number() + } + + // IndexNameFieldOrder sorts non-extension fields before extension fields. + // Non-extensions are sorted according to their declaration index. + // Extensions are sorted according to their full name. + IndexNameFieldOrder FieldOrder = func(x, y pref.FieldDescriptor) bool { + // Non-extension fields sort before extension fields. + if x.IsExtension() != y.IsExtension() { + return !x.IsExtension() && y.IsExtension() + } + // Extensions sorted by fullname. + if x.IsExtension() && y.IsExtension() { + return x.FullName() < y.FullName() + } + // Non-extensions sorted by declaration index. + return x.Index() < y.Index() + } +) + +// KeyOrder specifies the ordering to visit map entries. +// It is a function that reports whether x is ordered before y. +type KeyOrder func(x, y pref.MapKey) bool + +var ( + // AnyKeyOrder specifies no specific key ordering. + AnyKeyOrder KeyOrder = nil + + // GenericKeyOrder sorts false before true, numeric keys in ascending order, + // and strings in lexicographical ordering according to UTF-8 codepoints. + GenericKeyOrder KeyOrder = func(x, y pref.MapKey) bool { + switch x.Interface().(type) { + case bool: + return !x.Bool() && y.Bool() + case int32, int64: + return x.Int() < y.Int() + case uint32, uint64: + return x.Uint() < y.Uint() + case string: + return x.String() < y.String() + default: + panic("invalid map key type") + } + } +) diff --git a/internal/order/order_test.go b/internal/order/order_test.go new file mode 100644 index 000000000..ecf5e1827 --- /dev/null +++ b/internal/order/order_test.go @@ -0,0 +1,175 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package order + +import ( + "math/rand" + "sort" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/reflect/protoreflect" + pref "google.golang.org/protobuf/reflect/protoreflect" +) + +type fieldDesc struct { + index int + name protoreflect.FullName + number protoreflect.FieldNumber + extension bool + oneofIndex int // non-zero means within oneof; negative means synthetic + pref.FieldDescriptor +} + +func (d fieldDesc) Index() int { return d.index } +func (d fieldDesc) Name() pref.Name { return d.name.Name() } +func (d fieldDesc) FullName() pref.FullName { return d.name } +func (d fieldDesc) Number() pref.FieldNumber { return d.number } +func (d fieldDesc) IsExtension() bool { return d.extension } +func (d fieldDesc) ContainingOneof() pref.OneofDescriptor { + switch { + case d.oneofIndex < 0: + return oneofDesc{index: -d.oneofIndex, synthetic: true} + case d.oneofIndex > 0: + return oneofDesc{index: +d.oneofIndex, synthetic: false} + default: + return nil + } +} + +type oneofDesc struct { + index int + synthetic bool + pref.OneofDescriptor +} + +func (d oneofDesc) Index() int { return d.index } +func (d oneofDesc) IsSynthetic() bool { return d.synthetic } + +func TestFieldOrder(t *testing.T) { + tests := []struct { + label string + order FieldOrder + fields []fieldDesc + }{{ + label: "LegacyFieldOrder", + order: LegacyFieldOrder, + fields: []fieldDesc{ + // Extension fields sorted first by field number. + {number: 2, extension: true}, + {number: 4, extension: true}, + {number: 100, extension: true}, + {number: 120, extension: true}, + + // Non-extension fields that are not within a oneof + // sorted next by field number. + {number: 1}, + {number: 5, oneofIndex: -9}, // synthetic oneof + {number: 10}, + {number: 11, oneofIndex: -10}, // synthetic oneof + {number: 12}, + + // Non-synthetic oneofs sorted last by index. + {number: 13, oneofIndex: 4}, + {number: 3, oneofIndex: 5}, + {number: 9, oneofIndex: 5}, + {number: 7, oneofIndex: 8}, + }, + }, { + label: "NumberFieldOrder", + order: NumberFieldOrder, + fields: []fieldDesc{ + {number: 1, index: 5, name: "c"}, + {number: 2, index: 2, name: "b"}, + {number: 3, index: 3, name: "d"}, + {number: 5, index: 1, name: "a"}, + {number: 7, index: 7, name: "e"}, + }, + }, { + label: "IndexNameFieldOrder", + order: IndexNameFieldOrder, + fields: []fieldDesc{ + // Non-extension fields sorted first by index. + {index: 0, number: 5, name: "c"}, + {index: 2, number: 2, name: "a"}, + {index: 4, number: 4, name: "b"}, + {index: 7, number: 6, name: "d"}, + + // Extension fields sorted last by full name. + {index: 3, number: 1, name: "d.a", extension: true}, + {index: 5, number: 3, name: "e", extension: true}, + {index: 1, number: 7, name: "g", extension: true}, + }, + }} + + for _, tt := range tests { + t.Run(tt.label, func(t *testing.T) { + want := tt.fields + got := append([]fieldDesc(nil), want...) + for i, j := range rand.Perm(len(got)) { + got[i], got[j] = got[j], got[i] + } + sort.Slice(got, func(i, j int) bool { + return tt.order(got[i], got[j]) + }) + if diff := cmp.Diff(want, got, + cmp.Comparer(func(x, y fieldDesc) bool { return x == y }), + ); diff != "" { + t.Errorf("order mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestKeyOrder(t *testing.T) { + tests := []struct { + label string + order KeyOrder + keys []interface{} + }{{ + label: "GenericKeyOrder", + order: GenericKeyOrder, + keys: []interface{}{false, true}, + }, { + label: "GenericKeyOrder", + order: GenericKeyOrder, + keys: []interface{}{int32(-100), int32(-99), int32(-10), int32(-9), int32(-1), int32(0), int32(+1), int32(+9), int32(+10), int32(+99), int32(+100)}, + }, { + label: "GenericKeyOrder", + order: GenericKeyOrder, + keys: []interface{}{int64(-100), int64(-99), int64(-10), int64(-9), int64(-1), int64(0), int64(+1), int64(+9), int64(+10), int64(+99), int64(+100)}, + }, { + label: "GenericKeyOrder", + order: GenericKeyOrder, + keys: []interface{}{uint32(0), uint32(1), uint32(9), uint32(10), uint32(99), uint32(100)}, + }, { + label: "GenericKeyOrder", + order: GenericKeyOrder, + keys: []interface{}{uint64(0), uint64(1), uint64(9), uint64(10), uint64(99), uint64(100)}, + }, { + label: "GenericKeyOrder", + order: GenericKeyOrder, + keys: []interface{}{"", "a", "aa", "ab", "ba", "bb", "\u0080", "\u0080\u0081", "\u0082\u0080"}, + }} + + for _, tt := range tests { + t.Run(tt.label, func(t *testing.T) { + var got, want []protoreflect.MapKey + for _, v := range tt.keys { + want = append(want, pref.ValueOf(v).MapKey()) + } + got = append(got, want...) + for i, j := range rand.Perm(len(got)) { + got[i], got[j] = got[j], got[i] + } + sort.Slice(got, func(i, j int) bool { + return tt.order(got[i], got[j]) + }) + if diff := cmp.Diff(want, got, cmp.Transformer("", protoreflect.MapKey.Interface)); diff != "" { + t.Errorf("order mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/internal/order/range.go b/internal/order/range.go new file mode 100644 index 000000000..c8090e0c5 --- /dev/null +++ b/internal/order/range.go @@ -0,0 +1,115 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package order provides ordered access to messages and maps. +package order + +import ( + "sort" + "sync" + + pref "google.golang.org/protobuf/reflect/protoreflect" +) + +type messageField struct { + fd pref.FieldDescriptor + v pref.Value +} + +var messageFieldPool = sync.Pool{ + New: func() interface{} { return new([]messageField) }, +} + +type ( + // FieldRnger is an interface for visiting all fields in a message. + // The protoreflect.Message type implements this interface. + FieldRanger interface{ Range(VisitField) } + // VisitField is called everytime a message field is visited. + VisitField = func(pref.FieldDescriptor, pref.Value) bool +) + +// RangeFields iterates over the fields of fs according to the specified order. +func RangeFields(fs FieldRanger, less FieldOrder, fn VisitField) { + if less == nil { + fs.Range(fn) + return + } + + // Obtain a pre-allocated scratch buffer. + p := messageFieldPool.Get().(*[]messageField) + fields := (*p)[:0] + defer func() { + if cap(fields) < 1024 { + *p = fields + messageFieldPool.Put(p) + } + }() + + // Collect all fields in the message and sort them. + fs.Range(func(fd pref.FieldDescriptor, v pref.Value) bool { + fields = append(fields, messageField{fd, v}) + return true + }) + sort.Slice(fields, func(i, j int) bool { + return less(fields[i].fd, fields[j].fd) + }) + + // Visit the fields in the specified ordering. + for _, f := range fields { + if !fn(f.fd, f.v) { + return + } + } +} + +type mapEntry struct { + k pref.MapKey + v pref.Value +} + +var mapEntryPool = sync.Pool{ + New: func() interface{} { return new([]mapEntry) }, +} + +type ( + // EntryRanger is an interface for visiting all fields in a message. + // The protoreflect.Map type implements this interface. + EntryRanger interface{ Range(VisitEntry) } + // VisitEntry is called everytime a map entry is visited. + VisitEntry = func(pref.MapKey, pref.Value) bool +) + +// RangeEntries iterates over the entries of es according to the specified order. +func RangeEntries(es EntryRanger, less KeyOrder, fn VisitEntry) { + if less == nil { + es.Range(fn) + return + } + + // Obtain a pre-allocated scratch buffer. + p := mapEntryPool.Get().(*[]mapEntry) + entries := (*p)[:0] + defer func() { + if cap(entries) < 1024 { + *p = entries + mapEntryPool.Put(p) + } + }() + + // Collect all entries in the map and sort them. + es.Range(func(k pref.MapKey, v pref.Value) bool { + entries = append(entries, mapEntry{k, v}) + return true + }) + sort.Slice(entries, func(i, j int) bool { + return less(entries[i].k, entries[j].k) + }) + + // Visit the entries in the specified ordering. + for _, e := range entries { + if !fn(e.k, e.v) { + return + } + } +} diff --git a/proto/encode.go b/proto/encode.go index 7b47a1180..d18239c23 100644 --- a/proto/encode.go +++ b/proto/encode.go @@ -5,12 +5,9 @@ package proto import ( - "sort" - "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/internal/encoding/messageset" - "google.golang.org/protobuf/internal/fieldsort" - "google.golang.org/protobuf/internal/mapsort" + "google.golang.org/protobuf/internal/order" "google.golang.org/protobuf/internal/pragma" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/runtime/protoiface" @@ -211,14 +208,15 @@ func (o MarshalOptions) marshalMessageSlow(b []byte, m protoreflect.Message) ([] if messageset.IsMessageSet(m.Descriptor()) { return o.marshalMessageSet(b, m) } - // There are many choices for what order we visit fields in. The default one here - // is chosen for reasonable efficiency and simplicity given the protoreflect API. - // It is not deterministic, since Message.Range does not return fields in any - // defined order. - // - // When using deterministic serialization, we sort the known fields. + fieldOrder := order.AnyFieldOrder + if o.Deterministic { + // TODO: This should use a more natural ordering like NumberFieldOrder, + // but doing so breaks golden tests that make invalid assumption about + // output stability of this implementation. + fieldOrder = order.LegacyFieldOrder + } var err error - o.rangeFields(m, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { + order.RangeFields(m, fieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { b, err = o.marshalField(b, fd, v) return err == nil }) @@ -229,27 +227,6 @@ func (o MarshalOptions) marshalMessageSlow(b []byte, m protoreflect.Message) ([] return b, nil } -// rangeFields visits fields in a defined order when deterministic serialization is enabled. -func (o MarshalOptions) rangeFields(m protoreflect.Message, f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) { - if !o.Deterministic { - m.Range(f) - return - } - var fds []protoreflect.FieldDescriptor - m.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool { - fds = append(fds, fd) - return true - }) - sort.Slice(fds, func(a, b int) bool { - return fieldsort.Less(fds[a], fds[b]) - }) - for _, fd := range fds { - if !f(fd, m.Get(fd)) { - break - } - } -} - func (o MarshalOptions) marshalField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value) ([]byte, error) { switch { case fd.IsList(): @@ -292,8 +269,12 @@ func (o MarshalOptions) marshalList(b []byte, fd protoreflect.FieldDescriptor, l func (o MarshalOptions) marshalMap(b []byte, fd protoreflect.FieldDescriptor, mapv protoreflect.Map) ([]byte, error) { keyf := fd.MapKey() valf := fd.MapValue() + keyOrder := order.AnyKeyOrder + if o.Deterministic { + keyOrder = order.GenericKeyOrder + } var err error - o.rangeMap(mapv, keyf.Kind(), func(key protoreflect.MapKey, value protoreflect.Value) bool { + order.RangeEntries(mapv, keyOrder, func(key protoreflect.MapKey, value protoreflect.Value) bool { b = protowire.AppendTag(b, fd.Number(), protowire.BytesType) var pos int b, pos = appendSpeculativeLength(b) @@ -312,14 +293,6 @@ func (o MarshalOptions) marshalMap(b []byte, fd protoreflect.FieldDescriptor, ma return b, err } -func (o MarshalOptions) rangeMap(mapv protoreflect.Map, kind protoreflect.Kind, f func(protoreflect.MapKey, protoreflect.Value) bool) { - if !o.Deterministic { - mapv.Range(f) - return - } - mapsort.Range(mapv, kind, f) -} - // When encoding length-prefixed fields, we speculatively set aside some number of bytes // for the length, encode the data, and then encode the length (shifting the data if necessary // to make room). diff --git a/proto/messageset.go b/proto/messageset.go index 1d692c3a8..312d5d45c 100644 --- a/proto/messageset.go +++ b/proto/messageset.go @@ -9,6 +9,7 @@ import ( "google.golang.org/protobuf/internal/encoding/messageset" "google.golang.org/protobuf/internal/errors" "google.golang.org/protobuf/internal/flags" + "google.golang.org/protobuf/internal/order" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" ) @@ -28,8 +29,12 @@ func (o MarshalOptions) marshalMessageSet(b []byte, m protoreflect.Message) ([]b if !flags.ProtoLegacy { return b, errors.New("no support for message_set_wire_format") } + fieldOrder := order.AnyFieldOrder + if o.Deterministic { + fieldOrder = order.NumberFieldOrder + } var err error - o.rangeFields(m, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { + order.RangeFields(m, fieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { b, err = o.marshalMessageSetField(b, fd, v) return err == nil })