diff --git a/v2/distributed_gobreaker.go b/v2/distributed_gobreaker.go index fd3699f..e790647 100644 --- a/v2/distributed_gobreaker.go +++ b/v2/distributed_gobreaker.go @@ -1,7 +1,6 @@ package gobreaker import ( - "context" "encoding/json" "errors" "time" @@ -24,8 +23,8 @@ type SharedState struct { // SharedDataStore stores the shared state of DistributedCircuitBreaker. type SharedDataStore interface { - GetData(ctx context.Context, name string) ([]byte, error) - SetData(ctx context.Context, name string, data []byte) error + GetData(name string) ([]byte, error) + SetData(name string, data []byte) error } // DistributedCircuitBreaker extends CircuitBreaker with SharedDataStore. @@ -35,7 +34,7 @@ type DistributedCircuitBreaker[T any] struct { } // NewDistributedCircuitBreaker returns a new DistributedCircuitBreaker. -func NewDistributedCircuitBreaker[T any](ctx context.Context, store SharedDataStore, settings Settings) (*DistributedCircuitBreaker[T], error) { +func NewDistributedCircuitBreaker[T any](store SharedDataStore, settings Settings) (*DistributedCircuitBreaker[T], error) { if store == nil { return nil, ErrNoSharedStore } @@ -45,9 +44,9 @@ func NewDistributedCircuitBreaker[T any](ctx context.Context, store SharedDataSt store: store, } - _, err := dcb.getSharedState(ctx) + _, err := dcb.getSharedState() if err == ErrNoSharedState { - err = dcb.setSharedState(ctx, dcb.extract()) + err = dcb.setSharedState(dcb.extract()) } if err != nil { return nil, err @@ -60,13 +59,13 @@ func (dcb *DistributedCircuitBreaker[T]) sharedStateKey() string { return "gobreaker:" + dcb.name } -func (dcb *DistributedCircuitBreaker[T]) getSharedState(ctx context.Context) (SharedState, error) { +func (dcb *DistributedCircuitBreaker[T]) getSharedState() (SharedState, error) { var state SharedState if dcb.store == nil { return state, ErrNoSharedStore } - data, err := dcb.store.GetData(ctx, dcb.sharedStateKey()) + data, err := dcb.store.GetData(dcb.sharedStateKey()) if len(data) == 0 { return state, ErrNoSharedState } else if err != nil { @@ -77,7 +76,7 @@ func (dcb *DistributedCircuitBreaker[T]) getSharedState(ctx context.Context) (Sh return state, err } -func (dcb *DistributedCircuitBreaker[T]) setSharedState(ctx context.Context, state SharedState) error { +func (dcb *DistributedCircuitBreaker[T]) setSharedState(state SharedState) error { if dcb.store == nil { return ErrNoSharedStore } @@ -87,7 +86,7 @@ func (dcb *DistributedCircuitBreaker[T]) setSharedState(ctx context.Context, sta return err } - return dcb.store.SetData(ctx, dcb.sharedStateKey(), data) + return dcb.store.SetData(dcb.sharedStateKey(), data) } func (dcb *DistributedCircuitBreaker[T]) inject(shared SharedState) { @@ -113,8 +112,8 @@ func (dcb *DistributedCircuitBreaker[T]) extract() SharedState { } // State returns the State of DistributedCircuitBreaker. -func (dcb *DistributedCircuitBreaker[T]) State(ctx context.Context) (State, error) { - shared, err := dcb.getSharedState(ctx) +func (dcb *DistributedCircuitBreaker[T]) State() (State, error) { + shared, err := dcb.getSharedState() if err != nil { return shared.State, err } @@ -123,13 +122,13 @@ func (dcb *DistributedCircuitBreaker[T]) State(ctx context.Context) (State, erro state := dcb.CircuitBreaker.State() shared = dcb.extract() - err = dcb.setSharedState(ctx, shared) + err = dcb.setSharedState(shared) return state, err } // Execute runs the given request if the DistributedCircuitBreaker accepts it. -func (dcb *DistributedCircuitBreaker[T]) Execute(ctx context.Context, req func() (T, error)) (T, error) { - shared, err := dcb.getSharedState(ctx) +func (dcb *DistributedCircuitBreaker[T]) Execute(req func() (T, error)) (T, error) { + shared, err := dcb.getSharedState() if err != nil { var defaultValue T return defaultValue, err @@ -139,7 +138,7 @@ func (dcb *DistributedCircuitBreaker[T]) Execute(ctx context.Context, req func() t, e := dcb.CircuitBreaker.Execute(req) shared = dcb.extract() - err = dcb.setSharedState(ctx, shared) + err = dcb.setSharedState(shared) if err != nil { var defaultValue T return defaultValue, err diff --git a/v2/distributed_gobreaker_test.go b/v2/distributed_gobreaker_test.go index b194fa1..28a9d43 100644 --- a/v2/distributed_gobreaker_test.go +++ b/v2/distributed_gobreaker_test.go @@ -12,20 +12,21 @@ import ( ) type storeAdapter struct { + ctx context.Context client *redis.Client } -func (r *storeAdapter) GetData(ctx context.Context, key string) ([]byte, error) { - return r.client.Get(ctx, key).Bytes() +func (sa *storeAdapter) GetData(key string) ([]byte, error) { + return sa.client.Get(sa.ctx, key).Bytes() } -func (r *storeAdapter) SetData(ctx context.Context, key string, value []byte) error { - return r.client.Set(ctx, key, value, 0).Err() +func (sa *storeAdapter) SetData(key string, value []byte) error { + return sa.client.Set(sa.ctx, key, value, 0).Err() } var redisServer *miniredis.Miniredis -func setUpDCB(ctx context.Context) *DistributedCircuitBreaker[any] { +func setUpDCB() *DistributedCircuitBreaker[any] { var err error redisServer, err := miniredis.Run() if err != nil { @@ -36,9 +37,12 @@ func setUpDCB(ctx context.Context) *DistributedCircuitBreaker[any] { Addr: redisServer.Addr(), }) - store := &storeAdapter{client: client} + store := &storeAdapter{ + ctx: context.Background(), + client: client, + } - dcb, err := NewDistributedCircuitBreaker[any](ctx, store, Settings{ + dcb, err := NewDistributedCircuitBreaker[any](store, Settings{ Name: "TestBreaker", MaxRequests: 3, Interval: time.Second, @@ -67,8 +71,8 @@ func tearDownDCB(dcb *DistributedCircuitBreaker[any]) { } } -func dcbPseudoSleep(ctx context.Context, dcb *DistributedCircuitBreaker[any], period time.Duration) { - state, err := dcb.getSharedState(ctx) +func dcbPseudoSleep(dcb *DistributedCircuitBreaker[any], period time.Duration) { + state, err := dcb.getSharedState() if err != nil { panic(err) } @@ -79,34 +83,33 @@ func dcbPseudoSleep(ctx context.Context, dcb *DistributedCircuitBreaker[any], pe state.Counts = Counts{} } - err = dcb.setSharedState(ctx, state) + err = dcb.setSharedState(state) if err != nil { panic(err) } } -func successRequest(ctx context.Context, dcb *DistributedCircuitBreaker[any]) error { - _, err := dcb.Execute(ctx, func() (interface{}, error) { return nil, nil }) +func successRequest(dcb *DistributedCircuitBreaker[any]) error { + _, err := dcb.Execute(func() (interface{}, error) { return nil, nil }) return err } -func failRequest(ctx context.Context, dcb *DistributedCircuitBreaker[any]) error { - _, err := dcb.Execute(ctx, func() (interface{}, error) { return nil, errors.New("fail") }) +func failRequest(dcb *DistributedCircuitBreaker[any]) error { + _, err := dcb.Execute(func() (interface{}, error) { return nil, errors.New("fail") }) if err != nil && err.Error() == "fail" { return nil } return err } -func assertState(ctx context.Context, t *testing.T, dcb *DistributedCircuitBreaker[any], expected State) { - state, err := dcb.State(ctx) +func assertState(t *testing.T, dcb *DistributedCircuitBreaker[any], expected State) { + state, err := dcb.State() assert.Equal(t, expected, state) assert.NoError(t, err) } func TestDistributedCircuitBreakerInitialization(t *testing.T) { - ctx := context.Background() - dcb := setUpDCB(ctx) + dcb := setUpDCB() defer tearDownDCB(dcb) assert.Equal(t, "TestBreaker", dcb.Name()) @@ -115,58 +118,56 @@ func TestDistributedCircuitBreakerInitialization(t *testing.T) { assert.Equal(t, time.Second*2, dcb.timeout) assert.NotNil(t, dcb.readyToTrip) - assertState(ctx, t, dcb, StateClosed) + assertState(t, dcb, StateClosed) } func TestDistributedCircuitBreakerStateTransitions(t *testing.T) { - ctx := context.Background() - dcb := setUpDCB(ctx) + dcb := setUpDCB() defer tearDownDCB(dcb) // Check if initial state is closed - assertState(ctx, t, dcb, StateClosed) + assertState(t, dcb, StateClosed) // StateClosed to StateOpen for i := 0; i < 6; i++ { - assert.NoError(t, failRequest(ctx, dcb)) + assert.NoError(t, failRequest(dcb)) } - assertState(ctx, t, dcb, StateOpen) + assertState(t, dcb, StateOpen) // Ensure requests fail when the circuit is open - err := failRequest(ctx, dcb) + err := failRequest(dcb) assert.Equal(t, ErrOpenState, err) // Wait for timeout so that the state will move to half-open - dcbPseudoSleep(ctx, dcb, dcb.timeout) - assertState(ctx, t, dcb, StateHalfOpen) + dcbPseudoSleep(dcb, dcb.timeout) + assertState(t, dcb, StateHalfOpen) // StateHalfOpen to StateClosed for i := 0; i < int(dcb.maxRequests); i++ { - assert.NoError(t, successRequest(ctx, dcb)) + assert.NoError(t, successRequest(dcb)) } - assertState(ctx, t, dcb, StateClosed) + assertState(t, dcb, StateClosed) // StateClosed to StateOpen (again) for i := 0; i < 6; i++ { - assert.NoError(t, failRequest(ctx, dcb)) + assert.NoError(t, failRequest(dcb)) } - assertState(ctx, t, dcb, StateOpen) + assertState(t, dcb, StateOpen) } func TestDistributedCircuitBreakerExecution(t *testing.T) { - ctx := context.Background() - dcb := setUpDCB(ctx) + dcb := setUpDCB() defer tearDownDCB(dcb) // Test successful execution - result, err := dcb.Execute(ctx, func() (interface{}, error) { + result, err := dcb.Execute(func() (interface{}, error) { return "success", nil }) assert.NoError(t, err) assert.Equal(t, "success", result) // Test failed execution - _, err = dcb.Execute(ctx, func() (interface{}, error) { + _, err = dcb.Execute(func() (interface{}, error) { return nil, errors.New("test error") }) assert.Error(t, err) @@ -174,20 +175,19 @@ func TestDistributedCircuitBreakerExecution(t *testing.T) { } func TestDistributedCircuitBreakerCounts(t *testing.T) { - ctx := context.Background() - dcb := setUpDCB(ctx) + dcb := setUpDCB() defer tearDownDCB(dcb) for i := 0; i < 5; i++ { - assert.Nil(t, successRequest(ctx, dcb)) + assert.Nil(t, successRequest(dcb)) } - state, err := dcb.getSharedState(ctx) + state, err := dcb.getSharedState() assert.Equal(t, Counts{5, 5, 0, 5, 0}, state.Counts) assert.NoError(t, err) - assert.Nil(t, failRequest(ctx, dcb)) - state, err = dcb.getSharedState(ctx) + assert.Nil(t, failRequest(dcb)) + state, err = dcb.getSharedState() assert.Equal(t, Counts{6, 5, 1, 0, 1}, state.Counts) assert.NoError(t, err) } @@ -195,8 +195,6 @@ func TestDistributedCircuitBreakerCounts(t *testing.T) { var customDCB *DistributedCircuitBreaker[any] func TestCustomDistributedCircuitBreaker(t *testing.T) { - ctx := context.Background() - mr, err := miniredis.Run() if err != nil { panic(err) @@ -207,9 +205,12 @@ func TestCustomDistributedCircuitBreaker(t *testing.T) { Addr: mr.Addr(), }) - store := &storeAdapter{client: client} + store := &storeAdapter{ + ctx: context.Background(), + client: client, + } - customDCB, err = NewDistributedCircuitBreaker[any](ctx, store, Settings{ + customDCB, err = NewDistributedCircuitBreaker[any](store, Settings{ Name: "CustomBreaker", MaxRequests: 3, Interval: time.Second * 30, @@ -224,53 +225,53 @@ func TestCustomDistributedCircuitBreaker(t *testing.T) { t.Run("Initialization", func(t *testing.T) { assert.Equal(t, "CustomBreaker", customDCB.Name()) - assertState(ctx, t, customDCB, StateClosed) + assertState(t, customDCB, StateClosed) }) t.Run("Counts and State Transitions", func(t *testing.T) { // Perform 5 successful and 5 failed requests for i := 0; i < 5; i++ { - assert.NoError(t, successRequest(ctx, customDCB)) - assert.NoError(t, failRequest(ctx, customDCB)) + assert.NoError(t, successRequest(customDCB)) + assert.NoError(t, failRequest(customDCB)) } - state, err := customDCB.getSharedState(ctx) + state, err := customDCB.getSharedState() assert.NoError(t, err) assert.Equal(t, StateClosed, state.State) assert.Equal(t, Counts{10, 5, 5, 0, 1}, state.Counts) // Perform one more successful request - assert.NoError(t, successRequest(ctx, customDCB)) - state, err = customDCB.getSharedState(ctx) + assert.NoError(t, successRequest(customDCB)) + state, err = customDCB.getSharedState() assert.NoError(t, err) assert.Equal(t, Counts{11, 6, 5, 1, 0}, state.Counts) // Simulate time passing to reset counts - dcbPseudoSleep(ctx, customDCB, time.Second*30) + dcbPseudoSleep(customDCB, time.Second*30) // Perform requests to trigger StateOpen - assert.NoError(t, successRequest(ctx, customDCB)) - assert.NoError(t, failRequest(ctx, customDCB)) - assert.NoError(t, failRequest(ctx, customDCB)) + assert.NoError(t, successRequest(customDCB)) + assert.NoError(t, failRequest(customDCB)) + assert.NoError(t, failRequest(customDCB)) // Check if the circuit breaker is now open - assertState(ctx, t, customDCB, StateOpen) + assertState(t, customDCB, StateOpen) - state, err = customDCB.getSharedState(ctx) + state, err = customDCB.getSharedState() assert.NoError(t, err) assert.Equal(t, Counts{0, 0, 0, 0, 0}, state.Counts) }) t.Run("Timeout and Half-Open State", func(t *testing.T) { // Simulate timeout to transition to half-open state - dcbPseudoSleep(ctx, customDCB, time.Second*90) - assertState(ctx, t, customDCB, StateHalfOpen) + dcbPseudoSleep(customDCB, time.Second*90) + assertState(t, customDCB, StateHalfOpen) // Successful requests in half-open state should close the circuit for i := 0; i < 3; i++ { - assert.NoError(t, successRequest(ctx, customDCB)) + assert.NoError(t, successRequest(customDCB)) } - assertState(ctx, t, customDCB, StateClosed) + assertState(t, customDCB, StateClosed) }) } @@ -290,8 +291,6 @@ func TestCustomDistributedCircuitBreakerStateTransitions(t *testing.T) { }, } - ctx := context.Background() - mr, err := miniredis.Run() if err != nil { t.Fatalf("Failed to start miniredis: %v", err) @@ -302,44 +301,47 @@ func TestCustomDistributedCircuitBreakerStateTransitions(t *testing.T) { Addr: mr.Addr(), }) - store := &storeAdapter{client: client} + store := &storeAdapter{ + ctx: context.Background(), + client: client, + } - dcb, err := NewDistributedCircuitBreaker[any](ctx, store, customSt) + dcb, err := NewDistributedCircuitBreaker[any](store, customSt) assert.NoError(t, err) // Test case t.Run("Circuit Breaker State Transitions", func(t *testing.T) { // Initial state should be Closed - assertState(ctx, t, dcb, StateClosed) + assertState(t, dcb, StateClosed) // Cause two consecutive failures to trip the circuit for i := 0; i < 2; i++ { - err := failRequest(ctx, dcb) + err := failRequest(dcb) assert.NoError(t, err, "Fail request should not return an error") } // Circuit should now be Open - assertState(ctx, t, dcb, StateOpen) + assertState(t, dcb, StateOpen) assert.Equal(t, StateChange{"cb", StateClosed, StateOpen}, stateChange) // Requests should fail immediately when circuit is Open - err := successRequest(ctx, dcb) + err := successRequest(dcb) assert.Error(t, err) assert.Equal(t, ErrOpenState, err) // Simulate timeout to transition to Half-Open - dcbPseudoSleep(ctx, dcb, 6*time.Second) - assertState(ctx, t, dcb, StateHalfOpen) + dcbPseudoSleep(dcb, 6*time.Second) + assertState(t, dcb, StateHalfOpen) assert.Equal(t, StateChange{"cb", StateOpen, StateHalfOpen}, stateChange) // Successful requests in Half-Open state should close the circuit for i := 0; i < int(dcb.maxRequests); i++ { - err := successRequest(ctx, dcb) + err := successRequest(dcb) assert.NoError(t, err) } // Circuit should now be Closed - assertState(ctx, t, dcb, StateClosed) + assertState(t, dcb, StateClosed) assert.Equal(t, StateChange{"cb", StateHalfOpen, StateClosed}, stateChange) }) }