Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: add HybridResolver #62

Merged
merged 8 commits into from
May 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,8 @@ require (
golang.org/x/text v0.9.0 // indirect
google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect
)

retract (
// API changed in an incompatible way
v1.4.8
)
91 changes: 0 additions & 91 deletions proto/all_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1646,97 +1646,6 @@ func TestExtensionMapFieldMarshalDeterministic(t *testing.T) {
}
}

// Many extensions, because small maps might not iterate differently on each iteration.
var exts = []*ExtensionDesc{
E_X201,
E_X202,
E_X203,
E_X204,
E_X205,
E_X206,
E_X207,
E_X208,
E_X209,
E_X210,
E_X211,
E_X212,
E_X213,
E_X214,
E_X215,
E_X216,
E_X217,
E_X218,
E_X219,
E_X220,
E_X221,
E_X222,
E_X223,
E_X224,
E_X225,
E_X226,
E_X227,
E_X228,
E_X229,
E_X230,
E_X231,
E_X232,
E_X233,
E_X234,
E_X235,
E_X236,
E_X237,
E_X238,
E_X239,
E_X240,
E_X241,
E_X242,
E_X243,
E_X244,
E_X245,
E_X246,
E_X247,
E_X248,
E_X249,
E_X250,
}

func TestMessageSetMarshalOrder(t *testing.T) {
m := &MyMessageSet{}
for _, x := range exts {
if err := SetExtension(m, x, &Empty{}); err != nil {
t.Fatalf("SetExtension: %v", err)
}
}

buf, err := Marshal(m)
if err != nil {
t.Fatalf("Marshal: %v", err)
}

// Serialize m several times, and check we get the same bytes each time.
for i := 0; i < 10; i++ {
b1, err := Marshal(m)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
if !bytes.Equal(b1, buf) {
t.Errorf("Bytes differ on re-Marshal #%d", i)
}

m2 := &MyMessageSet{}
if err = Unmarshal(buf, m2); err != nil {
t.Errorf("Unmarshal: %v", err)
}
b2, err := Marshal(m2)
if err != nil {
t.Errorf("re-Marshal: %v", err)
}
if !bytes.Equal(b2, buf) {
t.Errorf("Bytes differ on round-trip #%d", i)
}
}
}

func TestUnmarshalMergesMessages(t *testing.T) {
// If a nested message occurs twice in the input,
// the fields should be merged when decoding.
Expand Down
6 changes: 0 additions & 6 deletions proto/extensions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,12 +458,6 @@ func TestExtensionsRoundTrip(t *testing.T) {
if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
t.Errorf("got %v, expected ErrMissingExtension", e)
}
if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
t.Error("expected bad extension error, got nil")
}
if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
t.Error("expected extension err")
}
if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
t.Error("expected some sort of type mismatch error, got nil")
}
Expand Down
130 changes: 77 additions & 53 deletions proto/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package proto

