Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cherry-pick #7557 to v1.66.x branch #7564

Merged
merged 1 commit into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,11 @@ type baseCodec interface {
// with encoding.GetCodec and if it is registered wraps it with newCodecV1Bridge
// to turn it into an encoding.CodecV2. Returns nil otherwise.
func getCodec(name string) encoding.CodecV2 {
codecV2 := encoding.GetCodecV2(name)
if codecV2 != nil {
return codecV2
}

codecV1 := encoding.GetCodec(name)
if codecV1 != nil {
if codecV1 := encoding.GetCodec(name); codecV1 != nil {
return newCodecV1Bridge(codecV1)
}

return nil
return encoding.GetCodecV2(name)
}

func newCodecV0Bridge(c Codec) baseCodec {
Expand Down
2 changes: 1 addition & 1 deletion codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
)

func (s) TestGetCodecForProtoIsNotNil(t *testing.T) {
if encoding.GetCodec(proto.Name) == nil {
if encoding.GetCodecV2(proto.Name) == nil {
t.Fatalf("encoding.GetCodec(%q) must not be nil by default", proto.Name)
}
}
5 changes: 3 additions & 2 deletions encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ type Codec interface {
Name() string
}

var registeredCodecs = make(map[string]Codec)
var registeredCodecs = make(map[string]any)

// RegisterCodec registers the provided Codec for use with all gRPC clients and
// servers.
Expand Down Expand Up @@ -126,5 +126,6 @@ func RegisterCodec(codec Codec) {
//
// The content-subtype is expected to be lowercase.
func GetCodec(contentSubtype string) Codec {
return registeredCodecs[contentSubtype]
c, _ := registeredCodecs[contentSubtype].(Codec)
return c
}
37 changes: 19 additions & 18 deletions encoding/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

Expand Down Expand Up @@ -90,18 +91,18 @@ type errProtoCodec struct {
decodingErr error
}

func (c *errProtoCodec) Marshal(v any) ([]byte, error) {
func (c *errProtoCodec) Marshal(v any) (mem.BufferSlice, error) {
if c.encodingErr != nil {
return nil, c.encodingErr
}
return encoding.GetCodec(proto.Name).Marshal(v)
return encoding.GetCodecV2(proto.Name).Marshal(v)
}

func (c *errProtoCodec) Unmarshal(data []byte, v any) error {
func (c *errProtoCodec) Unmarshal(data mem.BufferSlice, v any) error {
if c.decodingErr != nil {
return c.decodingErr
}
return encoding.GetCodec(proto.Name).Unmarshal(data, v)
return encoding.GetCodecV2(proto.Name).Unmarshal(data, v)
}

func (c *errProtoCodec) Name() string {
Expand All @@ -118,7 +119,7 @@ func (s) TestEncodeDoesntPanicOnServer(t *testing.T) {
ec := &errProtoCodec{name: t.Name(), encodingErr: encodingErr}

// Start a server with the above codec.
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodec(ec))
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodecV2(ec))
defer backend.Stop()

// Create a channel to the above server.
Expand Down Expand Up @@ -154,7 +155,7 @@ func (s) TestDecodeDoesntPanicOnServer(t *testing.T) {
ec := &errProtoCodec{name: t.Name(), decodingErr: decodingErr}

// Start a server with the above codec.
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodec(ec))
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodecV2(ec))
defer backend.Stop()

// Create a channel to the above server. Since we do not specify any codec
Expand Down Expand Up @@ -206,15 +207,15 @@ func (s) TestEncodeDoesntPanicOnClient(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
client := testgrpc.NewTestServiceClient(cc)
_, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec))
_, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(ec))
if err == nil || !strings.Contains(err.Error(), encodingErr.Error()) {
t.Fatalf("RPC failed with error: %v, want: %v", err, encodingErr)
}

// Configure the codec on the client to not return errors anymore and expect
// the RPC to succeed.
ec.encodingErr = nil
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec)); err != nil {
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(ec)); err != nil {
t.Fatalf("RPC failed with error: %v", err)
}
}
Expand Down Expand Up @@ -242,15 +243,15 @@ func (s) TestDecodeDoesntPanicOnClient(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
client := testgrpc.NewTestServiceClient(cc)
_, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec))
_, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(ec))
if err == nil || !strings.Contains(err.Error(), decodingErr.Error()) {
t.Fatalf("RPC failed with error: %v, want: %v", err, decodingErr)
}

// Configure the codec on the client to not return errors anymore and expect
// the RPC to succeed.
ec.decodingErr = nil
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec)); err != nil {
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(ec)); err != nil {
t.Fatalf("RPC failed with error: %v", err)
}
}
Expand All @@ -265,14 +266,14 @@ type countingProtoCodec struct {
unmarshalCount int32
}

