Skip to content

Latest commit

 

History

History
709 lines (637 loc) · 23.1 KB

grpcurl.go

File metadata and controls

709 lines (637 loc) · 23.1 KB
 
1
2
3
4
5
6
7
8
9
10
// Package grpcurl provides the core functionality exposed by the grpcurl command, for
// dynamically connecting to a server, using the reflection service to inspect the server,
// and invoking RPCs. The grpcurl command-line tool constructs a DescriptorSource based
// on the command-line parameters and provides an InvocationEventHandler that supplies request
// data (which can come from command-line args or the process's stdin) and logs the
// events (to the process's stdout).
package grpcurl
import (
"bytes"
Feb 22, 2021
Feb 22, 2021
11
"context"
12
13
"crypto/tls"
"crypto/x509"
Dec 13, 2017
Dec 13, 2017
14
"encoding/base64"
15
16
"errors"
"fmt"
Dec 13, 2017
Dec 13, 2017
17
"net"
Sep 26, 2019
Sep 26, 2019
18
19
"os"
"regexp"
20
21
22
"sort"
"strings"
Feb 22, 2021
Feb 22, 2021
23
"github.com/golang/protobuf/proto" //lint:ignore SA1019 we have to import this because it appears in exported API
24
"github.com/jhump/protoreflect/desc"
Oct 18, 2018
Oct 18, 2018
25
"github.com/jhump/protoreflect/desc/protoprint"
26
27
28
"github.com/jhump/protoreflect/dynamic"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
Feb 5, 2022
Feb 5, 2022
29
"google.golang.org/grpc/credentials/insecure"
Feb 12, 2024
Feb 12, 2024
30
xdsCredentials "google.golang.org/grpc/credentials/xds"
May 8, 2024
May 8, 2024
31
_ "google.golang.org/grpc/health" // import grpc/health to enable transparent client side checking
32
"google.golang.org/grpc/metadata"
Feb 22, 2021
Feb 22, 2021
33
34
35
36
37
protov2 "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/structpb"
38
39
40
41
42
43
44
45
46
47
48
49
50
)
// ListServices returns the fully-qualified names of every service known to the
// given descriptor source, sorted lexically. Any error from the source is
// returned unchanged, with a nil slice.
func ListServices(source DescriptorSource) ([]string, error) {
	names, err := source.ListServices()
	if err == nil {
		sort.Strings(names)
		return names, nil
	}
	return nil, err
}
Oct 17, 2018
Oct 17, 2018
51
52
53
54
55
56
57
58
59
60
61
// sourceWithFiles is implemented by descriptor sources that can enumerate all
// of their file descriptors directly, without going through service lookups.
type sourceWithFiles interface {
	GetAllFiles() ([]*desc.FileDescriptor, error)
}

// Compile-time assertion that *fileSource implements sourceWithFiles.
var _ = sourceWithFiles((*fileSource)(nil))
// GetAllFiles uses the given descriptor source to return a list of file descriptors.
func GetAllFiles(source DescriptorSource) ([]*desc.FileDescriptor, error) {
var files []*desc.FileDescriptor
srcFiles, ok := source.(sourceWithFiles)
Dec 13, 2018
Dec 13, 2018
62
63
64
// If an error occurs, we still try to load as many files as we can, so that
// caller can decide whether to ignore error or not.
var firstError error
Oct 17, 2018
Oct 17, 2018
65
if ok {
Dec 13, 2018
Dec 13, 2018
66
files, firstError = srcFiles.GetAllFiles()
Oct 17, 2018
Oct 17, 2018
67
68
69
70
71
} else {
// Source does not implement GetAllFiles method, so use ListServices
// and grab files from there.
svcNames, err := source.ListServices()
if err != nil {
Dec 13, 2018
Dec 13, 2018
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
firstError = err
} else {
allFiles := map[string]*desc.FileDescriptor{}
for _, name := range svcNames {
d, err := source.FindSymbol(name)
if err != nil {
if firstError == nil {
firstError = err
}
} else {
addAllFilesToSet(d.GetFile(), allFiles)
}
}
files = make([]*desc.FileDescriptor, len(allFiles))
i := 0
for _, fd := range allFiles {
files[i] = fd
i++
Oct 17, 2018
Oct 17, 2018
90
91
92
93
94
}
}
}
sort.Sort(filesByName(files))
Dec 13, 2018
Dec 13, 2018
95
return files, firstError
Oct 17, 2018
Oct 17, 2018
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
}
// filesByName implements sort.Interface, ordering file descriptors by name.
type filesByName []*desc.FileDescriptor

