// Package grpcurl provides the core functionality exposed by the grpcurl command, for
// dynamically connecting to a server, using the reflection service to inspect the server,
// and invoking RPCs. The grpcurl command-line tool constructs a DescriptorSource, based
// on the command-line parameters, and supplies an InvocationEventHandler to provide request
// data (which can come from command-line args or the process's stdin) and to log the
// events (to the process's stdout).
package grpcurl
import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"errors"
	"fmt"
	"net"
	"os"
	"regexp"
	"sort"
	"strings"

	"github.com/golang/protobuf/proto" //lint:ignore SA1019 we have to import this because it appears in exported API
	"github.com/jhump/protoreflect/desc"
	"github.com/jhump/protoreflect/desc/protoprint"
	"github.com/jhump/protoreflect/dynamic"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
	xdsCredentials "google.golang.org/grpc/credentials/xds"
	"google.golang.org/grpc/metadata"
	protov2 "google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/descriptorpb"
	"google.golang.org/protobuf/types/known/anypb"
	"google.golang.org/protobuf/types/known/emptypb"
	"google.golang.org/protobuf/types/known/structpb"
)
// ListServices uses the given descriptor source to return a sorted list of fully-qualified
// service names.
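//
// A minimal usage sketch (descSource and the service names shown are placeholders):
//
//	svcs, err := ListServices(descSource)
//	// svcs might look like []string{"grpc.health.v1.Health", "my.pkg.MyService"}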
func ListServices(source DescriptorSource) ([]string, error) {
svcs, err := source.ListServices()
if err != nil {
return nil, err
}
sort.Strings(svcs)
return svcs, nil
}
type sourceWithFiles interface {
GetAllFiles() ([]*desc.FileDescriptor, error)
}
var _ sourceWithFiles = (*fileSource)(nil)
// GetAllFiles uses the given descriptor source to return a list of file descriptors.
func GetAllFiles(source DescriptorSource) ([]*desc.FileDescriptor, error) {
var files []*desc.FileDescriptor
srcFiles, ok := source.(sourceWithFiles)
// If an error occurs, we still try to load as many files as we can, so that the
// caller can decide whether or not to ignore the error.
var firstError error
if ok {
files, firstError = srcFiles.GetAllFiles()
} else {
// Source does not implement GetAllFiles method, so use ListServices
// and grab files from there.
svcNames, err := source.ListServices()
if err != nil {
firstError = err
} else {
allFiles := map[string]*desc.FileDescriptor{}
for _, name := range svcNames {
d, err := source.FindSymbol(name)
if err != nil {
if firstError == nil {
firstError = err
}
} else {
addAllFilesToSet(d.GetFile(), allFiles)
}
}
files = make([]*desc.FileDescriptor, len(allFiles))
i := 0
for _, fd := range allFiles {
files[i] = fd
i++
}
}
}
sort.Sort(filesByName(files))
return files, firstError
}
type filesByName []*desc.FileDescriptor
func (f filesByName) Len() int {
return len(f)
}
func (f filesByName) Less(i, j int) bool {
return f[i].GetName() < f[j].GetName()
}
func (f filesByName) Swap(i, j int) {
f[i], f[j] = f[j], f[i]
}
func addAllFilesToSet(fd *desc.FileDescriptor, all map[string]*desc.FileDescriptor) {
if _, ok := all[fd.GetName()]; ok {
// already added
return
}
all[fd.GetName()] = fd
for _, dep := range fd.GetDependencies() {
addAllFilesToSet(dep, all)
}
}
// ListMethods uses the given descriptor source to return a sorted list of method names
// for the specified fully-qualified service name.
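//
// A minimal usage sketch (descSource is a placeholder; the reflection service is just an example):
//
//	methods, err := ListMethods(descSource, "grpc.reflection.v1alpha.ServerReflection")
//	// methods would include "grpc.reflection.v1alpha.ServerReflection.ServerReflectionInfo"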
func ListMethods(source DescriptorSource, serviceName string) ([]string, error) {
dsc, err := source.FindSymbol(serviceName)
if err != nil {
return nil, err
}
if sd, ok := dsc.(*desc.ServiceDescriptor); !ok {
return nil, notFound("Service", serviceName)
} else {
methods := make([]string, 0, len(sd.GetMethods()))
for _, method := range sd.GetMethods() {
methods = append(methods, method.GetFullyQualifiedName())
}
sort.Strings(methods)
return methods, nil
}
}
// MetadataFromHeaders converts a list of header strings (each string in
// "Header-Name: Header-Value" form) into metadata. If a string has a header
// name without a value (e.g. does not contain a colon), the value is assumed
// to be blank. Binary headers (those whose names end in "-bin") should be
// base64-encoded. But if they cannot be base64-decoded, they will be assumed to
// be in raw form and used as is.
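//
// A minimal sketch of the conversion (header names are lower-cased and values trimmed):
//
//	md := MetadataFromHeaders([]string{"Authorization: Bearer abc123", "Trace-Bin: aGVsbG8="})
//	// md["authorization"] == []string{"Bearer abc123"}
//	// md["trace-bin"] == []string{"hello"} (base64-decoded because the name ends in "-bin")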
func MetadataFromHeaders(headers []string) metadata.MD {
md := make(metadata.MD)
for _, part := range headers {
if part != "" {
pieces := strings.SplitN(part, ":", 2)
if len(pieces) == 1 {
pieces = append(pieces, "") // if no value was specified, just make it "" (maybe the header value doesn't matter)
}
headerName := strings.ToLower(strings.TrimSpace(pieces[0]))
val := strings.TrimSpace(pieces[1])
if strings.HasSuffix(headerName, "-bin") {
if v, err := decode(val); err == nil {
val = v
}
}
md[headerName] = append(md[headerName], val)
}
}
return md
}
var envVarRegex = regexp.MustCompile(`\${\w+}`)
// ExpandHeaders expands environment variables contained in the given header strings.
// If no corresponding environment variable is found, an error is returned.
// TODO: Add escaping for `${`
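//
// A minimal sketch (TOKEN is a hypothetical environment variable):
//
//	os.Setenv("TOKEN", "abc123")
//	headers, err := ExpandHeaders([]string{"Authorization: Bearer ${TOKEN}"})
//	// headers == []string{"Authorization: Bearer abc123"}, err == nil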
func ExpandHeaders(headers []string) ([]string, error) {
expandedHeaders := make([]string, len(headers))
for idx, header := range headers {
if header == "" {
continue
}
results := envVarRegex.FindAllString(header, -1)
if len(results) == 0 {
expandedHeaders[idx] = headers[idx]
continue
}
expandedHeader := header
for _, result := range results {
envVarName := result[2 : len(result)-1] // strip leading `${` and trailing `}`
envVarValue, ok := os.LookupEnv(envVarName)
if !ok {
return nil, fmt.Errorf("header %q refers to missing environment variable %q", header, envVarName)
}
expandedHeader = strings.Replace(expandedHeader, result, envVarValue, -1)
}
expandedHeaders[idx] = expandedHeader
}
return expandedHeaders, nil
}
var base64Codecs = []*base64.Encoding{base64.StdEncoding, base64.URLEncoding, base64.RawStdEncoding, base64.RawURLEncoding}
func decode(val string) (string, error) {
var firstErr error
var b []byte
// we are lenient and can accept any of the flavors of base64 encoding
for _, d := range base64Codecs {
var err error
b, err = d.DecodeString(val)
if err != nil {
if firstErr == nil {
firstErr = err
}
continue
}
return string(b), nil
}
return "", firstErr
}
// MetadataToString returns a string representation of the given metadata, for
// displaying to users.
func MetadataToString(md metadata.MD) string {
if len(md) == 0 {
return "(empty)"
}
keys := make([]string, 0, len(md))
for k := range md {
keys = append(keys, k)
}
sort.Strings(keys)
var b bytes.Buffer
first := true
for _, k := range keys {
vs := md[k]
for _, v := range vs {
if first {
first = false
} else {
b.WriteString("\n")
}
b.WriteString(k)
b.WriteString(": ")
if strings.HasSuffix(k, "-bin") {
v = base64.StdEncoding.EncodeToString([]byte(v))
}
b.WriteString(v)
}
}
return b.String()
}
var printer = &protoprint.Printer{
Compact: true,
OmitComments: protoprint.CommentsNonDoc,
SortElements: true,
ForceFullyQualifiedNames: true,
}
// GetDescriptorText returns a string representation of the given descriptor.
// This returns a snippet of proto source that describes the given element.
func GetDescriptorText(dsc desc.Descriptor, _ DescriptorSource) (string, error) {
// Note: DescriptorSource is not used, but remains an argument for backwards
// compatibility with previous implementation.
txt, err := printer.PrintProtoToString(dsc)
if err != nil {
return "", err
}
// callers don't expect trailing newlines
if txt[len(txt)-1] == '\n' {
txt = txt[:len(txt)-1]
}
return txt, nil
}
// EnsureExtensions uses the given descriptor source to download extensions for
// the given message. It returns a copy of the given message, but as a dynamic
// message that knows about all extensions known to the given descriptor source.
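//
// A minimal usage sketch (descSource is a placeholder for a reflection-backed source):
//
//	msg = EnsureExtensions(descSource, msg)
//	// on any failure the original message is returned unchanged, so the call is safe to chain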
func EnsureExtensions(source DescriptorSource, msg proto.Message) proto.Message {
// load any server extensions so we can properly describe custom options
dsc, err := desc.LoadMessageDescriptorForMessage(msg)
if err != nil {
return msg
}
var ext dynamic.ExtensionRegistry
if err = fetchAllExtensions(source, &ext, dsc, map[string]bool{}); err != nil {
return msg
}
// convert message into dynamic message that knows about applicable extensions
// (that way we can show meaningful info for custom options instead of printing as unknown)
msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext)
dm, err := fullyConvertToDynamic(msgFactory, msg)
if err != nil {
return msg
}
return dm
}
// fetchAllExtensions recursively fetches from the server extensions for the given message type as well as
// for all message types of nested fields. The extensions are added to the given dynamic registry of extensions
// so that all server-known extensions can be correctly parsed by grpcurl.
func fetchAllExtensions(source DescriptorSource, ext *dynamic.ExtensionRegistry, md *desc.MessageDescriptor, alreadyFetched map[string]bool) error {
msgTypeName := md.GetFullyQualifiedName()
if alreadyFetched[msgTypeName] {
return nil
}
alreadyFetched[msgTypeName] = true
if len(md.GetExtensionRanges()) > 0 {
fds, err := source.AllExtensionsForType(msgTypeName)
if err != nil {
return fmt.Errorf("failed to query for extensions of type %s: %v", msgTypeName, err)
}
for _, fd := range fds {
if err := ext.AddExtension(fd); err != nil {
return fmt.Errorf("could not register extension %s of type %s: %v", fd.GetFullyQualifiedName(), msgTypeName, err)
}
}
}
// recursively fetch extensions for the types of any message fields
for _, fd := range md.GetFields() {
if fd.GetMessageType() != nil {
err := fetchAllExtensions(source, ext, fd.GetMessageType(), alreadyFetched)
if err != nil {
return err
}
}
}
return nil
}
// fullyConvertToDynamic attempts to convert the given message to a dynamic message as well
// as any nested messages it may contain as field values. If the given message factory has
// extensions registered that were not known when the given message was parsed, this effectively
// allows re-parsing to identify those extensions.
func fullyConvertToDynamic(msgFact *dynamic.MessageFactory, msg proto.Message) (proto.Message, error) {
if _, ok := msg.(*dynamic.Message); ok {
return msg, nil // already a dynamic message
}
md, err := desc.LoadMessageDescriptorForMessage(msg)
if err != nil {
return nil, err
}
newMsg := msgFact.NewMessage(md)
dm, ok := newMsg.(*dynamic.Message)
if !ok {
// if message factory didn't produce a dynamic message, then we should leave msg as is
return msg, nil
}
if err := dm.ConvertFrom(msg); err != nil {
return nil, err
}
// recursively convert all field values, too
for _, fd := range md.GetFields() {
if fd.IsMap() {
if fd.GetMapValueType().GetMessageType() != nil {
m := dm.GetField(fd).(map[interface{}]interface{})
for k, v := range m {
// keys can't be nested messages; so we only need to recurse through map values, not keys
newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message))
if err != nil {
return nil, err
}
dm.PutMapField(fd, k, newVal)
}
}
} else if fd.IsRepeated() {
if fd.GetMessageType() != nil {
s := dm.GetField(fd).([]interface{})
for i, e := range s {
newVal, err := fullyConvertToDynamic(msgFact, e.(proto.Message))
if err != nil {
return nil, err
}
dm.SetRepeatedField(fd, i, newVal)
}
}
} else {
if fd.GetMessageType() != nil {
v := dm.GetField(fd)
newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message))
if err != nil {
return nil, err
}
dm.SetField(fd, newVal)
}
}
}
return dm, nil
}
// MakeTemplate returns a message instance for the given descriptor that is a
// suitable template for creating an instance of that message in JSON. In
// particular, it ensures that any repeated fields (which include map fields)
// are not empty, so they will render with a single element (to show the types
// and optionally nested fields). It also ensures that nested messages are not
// nil by setting them to a message that is also fleshed out as a template
// message.
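//
// A minimal usage sketch, assuming dsc was found via a DescriptorSource:
//
//	if md, ok := dsc.(*desc.MessageDescriptor); ok {
//		tmpl := MakeTemplate(md)
//		// tmpl can then be formatted (e.g. as JSON) to show the expected request shape
//	}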
func MakeTemplate(md *desc.MessageDescriptor) proto.Message {
return makeTemplate(md, nil)
}
func makeTemplate(md *desc.MessageDescriptor, path []*desc.MessageDescriptor) proto.Message {
switch md.GetFullyQualifiedName() {
case "google.protobuf.Any":
// empty type URL is not allowed by JSON representation
// so we must give it a dummy type
var anyVal anypb.Any
_ = anypb.MarshalFrom(&anyVal, &emptypb.Empty{}, protov2.MarshalOptions{})
return &anyVal
case "google.protobuf.Value":
// unset kind is not allowed by JSON representation
// so we must give it something
return &structpb.Value{
Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{
Fields: map[string]*structpb.Value{
"google.protobuf.Value": {Kind: &structpb.Value_StringValue{
StringValue: "supports arbitrary JSON",
}},
},
}},
}
case "google.protobuf.ListValue":
return &structpb.ListValue{
Values: []*structpb.Value{
{
Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{
Fields: map[string]*structpb.Value{
"google.protobuf.ListValue": {Kind: &structpb.Value_StringValue{
StringValue: "is an array of arbitrary JSON values",
}},
},
}},
},
},
}
case "google.protobuf.Struct":
return &structpb.Struct{
Fields: map[string]*structpb.Value{
"google.protobuf.Struct": {Kind: &structpb.Value_StringValue{
StringValue: "supports arbitrary JSON objects",
}},
},
}
}
dm := dynamic.NewMessage(md)
// if the message is a recursive structure, we don't want to blow the stack
for _, seen := range path {
if seen == md {
// already visited this type; avoid infinite recursion
return dm
}
}
path = append(path, dm.GetMessageDescriptor())
// for repeated fields, add a single element with default value
// and for message fields, add a message with all default fields
// that also has non-nil message and non-empty repeated fields
for _, fd := range dm.GetMessageDescriptor().GetFields() {
if fd.IsRepeated() {
switch fd.GetType() {
case descriptorpb.FieldDescriptorProto_TYPE_FIXED32,
descriptorpb.FieldDescriptorProto_TYPE_UINT32:
dm.AddRepeatedField(fd, uint32(0))
case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32,
descriptorpb.FieldDescriptorProto_TYPE_SINT32,
descriptorpb.FieldDescriptorProto_TYPE_INT32,
descriptorpb.FieldDescriptorProto_TYPE_ENUM:
dm.AddRepeatedField(fd, int32(0))
case descriptorpb.FieldDescriptorProto_TYPE_FIXED64,
descriptorpb.FieldDescriptorProto_TYPE_UINT64:
dm.AddRepeatedField(fd, uint64(0))
case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64,
descriptorpb.FieldDescriptorProto_TYPE_SINT64,
descriptorpb.FieldDescriptorProto_TYPE_INT64:
dm.AddRepeatedField(fd, int64(0))
case descriptorpb.FieldDescriptorProto_TYPE_STRING:
dm.AddRepeatedField(fd, "")
case descriptorpb.FieldDescriptorProto_TYPE_BYTES:
dm.AddRepeatedField(fd, []byte{})
case descriptorpb.FieldDescriptorProto_TYPE_BOOL:
dm.AddRepeatedField(fd, false)
case descriptorpb.FieldDescriptorProto_TYPE_FLOAT:
dm.AddRepeatedField(fd, float32(0))
case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:
dm.AddRepeatedField(fd, float64(0))
case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE,
descriptorpb.FieldDescriptorProto_TYPE_GROUP:
dm.AddRepeatedField(fd, makeTemplate(fd.GetMessageType(), path))
}
} else if fd.GetMessageType() != nil {
dm.SetField(fd, makeTemplate(fd.GetMessageType(), path))
}
}
return dm
}
// ClientTransportCredentials is a helper function that constructs a TLS config with
// the given properties (see ClientTLSConfig) and then constructs and returns gRPC
// transport credentials using that config.
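//
// A minimal sketch of the preferred replacement (file paths are placeholders):
//
//	tlsConf, err := ClientTLSConfig(false, "ca.pem", "client.crt", "client.key")
//	if err != nil {
//		return nil, err
//	}
//	return credentials.NewTLS(tlsConf), nil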
//
// Deprecated: Use grpcurl.ClientTLSConfig and credentials.NewTLS instead.
func ClientTransportCredentials(insecureSkipVerify bool, cacertFile, clientCertFile, clientKeyFile string) (credentials.TransportCredentials, error) {
tlsConf, err := ClientTLSConfig(insecureSkipVerify, cacertFile, clientCertFile, clientKeyFile)
if err != nil {
return nil, err
}
return credentials.NewTLS(tlsConf), nil
}
// ClientTLSConfig builds transport-layer config for a gRPC client using the
// given properties. If cacertFile is blank, only standard trusted certs are used to
// verify the server certs. If clientCertFile is blank, the client will not use a client
// certificate. If clientCertFile is not blank then clientKeyFile must not be blank.
func ClientTLSConfig(insecureSkipVerify bool, cacertFile, clientCertFile, clientKeyFile string) (*tls.Config, error) {
var tlsConf tls.Config
if clientCertFile != "" {
// Load the client certificates from disk
certificate, err := tls.LoadX509KeyPair(clientCertFile, clientKeyFile)
if err != nil {
return nil, fmt.Errorf("could not load client key pair: %v", err)
}
tlsConf.Certificates = []tls.Certificate{certificate}
}
if insecureSkipVerify {
tlsConf.InsecureSkipVerify = true
} else if cacertFile != "" {
// Create a certificate pool from the certificate authority
certPool := x509.NewCertPool()
ca, err := os.ReadFile(cacertFile)
if err != nil {
return nil, fmt.Errorf("could not read ca certificate: %v", err)
}
// Append the certificates from the CA
if ok := certPool.AppendCertsFromPEM(ca); !ok {
return nil, errors.New("failed to append ca certs")
}
tlsConf.RootCAs = certPool
}
return &tlsConf, nil
}
// ServerTransportCredentials builds transport credentials for a gRPC server using the
// given properties. If cacertFile is blank, the server will not request client certs
// unless requireClientCerts is true. When requireClientCerts is false and cacertFile is
// not blank, the server will verify client certs when presented, but will not require
// client certs. The serverCertFile and serverKeyFile must both not be blank.
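//
// A minimal usage sketch (file paths are placeholders):
//
//	creds, err := ServerTransportCredentials("ca.pem", "server.crt", "server.key", false)
//	if err != nil {
//		// handle error
//	}
//	svr := grpc.NewServer(grpc.Creds(creds))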
func ServerTransportCredentials(cacertFile, serverCertFile, serverKeyFile string, requireClientCerts bool) (credentials.TransportCredentials, error) {
var tlsConf tls.Config
// TODO(jh): Remove this line once https://github.com/golang/go/issues/28779 is fixed
// in Go tip. Until then, the recently merged TLS 1.3 support breaks the TLS tests.
tlsConf.MaxVersion = tls.VersionTLS12
// Load the server certificates from disk
certificate, err := tls.LoadX509KeyPair(serverCertFile, serverKeyFile)
if err != nil {
return nil, fmt.Errorf("could not load key pair: %v", err)
}
tlsConf.Certificates = []tls.Certificate{certificate}
if cacertFile != "" {
// Create a certificate pool from the certificate authority
certPool := x509.NewCertPool()
ca, err := os.ReadFile(cacertFile)
if err != nil {
return nil, fmt.Errorf("could not read ca certificate: %v", err)
}
// Append the certificates from the CA
if ok := certPool.AppendCertsFromPEM(ca); !ok {
return nil, errors.New("failed to append ca certs")
}
tlsConf.ClientCAs = certPool
}
if requireClientCerts {
tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
} else if cacertFile != "" {
tlsConf.ClientAuth = tls.VerifyClientCertIfGiven
} else {
tlsConf.ClientAuth = tls.NoClientCert
}
return credentials.NewTLS(&tlsConf), nil
}
// BlockingDial is a helper method to dial the given address, using optional TLS credentials,
// and blocking until the returned connection is ready. If the given credentials are nil, the
// connection will be insecure (plain-text).
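//
// A minimal usage sketch (ctx, the address, and the file path are placeholders):
//
//	tlsConf, err := ClientTLSConfig(false, "ca.pem", "", "")
//	if err != nil {
//		// handle error
//	}
//	cc, err := BlockingDial(ctx, "tcp", "localhost:8443", credentials.NewTLS(tlsConf))
//	if err != nil {
//		// handle error
//	}
//	defer cc.Close()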
func BlockingDial(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
if creds == nil {
creds = insecure.NewCredentials()
}
var err error
if strings.HasPrefix(address, "xds:///") {
// The xds:/// prefix is used to signal to the gRPC client to use an xDS server to resolve the
// target. The relevant credentials will be automatically pulled from the GRPC_XDS_BOOTSTRAP or
// GRPC_XDS_BOOTSTRAP_CONFIG env vars.
creds, err = xdsCredentials.NewClientCredentials(xdsCredentials.ClientOptions{FallbackCreds: creds})
if err != nil {
return nil, err
}
}
// grpc.Dial doesn't provide any information on permanent connection errors (like
// TLS handshake failures). So in order to provide good error messages, we need a
// custom dialer that can provide that info. That means we manage the TLS handshake.
result := make(chan interface{}, 1)
writeResult := func(res interface{}) {
// non-blocking write: we only need the first result
select {
case result <- res:
default:
}
}
// custom credentials and dialer will notify on error via the
// writeResult function
creds = &errSignalingCreds{
TransportCredentials: creds,
writeResult: writeResult,
}
dialer := func(ctx context.Context, address string) (net.Conn, error) {
// NB: We *could* handle the TLS handshake ourselves, in the custom
// dialer (instead of customizing both the dialer and the credentials).
// But that requires using insecure.NewCredentials() dial transport
// option (so that the gRPC library doesn't *also* try to do a
// handshake). And that would mean that the library would send the
// wrong ":scheme" metaheader to servers: it would send "http" instead
// of "https" because it is unaware that TLS is actually in use.
conn, err := (&net.Dialer{}).DialContext(ctx, network, address)
if err != nil {
writeResult(err)
}
return conn, err
}
// Even with grpc.FailOnNonTempDialError, this call will usually time out in
// the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
// know when we're done. So we run it in a goroutine and then use the result
// channel to either get the connection or fail fast.
go func() {
// We put grpc.FailOnNonTempDialError *before* the explicitly provided
// options so that it could be overridden.
opts = append([]grpc.DialOption{grpc.FailOnNonTempDialError(true)}, opts...)
// But we don't want caller to be able to override these two, so we put
// them *after* the explicitly provided options.
opts = append(opts, grpc.WithBlock(), grpc.WithContextDialer(dialer), grpc.WithTransportCredentials(creds))
conn, err := grpc.DialContext(ctx, address, opts...)
var res interface{}
if err != nil {
res = err
} else {
res = conn
}
writeResult(res)
}()
select {
case res := <-result:
if conn, ok := res.(*grpc.ClientConn); ok {
return conn, nil
}
return nil, res.(error)
case <-ctx.Done():
return nil, ctx.Err()
}
}
// errSignalingCreds is a wrapper around a TransportCredentials value, but
// it will use the writeResult function to notify on error.
type errSignalingCreds struct {
credentials.TransportCredentials
writeResult func(res interface{})
}
func (c *errSignalingCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
conn, auth, err := c.TransportCredentials.ClientHandshake(ctx, addr, rawConn)
if err != nil {
c.writeResult(err)
}
return conn, auth, err
}