Skip to content

Commit

Permalink
Small fixups in preparation to add one-to-many proxying.
Browse files Browse the repository at this point in the history
  • Loading branch information
smira committed Nov 19, 2019
1 parent 6d76ffc commit e3111ef
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 95 deletions.
2 changes: 1 addition & 1 deletion proxy/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func ExampleRegisterService() {
server := grpc.NewServer(grpc.CustomCodec(proxy.Codec()))
// Register a TestService with 4 of its methods explicitly.
proxy.RegisterService(server, director,
"mwitkow.testproto.TestService",
"smira.testproto.TestService",
"PingEmpty", "Ping", "PingError", "PingList")
}

Expand Down
43 changes: 23 additions & 20 deletions proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ type backendConnection struct {
connError error

clientStream grpc.ClientStream

clientCtx context.Context
clientCancel context.CancelFunc
}

// handler is where the real magic of proxying happens.
Expand All @@ -89,20 +86,21 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
var establishedConnections int
backendConnections := make([]backendConnection, len(backends))

clientCtx, clientCancel := context.WithCancel(serverStream.Context())
defer clientCancel()

for i := range backends {
backendConnections[i].backend = backends[i]

//We require that the backend's returned context inherits from the serverStream.Context().
// We require that the backend's returned context inherits from the serverStream.Context().
var outgoingCtx context.Context
outgoingCtx, backendConnections[i].backendConn, backendConnections[i].connError = backends[i].GetConnection(serverStream.Context())
outgoingCtx, backendConnections[i].backendConn, backendConnections[i].connError = backends[i].GetConnection(clientCtx)

if backendConnections[i].connError != nil {
continue
}

backendConnections[i].clientCtx, backendConnections[i].clientCancel = context.WithCancel(outgoingCtx)

backendConnections[i].clientStream, backendConnections[i].connError = grpc.NewClientStream(backendConnections[i].clientCtx, clientStreamDescForProxying,
backendConnections[i].clientStream, backendConnections[i].connError = grpc.NewClientStream(outgoingCtx, clientStreamDescForProxying,
backendConnections[i].backendConn, fullMethodName)

if backendConnections[i].connError != nil {
Expand All @@ -121,14 +119,11 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
return backendConnections[0].connError
}

clientStream := backendConnections[0].clientStream
defer backendConnections[0].clientCancel()

// Explicitly *do not close* s2cErrChan and c2sErrChan, otherwise the select below will not terminate.
// Channels do not have to be closed, it is just a control flow mechanism, see
// https://groups.google.com/forum/#!msg/golang-nuts/pZwdYRGxCIk/qpbHxRRPJdUJ
s2cErrChan := s.forwardServerToClient(serverStream, clientStream)
c2sErrChan := s.forwardClientToServer(clientStream, serverStream)
s2cErrChan := s.forwardServerToClient(serverStream, &backendConnections[0])
c2sErrChan := s.forwardClientToServer(&backendConnections[0], serverStream)
// We don't know which side is going to stop sending first, so we need a select between the two.
for i := 0; i < 2; i++ {
select {
Expand All @@ -137,7 +132,7 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
// this is the happy case where the sender has encountered io.EOF, and won't be sending anymore./
// the clientStream>serverStream may continue pumping though.
//nolint: errcheck
clientStream.CloseSend()
backendConnections[0].clientStream.CloseSend()
break
} else {
// however, we may have gotten a receive error (stream disconnected, a read error etc) in which case we need
Expand All @@ -149,7 +144,7 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
// This happens when the clientStream has nothing else to offer (io.EOF), returned a gRPC error. In those two
// cases we may have received Trailers as part of the call. In case of other errors (stream closed) the trailers
// will be nil.
serverStream.SetTrailer(clientStream.Trailer())
serverStream.SetTrailer(backendConnections[0].clientStream.Trailer())
// c2sErr will contain RPC error from client code. If not io.EOF return the RPC error as server stream error.
if c2sErr != io.EOF {
return c2sErr
Expand All @@ -160,20 +155,28 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
return status.Errorf(codes.Internal, "gRPC proxying should never reach this stage.")
}

func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream) chan error {
func (s *handler) forwardClientToServer(src *backendConnection, dst grpc.ServerStream) chan error {
ret := make(chan error, 1)
go func() {
f := &frame{}
for i := 0; ; i++ {
if err := src.RecvMsg(f); err != nil {
if err := src.clientStream.RecvMsg(f); err != nil {
ret <- err // this can be io.EOF which is happy case
break
}

var err error
f.payload, err = src.backend.AppendInfo(f.payload)
if err != nil {
ret <- err
break
}

if i == 0 {
// This is a bit of a hack, but client to server headers are only readable after first client msg is
// received but must be written to server stream before the first msg is flushed.
// This is the only place to do it nicely.
md, err := src.Header()
md, err := src.clientStream.Header()
if err != nil {
ret <- err
break
Expand All @@ -192,7 +195,7 @@ func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerSt
return ret
}

func (s *handler) forwardServerToClient(src grpc.ServerStream, dst grpc.ClientStream) chan error {
func (s *handler) forwardServerToClient(src grpc.ServerStream, dst *backendConnection) chan error {
ret := make(chan error, 1)
go func() {
f := &frame{}
Expand All @@ -201,7 +204,7 @@ func (s *handler) forwardServerToClient(src grpc.ServerStream, dst grpc.ClientSt
ret <- err // this can be io.EOF which is happy case
break
}
if err := dst.SendMsg(f); err != nil {
if err := dst.clientStream.SendMsg(f); err != nil {
ret <- err
break
}
Expand Down
2 changes: 1 addition & 1 deletion proxy/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func (s *ProxyHappySuite) SetupSuite() {
)
// Ping handler is handled as an explicit registration and not as a TransparentHandler.
proxy.RegisterService(s.proxy, director,
"mwitkow.testproto.TestService",
"smira.testproto.TestService",
"Ping")

// Start the serving loops.
Expand Down
Loading

0 comments on commit e3111ef

Please sign in to comment.