func (s filesByName) Len() int { return len(s) }

func (s filesByName) Less(a, b int) bool { return s[a].GetName() < s[b].GetName() }

func (s filesByName) Swap(a, b int) { s[a], s[b] = s[b], s[a] }
// addAllFilesToSet records fd and all of its transitive dependencies in all,
// keyed by file name. Files whose name is already present are skipped, along
// with their dependencies. The traversal is an explicit-stack depth-first walk
// that visits files in the same order as a recursive preorder walk.
func addAllFilesToSet(fd *desc.FileDescriptor, all map[string]*desc.FileDescriptor) {
	pending := []*desc.FileDescriptor{fd}
	for len(pending) > 0 {
		top := len(pending) - 1
		cur := pending[top]
		pending = pending[:top]

		name := cur.GetName()
		if _, dup := all[name]; dup {
			continue
		}
		all[name] = cur

		// Push in reverse so the first dependency is visited next.
		deps := cur.GetDependencies()
		for i := len(deps) - 1; i >= 0; i-- {
			pending = append(pending, deps[i])
		}
	}
}
123
124
125
126
127
128
129
130
131
132
133
134
// ListMethods uses the given descriptor source to return a sorted list of method names
// for the specified fully-qualified service name.
func ListMethods(source DescriptorSource, serviceName string) ([]string, error) {
dsc, err := source.FindSymbol(serviceName)
if err != nil {
return nil, err
}
if sd, ok := dsc.(*desc.ServiceDescriptor); !ok {
return nil, notFound("Service", serviceName)
} else {
methods := make([]string, 0, len(sd.GetMethods()))
for _, method := range sd.GetMethods() {
Oct 19, 2018
Oct 19, 2018
135
methods = append(methods, method.GetFullyQualifiedName())
136
137
138
139
140
141
}
sort.Strings(methods)
return methods, nil
}
}
Dec 13, 2017
Dec 13, 2017
142
143
144
145
146
147
// MetadataFromHeaders converts a list of header strings (each string in
// "Header-Name: Header-Value" form) into metadata. If a string has a header
// name without a value (e.g. does not contain a colon), the value is assumed
// to be blank. Binary headers (those whose names end in "-bin") should be
// base64-encoded. But if they cannot be base64-decoded, they will be assumed to
// be in raw form and used as is.
148
149
150
151
152
153
154
155
156
func MetadataFromHeaders(headers []string) metadata.MD {
md := make(metadata.MD)
for _, part := range headers {
if part != "" {
pieces := strings.SplitN(part, ":", 2)
if len(pieces) == 1 {
pieces = append(pieces, "") // if no value was specified, just make it "" (maybe the header value doesn't matter)
}
headerName := strings.ToLower(strings.TrimSpace(pieces[0]))
Dec 13, 2017
Dec 13, 2017
157
158
159
160
161
162
163
val := strings.TrimSpace(pieces[1])
if strings.HasSuffix(headerName, "-bin") {
if v, err := decode(val); err == nil {
val = v
}
}
md[headerName] = append(md[headerName], val)
164
165
166
167
168
}
}
return md
}
Sep 26, 2019
Sep 26, 2019
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
// envVarRegex matches a ${NAME} environment-variable reference, where NAME is
// one or more word characters.
var envVarRegex = regexp.MustCompile(`\${\w+}`)

