@@ -24,13 +24,15 @@ import (
24
24
"crypto/x509"
25
25
"fmt"
26
26
"os"
27
+ "regexp"
27
28
"strings"
28
29
"testing"
29
30
"time"
30
31
31
32
"google.golang.org/grpc"
32
33
"google.golang.org/grpc/codes"
33
34
"google.golang.org/grpc/credentials"
35
+ "google.golang.org/grpc/internal/envconfig"
34
36
"google.golang.org/grpc/internal/grpctest"
35
37
"google.golang.org/grpc/internal/stubserver"
36
38
"google.golang.org/grpc/status"
@@ -243,6 +245,11 @@ func (s) TestTLS_DisabledALPN(t *testing.T) {
243
245
ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
244
246
defer cancel ()
245
247
248
+ initialVal := envconfig .EnforceALPNEnabled
249
+ defer func () {
250
+ envconfig .EnforceALPNEnabled = initialVal
251
+ }()
252
+
246
253
// Start a non gRPC TLS server.
247
254
config := & tls.Config {
248
255
Certificates : []tls.Certificate {serverCert },
@@ -254,31 +261,76 @@ func (s) TestTLS_DisabledALPN(t *testing.T) {
254
261
}
255
262
defer listner .Close ()
256
263
257
- // Start listening for server requests in a new go routine.
258
- go func () {
259
- conn , err := listner .Accept ()
260
- if err != nil {
261
- t .Errorf ("tls.Accept failed err = %v" , err )
262
- } else {
263
- _ , _ = conn .Write ([]byte ("Hello, World!" ))
264
- _ = conn .Close ()
265
- }
266
- }()
267
-
268
- clientCreds := credentials .NewTLS (& tls.Config {
269
- ServerName : serverName ,
270
- RootCAs : certPool ,
271
- })
272
-
273
- cc , err := grpc .NewClient ("dns:" + listner .Addr ().String (), grpc .WithTransportCredentials (clientCreds ))
274
- if err != nil {
275
- t .Fatalf ("grpc.NewClient error: %v" , err )
264
+ tests := []struct {
265
+ description string
266
+ alpnEnforced bool
267
+ wantErrMatchPattern string
268
+ wantErrNonMatchPattern string
269
+ }{
270
+ {
271
+ description : "enforced" ,
272
+ alpnEnforced : true ,
273
+ wantErrMatchPattern : "transport: .*missing selected ALPN property" ,
274
+ },
275
+ {
276
+ description : "not_enforced" ,
277
+ wantErrNonMatchPattern : "transport:" ,
278
+ },
279
+ {
280
+ description : "default_value" ,
281
+ wantErrNonMatchPattern : "transport:" ,
282
+ alpnEnforced : initialVal ,
283
+ },
276
284
}
277
- defer cc .Close ()
278
- client := testgrpc .NewTestServiceClient (cc )
279
285
280
- const wantStr = "missing selected ALPN property"
281
- if _ , err = client .EmptyCall (ctx , & testpb.Empty {}); status .Code (err ) != codes .Unavailable || ! strings .Contains (status .Convert (err ).Message (), wantStr ) {
282
- t .Fatalf ("EmptyCall err = %v; want code=%v, message contains %q" , err , codes .Unavailable , wantStr )
286
+ for _ , tc := range tests {
287
+ t .Run (tc .description , func (t * testing.T ) {
288
+ envconfig .EnforceALPNEnabled = tc .alpnEnforced
289
+ // Listen to one TCP connection request.
290
+ go func () {
291
+ conn , err := listner .Accept ()
292
+ if err != nil {
293
+ t .Errorf ("tls.Accept failed err = %v" , err )
294
+ } else {
295
+ _ , _ = conn .Write ([]byte ("Hello, World!" ))
296
+ _ = conn .Close ()
297
+ }
298
+ }()
299
+
300
+ clientCreds := credentials .NewTLS (& tls.Config {
301
+ ServerName : serverName ,
302
+ RootCAs : certPool ,
303
+ })
304
+
305
+ cc , err := grpc .NewClient ("dns:" + listner .Addr ().String (), grpc .WithTransportCredentials (clientCreds ))
306
+ if err != nil {
307
+ t .Fatalf ("grpc.NewClient error: %v" , err )
308
+ }
309
+ defer cc .Close ()
310
+ client := testgrpc .NewTestServiceClient (cc )
311
+ _ , rpcErr := client .EmptyCall (ctx , & testpb.Empty {})
312
+
313
+ if gotCode := status .Code (rpcErr ); gotCode != codes .Unavailable {
314
+ t .Errorf ("EmptyCall returned unexpected code: got=%v, want=%v" , gotCode , codes .Unavailable )
315
+ }
316
+
317
+ matchPat , err := regexp .Compile (tc .wantErrMatchPattern )
318
+ if err != nil {
319
+ t .Fatalf ("Error message match pattern %q is invalid due to error: %v" , tc .wantErrMatchPattern , err )
320
+ }
321
+
322
+ if tc .wantErrMatchPattern != "" && ! matchPat .MatchString (status .Convert (rpcErr ).Message ()) {
323
+ t .Errorf ("EmptyCall err = %v; want pattern match %q" , rpcErr , matchPat )
324
+ }
325
+ nonMatchPat , err := regexp .Compile (tc .wantErrNonMatchPattern )
326
+ if err != nil {
327
+ t .Fatalf ("Error message non-match pattern %q is invalid due to error: %v" , tc .wantErrNonMatchPattern , err )
328
+ }
329
+
330
+ if tc .wantErrNonMatchPattern != "" && nonMatchPat .MatchString (status .Convert (rpcErr ).Message ()) {
331
+ t .Errorf ("EmptyCall err = %v; want pattern missing %q" , rpcErr , nonMatchPat )
332
+ }
333
+ })
334
+
283
335
}
284
336
}
0 commit comments