-
Notifications
You must be signed in to change notification settings - Fork 513
/
grpcurl.go
709 lines (637 loc) · 23.1 KB
1
2
3
4
5
6
7
8
9
10
// Package grpcurl provides the core functionality exposed by the grpcurl command, for
// dynamically connecting to a server, using the reflection service to inspect the server,
// and invoking RPCs. The grpcurl command-line tool constructs a DescriptorSource, based
// on the command-line parameters, and supplies an InvocationEventHandler to supply request
// data (which can come from command-line args or the process's stdin) and to log the
// events (to the process's stdout).
package grpcurl
import (
"bytes"
11
"context"
12
13
"crypto/tls"
"crypto/x509"
14
"encoding/base64"
15
16
"errors"
"fmt"
17
"net"
18
19
"os"
"regexp"
20
21
22
"sort"
"strings"
23
"github.com/golang/protobuf/proto" //lint:ignore SA1019 we have to import this because it appears in exported API
24
"github.com/jhump/protoreflect/desc"
25
"github.com/jhump/protoreflect/desc/protoprint"
26
27
28
"github.com/jhump/protoreflect/dynamic"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
29
"google.golang.org/grpc/credentials/insecure"
30
xdsCredentials "google.golang.org/grpc/credentials/xds"
31
_ "google.golang.org/grpc/health" // import grpc/health to enable transparent client side checking
32
"google.golang.org/grpc/metadata"
33
34
35
36
37
protov2 "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/structpb"
38
39
40
41
42
43
44
45
46
47
48
49
50
)
// ListServices asks the given descriptor source for every service it knows
// about and returns their fully-qualified names in ascending order. Any error
// reported by the source is returned unchanged, with a nil slice.
func ListServices(source DescriptorSource) ([]string, error) {
	names, err := source.ListServices()
	if err == nil {
		sort.Strings(names)
		return names, nil
	}
	return nil, err
}
51
52
53
54
55
56
57
58
59
60
61
// sourceWithFiles is an optional capability of a descriptor source: a source
// that can enumerate all of its file descriptors directly. GetAllFiles checks
// for it with a type assertion and falls back to walking services otherwise.
type sourceWithFiles interface {
// GetAllFiles returns the file descriptors known to the source.
GetAllFiles() ([]*desc.FileDescriptor, error)
}
// Compile-time check that *fileSource (declared elsewhere in this package)
// satisfies sourceWithFiles.
var _ sourceWithFiles = (*fileSource)(nil)
// GetAllFiles returns the file descriptors reachable from the given descriptor
// source, sorted by file name.
//
// Sources that implement sourceWithFiles are asked for their files directly.
// For any other source, the services are listed, each service symbol is
// resolved, and the file that declares it is collected together with its
// transitive dependencies.
//
// Loading is best-effort: the first error encountered is remembered and
// returned alongside whatever files could still be loaded, so the caller can
// decide whether the error is fatal.
func GetAllFiles(source DescriptorSource) ([]*desc.FileDescriptor, error) {
	var (
		files    []*desc.FileDescriptor
		firstErr error
	)

	if withFiles, ok := source.(sourceWithFiles); ok {
		files, firstErr = withFiles.GetAllFiles()
	} else if svcNames, err := source.ListServices(); err != nil {
		firstErr = err
	} else {
		// No direct file listing available: discover files via the services.
		seen := map[string]*desc.FileDescriptor{}
		for _, svc := range svcNames {
			sym, err := source.FindSymbol(svc)
			if err != nil {
				if firstErr == nil {
					firstErr = err
				}
				continue
			}
			addAllFilesToSet(sym.GetFile(), seen)
		}
		files = make([]*desc.FileDescriptor, 0, len(seen))
		for _, fd := range seen {
			files = append(files, fd)
		}
	}

	sort.Sort(filesByName(files))
	return files, firstErr
}
// filesByName implements sort.Interface, ordering file descriptors by the
// value of GetName in ascending order.
type filesByName []*desc.FileDescriptor