// ExpandHeaders expands ${NAME} environment-variable references in each header
// string. The expansion is a single pass: every reference is replaced with the
// variable's value, and substituted values are not scanned again. A value that
// itself contains "${...}" is therefore inserted literally. Empty header
// strings yield empty strings at the same index.
//
// If a header refers to a variable that is not set, an error naming the first
// such variable is returned.
// TODO: Add escaping for `${`
func ExpandHeaders(headers []string) ([]string, error) {
	expandedHeaders := make([]string, len(headers))
	for idx, header := range headers {
		if header == "" {
			continue
		}
		missing := "" // names matched by \w+ are never empty, so "" means none
		expanded := envVarRegex.ReplaceAllStringFunc(header, func(ref string) string {
			envVarName := ref[2 : len(ref)-1] // strip leading `${` and trailing `}`
			envVarValue, ok := os.LookupEnv(envVarName)
			if !ok {
				if missing == "" {
					missing = envVarName
				}
				return ref
			}
			return envVarValue
		})
		if missing != "" {
			return nil, fmt.Errorf("header %q refers to missing environment variable %q", header, missing)
		}
		expandedHeaders[idx] = expanded
	}
	return expandedHeaders, nil
}
Dec 13, 2017
Dec 13, 2017
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
// base64Codecs lists every base64 flavor accepted for binary header values,
// in the order they are tried: padded standard, padded URL-safe, unpadded
// standard, unpadded URL-safe.
var base64Codecs = []*base64.Encoding{base64.StdEncoding, base64.URLEncoding, base64.RawStdEncoding, base64.RawURLEncoding}

