Skip to content

Commit 586bde9

Browse files
committed
Factor out interface type construction into helper func
1 parent ba0d4e3 commit 586bde9

File tree

1 file changed

+39
-33
lines changed

1 file changed

+39
-33
lines changed

cmd/protoc-gen-go-grpc/grpc.go

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,17 @@ func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string
310310
return s
311311
}
312312

313+
func clientStreamInterface(g *protogen.GeneratedFile, method *protogen.Method) string {
314+
typeParam := g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent)
315+
if method.Desc.IsStreamingClient() && method.Desc.IsStreamingServer() {
316+
return g.QualifiedGoIdent(grpcPackage.Ident("BidiStreamClient")) + "[" + typeParam + "]"
317+
} else if method.Desc.IsStreamingClient() {
318+
return g.QualifiedGoIdent(grpcPackage.Ident("ClientStreamClient")) + "[" + typeParam + "]"
319+
} else { // i.e. if method.Desc.IsStreamingServer()
320+
return g.QualifiedGoIdent(grpcPackage.Ident("ServerStreamClient")) + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]"
321+
}
322+
}
323+
313324
func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, index int) {
314325
service := method.Parent
315326
fmSymbol := helper.formatFullMethodSymbol(service, method)
@@ -329,24 +340,16 @@ func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
329340
return
330341
}
331342

332-
streamType := unexport(service.GoName) + method.GoName + "Client"
333-
var streamInterface string
343+
streamImpl := unexport(service.GoName) + method.GoName + "Client"
334344
if *useGenericStreams {
335345
typeParam := g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent)
336-
streamType = g.QualifiedGoIdent(grpcPackage.Ident("StreamClientImpl")) + "[" + typeParam + "]"
337-
if method.Desc.IsStreamingClient() && method.Desc.IsStreamingServer() {
338-
streamInterface = g.QualifiedGoIdent(grpcPackage.Ident("BidiStreamClient")) + "[" + typeParam + "]"
339-
} else if method.Desc.IsStreamingClient() {
340-
streamInterface = g.QualifiedGoIdent(grpcPackage.Ident("ClientStreamClient")) + "[" + typeParam + "]"
341-
} else { // i.e. if method.Desc.IsStreamingServer()
342-
streamInterface = g.QualifiedGoIdent(grpcPackage.Ident("ServerStreamClient")) + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]"
343-
}
346+
streamImpl = g.QualifiedGoIdent(grpcPackage.Ident("StreamClientImpl")) + "[" + typeParam + "]"
344347
}
345348

346349
serviceDescVar := service.GoName + "_ServiceDesc"
347350
g.P("stream, err := c.cc.NewStream(ctx, &", serviceDescVar, ".Streams[", index, `], `, fmSymbol, `, cOpts...)`)
348351
g.P("if err != nil { return nil, err }")
349-
g.P("x := &", streamType, "{ClientStream: stream}")
352+
g.P("x := &", streamImpl, "{ClientStream: stream}")
350353
if !method.Desc.IsStreamingClient() {
351354
g.P("if err := x.ClientStream.SendMsg(in); err != nil { return nil, err }")
352355
g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
@@ -359,7 +362,7 @@ func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
359362
if *useGenericStreams {
360363
// Use a type alias so that the type name in the generated function
361364
// signature can remain identical even while we swap out the implementation.
362-
g.P("type ", service.GoName, "_", method.GoName, "Client = ", streamInterface)
365+
g.P("type ", service.GoName, "_", method.GoName, "Client = ", clientStreamInterface(g, method))
363366
g.P()
364367
return
365368
}
@@ -382,27 +385,27 @@ func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
382385
g.P("}")
383386
g.P()
384387

385-
g.P("type ", streamType, " struct {")
388+
g.P("type ", streamImpl, " struct {")
386389
g.P(grpcPackage.Ident("ClientStream"))
387390
g.P("}")
388391
g.P()
389392

390393
if genSend {
391-
g.P("func (x *", streamType, ") Send(m *", method.Input.GoIdent, ") error {")
394+
g.P("func (x *", streamImpl, ") Send(m *", method.Input.GoIdent, ") error {")
392395
g.P("return x.ClientStream.SendMsg(m)")
393396
g.P("}")
394397
g.P()
395398
}
396399
if genRecv {
397-
g.P("func (x *", streamType, ") Recv() (*", method.Output.GoIdent, ", error) {")
400+
g.P("func (x *", streamImpl, ") Recv() (*", method.Output.GoIdent, ", error) {")
398401
g.P("m := new(", method.Output.GoIdent, ")")
399402
g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
400403
g.P("return m, nil")
401404
g.P("}")
402405
g.P()
403406
}
404407
if genCloseAndRecv {
405-
g.P("func (x *", streamType, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
408+
g.P("func (x *", streamImpl, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
406409
g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
407410
g.P("m := new(", method.Output.GoIdent, ")")
408411
g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
@@ -469,6 +472,17 @@ func genServiceDesc(file *protogen.File, g *protogen.GeneratedFile, serviceDescV
469472
g.P()
470473
}
471474

475+
func serverStreamInterface(g *protogen.GeneratedFile, method *protogen.Method) string {
476+
typeParam := g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent)
477+
if method.Desc.IsStreamingClient() && method.Desc.IsStreamingServer() {
478+
return g.QualifiedGoIdent(grpcPackage.Ident("BidiStreamServer")) + "[" + typeParam + "]"
479+
} else if method.Desc.IsStreamingClient() {
480+
return g.QualifiedGoIdent(grpcPackage.Ident("ClientStreamServer")) + "[" + typeParam + "]"
481+
} else { // i.e. if method.Desc.IsStreamingServer()
482+
return g.QualifiedGoIdent(grpcPackage.Ident("ServerStreamServer")) + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]"
483+
}
484+
}
485+
472486
func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, hnameFuncNameFormatter func(string) string) string {
473487
service := method.Parent
474488
hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName)
@@ -492,27 +506,19 @@ func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
492506
return hname
493507
}
494508

