diff --git a/engine/access/rpc/connection/connection_test.go b/engine/access/rpc/connection/connection_test.go index a3e1ee3988c..21595fa8a64 100644 --- a/engine/access/rpc/connection/connection_test.go +++ b/engine/access/rpc/connection/connection_test.go @@ -802,7 +802,15 @@ func setupGRPCServer(t *testing.T) *grpc.ClientConn { return conn } -// TestCircuitBreakerExecutionNode tests the circuit breaker state changes for execution nodes. +var successCodes = []codes.Code{ + codes.Canceled, + codes.InvalidArgument, + codes.NotFound, + codes.Unimplemented, + codes.OutOfRange, +} + +// TestCircuitBreakerExecutionNode tests the circuit breaker for execution nodes. func TestCircuitBreakerExecutionNode(t *testing.T) { requestTimeout := 500 * time.Millisecond circuitBreakerRestoreTimeout := 1500 * time.Millisecond @@ -812,11 +820,6 @@ func TestCircuitBreakerExecutionNode(t *testing.T) { en.start(t) defer en.stop(t) - // Set up the handler mock to not respond within the requestTimeout. - req := &execution.PingRequest{} - resp := &execution.PingResponse{} - en.handler.On("Ping", testifymock.Anything, req).After(2*requestTimeout).Return(resp, nil) - // Create the connection factory. connectionFactory := new(ConnectionFactoryImpl) @@ -852,10 +855,11 @@ func TestCircuitBreakerExecutionNode(t *testing.T) { client, _, err := connectionFactory.GetExecutionAPIClient(en.listener.Addr().String()) require.NoError(t, err) - ctx := context.Background() + req := &execution.PingRequest{} + resp := &execution.PingResponse{} // Helper function to make the Ping call to the execution node and measure the duration. - callAndMeasurePingDuration := func() (time.Duration, error) { + callAndMeasurePingDuration := func(ctx context.Context) (time.Duration, error) { start := time.Now() // Make the call to the execution node. @@ -865,30 +869,51 @@ func TestCircuitBreakerExecutionNode(t *testing.T) { return time.Since(start), err } - // Call and measure the duration for the first invocation. - duration, err := callAndMeasurePingDuration() - assert.Equal(t, codes.DeadlineExceeded, status.Code(err)) - assert.LessOrEqual(t, requestTimeout, duration) + t.Run("test different states of the circuit breaker", func(t *testing.T) { + ctx := context.Background() + + // Set up the handler mock to not respond within the requestTimeout. + en.handler.On("Ping", testifymock.Anything, req).After(2*requestTimeout).Return(resp, nil) + + // Call and measure the duration for the first invocation. + duration, err := callAndMeasurePingDuration(ctx) + assert.Equal(t, codes.DeadlineExceeded, status.Code(err)) + assert.LessOrEqual(t, requestTimeout, duration) - // Call and measure the duration for the second invocation (circuit breaker state is now "Open"). - duration, err = callAndMeasurePingDuration() - assert.Equal(t, gobreaker.ErrOpenState, err) - assert.Greater(t, requestTimeout, duration) + // Call and measure the duration for the second invocation (circuit breaker state is now "Open"). + duration, err = callAndMeasurePingDuration(ctx) + assert.Equal(t, gobreaker.ErrOpenState, err) + assert.Greater(t, requestTimeout, duration) + + // Reset the mock Ping for the next invocation to return response without delay + en.handler.On("Ping", testifymock.Anything, req).Unset() + en.handler.On("Ping", testifymock.Anything, req).Return(resp, nil) + + // Wait until the circuit breaker transitions to the "HalfOpen" state. + time.Sleep(circuitBreakerRestoreTimeout + (500 * time.Millisecond)) + + // Call and measure the duration for the third invocation (circuit breaker state is now "HalfOpen"). + duration, err = callAndMeasurePingDuration(ctx) + assert.Greater(t, requestTimeout, duration) + assert.Equal(t, nil, err) + }) - // Reset the mock Ping for the next invocation to return response without delay - en.handler.On("Ping", testifymock.Anything, req).Unset() - en.handler.On("Ping", testifymock.Anything, req).Return(resp, nil) + for _, code := range successCodes { + t.Run(fmt.Sprintf("test error %s treated as a success for circuit breaker ", code.String()), func(t *testing.T) { + ctx := context.Background() - // Wait until the circuit breaker transitions to the "HalfOpen" state. - time.Sleep(circuitBreakerRestoreTimeout + (500 * time.Millisecond)) + en.handler.On("Ping", testifymock.Anything, req).Unset() + en.handler.On("Ping", testifymock.Anything, req).Return(nil, status.Error(code, code.String())) - // Call and measure the duration for the third invocation (circuit breaker state is now "HalfOpen"). - duration, err = callAndMeasurePingDuration() - assert.Greater(t, requestTimeout, duration) - assert.Equal(t, nil, err) + duration, err := callAndMeasurePingDuration(ctx) + require.Error(t, err) + require.Equal(t, code, status.Code(err)) + require.Greater(t, requestTimeout, duration) + }) + } } -// TestCircuitBreakerCollectionNode tests the circuit breaker state changes for collection nodes. +// TestCircuitBreakerCollectionNode tests the circuit breaker for collection nodes. func TestCircuitBreakerCollectionNode(t *testing.T) { requestTimeout := 500 * time.Millisecond circuitBreakerRestoreTimeout := 1500 * time.Millisecond @@ -898,11 +923,6 @@ func TestCircuitBreakerCollectionNode(t *testing.T) { cn.start(t) defer cn.stop(t) - // Set up the handler mock to not respond within the requestTimeout. - req := &access.PingRequest{} - resp := &access.PingResponse{} - cn.handler.On("Ping", testifymock.Anything, req).After(2*requestTimeout).Return(resp, nil) - // Create the connection factory. connectionFactory := new(ConnectionFactoryImpl) @@ -938,10 +958,11 @@ func TestCircuitBreakerCollectionNode(t *testing.T) { client, _, err := connectionFactory.GetAccessAPIClient(cn.listener.Addr().String()) assert.NoError(t, err) - ctx := context.Background() + req := &access.PingRequest{} + resp := &access.PingResponse{} // Helper function to make the Ping call to the collection node and measure the duration. - callAndMeasurePingDuration := func() (time.Duration, error) { + callAndMeasurePingDuration := func(ctx context.Context) (time.Duration, error) { start := time.Now() // Make the call to the collection node. @@ -951,25 +972,46 @@ func TestCircuitBreakerCollectionNode(t *testing.T) { return time.Since(start), err } - // Call and measure the duration for the first invocation. - duration, err := callAndMeasurePingDuration() - assert.Equal(t, codes.DeadlineExceeded, status.Code(err)) - assert.LessOrEqual(t, requestTimeout, duration) + t.Run("test different states of the circuit breaker", func(t *testing.T) { + ctx := context.Background() + + // Set up the handler mock to not respond within the requestTimeout. + cn.handler.On("Ping", testifymock.Anything, req).After(2*requestTimeout).Return(resp, nil) + + // Call and measure the duration for the first invocation. + duration, err := callAndMeasurePingDuration(ctx) + assert.Equal(t, codes.DeadlineExceeded, status.Code(err)) + assert.LessOrEqual(t, requestTimeout, duration) + + // Call and measure the duration for the second invocation (circuit breaker state is now "Open"). + duration, err = callAndMeasurePingDuration(ctx) + assert.Equal(t, gobreaker.ErrOpenState, err) + assert.Greater(t, requestTimeout, duration) - // Call and measure the duration for the second invocation (circuit breaker state is now "Open"). - duration, err = callAndMeasurePingDuration() - assert.Equal(t, gobreaker.ErrOpenState, err) - assert.Greater(t, requestTimeout, duration) + // Reset the mock Ping for the next invocation to return response without delay + cn.handler.On("Ping", testifymock.Anything, req).Unset() + cn.handler.On("Ping", testifymock.Anything, req).Return(resp, nil) - // Reset the mock Ping for the next invocation to return response without delay - cn.handler.On("Ping", testifymock.Anything, req).Unset() - cn.handler.On("Ping", testifymock.Anything, req).Return(resp, nil) + // Wait until the circuit breaker transitions to the "HalfOpen" state. + time.Sleep(circuitBreakerRestoreTimeout + (500 * time.Millisecond)) - // Wait until the circuit breaker transitions to the "HalfOpen" state. - time.Sleep(circuitBreakerRestoreTimeout + (500 * time.Millisecond)) + // Call and measure the duration for the third invocation (circuit breaker state is now "HalfOpen"). + duration, err = callAndMeasurePingDuration(ctx) + assert.Greater(t, requestTimeout, duration) + assert.Equal(t, nil, err) + }) + + for _, code := range successCodes { + t.Run(fmt.Sprintf("test error %s treated as a success for circuit breaker ", code.String()), func(t *testing.T) { + ctx := context.Background() - // Call and measure the duration for the third invocation (circuit breaker state is now "HalfOpen"). - duration, err = callAndMeasurePingDuration() - assert.Greater(t, requestTimeout, duration) - assert.Equal(t, nil, err) + cn.handler.On("Ping", testifymock.Anything, req).Unset() + cn.handler.On("Ping", testifymock.Anything, req).Return(nil, status.Error(code, code.String())) + + duration, err := callAndMeasurePingDuration(ctx) + require.Error(t, err) + require.Equal(t, code, status.Code(err)) + require.Greater(t, requestTimeout, duration) + }) + } } diff --git a/engine/access/rpc/connection/manager.go b/engine/access/rpc/connection/manager.go index c50a9026748..b9cfc792620 100644 --- a/engine/access/rpc/connection/manager.go +++ b/engine/access/rpc/connection/manager.go @@ -368,6 +368,25 @@ func (m *Manager) createCircuitBreakerInterceptor() grpc.UnaryClientInterceptor // MaxRequests defines the max number of concurrent requests while the circuit breaker is in the HalfClosed // state. MaxRequests: m.circuitBreakerConfig.MaxRequests, + // IsSuccessful defines gRPC status codes that should be treated as a successful result for the circuit breaker. + IsSuccessful: func(err error) bool { + if se, ok := status.FromError(err); ok { + if se == nil { + return true + } + + // There are several error cases that may occur during normal operation and should be considered + // as "successful" from the perspective of the circuit breaker. + switch se.Code() { + case codes.OK, codes.Canceled, codes.InvalidArgument, codes.NotFound, codes.Unimplemented, codes.OutOfRange: + return true + default: + return false + } + } + + return false + }, }) circuitBreakerInterceptor := func(