// decode base64-decodes val with the first codec in base64Codecs that accepts
// it. We are lenient and accept any of those flavors. If none succeeds, the
// error reported by the first codec is returned.
func decode(val string) (string, error) {
	var firstErr error
	for i := range base64Codecs {
		raw, err := base64Codecs[i].DecodeString(val)
		if err == nil {
			return string(raw), nil
		}
		if i == 0 {
			firstErr = err
		}
	}
	return "", firstErr
}
// MetadataToString returns a string representation of the given metadata, for
// displaying to users.
221
222
223
224
func MetadataToString(md metadata.MD) string {
if len(md) == 0 {
return "(empty)"
}
Oct 17, 2018
Oct 17, 2018
225
226
227
228
229
230
231
keys := make([]string, 0, len(md))
for k := range md {
keys = append(keys, k)
}
sort.Strings(keys)
232
var b bytes.Buffer
Oct 17, 2018
Oct 17, 2018
233
234
235
first := true
for _, k := range keys {
vs := md[k]
236
for _, v := range vs {
Oct 17, 2018
Oct 17, 2018
237
238
239
240
241
if first {
first = false
} else {
b.WriteString("\n")
}
242
243
b.WriteString(k)
b.WriteString(": ")
Dec 13, 2017
Dec 13, 2017
244
245
246
if strings.HasSuffix(k, "-bin") {
v = base64.StdEncoding.EncodeToString([]byte(v))
}
247
248
249
250
251
252
b.WriteString(v)
}
}
return b.String()
}
Oct 18, 2018
Oct 18, 2018
253
254
255
256
257
258
259
// printer renders descriptors as compact proto source for display. It omits
// non-doc comments, sorts elements, and prints fully-qualified type names.
var printer = &protoprint.Printer{
	ForceFullyQualifiedNames: true,
	SortElements:             true,
	OmitComments:             protoprint.CommentsNonDoc,
	Compact:                  true,
}
Dec 13, 2017
Dec 13, 2017
260
// GetDescriptorText returns a string representation of the given descriptor.
Oct 18, 2018
Oct 18, 2018
261
262
263
264
265
266
267
268
269
270
271
272
273
// This returns a snippet of proto source that describes the given element.
func GetDescriptorText(dsc desc.Descriptor, _ DescriptorSource) (string, error) {
// Note: DescriptorSource is not used, but remains an argument for backwards
// compatibility with previous implementation.
txt, err := printer.PrintProtoToString(dsc)
if err != nil {
return "", err
}
// callers don't expect trailing newlines
if txt[len(txt)-1] == '\n' {
txt = txt[:len(txt)-1]
}
return txt, nil
Dec 13, 2017
Dec 13, 2017
276
277
278
// EnsureExtensions uses the given descriptor source to download extensions for
// the given message. It returns a copy of the given message, but as a dynamic
// message that knows about all extensions known to the given descriptor source.
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
func EnsureExtensions(source DescriptorSource, msg proto.Message) proto.Message {
// load any server extensions so we can properly describe custom options
dsc, err := desc.LoadMessageDescriptorForMessage(msg)
if err != nil {
return msg
}
var ext dynamic.ExtensionRegistry
if err = fetchAllExtensions(source, &ext, dsc, map[string]bool{}); err != nil {
return msg
}
// convert message into dynamic message that knows about applicable extensions
// (that way we can show meaningful info for custom options instead of printing as unknown)
msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext)
dm, err := fullyConvertToDynamic(msgFactory, msg)
if err != nil {
return msg
}
return dm
}
// fetchAllExtensions recursively fetches from the server extensions for the given message type as well as
// for all message types of nested fields. The extensions are added to the given dynamic registry of extensions
// so that all server-known extensions can be correctly parsed by grpcurl.
func fetchAllExtensions(source DescriptorSource, ext *dynamic.ExtensionRegistry, md *desc.MessageDescriptor, alreadyFetched map[string]bool) error {
msgTypeName := md.GetFullyQualifiedName()
if alreadyFetched[msgTypeName] {
return nil
}
alreadyFetched[msgTypeName] = true
if len(md.GetExtensionRanges()) > 0 {
fds, err := source.AllExtensionsForType(msgTypeName)
Mar 24, 2018
Mar 24, 2018
312
313
314
if err != nil {
return fmt.Errorf("failed to query for extensions of type %s: %v", msgTypeName, err)
}
315
for _, fd := range fds {
Mar 24, 2018
Mar 24, 2018
316
if err := ext.AddExtension(fd); err != nil {
Dec 5, 2017
Dec 5, 2017
317
return fmt.Errorf("could not register extension %s of type %s: %v", fd.GetFullyQualifiedName(), msgTypeName, err)
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
}
}
}
// recursively fetch extensions for the types of any message fields
for _, fd := range md.GetFields() {
if fd.GetMessageType() != nil {
err := fetchAllExtensions(source, ext, fd.GetMessageType(), alreadyFetched)
if err != nil {
return err
}
}
}
return nil
}
Nov 9, 2022
Nov 9, 2022
333
// fullyConvertToDynamic attempts to convert the given message to a dynamic message as well
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
// as any nested messages it may contain as field values. If the given message factory has
// extensions registered that were not known when the given message was parsed, this effectively
// allows re-parsing to identify those extensions.
func fullyConvertToDynamic(msgFact *dynamic.MessageFactory, msg proto.Message) (proto.Message, error) {
if _, ok := msg.(*dynamic.Message); ok {
return msg, nil // already a dynamic message
}
md, err := desc.LoadMessageDescriptorForMessage(msg)
if err != nil {
return nil, err
}
newMsg := msgFact.NewMessage(md)
dm, ok := newMsg.(*dynamic.Message)
if !ok {
// if message factory didn't produce a dynamic message, then we should leave msg as is
return msg, nil
}
if err := dm.ConvertFrom(msg); err != nil {
return nil, err
}
// recursively convert all field values, too
for _, fd := range md.GetFields() {
if fd.IsMap() {
if fd.GetMapValueType().GetMessageType() != nil {
m := dm.GetField(fd).(map[interface{}]interface{})
for k, v := range m {
// keys can't be nested messages; so we only need to recurse through map values, not keys
newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message))
if err != nil {
return nil, err
}
Mar 24, 2018
Mar 24, 2018
367
dm.PutMapField(fd, k, newVal)
368
369
370
371
372
373
374
375
376
377
}
}
} else if fd.IsRepeated() {
if fd.GetMessageType() != nil {
s := dm.GetField(fd).([]interface{})
for i, e := range s {
newVal, err := fullyConvertToDynamic(msgFact, e.(proto.Message))
if err != nil {
return nil, err
}
Mar 24, 2018
Mar 24, 2018
378
dm.SetRepeatedField(fd, i, newVal)
379
380
381
382
383
384
385
386
387
}
}
} else {
if fd.GetMessageType() != nil {
v := dm.GetField(fd)
newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message))
if err != nil {
return nil, err
}
Mar 24, 2018
Mar 24, 2018
388
dm.SetField(fd, newVal)
389
390
391
392
393
394
}
}
}
return dm, nil
}
Oct 23, 2018
Oct 23, 2018
395
396
397
398
399
400
401
402
403
// MakeTemplate returns a message instance for the given descriptor that is a
// suitable template for creating an instance of that message in JSON. In
// particular, it ensures that any repeated fields (which include map fields)
// are not empty, so they will render with a single element (to show the types
// and optionally nested fields). It also ensures that nested messages are not
// nil by setting them to a message that is also fleshed out as a template
// message.
func MakeTemplate(md *desc.MessageDescriptor) proto.Message {
return makeTemplate(md, nil)
Oct 19, 2018
Oct 19, 2018
404
405
}
Oct 23, 2018
Oct 23, 2018
406
407
408
409
410
func makeTemplate(md *desc.MessageDescriptor, path []*desc.MessageDescriptor) proto.Message {
switch md.GetFullyQualifiedName() {
case "google.protobuf.Any":
// empty type URL is not allowed by JSON representation
// so we must give it a dummy type
Nov 1, 2022
Nov 1, 2022
411
412
413
var anyVal anypb.Any
_ = anypb.MarshalFrom(&anyVal, &emptypb.Empty{}, protov2.MarshalOptions{})
return &anyVal
Oct 23, 2018
Oct 23, 2018
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
case "google.protobuf.Value":
// unset kind is not allowed by JSON representation
// so we must give it something
return &structpb.Value{
Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{
Fields: map[string]*structpb.Value{
"google.protobuf.Value": {Kind: &structpb.Value_StringValue{
StringValue: "supports arbitrary JSON",
}},
},
}},
}
case "google.protobuf.ListValue":
return &structpb.ListValue{
Values: []*structpb.Value{
{
Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{
Fields: map[string]*structpb.Value{
"google.protobuf.ListValue": {Kind: &structpb.Value_StringValue{
StringValue: "is an array of arbitrary JSON values",
}},
},
}},
},
},
}
case "google.protobuf.Struct":
return &structpb.Struct{
Fields: map[string]*structpb.Value{
"google.protobuf.Struct": {Kind: &structpb.Value_StringValue{
StringValue: "supports arbitrary JSON objects",
}},
},
}
Oct 19, 2018
Oct 19, 2018
448
449
}
Oct 23, 2018
Oct 23, 2018
450
451
452
453
454
dm := dynamic.NewMessage(md)
// if the message is a recursive structure, we don't want to blow the stack
for _, seen := range path {
if seen == md {
Oct 19, 2018
Oct 19, 2018
455
// already visited this type; avoid infinite recursion
Oct 23, 2018
Oct 23, 2018
456
return dm
Oct 19, 2018
Oct 19, 2018
457
458
459
460
461
462
463
}
}
path = append(path, dm.GetMessageDescriptor())
// for repeated fields, add a single element with default value
// and for message fields, add a message with all default fields
// that also has non-nil message and non-empty repeated fields
Oct 23, 2018
Oct 23, 2018
464
Oct 19, 2018
Oct 19, 2018
465
466
467
for _, fd := range dm.GetMessageDescriptor().GetFields() {
if fd.IsRepeated() {
switch fd.GetType() {
Feb 22, 2021
Feb 22, 2021
468
469
case descriptorpb.FieldDescriptorProto_TYPE_FIXED32,
descriptorpb.FieldDescriptorProto_TYPE_UINT32:
Oct 19, 2018
Oct 19, 2018
470
471
dm.AddRepeatedField(fd, uint32(0))
Feb 22, 2021
Feb 22, 2021
472
473
474
475
case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32,
descriptorpb.FieldDescriptorProto_TYPE_SINT32,
descriptorpb.FieldDescriptorProto_TYPE_INT32,
descriptorpb.FieldDescriptorProto_TYPE_ENUM:
Oct 19, 2018
Oct 19, 2018
476
477
dm.AddRepeatedField(fd, int32(0))
Feb 22, 2021
Feb 22, 2021
478
479
case descriptorpb.FieldDescriptorProto_TYPE_FIXED64,
descriptorpb.FieldDescriptorProto_TYPE_UINT64:
Oct 19, 2018
Oct 19, 2018
480
481
dm.AddRepeatedField(fd, uint64(0))
Feb 22, 2021
Feb 22, 2021
482
483
484
case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64,
descriptorpb.FieldDescriptorProto_TYPE_SINT64,
descriptorpb.FieldDescriptorProto_TYPE_INT64:
Oct 19, 2018
Oct 19, 2018
485
486
dm.AddRepeatedField(fd, int64(0))
Feb 22, 2021
Feb 22, 2021
487
case descriptorpb.FieldDescriptorProto_TYPE_STRING:
Oct 19, 2018
Oct 19, 2018
488
489
dm.AddRepeatedField(fd, "")
Feb 22, 2021
Feb 22, 2021
490
case descriptorpb.FieldDescriptorProto_TYPE_BYTES:
Oct 19, 2018
Oct 19, 2018
491
492
dm.AddRepeatedField(fd, []byte{})
Feb 22, 2021
Feb 22, 2021
493
case descriptorpb.FieldDescriptorProto_TYPE_BOOL:
Oct 19, 2018
Oct 19, 2018
494
495
dm.AddRepeatedField(fd, false)
Feb 22, 2021
Feb 22, 2021
496
case descriptorpb.FieldDescriptorProto_TYPE_FLOAT:
Oct 19, 2018
Oct 19, 2018
497
498
dm.AddRepeatedField(fd, float32(0))
Feb 22, 2021
Feb 22, 2021
499
case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:
Oct 19, 2018
Oct 19, 2018
500
501
dm.AddRepeatedField(fd, float64(0))
Feb 22, 2021
Feb 22, 2021
502
503
case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE,
descriptorpb.FieldDescriptorProto_TYPE_GROUP:
Oct 23, 2018
Oct 23, 2018
504
dm.AddRepeatedField(fd, makeTemplate(fd.GetMessageType(), path))
Oct 19, 2018
Oct 19, 2018
505
506
}
} else if fd.GetMessageType() != nil {
Oct 23, 2018
Oct 23, 2018
507
dm.SetField(fd, makeTemplate(fd.GetMessageType(), path))
Oct 19, 2018
Oct 19, 2018
508
509
510
511
512
}
}
return dm
}
Sep 20, 2021
Sep 20, 2021
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
// ClientTransportCredentials builds a client TLS config from the given
// properties (see ClientTLSConfig) and wraps it in gRPC transport credentials.
// Any error from ClientTLSConfig is returned unchanged.
//
// Deprecated: Use grpcurl.ClientTLSConfig and credentials.NewTLS instead.
func ClientTransportCredentials(insecureSkipVerify bool, cacertFile, clientCertFile, clientKeyFile string) (credentials.TransportCredentials, error) {
	cfg, err := ClientTLSConfig(insecureSkipVerify, cacertFile, clientCertFile, clientKeyFile)
	if err == nil {
		return credentials.NewTLS(cfg), nil
	}
	return nil, err
}
// ClientTLSConfig builds transport-layer config for a gRPC client using the
// given properties. If cacertFile is blank, only standard trusted certs are used to
// verify the server certs. If clientCertFile is blank, the client will not use a client
// certificate. If clientCertFile is not blank then clientKeyFile must not be blank.
//
// When insecureSkipVerify is true, server certificates are not verified and
// cacertFile is ignored. File-loading errors wrap their cause with %w, so
// callers can test it, e.g. errors.Is(err, fs.ErrNotExist).
func ClientTLSConfig(insecureSkipVerify bool, cacertFile, clientCertFile, clientKeyFile string) (*tls.Config, error) {
	var tlsConf tls.Config
	if clientCertFile != "" {
		// Load the client certificates from disk
		certificate, err := tls.LoadX509KeyPair(clientCertFile, clientKeyFile)
		if err != nil {
			return nil, fmt.Errorf("could not load client key pair: %w", err)
		}
		tlsConf.Certificates = []tls.Certificate{certificate}
	}

	switch {
	case insecureSkipVerify:
		tlsConf.InsecureSkipVerify = true
	case cacertFile != "":
		// Create a certificate pool from the certificate authority
		certPool := x509.NewCertPool()
		ca, err := os.ReadFile(cacertFile)
		if err != nil {
			return nil, fmt.Errorf("could not read ca certificate: %w", err)
		}
		// Append the certificates from the CA
		if ok := certPool.AppendCertsFromPEM(ca); !ok {
			return nil, errors.New("failed to append ca certs")
		}
		tlsConf.RootCAs = certPool
	}

	return &tlsConf, nil
}
Oct 17, 2018
Oct 17, 2018
564
// ServerTransportCredentials builds transport credentials for a gRPC server using the
565
566
567
568
569
570
// given properties. If cacertFile is blank, the server will not request client certs
// unless requireClientCerts is true. When requireClientCerts is false and cacertFile is
// not blank, the server will verify client certs when presented, but will not require
// client certs. The serverCertFile and serverKeyFile must both not be blank.
func ServerTransportCredentials(cacertFile, serverCertFile, serverKeyFile string, requireClientCerts bool) (credentials.TransportCredentials, error) {
var tlsConf tls.Config
Nov 16, 2018
Nov 16, 2018
571
572
573
// TODO(jh): Remove this line once https://github.com/golang/go/issues/28779 is fixed
// in Go tip. Until then, the recently merged TLS 1.3 support breaks the TLS tests.
tlsConf.MaxVersion = tls.VersionTLS12
574
575
576
577
578
579
580
581
582
583
584
// Load the server certificates from disk
certificate, err := tls.LoadX509KeyPair(serverCertFile, serverKeyFile)
if err != nil {
return nil, fmt.Errorf("could not load key pair: %v", err)
}
tlsConf.Certificates = []tls.Certificate{certificate}
if cacertFile != "" {
// Create a certificate pool from the certificate authority
certPool := x509.NewCertPool()
Apr 9, 2024
Apr 9, 2024
585
ca, err := os.ReadFile(cacertFile)
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
if err != nil {
return nil, fmt.Errorf("could not read ca certificate: %v", err)
}
// Append the certificates from the CA
if ok := certPool.AppendCertsFromPEM(ca); !ok {
return nil, errors.New("failed to append ca certs")
}
tlsConf.ClientCAs = certPool
}
if requireClientCerts {
tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
} else if cacertFile != "" {
tlsConf.ClientAuth = tls.VerifyClientCertIfGiven
} else {
tlsConf.ClientAuth = tls.NoClientCert
}
return credentials.NewTLS(&tlsConf), nil
}
Dec 13, 2017
Dec 13, 2017
608
609
610
611
// BlockingDial is a helper method to dial the given address, using optional TLS credentials,
// and blocking until the returned connection is ready. If the given credentials are nil, the
// connection will be insecure (plain-text).
Mar 27, 2018
Mar 27, 2018
612
func BlockingDial(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
Feb 12, 2024
Feb 12, 2024
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
if creds == nil {
creds = insecure.NewCredentials()
}
var err error
if strings.HasPrefix(address, "xds:///") {
// The xds:/// prefix is used to signal to the gRPC client to use an xDS server to resolve the
// target. The relevant credentials will be automatically pulled from the GRPC_XDS_BOOTSTRAP or
// GRPC_XDS_BOOTSTRAP_CONFIG env vars.
creds, err = xdsCredentials.NewClientCredentials(xdsCredentials.ClientOptions{FallbackCreds: creds})
if err != nil {
return nil, err
}
}
Dec 13, 2017
Dec 13, 2017
628
629
630
631
632
633
634
635
636
637
638
639
640
// grpc.Dial doesn't provide any information on permanent connection errors (like
// TLS handshake failures). So in order to provide good error messages, we need a
// custom dialer that can provide that info. That means we manage the TLS handshake.
result := make(chan interface{}, 1)
writeResult := func(res interface{}) {
// non-blocking write: we only need the first result
select {
case result <- res:
default:
}
}
Jan 27, 2020
Jan 27, 2020
641
642
// custom credentials and dialer will notify on error via the
// writeResult function
Feb 12, 2024
Feb 12, 2024
643
644
645
creds = &errSignalingCreds{
TransportCredentials: creds,
writeResult: writeResult,
Jan 27, 2020
Jan 27, 2020
646
}
Feb 12, 2024
Feb 12, 2024
647
Feb 27, 2019
Feb 27, 2019
648
dialer := func(ctx context.Context, address string) (net.Conn, error) {
Jan 27, 2020
Jan 27, 2020
649
650
// NB: We *could* handle the TLS handshake ourselves, in the custom
// dialer (instead of customizing both the dialer and the credentials).
Feb 5, 2022
Feb 5, 2022
651
652
653
654
655
// But that requires using insecure.NewCredentials() dial transport
// option (so that the gRPC library doesn't *also* try to do a
// handshake). And that would mean that the library would send the
// wrong ":scheme" metaheader to servers: it would send "http" instead
// of "https" because it is unaware that TLS is actually in use.
Feb 27, 2019
Feb 27, 2019
656
conn, err := (&net.Dialer{}).DialContext(ctx, network, address)
Dec 13, 2017
Dec 13, 2017
657
658
659
if err != nil {
writeResult(err)
}
Jan 27, 2020
Jan 27, 2020
660
return conn, err
Dec 13, 2017
Dec 13, 2017
661
662
663
664
665
}
// Even with grpc.FailOnNonTempDialError, this call will usually timeout in
// the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
// know when we're done. So we run it in a goroutine and then use result
Jan 27, 2020
Jan 27, 2020
666
// channel to either get the connection or fail-fast.
Dec 13, 2017
Dec 13, 2017
667
go func() {
Jul 12, 2021
Jul 12, 2021
668
669
670
671
672
// We put grpc.FailOnNonTempDialError *before* the explicitly provided
// options so that it could be overridden.
opts = append([]grpc.DialOption{grpc.FailOnNonTempDialError(true)}, opts...)
// But we don't want caller to be able to override these two, so we put
// them *after* the explicitly provided options.
Feb 12, 2024
Feb 12, 2024
673
opts = append(opts, grpc.WithBlock(), grpc.WithContextDialer(dialer), grpc.WithTransportCredentials(creds))
Jul 12, 2021
Jul 12, 2021
674
Dec 13, 2017
Dec 13, 2017
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
conn, err := grpc.DialContext(ctx, address, opts...)
var res interface{}
if err != nil {
res = err
} else {
res = conn
}
writeResult(res)
}()
select {
case res := <-result:
if conn, ok := res.(*grpc.ClientConn); ok {
return conn, nil
}
Mar 24, 2018
Mar 24, 2018
690
return nil, res.(error)
Dec 13, 2017
Dec 13, 2017
691
692
693
694
case <-ctx.Done():
return nil, ctx.Err()
}
}
Jan 27, 2020
Jan 27, 2020
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
// errSignalingCreds wraps a TransportCredentials value. It behaves exactly
// like the wrapped credentials, except that a failed client handshake is also
// reported through writeResult.
type errSignalingCreds struct {
	credentials.TransportCredentials
	writeResult func(res interface{})
}

// ClientHandshake delegates to the wrapped credentials. On failure it passes
// the error to writeResult before returning the wrapped call's results.
func (c *errSignalingCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	conn, auth, err := c.TransportCredentials.ClientHandshake(ctx, addr, rawConn)
	if err == nil {
		return conn, auth, nil
	}
	c.writeResult(err)
	return conn, auth, err
}