Skip to content

Commit bc16b5f

Browse files
authored
interop: support custom creds flag for stress test client (#6809)
1 parent 02ea031 commit bc16b5f

File tree

1 file changed

+40
-18
lines changed

1 file changed

+40
-18
lines changed

interop/stress/client/main.go

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"fmt"
2626
"math/rand"
2727
"net"
28+
"os"
2829
"strconv"
2930
"strings"
3031
"sync"
@@ -34,27 +35,37 @@ import (
3435
"google.golang.org/grpc"
3536
"google.golang.org/grpc/codes"
3637
"google.golang.org/grpc/credentials"
38+
"google.golang.org/grpc/credentials/google"
3739
"google.golang.org/grpc/credentials/insecure"
3840
"google.golang.org/grpc/grpclog"
3941
"google.golang.org/grpc/interop"
42+
"google.golang.org/grpc/resolver"
4043
"google.golang.org/grpc/status"
4144
"google.golang.org/grpc/testdata"
4245

46+
_ "google.golang.org/grpc/xds/googledirectpath" // Register xDS resolver required for c2p directpath.
47+
4348
testgrpc "google.golang.org/grpc/interop/grpc_testing"
4449
metricspb "google.golang.org/grpc/interop/stress/grpc_testing"
4550
)
4651

52+
const (
53+
googleDefaultCredsName = "google_default_credentials"
54+
computeEngineCredsName = "compute_engine_channel_creds"
55+
)
56+
4757
var (
48-
serverAddresses = flag.String("server_addresses", "localhost:8080", "a list of server addresses")
49-
testCases = flag.String("test_cases", "", "a list of test cases along with the relative weights")
50-
testDurationSecs = flag.Int("test_duration_secs", -1, "test duration in seconds")
51-
numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
52-
numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
53-
metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
54-
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
55-
testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
56-
tlsServerName = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
57-
caFile = flag.String("ca_file", "", "The file containing the CA root cert file")
58+
serverAddresses = flag.String("server_addresses", "localhost:8080", "a list of server addresses")
59+
testCases = flag.String("test_cases", "", "a list of test cases along with the relative weights")
60+
testDurationSecs = flag.Int("test_duration_secs", -1, "test duration in seconds")
61+
numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
62+
numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
63+
metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
64+
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
65+
testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
66+
tlsServerName = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
67+
caFile = flag.String("ca_file", "", "The file containing the CA root cert file")
68+
customCredentialsType = flag.String("custom_credentials_type", "", "Custom credentials type to use")
5869

5970
totalNumCalls int64
6071
logger = grpclog.Component("stress")
@@ -71,12 +82,13 @@ func parseTestCases(testCaseString string) []testCaseWithWeight {
7182
testCaseStrings := strings.Split(testCaseString, ",")
7283
testCases := make([]testCaseWithWeight, len(testCaseStrings))
7384
for i, str := range testCaseStrings {
74-
testCase := strings.Split(str, ":")
75-
if len(testCase) != 2 {
85+
testCaseNameAndWeight := strings.Split(str, ":")
86+
if len(testCaseNameAndWeight) != 2 {
7687
panic(fmt.Sprintf("invalid test case with weight: %s", str))
7788
}
7889
// Check if test case is supported.
79-
switch testCase[0] {
90+
testCaseName := strings.ToLower(testCaseNameAndWeight[0])
91+
switch testCaseName {
8092
case
8193
"empty_unary",
8294
"large_unary",
@@ -90,10 +102,10 @@ func parseTestCases(testCaseString string) []testCaseWithWeight {
90102
"status_code_and_message",
91103
"custom_metadata":
92104
default:
93-
panic(fmt.Sprintf("unknown test type: %s", testCase[0]))
105+
panic(fmt.Sprintf("unknown test type: %s", testCaseNameAndWeight[0]))
94106
}
95-
testCases[i].name = testCase[0]
96-
w, err := strconv.Atoi(testCase[1])
107+
testCases[i].name = testCaseName
108+
w, err := strconv.Atoi(testCaseNameAndWeight[1])
97109
if err != nil {
98110
panic(fmt.Sprintf("%v", err))
99111
}
@@ -263,6 +275,7 @@ func logParameterInfo(addresses []string, tests []testCaseWithWeight) {
263275
logger.Infof("use_tls: %t", *useTLS)
264276
logger.Infof("use_test_ca: %t", *testCA)
265277
logger.Infof("server_host_override: %s", *tlsServerName)
278+
logger.Infof("custom_credentials_type: %s", *customCredentialsType)
266279

267280
logger.Infoln("addresses:")
268281
for i, addr := range addresses {
@@ -276,7 +289,15 @@ func logParameterInfo(addresses []string, tests []testCaseWithWeight) {
276289

277290
func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) {
278291
var opts []grpc.DialOption
279-
if useTLS {
292+
if *customCredentialsType != "" {
293+
if *customCredentialsType == googleDefaultCredsName {
294+
opts = append(opts, grpc.WithCredentialsBundle(google.NewDefaultCredentials()))
295+
} else if *customCredentialsType == computeEngineCredsName {
296+
opts = append(opts, grpc.WithCredentialsBundle(google.NewComputeEngineCredentials()))
297+
} else {
298+
logger.Fatalf("Unknown custom credentials: %v", *customCredentialsType)
299+
}
300+
} else if useTLS {
280301
var sn string
281302
if tlsServerName != "" {
282303
sn = tlsServerName
@@ -303,6 +324,7 @@ func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.C
303324

304325
func main() {
305326
flag.Parse()
327+
resolver.SetDefaultScheme("dns")
306328
addresses := strings.Split(*serverAddresses, ",")
307329
tests := parseTestCases(*testCases)
308330
logParameterInfo(addresses, tests)
@@ -337,6 +359,6 @@ func main() {
337359
close(stop)
338360
}
339361
wg.Wait()
340-
logger.Infof("Total calls made: %v", totalNumCalls)
362+
fmt.Fprintf(os.Stdout, "Total calls made: %v\n", totalNumCalls)
341363
logger.Infof(" ===== ALL DONE ===== ")
342364
}

0 commit comments

Comments
 (0)