func (p *countingProtoCodec) Marshal(v any) ([]byte, error) {
func (p *countingProtoCodec) Marshal(v any) (mem.BufferSlice, error) {
atomic.AddInt32(&p.marshalCount, 1)
return encoding.GetCodec(proto.Name).Marshal(v)
return encoding.GetCodecV2(proto.Name).Marshal(v)
}

func (p *countingProtoCodec) Unmarshal(data []byte, v any) error {
func (p *countingProtoCodec) Unmarshal(data mem.BufferSlice, v any) error {
atomic.AddInt32(&p.unmarshalCount, 1)
return encoding.GetCodec(proto.Name).Unmarshal(data, v)
return encoding.GetCodecV2(proto.Name).Unmarshal(data, v)
}

func (p *countingProtoCodec) Name() string {
Expand All @@ -284,7 +285,7 @@ func (p *countingProtoCodec) Name() string {
func (s) TestForceServerCodec(t *testing.T) {
// Create an server with the counting proto codec.
codec := &countingProtoCodec{name: t.Name()}
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodec(codec))
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodecV2(codec))
defer backend.Stop()

// Create a channel to the above server.
Expand Down Expand Up @@ -317,7 +318,7 @@ func (s) TestForceServerCodec(t *testing.T) {

// renameProtoCodec wraps the proto codec and allows customizing the Name().
type renameProtoCodec struct {
encoding.Codec
encoding.CodecV2
name string
}

Expand Down Expand Up @@ -356,9 +357,9 @@ func (s) TestForceCodecName(t *testing.T) {

// Force the use of the custom codec on the client with the ForceCodec call
// option. Confirm the name is converted to lowercase before transmitting.
codec := &renameProtoCodec{Codec: encoding.GetCodec(proto.Name), name: t.Name()}
codec := &renameProtoCodec{CodecV2: encoding.GetCodecV2(proto.Name), name: t.Name()}
wantContentTypeCh <- []string{fmt.Sprintf("application/grpc+%s", strings.ToLower(t.Name()))}
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(codec)); err != nil {
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(codec)); err != nil {
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
}
}
7 changes: 3 additions & 4 deletions encoding/encoding_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ type CodecV2 interface {
Name() string
}

var registeredV2Codecs = make(map[string]CodecV2)

// RegisterCodecV2 registers the provided CodecV2 for use with all gRPC clients and
// servers.
//
Expand All @@ -70,13 +68,14 @@ func RegisterCodecV2(codec CodecV2) {
panic("cannot register CodecV2 with empty string result for Name()")
}
contentSubtype := strings.ToLower(codec.Name())
registeredV2Codecs[contentSubtype] = codec
registeredCodecs[contentSubtype] = codec
}

// GetCodecV2 gets a registered CodecV2 by content-subtype, or nil if no CodecV2 is
// registered for the content-subtype.
//
// The content-subtype is expected to be lowercase.
func GetCodecV2(contentSubtype string) CodecV2 {
return registeredV2Codecs[contentSubtype]
c, _ := registeredCodecs[contentSubtype].(CodecV2)
return c
}
44 changes: 34 additions & 10 deletions encoding/proto/proto.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright 2018 gRPC authors.
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,6 +24,7 @@
"fmt"

"google.golang.org/grpc/encoding"
"google.golang.org/grpc/mem"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/protoadapt"
)
Expand All @@ -32,28 +33,51 @@
const Name = "proto"

func init() {
encoding.RegisterCodec(codec{})
encoding.RegisterCodecV2(&codecV2{})
}

// codec is a Codec implementation with protobuf. It is the default codec for gRPC.
type codec struct{}
// codec is a CodecV2 implementation with protobuf. It is the default codec for
// gRPC.
type codecV2 struct{}

func (codec) Marshal(v any) ([]byte, error) {
func (c *codecV2) Marshal(v any) (data mem.BufferSlice, err error) {
vv := messageV2Of(v)
if vv == nil {
return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v)
return nil, fmt.Errorf("proto: failed to marshal, message is %T, want proto.Message", v)

Check warning on line 46 in encoding/proto/proto.go

View check run for this annotation

Codecov / codecov/patch

encoding/proto/proto.go#L46

Added line #L46 was not covered by tests
}

return proto.Marshal(vv)
size := proto.Size(vv)
if mem.IsBelowBufferPoolingThreshold(size) {
buf, err := proto.Marshal(vv)
if err != nil {
return nil, err

Check warning on line 53 in encoding/proto/proto.go

View check run for this annotation

Codecov / codecov/patch

encoding/proto/proto.go#L53

Added line #L53 was not covered by tests
}
data = append(data, mem.SliceBuffer(buf))
} else {
pool := mem.DefaultBufferPool()
buf := pool.Get(size)
if _, err := (proto.MarshalOptions{}).MarshalAppend((*buf)[:0], vv); err != nil {
pool.Put(buf)
return nil, err

Check warning on line 61 in encoding/proto/proto.go

View check run for this annotation

Codecov / codecov/patch

encoding/proto/proto.go#L60-L61

Added lines #L60 - L61 were not covered by tests
}
data = append(data, mem.NewBuffer(buf, pool))
}

return data, nil
}

func (codec) Unmarshal(data []byte, v any) error {
func (c *codecV2) Unmarshal(data mem.BufferSlice, v any) (err error) {
vv := messageV2Of(v)
if vv == nil {
return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v)
}

return proto.Unmarshal(data, vv)
buf := data.MaterializeToBuffer(mem.DefaultBufferPool())
defer buf.Free()
// TODO: Upgrade proto.Unmarshal to support mem.BufferSlice. Right now, it's not
// really possible without a major overhaul of the proto package, but the
// vtprotobuf library may be able to support this.
return proto.Unmarshal(buf.ReadOnlyData(), vv)
}

func messageV2Of(v any) proto.Message {
Expand All @@ -67,6 +91,6 @@
return nil
}

func (codec) Name() string {
func (c *codecV2) Name() string {
return Name
}
6 changes: 3 additions & 3 deletions encoding/proto/proto_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func BenchmarkProtoCodec(b *testing.B) {
protoStructs := setupBenchmarkProtoCodecInputs(s)
name := fmt.Sprintf("MinPayloadSize:%v/SetParallelism(%v)", s, p)
b.Run(name, func(b *testing.B) {
codec := &codec{}
codec := &codecV2{}
b.SetParallelism(p)
b.RunParallel(func(pb *testing.PB) {
benchmarkProtoCodec(codec, protoStructs, pb, b)
Expand All @@ -78,7 +78,7 @@ func BenchmarkProtoCodec(b *testing.B) {
}
}

func benchmarkProtoCodec(codec *codec, protoStructs []proto.Message, pb *testing.PB, b *testing.B) {
func benchmarkProtoCodec(codec *codecV2, protoStructs []proto.Message, pb *testing.PB, b *testing.B) {
counter := 0
for pb.Next() {
counter++
Expand All @@ -87,7 +87,7 @@ func benchmarkProtoCodec(codec *codec, protoStructs []proto.Message, pb *testing
}
}

func fastMarshalAndUnmarshal(codec encoding.Codec, protoStruct proto.Message, b *testing.B) {
func fastMarshalAndUnmarshal(codec encoding.CodecV2, protoStruct proto.Message, b *testing.B) {
marshaledBytes, err := codec.Marshal(protoStruct)
if err != nil {
b.Errorf("codec.Marshal(_) returned an error")
Expand Down
13 changes: 7 additions & 6 deletions encoding/proto/proto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ import (

"google.golang.org/grpc/encoding"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/mem"
pb "google.golang.org/grpc/test/codec_perf"
)

func marshalAndUnmarshal(t *testing.T, codec encoding.Codec, expectedBody []byte) {
func marshalAndUnmarshal(t *testing.T, codec encoding.CodecV2, expectedBody []byte) {
p := &pb.Buffer{}
p.Body = expectedBody

Expand All @@ -55,7 +56,7 @@ func Test(t *testing.T) {
}

func (s) TestBasicProtoCodecMarshalAndUnmarshal(t *testing.T) {
marshalAndUnmarshal(t, codec{}, []byte{1, 2, 3})
marshalAndUnmarshal(t, &codecV2{}, []byte{1, 2, 3})
}

// Try to catch possible race conditions around use of pools
Expand All @@ -75,7 +76,7 @@ func (s) TestConcurrentUsage(t *testing.T) {
}

var wg sync.WaitGroup
codec := codec{}
codec := &codecV2{}

for i := 0; i < numGoRoutines; i++ {
wg.Add(1)
Expand All @@ -93,16 +94,16 @@ func (s) TestConcurrentUsage(t *testing.T) {
// TestStaggeredMarshalAndUnmarshalUsingSamePool tries to catch potential errors in which slices get
// stomped on during reuse of a proto.Buffer.
func (s) TestStaggeredMarshalAndUnmarshalUsingSamePool(t *testing.T) {
codec1 := codec{}
codec2 := codec{}
codec1 := &codecV2{}
codec2 := &codecV2{}

expectedBody1 := []byte{1, 2, 3}
expectedBody2 := []byte{4, 5, 6}

proto1 := pb.Buffer{Body: expectedBody1}
proto2 := pb.Buffer{Body: expectedBody2}

var m1, m2 []byte
var m1, m2 mem.BufferSlice
var err error

if m1, err = codec1.Marshal(&proto1); err != nil {
Expand Down
Loading
Loading