Skip to content

Commit 511c171

Browse files
committed
client: set grpc-accept-encoding header with all registered compressors
1 parent d83070e commit 511c171

File tree

6 files changed

+195
-3
lines changed

6 files changed

+195
-3
lines changed

encoding/encoding.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ package encoding
2828
import (
2929
"io"
3030
"strings"
31+
32+
"google.golang.org/grpc/internal/grpcutil"
3133
)
3234

3335
// Identity specifies the optional encoding for uncompressed streams.
@@ -73,6 +75,7 @@ var registeredCompressor = make(map[string]Compressor)
7375
// registered with the same name, the one registered last will take effect.
7476
func RegisterCompressor(c Compressor) {
7577
registeredCompressor[c.Name()] = c
78+
grpcutil.RegisteredCompressorNames = append(grpcutil.RegisteredCompressorNames, c.Name())
7679
}
7780

7881
// GetCompressor returns Compressor for the given compressor name.

internal/envconfig/envconfig.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,15 @@ import (
2525
)
2626

2727
const (
28-
prefix = "GRPC_GO_"
29-
txtErrIgnoreStr = prefix + "IGNORE_TXT_ERRORS"
28+
prefix = "GRPC_GO_"
29+
txtErrIgnoreStr = prefix + "IGNORE_TXT_ERRORS"
30+
disableCompressorAdStr = prefix + "DISABLE_COMPRESSOR_ADVERTISEMENT"
3031
)
3132

3233
var (
3334
// TXTErrIgnore is set if TXT errors should be ignored ("GRPC_GO_IGNORE_TXT_ERRORS" is not "false").
3435
TXTErrIgnore = !strings.EqualFold(os.Getenv(txtErrIgnoreStr), "false")
36+
// DisableCompressorAd is set if registered compressor advertisement should
37+
// be disabled ("GRPC_GO_DISABLE_COMPRESSOR_ADVERTISEMENT" is "true").
38+
DisableCompressorAd = strings.EqualFold(os.Getenv(disableCompressorAdStr), "true")
3539
)

