@@ -24,18 +24,23 @@ package transport
24
24
25
25
import (
26
26
"context"
27
+ "crypto/tls"
28
+ "crypto/x509"
27
29
"fmt"
28
30
"io"
29
31
"net"
32
+ "os"
30
33
"strings"
31
34
"testing"
32
35
"time"
33
36
34
37
"golang.org/x/net/http2"
38
+ "google.golang.org/grpc/credentials"
35
39
"google.golang.org/grpc/internal/channelz"
36
40
"google.golang.org/grpc/internal/grpctest"
37
41
"google.golang.org/grpc/internal/syscall"
38
42
"google.golang.org/grpc/keepalive"
43
+ "google.golang.org/grpc/testdata"
39
44
)
40
45
41
46
const defaultTestTimeout = 10 * time .Second
@@ -581,47 +586,82 @@ func (s) TestKeepaliveServerEnforcementWithDormantKeepaliveOnClient(t *testing.T
581
586
// the keepalive timeout, as detailed in proposal A18.
582
587
func (s ) TestTCPUserTimeout (t * testing.T ) {
583
588
tests := []struct {
589
+ tls bool
584
590
time time.Duration
585
591
timeout time.Duration
586
592
clientWantTimeout time.Duration
587
593
serverWantTimeout time.Duration
588
594
}{
589
595
{
596
+ false ,
590
597
10 * time .Second ,
591
598
10 * time .Second ,
592
599
10 * 1000 * time .Millisecond ,
593
600
10 * 1000 * time .Millisecond ,
594
601
},
595
602
{
603
+ false ,
596
604
0 ,
597
605
0 ,
598
606
0 ,
599
607
20 * 1000 * time .Millisecond ,
600
608
},
601
609
{
610
+ false ,
611
+ infinity ,
612
+ infinity ,
613
+ 0 ,
614
+ 0 ,
615
+ },
616
+ {
617
+ true ,
618
+ 10 * time .Second ,
619
+ 10 * time .Second ,
620
+ 10 * 1000 * time .Millisecond ,
621
+ 10 * 1000 * time .Millisecond ,
622
+ },
623
+ {
624
+ true ,
625
+ 0 ,
626
+ 0 ,
627
+ 0 ,
628
+ 20 * 1000 * time .Millisecond ,
629
+ },
630
+ {
631
+ true ,
602
632
infinity ,
603
633
infinity ,
604
634
0 ,
605
635
0 ,
606
636
},
607
637
}
608
638
for _ , tt := range tests {
639
+ sopts := & ServerConfig {
640
+ KeepaliveParams : keepalive.ServerParameters {
641
+ Time : tt .time ,
642
+ Timeout : tt .timeout ,
643
+ },
644
+ }
645
+
646
+ copts := ConnectOptions {
647
+ KeepaliveParams : keepalive.ClientParameters {
648
+ Time : tt .time ,
649
+ Timeout : tt .timeout ,
650
+ },
651
+ }
652
+
653
+ if tt .tls {
654
+ copts .TransportCredentials = makeTLSCreds (t , "x509/client1_cert.pem" , "x509/client1_key.pem" , "x509/server_ca_cert.pem" )
655
+ sopts .Credentials = makeTLSCreds (t , "x509/server1_cert.pem" , "x509/server1_key.pem" , "x509/client_ca_cert.pem" )
656
+
657
+ }
658
+
609
659
server , client , cancel := setUpWithOptions (
610
660
t ,
611
661
0 ,
612
- & ServerConfig {
613
- KeepaliveParams : keepalive.ServerParameters {
614
- Time : tt .time ,
615
- Timeout : tt .timeout ,
616
- },
617
- },
662
+ sopts ,
618
663
normal ,
619
- ConnectOptions {
620
- KeepaliveParams : keepalive.ClientParameters {
621
- Time : tt .time ,
622
- Timeout : tt .timeout ,
623
- },
624
- },
664
+ copts ,
625
665
)
626
666
defer func () {
627
667
client .Close (fmt .Errorf ("closed manually by test" ))
@@ -630,6 +670,7 @@ func (s) TestTCPUserTimeout(t *testing.T) {
630
670
}()
631
671
632
672
var sc * http2Server
673
+ var srawConn net.Conn
633
674
// Wait until the server transport is setup.
634
675
for {
635
676
server .mu .Lock ()
@@ -644,6 +685,7 @@ func (s) TestTCPUserTimeout(t *testing.T) {
644
685
if ! ok {
645
686
t .Fatalf ("Failed to convert %v to *http2Server" , k )
646
687
}
688
+ srawConn = server .conns [k ]
647
689
}
648
690
server .mu .Unlock ()
649
691
break
@@ -657,25 +699,60 @@ func (s) TestTCPUserTimeout(t *testing.T) {
657
699
}
658
700
client .CloseStream (stream , io .EOF )
659
701
660
- cltOpt , err := syscall .GetTCPUserTimeout (client .conn )
661
- if err != nil {
662
- t .Fatalf ("syscall.GetTCPUserTimeout() failed: %v" , err )
702
+ // check client TCP user timeout only when non TLS
703
+ // TODO : find a way to get the underlying conn for client when TLS
704
+ if ! tt .tls {
705
+ cltOpt , err := syscall .GetTCPUserTimeout (client .conn )
706
+ if err != nil {
707
+ t .Fatalf ("syscall.GetTCPUserTimeout() failed: %v" , err )
708
+ }
709
+ if cltOpt < 0 {
710
+ t .Skipf ("skipping test on unsupported environment" )
711
+ }
712
+ if gotTimeout := time .Duration (cltOpt ) * time .Millisecond ; gotTimeout != tt .clientWantTimeout {
713
+ t .Fatalf ("syscall.GetTCPUserTimeout() = %d, want %d" , gotTimeout , tt .clientWantTimeout )
714
+ }
663
715
}
664
- if cltOpt < 0 {
665
- t .Skipf ("skipping test on unsupported environment" )
716
+ scConn := sc .conn
717
+ if tt .tls {
718
+ if _ , ok := sc .conn .(* net.TCPConn ); ok {
719
+ t .Fatalf ("sc.conn is should have wrapped conn with TLS" )
720
+ }
721
+ scConn = srawConn
666
722
}
667
- if gotTimeout := time .Duration (cltOpt ) * time .Millisecond ; gotTimeout != tt .clientWantTimeout {
668
- t .Fatalf ("syscall.GetTCPUserTimeout() = %d, want %d" , gotTimeout , tt .clientWantTimeout )
723
+ // verify the type of scConn (on which TCP user timeout will be got)
724
+ if _ , ok := scConn .(* net.TCPConn ); ! ok {
725
+ t .Fatalf ("server underlying conn is of type %T, want net.TCPConn" , scConn )
669
726
}
670
-
671
- srvOpt , err := syscall .GetTCPUserTimeout (sc .conn )
727
+ srvOpt , err := syscall .GetTCPUserTimeout (scConn )
672
728
if err != nil {
673
729
t .Fatalf ("syscall.GetTCPUserTimeout() failed: %v" , err )
674
730
}
675
731
if gotTimeout := time .Duration (srvOpt ) * time .Millisecond ; gotTimeout != tt .serverWantTimeout {
676
732
t .Fatalf ("syscall.GetTCPUserTimeout() = %d, want %d" , gotTimeout , tt .serverWantTimeout )
677
733
}
734
+
735
+ }
736
+ }
737
+
738
+ func makeTLSCreds (t * testing.T , certPath , keyPath , rootsPath string ) credentials.TransportCredentials {
739
+ cert , err := tls .LoadX509KeyPair (testdata .Path (certPath ), testdata .Path (keyPath ))
740
+ if err != nil {
741
+ t .Fatalf ("tls.LoadX509KeyPair(%q, %q) failed: %v" , certPath , keyPath , err )
742
+ }
743
+ b , err := os .ReadFile (testdata .Path (rootsPath ))
744
+ if err != nil {
745
+ t .Fatalf ("os.ReadFile(%q) failed: %v" , rootsPath , err )
746
+ }
747
+ roots := x509 .NewCertPool ()
748
+ if ! roots .AppendCertsFromPEM (b ) {
749
+ t .Fatal ("failed to append certificates" )
678
750
}
751
+ return credentials .NewTLS (& tls.Config {
752
+ Certificates : []tls.Certificate {cert },
753
+ RootCAs : roots ,
754
+ InsecureSkipVerify : true ,
755
+ })
679
756
}
680
757
681
758
// checkForHealthyStream attempts to create a stream and return error if any.
0 commit comments