Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
YoshiyukiMineo committed Nov 16, 2024
1 parent 11d03b2 commit d6880cf
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 56 deletions.
62 changes: 17 additions & 45 deletions v2/distributed_gobreaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@ package gobreaker

import (
"context"
"encoding/json"
"fmt"
"time"
)

// SharedState represents the CircuitBreaker state stored in Distributed Storage
type SharedState struct {
State State `json:"state"`
Generation uint64 `json:"generation"`
Counts Counts `json:"counts"`
Expiry time.Time `json:"expiry"`
}

type SharedStateStore interface {
GetState(ctx context.Context, key string) ([]byte, error)
SetState(ctx context.Context, key string, value interface{}, expiration time.Duration) error
GetState(ctx context.Context) (SharedState, error)
SetState(ctx context.Context, state SharedState) error
}

// DistributedCircuitBreaker extends CircuitBreaker with distributed state storage
Expand All @@ -27,20 +34,12 @@ func NewDistributedCircuitBreaker[T any](storageClient SharedStateStore, setting
}
}

// SharedState represents the CircuitBreaker state stored in Distributed Storage
type SharedState struct {
State State `json:"state"`
Generation uint64 `json:"generation"`
Counts Counts `json:"counts"`
Expiry time.Time `json:"expiry"`
}