495-
streamType := unexport(service.GoName) + method.GoName + "Server"
496-
var streamInterface string
509+
streamImpl := unexport(service.GoName) + method.GoName + "Server"
497510
if *useGenericStreams {
498511
typeParam := g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent)
499-
streamType = g.QualifiedGoIdent(grpcPackage.Ident("StreamServerImpl")) + "[" + typeParam + "]"
500-
if method.Desc.IsStreamingClient() && method.Desc.IsStreamingServer() {
501-
streamInterface = g.QualifiedGoIdent(grpcPackage.Ident("BidiStreamServer")) + "[" + typeParam + "]"
502-
} else if method.Desc.IsStreamingClient() {
503-
streamInterface = g.QualifiedGoIdent(grpcPackage.Ident("ClientStreamServer")) + "[" + typeParam + "]"
504-
} else { // i.e. if method.Desc.IsStreamingServer()
505-
streamInterface = g.QualifiedGoIdent(grpcPackage.Ident("ServerStreamServer")) + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]"
506-
}
512+
streamImpl = g.QualifiedGoIdent(grpcPackage.Ident("StreamServerImpl")) + "[" + typeParam + "]"
507513
}
508514

509515
g.P("func ", hnameFuncNameFormatter(hname), "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {")
510516
if !method.Desc.IsStreamingClient() {
511517
g.P("m := new(", method.Input.GoIdent, ")")
512518
g.P("if err := stream.RecvMsg(m); err != nil { return err }")
513-
g.P("return srv.(", service.GoName, "Server).", method.GoName, "(m, &", streamType, "{ServerStream: stream})")
519+
g.P("return srv.(", service.GoName, "Server).", method.GoName, "(m, &", streamImpl, "{ServerStream: stream})")
514520
} else {
515-
g.P("return srv.(", service.GoName, "Server).", method.GoName, "(&", streamType, "{ServerStream: stream})")
521+
g.P("return srv.(", service.GoName, "Server).", method.GoName, "(&", streamImpl, "{ServerStream: stream})")
516522
}
517523
g.P("}")
518524
g.P()
@@ -521,7 +527,7 @@ func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
521527
if *useGenericStreams {
522528
// Use a type alias so that the type name in the generated function
523529
// signature can remain identical even while we swap out the implementation.
524-
g.P("type ", service.GoName, "_", method.GoName, "Server = ", streamInterface)
530+
g.P("type ", service.GoName, "_", method.GoName, "Server = ", serverStreamInterface(g, method))
525531
g.P()
526532
return hname
527533
}
@@ -544,25 +550,25 @@ func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
544550
g.P("}")
545551
g.P()
546552

547-
g.P("type ", streamType, " struct {")
553+
g.P("type ", streamImpl, " struct {")
548554
g.P(grpcPackage.Ident("ServerStream"))
549555
g.P("}")
550556
g.P()
551557

552558
if genSend {
553-
g.P("func (x *", streamType, ") Send(m *", method.Output.GoIdent, ") error {")
559+
g.P("func (x *", streamImpl, ") Send(m *", method.Output.GoIdent, ") error {")
554560
g.P("return x.ServerStream.SendMsg(m)")
555561
g.P("}")
556562
g.P()
557563
}
558564
if genSendAndClose {
559-
g.P("func (x *", streamType, ") SendAndClose(m *", method.Output.GoIdent, ") error {")
565+
g.P("func (x *", streamImpl, ") SendAndClose(m *", method.Output.GoIdent, ") error {")
560566
g.P("return x.ServerStream.SendMsg(m)")
561567
g.P("}")
562568
g.P()
563569
}
564570
if genRecv {
565-
g.P("func (x *", streamType, ") Recv() (*", method.Input.GoIdent, ", error) {")
571+
g.P("func (x *", streamImpl, ") Recv() (*", method.Input.GoIdent, ", error) {")
566572
g.P("m := new(", method.Input.GoIdent, ")")
567573
g.P("if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err }")
568574
g.P("return m, nil")

0 commit comments

Comments
 (0)