import (
"bytes"
"compress/gzip"
"errors"
"fmt"
"runtime"
Expand All @@ -24,8 +23,8 @@ import (
//
// In contrast to MergedFileDescriptorsWithValidation,
// MergedFileDescriptors does not validate import paths
func MergedFileDescriptors(globalFiles *protoregistry.Files, appFiles map[string][]byte) (*descriptorpb.FileDescriptorSet, error) {
return mergedFileDescriptors(globalFiles, appFiles, false)
func MergedFileDescriptors(globalFiles *protoregistry.Files, gogoFiles *protoregistry.Files) (*descriptorpb.FileDescriptorSet, error) {
return mergedFileDescriptors(globalFiles, gogoFiles, false)
}

// MergedFileDescriptorsWithValidation returns a single FileDescriptorSet containing all the
Expand All @@ -34,22 +33,22 @@ func MergedFileDescriptors(globalFiles *protoregistry.Files, appFiles map[string
// If there are any incorrect import paths that do not match
// the fully qualified package name, or if there is a common file descriptor
// that differs accross globalFiles and appFiles, an error is returned.
func MergedFileDescriptorsWithValidation(globalFiles *protoregistry.Files, appFiles map[string][]byte) (*descriptorpb.FileDescriptorSet, error) {
return mergedFileDescriptors(globalFiles, appFiles, true)
func MergedFileDescriptorsWithValidation(globalFiles *protoregistry.Files, gogoFiles *protoregistry.Files) (*descriptorpb.FileDescriptorSet, error) {
return mergedFileDescriptors(globalFiles, gogoFiles, true)
}

// MergedGlobalFileDescriptors calls MergedFileDescriptors
// with [protoregistry.GlobalFiles] and all files
// registered through [RegisterFile].
func MergedGlobalFileDescriptors() (*descriptorpb.FileDescriptorSet, error) {
return MergedFileDescriptors(protoregistry.GlobalFiles, protoFiles)
return MergedFileDescriptors(protoregistry.GlobalFiles, gogoProtoRegistry)
}

// MergedGlobalFileDescriptorsWithValidation calls MergedFileDescriptorsWithValidation
// with [protoregistry.GlobalFiles] and all files
// registered through [RegisterFile].
func MergedGlobalFileDescriptorsWithValidation() (*descriptorpb.FileDescriptorSet, error) {
return MergedFileDescriptorsWithValidation(protoregistry.GlobalFiles, protoFiles)
return MergedFileDescriptorsWithValidation(protoregistry.GlobalFiles, gogoProtoRegistry)
}

// MergedRegistry returns a *protoregistry.Files that acts as a single registry
Expand Down Expand Up @@ -177,7 +176,7 @@ LOOP:
type descriptorProcessor struct {
processWG sync.WaitGroup
globalFileCh chan protoreflect.FileDescriptor
appFileCh chan []byte
appFileCh chan protoreflect.FileDescriptor

fdWG sync.WaitGroup
fdCh chan *descriptorpb.FileDescriptorProto
Expand All @@ -186,7 +185,7 @@ type descriptorProcessor struct {

// process reads from p.globalFileCh and p.appFileCh, processing each file descriptor as appropriate,
// and sends the processed file descriptors through p.fdCh for eventual return from mergedFileDescriptors.
// Any errors during processing are sent to ec.ProcessErrCh,
// Any errors during processing are sent to ec.ProcessErrCh,
// which collects the errors also for possible return from mergedFileDescriptors.
//
// If validate is true, extra work is performed to validate import paths
Expand All @@ -213,45 +212,19 @@ func (p *descriptorProcessor) process(globalFiles *protoregistry.Files, ec *desc
}

// Now handle all the app files.

// Reuse a single gzip reader throughout the loop,
// so we don't have to repeatedly allocate new readers.
gzr := new(gzip.Reader)

// Also reuse a single byte buffer for each gzip read.
buf := new(bytes.Buffer)

for compressedBz := range p.appFileCh {
if err := gzr.Reset(bytes.NewReader(compressedBz)); err != nil {
// This should only fail if there is an invalid gzip header in compressedBz.
ec.ProcessErrCh <- fmt.Errorf("failed to reset gzip reader: %w", err)
continue
}

buf.Reset()
if _, err := buf.ReadFrom(gzr); err != nil {
// This should only fail if there was invalidly gzipped content in compressedBz.
ec.ProcessErrCh <- fmt.Errorf("failed to read from gzip reader: %w", err)
continue
}

fd := &descriptorpb.FileDescriptorProto{}
if err := protov2.Unmarshal(buf.Bytes(), fd); err != nil {
// This should only fail if the gzipped data contained invalid bytes for a FileDescriptorProto.
ec.ProcessErrCh <- err
continue
}

for gogoFd := range p.appFileCh {
// If the app FD is not in protoregistry, we need to track it.
gogoFdp := protodesc.ToFileDescriptorProto(gogoFd)
if validate {
// Ensure the import path on the app file is good.
if err := CheckImportPath(fd.GetName(), fd.GetPackage()); err != nil {
if err := CheckImportPath(gogoFdp.GetName(), gogoFdp.GetPackage()); err != nil {
// Track the import error but don't stop processing.
// It is more helpful to present all the import errors,
// rather than just stopping on the first one.
ec.ImportErrCh <- err
// Don't break the loop here, continue to check for a file descriptor diff.
}
}

// If the app FD is not in protoregistry, we need to track it.
protoregFd, err := globalFiles.FindFileByPath(*fd.Name)
protoregFd, err := globalFiles.FindFileByPath(*gogoFdp.Name)
if err != nil {
if !errors.Is(err, protoregistry.NotFound) {
// Non-nil error, and it wasn't a not found error.
Expand All @@ -260,15 +233,16 @@ func (p *descriptorProcessor) process(globalFiles *protoregistry.Files, ec *desc
}
// Otherwise it was a not found error, so add it.
// At this point we can't validate.
p.fdCh <- fd
p.fdCh <- gogoFdp
continue
}

if validate {
fdp := protodesc.ToFileDescriptorProto(protoregFd)
if !protov2.Equal(fdp, fd) {
diff := cmp.Diff(fdp, fd, protocmp.Transform())
ec.DiffCh <- fmt.Sprintf("Mismatch in %s:\n%s", *fd.Name, diff)

if !protov2.Equal(fdp, gogoFdp) {
diff := cmp.Diff(fdp, gogoFdp, protocmp.Transform())
ec.DiffCh <- fmt.Sprintf("Mismatch in %s:\n%s", *gogoFdp.Name, diff)
}
}
}
Expand All @@ -295,7 +269,7 @@ func (p *descriptorProcessor) collectFDs() {
// If validate is true, do extra work to validate that import paths are properly formed
// and that "duplicated" file descriptors across globalFiles and appFiles
// are indeed identical, returning an error if either of those conditions are invalidated.
func mergedFileDescriptors(globalFiles *protoregistry.Files, appFiles map[string][]byte, validate bool) (*descriptorpb.FileDescriptorSet, error) {
func mergedFileDescriptors(globalFiles *protoregistry.Files, gogoFiles *protoregistry.Files, validate bool) (*descriptorpb.FileDescriptorSet, error) {
// GOMAXPROCS is the number of CPU cores available, by default.
// Respect that setting as the number of CPU-bound goroutines,
// and for channel sizes.
Expand All @@ -305,7 +279,7 @@ func mergedFileDescriptors(globalFiles *protoregistry.Files, appFiles map[string

p := &descriptorProcessor{
globalFileCh: make(chan protoreflect.FileDescriptor, nProcs),
appFileCh: make(chan []byte, nProcs),
appFileCh: make(chan protoreflect.FileDescriptor, nProcs),

fdCh: make(chan *descriptorpb.FileDescriptorProto, nProcs),
fds: make([]*descriptorpb.FileDescriptorProto, 0, globalFiles.NumFiles()),
Expand All @@ -330,10 +304,11 @@ func mergedFileDescriptors(globalFiles *protoregistry.Files, appFiles map[string
// Signal that no more global files will be sent.
close(p.globalFileCh)

// Same for appFiles: send everything then signal app files are finished.
for _, bz := range appFiles {
p.appFileCh <- bz
}
// Same for gogoFiles: send everything then signal app files are finished.
gogoFiles.RangeFiles(func(fileDescriptor protoreflect.FileDescriptor) bool {
p.appFileCh <- fileDescriptor
return true
})
close(p.appFileCh)

// Since we are done sending file descriptors and we have closed those channels,
Expand All @@ -360,3 +335,52 @@ func mergedFileDescriptors(globalFiles *protoregistry.Files, appFiles map[string
File: p.fds,
}, nil
}

// HybridResolver is a protodesc.Resolver that uses both protoregistry.GlobalFiles
// and the gogo proto global registry, checking protoregistry.GlobalFiles first and
// then gogo proto global registry.
var HybridResolver Resolver = &hybridResolver{}

// Resolver is a protodesc.Resolver that can range over all the files in the resolver.
type Resolver interface {
protodesc.Resolver

// RangeFiles calls f for each file descriptor in the resolver while f returns true.
RangeFiles(f func(fileDescriptor protoreflect.FileDescriptor) bool)
}

type hybridResolver struct{}

var _ protodesc.Resolver = &hybridResolver{}

func (r *hybridResolver) FindFileByPath(path string) (protoreflect.FileDescriptor, error) {
if fd, err := protoregistry.GlobalFiles.FindFileByPath(path); err == nil {
return fd, nil
}

return gogoProtoRegistry.FindFileByPath(path)
}

func (r *hybridResolver) FindDescriptorByName(name protoreflect.FullName) (protoreflect.Descriptor, error) {
if desc, err := protoregistry.GlobalFiles.FindDescriptorByName(name); err == nil {
return desc, nil
}

return gogoProtoRegistry.FindDescriptorByName(name)
}

func (r *hybridResolver) RangeFiles(f func(fileDescriptor protoreflect.FileDescriptor) bool) {
seen := make(map[protoreflect.FullName]bool, protoregistry.GlobalFiles.NumFiles())

protoregistry.GlobalFiles.RangeFiles(func(fd protoreflect.FileDescriptor) bool {
seen[fd.FullName()] = true
return f(fd)
})

gogoProtoRegistry.RangeFiles(func(fd protoreflect.FileDescriptor) bool {
if seen[fd.FullName()] {
return true
}
return f(fd)
})
}
Loading