func (rcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State {
if rcb.cacheClient == nil {
return rcb.CircuitBreaker.State()
}

state, err := rcb.getStoredState(ctx)
state, err := rcb.cacheClient.GetState(ctx)
if err != nil {
// Fallback to in-memory state if Storage fails
return rcb.CircuitBreaker.State()
Expand All @@ -52,7 +51,7 @@ func (rcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State {
// Update the state in Storage if it has changed
if currentState != state.State {
state.State = currentState
if err := rcb.setStoredState(ctx, state); err != nil {
if err := rcb.cacheClient.SetState(ctx, state); err != nil {
// Log the error, but continue with the current state
fmt.Printf("Failed to update state in storage: %v\n", err)
}
Expand Down Expand Up @@ -87,7 +86,7 @@ func (rcb *DistributedCircuitBreaker[T]) Execute(ctx context.Context, req func()
}

func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uint64, error) {
state, err := rcb.getStoredState(ctx)
state, err := rcb.cacheClient.GetState(ctx)
if err != nil {
return 0, err
}
Expand All @@ -96,7 +95,7 @@ func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin

if currentState != state.State {
rcb.setState(&state, currentState, now)
err = rcb.setStoredState(ctx, state)
err = rcb.cacheClient.SetState(ctx, state)
if err != nil {
return 0, err
}
Expand All @@ -109,7 +108,7 @@ func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin
}

state.Counts.onRequest()
err = rcb.setStoredState(ctx, state)
err = rcb.cacheClient.SetState(ctx, state)
if err != nil {
return 0, err
}
Expand All @@ -118,7 +117,7 @@ func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin
}

func (rcb *DistributedCircuitBreaker[T]) afterRequest(ctx context.Context, before uint64, success bool) {
state, err := rcb.getStoredState(ctx)
state, err := rcb.cacheClient.GetState(ctx)
if err != nil {
return
}
Expand All @@ -134,7 +133,7 @@ func (rcb *DistributedCircuitBreaker[T]) afterRequest(ctx context.Context, befor
rcb.onFailure(&state, currentState, now)
}

rcb.setStoredState(ctx, state)
rcb.cacheClient.SetState(ctx, state)
}

func (rcb *DistributedCircuitBreaker[T]) onSuccess(state *SharedState, currentState State, now time.Time) {
Expand Down Expand Up @@ -213,30 +212,3 @@ func (rcb *DistributedCircuitBreaker[T]) toNewGeneration(state *SharedState, now
state.Expiry = zero
}
}

func (rcb *DistributedCircuitBreaker[T]) getStorageKey() string {
return "cb:" + rcb.name
}

func (rcb *DistributedCircuitBreaker[T]) getStoredState(ctx context.Context) (SharedState, error) {
var state SharedState
data, err := rcb.cacheClient.GetState(ctx, rcb.getStorageKey())
if len(data) == 0 {
// Key doesn't exist, return default state
return SharedState{State: StateClosed}, nil
} else if err != nil {
return state, err
}

err = json.Unmarshal(data, &state)
return state, err
}

func (rcb *DistributedCircuitBreaker[T]) setStoredState(ctx context.Context, state SharedState) error {
data, err := json.Marshal(state)
if err != nil {
return err
}

return rcb.cacheClient.SetState(ctx, rcb.getStorageKey(), data, 0)
}
38 changes: 27 additions & 11 deletions v2/distributed_gobreaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gobreaker

import (
"context"
"encoding/json"
"errors"
"testing"
"time"
Expand All @@ -18,12 +19,27 @@ type storageAdapter struct {
client *redis.Client
}

func (r *storageAdapter) GetState(ctx context.Context, key string) ([]byte, error) {
return r.client.Get(ctx, key).Bytes()
func (r *storageAdapter) GetState(ctx context.Context) (SharedState, error) {
var state SharedState
data, err := r.client.Get(ctx, "gobreaker").Bytes()
if len(data) == 0 {
// Key doesn't exist, return default state
return SharedState{State: StateClosed}, nil
} else if err != nil {
return state, err
}

err = json.Unmarshal(data, &state)
return state, err
}

func (r *storageAdapter) SetState(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
return r.client.Set(ctx, key, value, expiration).Err()
func (r *storageAdapter) SetState(ctx context.Context, state SharedState) error {
data, err := json.Marshal(state)
if err != nil {
return err
}

return r.client.Set(ctx, "gobreaker", data, 0).Err()
}

func setupTestWithMiniredis() (*DistributedCircuitBreaker[any], *miniredis.Miniredis, *redis.Client) {
Expand All @@ -50,14 +66,14 @@ func setupTestWithMiniredis() (*DistributedCircuitBreaker[any], *miniredis.Minir
}

func pseudoSleepStorage(ctx context.Context, rcb *DistributedCircuitBreaker[any], period time.Duration) {
state, _ := rcb.getStoredState(ctx)
state, _ := rcb.cacheClient.GetState(ctx)

state.Expiry = state.Expiry.Add(-period)
// Reset counts if the interval has passed
if time.Now().After(state.Expiry) {
state.Counts = Counts{}
}
rcb.setStoredState(ctx, state)
rcb.cacheClient.SetState(ctx, state)
}

func successRequest(ctx context.Context, rcb *DistributedCircuitBreaker[any]) error {
Expand Down Expand Up @@ -158,11 +174,11 @@ func TestDistributedCircuitBreakerCounts(t *testing.T) {
assert.Nil(t, successRequest(ctx, rcb))
}

state, _ := rcb.getStoredState(ctx)
state, _ := rcb.cacheClient.GetState(ctx)
assert.Equal(t, Counts{5, 5, 0, 5, 0}, state.Counts)

assert.Nil(t, failRequest(ctx, rcb))
state, _ = rcb.getStoredState(ctx)
state, _ = rcb.cacheClient.GetState(ctx)
assert.Equal(t, Counts{6, 5, 1, 0, 1}, state.Counts)
}

Expand Down Expand Up @@ -224,14 +240,14 @@ func TestCustomDistributedCircuitBreaker(t *testing.T) {
assert.NoError(t, failRequest(ctx, customRCB))
}

state, err := customRCB.getStoredState(ctx)
state, err := customRCB.cacheClient.GetState(ctx)
assert.NoError(t, err)
assert.Equal(t, StateClosed, state.State)
assert.Equal(t, Counts{10, 5, 5, 0, 1}, state.Counts)

// Perform one more successful request
assert.NoError(t, successRequest(ctx, customRCB))
state, err = customRCB.getStoredState(ctx)
state, err = customRCB.cacheClient.GetState(ctx)
assert.NoError(t, err)
assert.Equal(t, Counts{11, 6, 5, 1, 0}, state.Counts)

Expand All @@ -246,7 +262,7 @@ func TestCustomDistributedCircuitBreaker(t *testing.T) {
// Check if the circuit breaker is now open
assert.Equal(t, StateOpen, customRCB.State(ctx))

state, err = customRCB.getStoredState(ctx)
state, err = customRCB.cacheClient.GetState(ctx)
assert.NoError(t, err)
assert.Equal(t, Counts{0, 0, 0, 0, 0}, state.Counts)
})
Expand Down

0 comments on commit d6880cf

Please sign in to comment.