// Len reports the number of file descriptors.
func (fs filesByName) Len() int { return len(fs) }

// Less reports whether the name of element a sorts before that of element b.
func (fs filesByName) Less(a, b int) bool {
	left, right := fs[a].GetName(), fs[b].GetName()
	return left < right
}

// Swap exchanges elements a and b.
func (fs filesByName) Swap(a, b int) {
	tmp := fs[a]
	fs[a] = fs[b]
	fs[b] = tmp
}
// addAllFilesToSet records fd in all, keyed by file name, and then does the
// same for each of its dependencies, recursively. A file whose name is already
// present is skipped, which also stops the recursion from revisiting shared
// dependencies.
func addAllFilesToSet(fd *desc.FileDescriptor, all map[string]*desc.FileDescriptor) {
	name := fd.GetName()
	if _, seen := all[name]; seen {
		return
	}
	all[name] = fd
	deps := fd.GetDependencies()
	for i := range deps {
		addAllFilesToSet(deps[i], all)
	}
}
123
124
125
126
127
128
129
130
131
132
133
134
// ListMethods resolves serviceName through the given descriptor source and
// returns the fully-qualified names of that service's methods, sorted
// ascending. It returns the source's error if the symbol cannot be resolved,
// and a "not found" error if the symbol exists but is not a service.
func ListMethods(source DescriptorSource, serviceName string) ([]string, error) {
	sym, err := source.FindSymbol(serviceName)
	if err != nil {
		return nil, err
	}
	svc, isService := sym.(*desc.ServiceDescriptor)
	if !isService {
		return nil, notFound("Service", serviceName)
	}

	mtds := svc.GetMethods()
	names := make([]string, len(mtds))
	for i, m := range mtds {
		names[i] = m.GetFullyQualifiedName()
	}
	sort.Strings(names)
	return names, nil
}
142
143
144
145
146
147
// MetadataFromHeaders converts a list of header strings (each string in
// "Header-Name: Header-Value" form) into metadata. If a string has a header
// name without a value (e.g. does not contain a colon), the value is assumed
// to be blank. Binary headers (those whose names end in "-bin") should be
// base64-encoded. But if they cannot be base64-decoded, they will be assumed to
// be in raw form and used as is.
148
149
150
151
152
153
154
155
156
func MetadataFromHeaders(headers []string) metadata.MD {
md := make(metadata.MD)
for _, part := range headers {
if part != "" {
pieces := strings.SplitN(part, ":", 2)
if len(pieces) == 1 {
pieces = append(pieces, "") // if no value was specified, just make it "" (maybe the header value doesn't matter)
}
headerName := strings.ToLower(strings.TrimSpace(pieces[0]))
157
158
159
160
161
162
163
val := strings.TrimSpace(pieces[1])
if strings.HasSuffix(headerName, "-bin") {
if v, err := decode(val); err == nil {
val = v
}
}
md[headerName] = append(md[headerName], val)
164
165
166
167
168
}
}
return md
}
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
// envVarRegex matches environment variable references of the form ${NAME}.
var envVarRegex = regexp.MustCompile(`\${\w+}`)

// ExpandHeaders expands environment variable references of the form ${NAME}
// contained in each header string and returns the expanded headers in the same
// order. If a referenced environment variable is not set, an error naming the
// header and the first missing variable is returned.
//
// Expansion is done in a single pass over each header: text that comes from a
// variable's value is inserted verbatim and is never itself re-expanded, even
// if it happens to look like another ${NAME} reference.
// TODO: Add escaping for `${`
func ExpandHeaders(headers []string) ([]string, error) {
	expandedHeaders := make([]string, len(headers))
	for idx, header := range headers {
		// \w+ guarantees a non-empty name, so "" safely means "none missing".
		missing := ""
		expanded := envVarRegex.ReplaceAllStringFunc(header, func(ref string) string {
			envVarName := ref[2 : len(ref)-1] // strip leading `${` and trailing `}`
			envVarValue, ok := os.LookupEnv(envVarName)
			if !ok && missing == "" {
				missing = envVarName
			}
			return envVarValue
		})
		if missing != "" {
			return nil, fmt.Errorf("header %q refers to missing environment variable %q", header, missing)
		}
		expandedHeaders[idx] = expanded
	}
	return expandedHeaders, nil
}
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
// base64Codecs lists the base64 flavors that decode will try, in order.
var base64Codecs = []*base64.Encoding{base64.StdEncoding, base64.URLEncoding, base64.RawStdEncoding, base64.RawURLEncoding}

