Skip to content

Commit

Permalink
feat:fork ctx
Browse files Browse the repository at this point in the history
  • Loading branch information
wuqinqiang committed Mar 16, 2022
1 parent e8f5e57 commit 14bc583
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 65 deletions.
4 changes: 2 additions & 2 deletions event_entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ func WithForkCtxFunc(fn func(ctx context.Context) context.Context) EventEntityOp
}
}

// Execute executes the event.
func (e *EventEntity) Execute(param *Param) (State, error) {
// execute executes the event.
func (e *EventEntity) execute(param *Param) (State, error) {
if e.hook != nil {
e.hook.Before(param)
}
Expand Down
54 changes: 23 additions & 31 deletions event_entity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,17 @@ func TestNewEventEntityNoOption(t *testing.T) {
handler := func(opt *Param) (State, error) {
return State(1), nil
}
testKey := "remember"
testVal := "ok"

wantEntity := EventEntity{
eventName: EventName(eventName),
eventFunc: handler,
observers: make([]EventObserver, 0),
forkCtxFunc := func(ctx context.Context) context.Context {
return context.WithValue(ctx, testKey, testVal)
}

wantEntity := NewEventEntity(EventName(eventName), handler, WithForkCtxFunc(forkCtxFunc))

t.Run(businessName, func(t *testing.T) {
got := NewEventEntity(EventName(eventName), handler)
got := NewEventEntity(EventName(eventName), handler, WithForkCtxFunc(forkCtxFunc))

if !reflect.DeepEqual(got.eventName, wantEntity.eventName) {
t.Errorf("eventEntity name =%v,want %v", got.eventName, wantEntity.eventName)
Expand All @@ -66,6 +68,12 @@ func TestNewEventEntityNoOption(t *testing.T) {
if !funcEqual(got.eventFunc, wantEntity.eventFunc) {
t.Errorf("eventEntity handler =%v,want %v", got.eventFunc, wantEntity.eventFunc)
}
ctx := got.forkCtxFunc(context.Background())
val := ctx.Value(testKey).(string)
if val != testVal {
t.Errorf("eventEntity ctxText =%v,want %v", val, testVal)
}

})
}

Expand All @@ -76,23 +84,13 @@ func TestNewEventEntityWithObservers(t *testing.T) {
return State(1), nil
}

wantEntity := EventEntity{
eventName: EventName(eventName),
eventFunc: handler,
observers: []EventObserver{ObserverTest{}},
}
wantEntity := NewEventEntity(EventName(eventName), handler, WithObservers(ObserverTest{}))

t.Run(businessName, func(t *testing.T) {
got := NewEventEntity(EventName(eventName), handler, WithObservers(ObserverTest{}))

if !reflect.DeepEqual(got.observers, wantEntity.observers) {
t.Errorf("eventEntity observers =%v,want %v", got.observers, wantEntity.observers)
if len(got.observers) != 1 {
t.Errorf("eventEntity observers len =%v,want %v", got.observers, wantEntity.observers)
}

if !reflect.DeepEqual(got.hook, wantEntity.hook) {
t.Errorf("eventEntity hook =%v,want %v", got.hook, wantEntity.hook)
}

})
}

Expand All @@ -103,12 +101,7 @@ func TestNewEventEntityWithHook(t *testing.T) {
return State(1), nil
}

wantEntity := EventEntity{
eventName: EventName(eventName),
eventFunc: handler,
hook: HookTest{},
observers: make([]EventObserver, 0),
}
wantEntity := NewEventEntity(EventName(eventName), handler, WithHook(HookTest{}))

t.Run(businessName, func(t *testing.T) {
got := NewEventEntity(EventName(eventName), handler, WithHook(HookTest{}))
Expand Down Expand Up @@ -137,20 +130,19 @@ func TestEventEntity_Execute_Success(t *testing.T) {
Ctx: context.TODO(),
Data: CreateOrderPar{OrderId: "wuqq0223"},
}
state, err := entity.Execute(param)
state, err := entity.execute(param)

wantState, wantErr := handler(param)
if err != nil {
t.Errorf("Execute err %v ,want %v", err, wantErr)
t.Errorf("execute err %v ,want %v", err, wantErr)
}

if state != wantState {
t.Errorf("Execute state %v ,want %v", err, wantErr)
t.Errorf("execute state %v ,want %v", err, wantErr)
}
}

func TestEventEntity_Execute_Err(t *testing.T) {

paidErr := fmt.Errorf("paid err")
eventName := "paid_order"
handler := func(opt *Param) (State, error) {
Expand All @@ -165,13 +157,13 @@ func TestEventEntity_Execute_Err(t *testing.T) {
Ctx: context.TODO(),
Data: CreateOrderPar{OrderId: "wuqq0223"},
}
state, err := entity.Execute(param)
state, err := entity.execute(param)

wantState, wantErr := handler(param)
if err == nil || !errors.Is(err, paidErr) {
t.Errorf("Execute err %v ,want %v", err, wantErr)
t.Errorf("execute err %v ,want %v", err, wantErr)
}
if state != wantState {
t.Errorf("Execute state %v ,want %v", err, wantErr)
t.Errorf("execute state %v ,want %v", err, wantErr)
}
}
2 changes: 1 addition & 1 deletion fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (f *FSM) Call(eventName EventName, opts ...ParamOption) (State, error) {
}

// call eventName func
state, err := eventEntity.Execute(param)
state, err := eventEntity.execute(param)
if err != nil {
return f.getState(), err
}
Expand Down
4 changes: 1 addition & 3 deletions fsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ var (
func Init() {
args := DefaultArgList()
for i := range DefaultArgList() {
RegisterStateMachine(DefaultBusinessName, args[i].state, &args[i].entity)
RegisterStateMachine(DefaultBusinessName, args[i].state, args[i].entity)
}
}

Expand All @@ -30,8 +30,6 @@ func TestNewFSM(t *testing.T) {
}

func TestFSM_Call(t *testing.T) {
//clear
stateMachineMap = make(map[BusinessName]map[State]map[EventName]*EventEntity)
Init()
type (
wantRes struct {
Expand Down
3 changes: 3 additions & 0 deletions param.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ type Param struct {

func WithCtx(ctx context.Context) ParamOption {
return func(opt *Param) {
if ctx == nil {
return
}
opt.Ctx = ctx
}
}
Expand Down
50 changes: 22 additions & 28 deletions register_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package easyfsm

import (
"context"
"sync"
"testing"
)

type arg struct {
state State
eventName EventName
entity EventEntity
entity *EventEntity
}

func DefaultArgList() []arg {
Expand All @@ -19,66 +20,57 @@ func DefaultArgList() []arg {
args = append(args, arg{
state: 0,
eventName: "crateOrder",
entity: EventEntity{
eventName: "crateOrder",
eventFunc: func(opt *Param) (State, error) {
return State(1), nil
},
},
entity: NewEventEntity("crateOrder", func(opt *Param) (State, error) {
return State(1), nil
}),
},
arg{
state: 1,
eventName: "payOrder",
entity: EventEntity{
eventName: "payOrder",
eventFunc: func(opt *Param) (State, error) {
entity: NewEventEntity(
"payOrder",
func(opt *Param) (State, error) {
return State(2), nil
},
},
}),
},
arg{
state: 1,
eventName: "cancelOrder",
entity: EventEntity{
eventName: "cancelOrder",
eventFunc: func(opt *Param) (State, error) {
entity: NewEventEntity(
"cancelOrder",
func(opt *Param) (State, error) {
return State(3), nil
},
},
}),
},
)
return args
}

//
func TestRegisterStateMachine(t *testing.T) {
businessName := BusinessName("business_order")
businessName := BusinessName("TestRegisterStateMachine")
args := DefaultArgList()
// clear
stateMachineMap = make(map[BusinessName]map[State]map[EventName]*EventEntity)
for i := range args {
RegisterStateMachine(businessName, args[i].state, &args[i].entity)
RegisterStateMachine(businessName, args[i].state, args[i].entity)
}
commonTest(args, businessName, t)

}

func TestRegisterStateMachineForConcurrent(t *testing.T) {
businessName := BusinessName("business_order")
businessName := BusinessName("TestRegisterStateMachineForConcurrent")
args := DefaultArgList()
// clear
stateMachineMap = make(map[BusinessName]map[State]map[EventName]*EventEntity)
var (
wg sync.WaitGroup
)
for i := range args {
wg.Add(1)
go func(index int) {
defer wg.Done()
RegisterStateMachine(businessName, args[index].state, &args[index].entity)
RegisterStateMachine(businessName, args[index].state, args[index].entity)
}(i)
}
wg.Wait()

commonTest(args, businessName, t)
}

Expand All @@ -101,8 +93,10 @@ func commonTest(args []arg, businessName BusinessName, t *testing.T) {
t.Errorf("entity shouldn't be nil")
}

state, err := entity.Execute(nil)
wantState, wantErr := args[j].entity.Execute(nil)
param := &Param{Ctx: context.TODO()}

state, err := entity.execute(param)
wantState, wantErr := args[j].entity.execute(param)
if err != nil {
t.Errorf("err %v want:%v", err, wantErr)
}
Expand Down

0 comments on commit 14bc583

Please sign in to comment.