internal/grpcutil/compressor.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
*
3+
* Copyright 2022 gRPC authors.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*
17+
*/
18+
19+
package grpcutil
20+
21+
import (
22+
"strings"
23+
24+
"google.golang.org/grpc/internal/envconfig"
25+
)
26+
27+
// RegisteredCompressorNames holds names of the registered compressors.
28+
var RegisteredCompressorNames []string
29+
30+
// IsCompressorNameRegistered returns true when name is available in registry.
31+
func IsCompressorNameRegistered(name string) bool {
32+
for _, compressor := range RegisteredCompressorNames {
33+
if compressor == name {
34+
return true
35+
}
36+
}
37+
return false
38+
}
39+
40+
// RegisteredCompressors returns a string of registered compressor names
41+
// separated by comma.
42+
func RegisteredCompressors() string {
43+
if envconfig.DisableCompressorAd {
44+
return ""
45+
}
46+
return strings.Join(RegisteredCompressorNames, ",")
47+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
*
3+
* Copyright 2022 gRPC authors.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*
17+
*/
18+
19+
package grpcutil
20+
21+
import (
22+
"testing"
23+
24+
"google.golang.org/grpc/internal/envconfig"
25+
)
26+
27+
func TestRegisteredCompressors(t *testing.T) {
28+
defer func(c []string) { RegisteredCompressorNames = c }(RegisteredCompressorNames)
29+
defer func(v bool) { envconfig.DisableCompressorAd = v }(envconfig.DisableCompressorAd)
30+
RegisteredCompressorNames = []string{"gzip", "snappy"}
31+
tests := []struct {
32+
desc string
33+
disableAd bool
34+
want string
35+
}{
36+
{desc: "compressor_ad_disabled", disableAd: true, want: ""},
37+
{desc: "compressor_ad_enabled", disableAd: false, want: "gzip,snappy"},
38+
}
39+
for _, tt := range tests {
40+
envconfig.DisableCompressorAd = tt.disableAd
41+
compressors := RegisteredCompressors()
42+
if compressors != tt.want {
43+
t.Fatalf("Unexpected compressors got:%s, want:%s", compressors, tt.want)
44+
}
45+
}
46+
}

internal/transport/http2_client.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ type http2Client struct {
109109
streamsQuotaAvailable chan struct{}
110110
waitingStreams uint32
111111
nextID uint32
112+
registeredCompressors string
112113

113114
// Do not access controlBuf with mu held.
114115
mu sync.Mutex // guard the following variables
@@ -299,6 +300,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
299300
ctxDone: ctx.Done(), // Cache Done chan.
300301
cancel: cancel,
301302
userAgent: opts.UserAgent,
303+
registeredCompressors: grpcutil.RegisteredCompressors(),
302304
conn: conn,
303305
remoteAddr: conn.RemoteAddr(),
304306
localAddr: conn.LocalAddr(),
@@ -507,9 +509,22 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
507509
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-previous-rpc-attempts", Value: strconv.Itoa(callHdr.PreviousAttempts)})
508510
}
509511

512+
registeredCompressors := t.registeredCompressors
510513
if callHdr.SendCompress != "" {
511514
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
512-
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-accept-encoding", Value: callHdr.SendCompress})
515+
// Include the outgoing compressor name when compressor is not registered
516+
// via encoding.RegisterCompressor. This is possible when client uses
517+
// WithCompressor dial option.
518+
if !grpcutil.IsCompressorNameRegistered(callHdr.SendCompress) {
519+
if registeredCompressors != "" {
520+
registeredCompressors += ","
521+
}
522+
registeredCompressors += callHdr.SendCompress
523+
}
524+
}
525+
526+
if registeredCompressors != "" {
527+
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-accept-encoding", Value: registeredCompressors})
513528
}
514529
if dl, ok := ctx.Deadline(); ok {
515530
// Send out timeout regardless its value. The server can detect timeout context by itself.

test/end2end_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3249,6 +3249,7 @@ func testMetadataUnaryRPC(t *testing.T, e env) {
32493249
delete(header, "date") // the Date header is also optional
32503250
delete(header, "user-agent")
32513251
delete(header, "content-type")
3252+
delete(header, "grpc-accept-encoding")
32523253
}
32533254
if !reflect.DeepEqual(header, testMetadata) {
32543255
t.Fatalf("Received header metadata %v, want %v", header, testMetadata)
@@ -3288,6 +3289,7 @@ func testMetadataOrderUnaryRPC(t *testing.T, e env) {
32883289
delete(header, "date") // the Date header is also optional
32893290
delete(header, "user-agent")
32903291
delete(header, "content-type")
3292+
delete(header, "grpc-accept-encoding")
32913293
}
32923294

32933295
if !reflect.DeepEqual(header, newMetadata) {
@@ -3400,6 +3402,8 @@ func testSetAndSendHeaderUnaryRPC(t *testing.T, e env) {
34003402
}
34013403
delete(header, "user-agent")
34023404
delete(header, "content-type")
3405+
delete(header, "grpc-accept-encoding")
3406+
34033407
expectedHeader := metadata.Join(testMetadata, testMetadata2)
34043408
if !reflect.DeepEqual(header, expectedHeader) {
34053409
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
@@ -3444,6 +3448,7 @@ func testMultipleSetHeaderUnaryRPC(t *testing.T, e env) {
34443448
}
34453449
delete(header, "user-agent")
34463450
delete(header, "content-type")
3451+
delete(header, "grpc-accept-encoding")
34473452
expectedHeader := metadata.Join(testMetadata, testMetadata2)
34483453
if !reflect.DeepEqual(header, expectedHeader) {
34493454
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
@@ -3487,6 +3492,7 @@ func testMultipleSetHeaderUnaryRPCError(t *testing.T, e env) {
34873492
}
34883493
delete(header, "user-agent")
34893494
delete(header, "content-type")
3495+
delete(header, "grpc-accept-encoding")
34903496
expectedHeader := metadata.Join(testMetadata, testMetadata2)
34913497
if !reflect.DeepEqual(header, expectedHeader) {
34923498
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
@@ -3527,6 +3533,7 @@ func testSetAndSendHeaderStreamingRPC(t *testing.T, e env) {
35273533
}
35283534
delete(header, "user-agent")
35293535
delete(header, "content-type")
3536+
delete(header, "grpc-accept-encoding")
35303537
expectedHeader := metadata.Join(testMetadata, testMetadata2)
35313538
if !reflect.DeepEqual(header, expectedHeader) {
35323539
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
@@ -3590,6 +3597,7 @@ func testMultipleSetHeaderStreamingRPC(t *testing.T, e env) {
35903597
}
35913598
delete(header, "user-agent")
35923599
delete(header, "content-type")
3600+
delete(header, "grpc-accept-encoding")
35933601
expectedHeader := metadata.Join(testMetadata, testMetadata2)
35943602
if !reflect.DeepEqual(header, expectedHeader) {
35953603
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
@@ -3650,6 +3658,7 @@ func testMultipleSetHeaderStreamingRPCError(t *testing.T, e env) {
36503658
}
36513659
delete(header, "user-agent")
36523660
delete(header, "content-type")
3661+
delete(header, "grpc-accept-encoding")
36533662
expectedHeader := metadata.Join(testMetadata, testMetadata2)
36543663
if !reflect.DeepEqual(header, expectedHeader) {
36553664
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
@@ -3981,6 +3990,7 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
39813990
delete(headerMD, "trailer") // ignore if present
39823991
delete(headerMD, "user-agent")
39833992
delete(headerMD, "content-type")
3993+
delete(headerMD, "grpc-accept-encoding")
39843994
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
39853995
t.Errorf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
39863996
}
@@ -3989,6 +3999,7 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
39893999
delete(headerMD, "trailer") // ignore if present
39904000
delete(headerMD, "user-agent")
39914001
delete(headerMD, "content-type")
4002+
delete(headerMD, "grpc-accept-encoding")
39924003
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
39934004
t.Errorf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
39944005
}
@@ -5431,6 +5442,72 @@ func (s) TestForceServerCodec(t *testing.T) {
54315442
}
54325443
}
54335444

5445+
// renameCompressor is a grpc.Compressor wrapper that allows customizing the
5446+
// Type() of another compressor.
5447+
type renameCompressor struct {
5448+
grpc.Compressor
5449+
name string
5450+
}
5451+
5452+
func (r *renameCompressor) Type() string { return r.name }
5453+
5454+
// renameDecompressor is a grpc.Decompressor wrapper that allows customizing the
5455+
// Type() of another Decompressor.
5456+
type renameDecompressor struct {
5457+
grpc.Decompressor
5458+
name string
5459+
}
5460+
5461+
func (r *renameDecompressor) Type() string { return r.name }
5462+
5463+
func (s) TestClientForwardsGrpcAcceptEncodingHeader(t *testing.T) {
5464+
wantGrpcAcceptEncodingCh := make(chan []string, 1)
5465+
defer close(wantGrpcAcceptEncodingCh)
5466+
5467+
compressor := renameCompressor{Compressor: grpc.NewGZIPCompressor(), name: "testgzip"}
5468+
decompressor := renameDecompressor{Decompressor: grpc.NewGZIPDecompressor(), name: "testgzip"}
5469+
5470+
ss := &stubserver.StubServer{
5471+
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
5472+
md, ok := metadata.FromIncomingContext(ctx)
5473+
if !ok {
5474+
return nil, status.Errorf(codes.Internal, "no metadata in context")
5475+
}
5476+
if got, want := md["grpc-accept-encoding"], <-wantGrpcAcceptEncodingCh; !reflect.DeepEqual(got, want) {
5477+
return nil, status.Errorf(codes.Internal, "got grpc-accept-encoding=%q; want [%q]", got, want)
5478+
}
5479+
return &testpb.Empty{}, nil
5480+
},
5481+
}
5482+
if err := ss.Start([]grpc.ServerOption{grpc.RPCDecompressor(&decompressor)}); err != nil {
5483+
t.Fatalf("Error starting endpoint server: %v", err)
5484+
}
5485+
defer ss.Stop()
5486+
5487+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
5488+
defer cancel()
5489+
5490+
wantGrpcAcceptEncodingCh <- []string{"gzip"}
5491+
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
5492+
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
5493+
}
5494+
5495+
wantGrpcAcceptEncodingCh <- []string{"gzip"}
5496+
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.UseCompressor("gzip")); err != nil {
5497+
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
5498+
}
5499+
5500+
// Use compressor directly which is not registered via
5501+
// encoding.RegisterCompressor.
5502+
if err := ss.StartClient(grpc.WithCompressor(&compressor)); err != nil {
5503+
t.Fatalf("Error starting client: %v", err)
5504+
}
5505+
wantGrpcAcceptEncodingCh <- []string{"gzip,testgzip"}
5506+
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
5507+
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
5508+
}
5509+
}
5510+
54345511
func (s) TestUnaryProxyDoesNotForwardMetadata(t *testing.T) {
54355512
const mdkey = "somedata"
54365513

0 commit comments

Comments
 (0)