// decode leniently base64-decodes val: each codec in base64Codecs is tried in
// turn and the result of the first one that succeeds is returned. If none of
// them can decode val, an empty string and the error from the first codec
// tried are returned.
func decode(val string) (string, error) {
	var firstErr error
	for _, codec := range base64Codecs {
		raw, err := codec.DecodeString(val)
		if err == nil {
			return string(raw), nil
		}
		if firstErr == nil {
			firstErr = err
		}
	}
	return "", firstErr
}
// MetadataToString returns a string representation of the given metadata, for
// displaying to users.
221
222
223
224
func MetadataToString(md metadata.MD) string {
if len(md) == 0 {
return "(empty)"
}
225
226
227
228
229
230
231
keys := make([]string, 0, len(md))
for k := range md {
keys = append(keys, k)
}
sort.Strings(keys)
232
var b bytes.Buffer
233
234
235
first := true
for _, k := range keys {
vs := md[k]
236
for _, v := range vs {
237
238
239
240
241
if first {
first = false
} else {
b.WriteString("\n")
}
242
243
b.WriteString(k)
b.WriteString(": ")
244
245
246
if strings.HasSuffix(k, "-bin") {
v = base64.StdEncoding.EncodeToString([]byte(v))
}
247
248
249
250
251
252
b.WriteString(v)
}
}
return b.String()
}
253
254
255
256
257
258
259
// printer is the package-wide protoprint.Printer that GetDescriptorText uses
// to render a descriptor as proto source text.
var printer = &protoprint.Printer{
Compact: true,
OmitComments: protoprint.CommentsNonDoc,
SortElements: true,
ForceFullyQualifiedNames: true,
}
260
// GetDescriptorText returns a string representation of the given descriptor.
261
262
263
264
265
266
267
268
269
270
271
272
273
// This returns a snippet of proto source that describes the given element.
func GetDescriptorText(dsc desc.Descriptor, _ DescriptorSource) (string, error) {
// Note: DescriptorSource is not used, but remains an argument for backwards
// compatibility with previous implementation.
txt, err := printer.PrintProtoToString(dsc)
if err != nil {
return "", err
}
// callers don't expect trailing newlines
if txt[len(txt)-1] == '\n' {
txt = txt[:len(txt)-1]
}
return txt, nil
274
275
}
276
277
278
// EnsureExtensions uses the given descriptor source to download extensions for
// the given message. It returns a copy of the given message, but as a dynamic
// message that knows about all extensions known to the given descriptor source.
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
func EnsureExtensions(source DescriptorSource, msg proto.Message) proto.Message {
// load any server extensions so we can properly describe custom options
dsc, err := desc.LoadMessageDescriptorForMessage(msg)
if err != nil {
return msg
}
var ext dynamic.ExtensionRegistry
if err = fetchAllExtensions(source, &ext, dsc, map[string]bool{}); err != nil {
return msg
}
// convert message into dynamic message that knows about applicable extensions
// (that way we can show meaningful info for custom options instead of printing as unknown)
msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext)
dm, err := fullyConvertToDynamic(msgFactory, msg)
if err != nil {
return msg
}
return dm
}
// fetchAllExtensions recursively fetches, from the given descriptor source,
// the extensions for the given message type as well as for the message types
// of all its fields. The extensions are added to the given dynamic registry so
// that all extensions known to the source can be correctly parsed by grpcurl.
//
// alreadyFetched records the fully-qualified names of message types that have
// been visited; it is updated in place and keeps recursive message structures
// from being processed more than once.
//
// Errors from the source or the registry are wrapped (with %w), so callers can
// still inspect the underlying cause with errors.Is / errors.As.
func fetchAllExtensions(source DescriptorSource, ext *dynamic.ExtensionRegistry, md *desc.MessageDescriptor, alreadyFetched map[string]bool) error {
	msgTypeName := md.GetFullyQualifiedName()
	if alreadyFetched[msgTypeName] {
		return nil
	}
	alreadyFetched[msgTypeName] = true

	// Only types that declare extension ranges can have extensions.
	if len(md.GetExtensionRanges()) > 0 {
		fds, err := source.AllExtensionsForType(msgTypeName)
		if err != nil {
			return fmt.Errorf("failed to query for extensions of type %s: %w", msgTypeName, err)
		}
		for _, fd := range fds {
			if err := ext.AddExtension(fd); err != nil {
				return fmt.Errorf("could not register extension %s of type %s: %w", fd.GetFullyQualifiedName(), msgTypeName, err)
			}
		}
	}

	// Recursively fetch extensions for the types of any message fields.
	for _, fd := range md.GetFields() {
		fieldType := fd.GetMessageType()
		if fieldType == nil {
			continue
		}
		if err := fetchAllExtensions(source, ext, fieldType, alreadyFetched); err != nil {
			return err
		}
	}
	return nil
}
333
// fullyConvertToDynamic attempts to convert the given message to a dynamic message as well
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
// as any nested messages it may contain as field values. If the given message factory has
// extensions registered that were not known when the given message was parsed, this effectively
// allows re-parsing to identify those extensions.
func fullyConvertToDynamic(msgFact *dynamic.MessageFactory, msg proto.Message) (proto.Message, error) {
if _, ok := msg.(*dynamic.Message); ok {
return msg, nil // already a dynamic message
}
md, err := desc.LoadMessageDescriptorForMessage(msg)
if err != nil {
return nil, err
}
newMsg := msgFact.NewMessage(md)
dm, ok := newMsg.(*dynamic.Message)
if !ok {
// if message factory didn't produce a dynamic message, then we should leave msg as is
return msg, nil
}
if err := dm.ConvertFrom(msg); err != nil {
return nil, err
}
// recursively convert all field values, too
for _, fd := range md.GetFields() {
if fd.IsMap() {
if fd.GetMapValueType().GetMessageType() != nil {
m := dm.GetField(fd).(map[interface{}]interface{})
for k, v := range m {
// keys can't be nested messages; so we only need to recurse through map values, not keys
newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message))
if err != nil {
return nil, err
}
367
dm.PutMapField(fd, k, newVal)
368
369
370
371
372
373
374
375
376
377
}
}
} else if fd.IsRepeated() {
if fd.GetMessageType() != nil {
s := dm.GetField(fd).([]interface{})
for i, e := range s {
newVal, err := fullyConvertToDynamic(msgFact, e.(proto.Message))
if err != nil {
return nil, err
}
378
dm.SetRepeatedField(fd, i, newVal)
379
380
381
382
383
384
385
386
387
}
}
} else {
if fd.GetMessageType() != nil {
v := dm.GetField(fd)
newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message))
if err != nil {
return nil, err
}
388
dm.SetField(fd, newVal)
389
390
391
392
393
394
}
}
}
return dm, nil
}
395
396
397
398
399
400
401
402
403
// MakeTemplate builds a message instance for the given descriptor that serves
// as a template for writing such a message in JSON. Repeated fields (which
// include map fields) are populated with a single element so that their
// element types, and optionally nested fields, are visible, and nested
// message fields are set to messages that are themselves fleshed out as
// templates rather than left nil.
func MakeTemplate(md *desc.MessageDescriptor) proto.Message {
	// Start with an empty path: no enclosing message types visited yet.
	var visited []*desc.MessageDescriptor
	return makeTemplate(md, visited)
}
406
407
408
409
410
func makeTemplate(md *desc.MessageDescriptor, path []*desc.MessageDescriptor) proto.Message {
switch md.GetFullyQualifiedName() {
case "google.protobuf.Any":
// empty type URL is not allowed by JSON representation
// so we must give it a dummy type
411
412
413
var anyVal anypb.Any
_ = anypb.MarshalFrom(&anyVal, &emptypb.Empty{}, protov2.MarshalOptions{})
return &anyVal
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
case "google.protobuf.Value":
// unset kind is not allowed by JSON representation
// so we must give it something
return &structpb.Value{
Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{
Fields: map[string]*structpb.Value{
"google.protobuf.Value": {Kind: &structpb.Value_StringValue{
StringValue: "supports arbitrary JSON",
}},
},
}},
}
case "google.protobuf.ListValue":
return &structpb.ListValue{
Values: []*structpb.Value{
{
Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{
Fields: map[string]*structpb.Value{
"google.protobuf.ListValue": {Kind: &structpb.Value_StringValue{
StringValue: "is an array of arbitrary JSON values",
}},
},
}},
},
},
}
case "google.protobuf.Struct":
return &structpb.Struct{
Fields: map[string]*structpb.Value{
"google.protobuf.Struct": {Kind: &structpb.Value_StringValue{
StringValue: "supports arbitrary JSON objects",
}},
},
}
448
449
}
450
451
452
453
454
dm := dynamic.NewMessage(md)
// if the message is a recursive structure, we don't want to blow the stack
for _, seen := range path {
if seen == md {
455
// already visited this type; avoid infinite recursion
456
return dm
457
458
459
460
461
462
463
}
}
path = append(path, dm.GetMessageDescriptor())
// for repeated fields, add a single element with default value
// and for message fields, add a message with all default fields
// that also has non-nil message and non-empty repeated fields
464
465
466
467
for _, fd := range dm.GetMessageDescriptor().GetFields() {
if fd.IsRepeated() {
switch fd.GetType() {
468
469
case descriptorpb.FieldDescriptorProto_TYPE_FIXED32,
descriptorpb.FieldDescriptorProto_TYPE_UINT32:
470
471
dm.AddRepeatedField(fd, uint32(0))
472
473
474
475
case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32,
descriptorpb.FieldDescriptorProto_TYPE_SINT32,
descriptorpb.FieldDescriptorProto_TYPE_INT32,
descriptorpb.FieldDescriptorProto_TYPE_ENUM:
476
477
dm.AddRepeatedField(fd, int32(0))
478
479
case descriptorpb.FieldDescriptorProto_TYPE_FIXED64,
descriptorpb.FieldDescriptorProto_TYPE_UINT64:
480
481
dm.AddRepeatedField(fd, uint64(0))
482
483
484
case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64,
descriptorpb.FieldDescriptorProto_TYPE_SINT64,
descriptorpb.FieldDescriptorProto_TYPE_INT64:
485
486
dm.AddRepeatedField(fd, int64(0))
487
case descriptorpb.FieldDescriptorProto_TYPE_STRING:
488
489
dm.AddRepeatedField(fd, "")
490
case descriptorpb.FieldDescriptorProto_TYPE_BYTES:
491
492
dm.AddRepeatedField(fd, []byte{})
493
case descriptorpb.FieldDescriptorProto_TYPE_BOOL:
494
495
dm.AddRepeatedField(fd, false)
496
case descriptorpb.FieldDescriptorProto_TYPE_FLOAT:
497
498
dm.AddRepeatedField(fd, float32(0))
499
case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:
500
501
dm.AddRepeatedField(fd, float64(0))
502
503
case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE,
descriptorpb.FieldDescriptorProto_TYPE_GROUP:
504
dm.AddRepeatedField(fd, makeTemplate(fd.GetMessageType(), path))
505
506
}
} else if fd.GetMessageType() != nil {
507
dm.SetField(fd, makeTemplate(fd.GetMessageType(), path))
508
509
510
511
512
}
}
return dm
}
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
// ClientTransportCredentials builds a TLS config from the given properties
// (see ClientTLSConfig) and wraps it in gRPC transport credentials. Any error
// from building the TLS config is returned unchanged.
//
// Deprecated: Use grpcurl.ClientTLSConfig and credentials.NewTLS instead.
func ClientTransportCredentials(insecureSkipVerify bool, cacertFile, clientCertFile, clientKeyFile string) (credentials.TransportCredentials, error) {
	conf, confErr := ClientTLSConfig(insecureSkipVerify, cacertFile, clientCertFile, clientKeyFile)
	if confErr == nil {
		return credentials.NewTLS(conf), nil
	}
	return nil, confErr
}
// ClientTLSConfig builds transport-layer config for a gRPC client using the
// given properties. If cacertFile is blank, only standard trusted certs are used to
// verify the server certs. If clientCertFile is blank, the client will not use a client
// certificate. If clientCertFile is not blank then clientKeyFile must not be blank.
//
// If insecureSkipVerify is true, server certificate verification is disabled
// and cacertFile is not read at all.
//
// Errors from loading the key pair or reading the CA file are wrapped (with
// %w), so callers can inspect the cause, e.g. errors.Is(err, os.ErrNotExist).
func ClientTLSConfig(insecureSkipVerify bool, cacertFile, clientCertFile, clientKeyFile string) (*tls.Config, error) {
	var tlsConf tls.Config

	if clientCertFile != "" {
		// Load the client certificates from disk
		certificate, err := tls.LoadX509KeyPair(clientCertFile, clientKeyFile)
		if err != nil {
			return nil, fmt.Errorf("could not load client key pair: %w", err)
		}
		tlsConf.Certificates = []tls.Certificate{certificate}
	}

	if insecureSkipVerify {
		tlsConf.InsecureSkipVerify = true
	} else if cacertFile != "" {
		// Create a certificate pool from the certificate authority
		certPool := x509.NewCertPool()
		ca, err := os.ReadFile(cacertFile)
		if err != nil {
			return nil, fmt.Errorf("could not read ca certificate: %w", err)
		}

		// Append the certificates from the CA
		if ok := certPool.AppendCertsFromPEM(ca); !ok {
			return nil, errors.New("failed to append ca certs")
		}

		tlsConf.RootCAs = certPool
	}

	return &tlsConf, nil
}
564
// ServerTransportCredentials builds transport credentials for a gRPC server using the
565
566
567
568
569
570
// given properties. If cacertFile is blank, the server will not request client certs
// unless requireClientCerts is true. When requireClientCerts is false and cacertFile is
// not blank, the server will verify client certs when presented, but will not require
// client certs. The serverCertFile and serverKeyFile must both not be blank.
func ServerTransportCredentials(cacertFile, serverCertFile, serverKeyFile string, requireClientCerts bool) (credentials.TransportCredentials, error) {
var tlsConf tls.Config
571
572
573
// TODO(jh): Remove this line once https://github.com/golang/go/issues/28779 is fixed
// in Go tip. Until then, the recently merged TLS 1.3 support breaks the TLS tests.
tlsConf.MaxVersion = tls.VersionTLS12
574
575
576
577
578
579
580
581
582
583
584
// Load the server certificates from disk
certificate, err := tls.LoadX509KeyPair(serverCertFile, serverKeyFile)
if err != nil {
return nil, fmt.Errorf("could not load key pair: %v", err)
}
tlsConf.Certificates = []tls.Certificate{certificate}
if cacertFile != "" {
// Create a certificate pool from the certificate authority
certPool := x509.NewCertPool()
585
ca, err := os.ReadFile(cacertFile)
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
if err != nil {
return nil, fmt.Errorf("could not read ca certificate: %v", err)
}
// Append the certificates from the CA
if ok := certPool.AppendCertsFromPEM(ca); !ok {
return nil, errors.New("failed to append ca certs")
}
tlsConf.ClientCAs = certPool
}
if requireClientCerts {
tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
} else if cacertFile != "" {
tlsConf.ClientAuth = tls.VerifyClientCertIfGiven
} else {
tlsConf.ClientAuth = tls.NoClientCert
}
return credentials.NewTLS(&tlsConf), nil
}
608
609
610
611
// BlockingDial is a helper method to dial the given address, using optional TLS credentials,
// and blocking until the returned connection is ready. If the given credentials are nil, the
// connection will be insecure (plain-text).
612
func BlockingDial(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
if creds == nil {
creds = insecure.NewCredentials()
}
var err error
if strings.HasPrefix(address, "xds:///") {
// The xds:/// prefix is used to signal to the gRPC client to use an xDS server to resolve the
// target. The relevant credentials will be automatically pulled from the GRPC_XDS_BOOTSTRAP or
// GRPC_XDS_BOOTSTRAP_CONFIG env vars.
creds, err = xdsCredentials.NewClientCredentials(xdsCredentials.ClientOptions{FallbackCreds: creds})
if err != nil {
return nil, err
}
}
628
629
630
631
632
633
634
635
636
637
638
639
640
// grpc.Dial doesn't provide any information on permanent connection errors (like
// TLS handshake failures). So in order to provide good error messages, we need a
// custom dialer that can provide that info. That means we manage the TLS handshake.
result := make(chan interface{}, 1)
writeResult := func(res interface{}) {
// non-blocking write: we only need the first result
select {
case result <- res:
default:
}
}
641
642
// custom credentials and dialer will notify on error via the
// writeResult function
643
644
645
creds = &errSignalingCreds{
TransportCredentials: creds,
writeResult: writeResult,
646
}
647
648
dialer := func(ctx context.Context, address string) (net.Conn, error) {
649
650
// NB: We *could* handle the TLS handshake ourselves, in the custom
// dialer (instead of customizing both the dialer and the credentials).
651
652
653
654
655
// But that requires using insecure.NewCredentials() dial transport
// option (so that the gRPC library doesn't *also* try to do a
// handshake). And that would mean that the library would send the
// wrong ":scheme" metaheader to servers: it would send "http" instead
// of "https" because it is unaware that TLS is actually in use.
656
conn, err := (&net.Dialer{}).DialContext(ctx, network, address)
657
658
659
if err != nil {
writeResult(err)
}
660
return conn, err
661
662
663
664
665
}
// Even with grpc.FailOnNonTempDialError, this call will usually timeout in
// the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
// know when we're done. So we run it in a goroutine and then use result
666
// channel to either get the connection or fail-fast.
667
go func() {
668
669
670
671
672
// We put grpc.FailOnNonTempDialError *before* the explicitly provided
// options so that it could be overridden.
opts = append([]grpc.DialOption{grpc.FailOnNonTempDialError(true)}, opts...)
// But we don't want caller to be able to override these two, so we put
// them *after* the explicitly provided options.
673
opts = append(opts, grpc.WithBlock(), grpc.WithContextDialer(dialer), grpc.WithTransportCredentials(creds))
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
conn, err := grpc.DialContext(ctx, address, opts...)
var res interface{}
if err != nil {
res = err
} else {
res = conn
}
writeResult(res)
}()
select {
case res := <-result:
if conn, ok := res.(*grpc.ClientConn); ok {
return conn, nil
}
690
return nil, res.(error)
691
692
693
694
case <-ctx.Done():
return nil, ctx.Err()
}
}
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
// errSignalingCreds is a wrapper around a TransportCredentials value, but
// it will use the writeResult function to notify on error.
type errSignalingCreds struct {
// The wrapped credentials. They are embedded, so their methods are promoted;
// in this file only ClientHandshake is overridden.
credentials.TransportCredentials
// writeResult is called with the error when a client handshake fails.
writeResult func(res interface{})
}
// ClientHandshake delegates the handshake to the wrapped credentials and
// returns their results unchanged. If the handshake fails, the error is also
// reported through c.writeResult before returning.
func (c *errSignalingCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	secured, authInfo, err := c.TransportCredentials.ClientHandshake(ctx, addr, rawConn)
	if err == nil {
		return secured, authInfo, nil
	}
	c.writeResult(err)
	return secured, authInfo, err
}