diff --git a/.gitignore b/.gitignore index 66fd13c..ee770a6 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,5 @@ # Dependency directories (remove the comment below to include it) # vendor/ + +.idea/ diff --git a/event_entity.go b/event_entity.go index a912025..fbb3e92 100644 --- a/event_entity.go +++ b/event_entity.go @@ -1,5 +1,7 @@ package easyfsm +import "context" + type ( // EventEntity is the core that wraps the basic Event methods. EventEntity struct { @@ -7,6 +9,8 @@ type ( eventName EventName observers []EventObserver eventFunc EventFunc + // issue:https://github.com/wuqinqiang/easyfsm/issues/16 + forkCtxFunc func(ctx context.Context) context.Context } EventEntityOpt func(entity *EventEntity) @@ -34,6 +38,9 @@ func NewEventEntity(event EventName, handler EventFunc, eventName: event, eventFunc: handler, observers: make([]EventObserver, 0), + forkCtxFunc: func(ctx context.Context) context.Context { + return context.Background() + }, } for _, opt := range opts { opt(entity) @@ -61,8 +68,14 @@ func WithHook(hook EventHook) EventEntityOpt { } } -// Execute executes the event. -func (e *EventEntity) Execute(param *Param) (State, error) { +func WithForkCtxFunc(fn func(ctx context.Context) context.Context) EventEntityOpt { + return func(entity *EventEntity) { + entity.forkCtxFunc = fn + } +} + +// execute executes the event. +func (e *EventEntity) execute(param *Param) (State, error) { if e.hook != nil { e.hook.Before(param) } @@ -78,6 +91,7 @@ func (e *EventEntity) Execute(param *Param) (State, error) { // Asynchronous notify observers GoSafe(func() { + param.Ctx = e.forkCtxFunc(param.Ctx) e.notify(param) }) return state, nil diff --git a/event_entity_test.go b/event_entity_test.go index b1dae15..237f34e 100644 --- a/event_entity_test.go +++ b/event_entity_test.go @@ -45,15 +45,17 @@ func TestNewEventEntityNoOption(t *testing.T) { handler := func(opt *Param) (State, error) { return State(1), nil } + testKey := "remember" + testVal := "ok" - wantEntity := EventEntity{ - eventName: EventName(eventName), - eventFunc: handler, - observers: make([]EventObserver, 0), + forkCtxFunc := func(ctx context.Context) context.Context { + return context.WithValue(ctx, testKey, testVal) } + wantEntity := NewEventEntity(EventName(eventName), handler, WithForkCtxFunc(forkCtxFunc)) + t.Run(businessName, func(t *testing.T) { - got := NewEventEntity(EventName(eventName), handler) + got := NewEventEntity(EventName(eventName), handler, WithForkCtxFunc(forkCtxFunc)) if !reflect.DeepEqual(got.eventName, wantEntity.eventName) { t.Errorf("eventEntity name =%v,want %v", got.eventName, wantEntity.eventName) @@ -66,6 +68,12 @@ func TestNewEventEntityNoOption(t *testing.T) { if !funcEqual(got.eventFunc, wantEntity.eventFunc) { t.Errorf("eventEntity handler =%v,want %v", got.eventFunc, wantEntity.eventFunc) } + ctx := got.forkCtxFunc(context.Background()) + val := ctx.Value(testKey).(string) + if val != testVal { + t.Errorf("eventEntity ctxText =%v,want %v", val, testVal) + } + }) } @@ -76,23 +84,13 @@ func TestNewEventEntityWithObservers(t *testing.T) { return State(1), nil } - wantEntity := EventEntity{ - eventName: EventName(eventName), - eventFunc: handler, - observers: []EventObserver{ObserverTest{}}, - } + wantEntity := NewEventEntity(EventName(eventName), handler, WithObservers(ObserverTest{})) t.Run(businessName, func(t *testing.T) { got := NewEventEntity(EventName(eventName), handler, WithObservers(ObserverTest{})) - - if !reflect.DeepEqual(got.observers, wantEntity.observers) { - t.Errorf("eventEntity observers =%v,want %v", got.observers, wantEntity.observers) + if len(got.observers) != 1 { + t.Errorf("eventEntity observers len =%v,want %v", got.observers, wantEntity.observers) } - - if !reflect.DeepEqual(got.hook, wantEntity.hook) { - t.Errorf("eventEntity hook =%v,want %v", got.hook, wantEntity.hook) - } - }) } @@ -103,12 +101,7 @@ func TestNewEventEntityWithHook(t *testing.T) { return State(1), nil } - wantEntity := EventEntity{ - eventName: EventName(eventName), - eventFunc: handler, - hook: HookTest{}, - observers: make([]EventObserver, 0), - } + wantEntity := NewEventEntity(EventName(eventName), handler, WithHook(HookTest{})) t.Run(businessName, func(t *testing.T) { got := NewEventEntity(EventName(eventName), handler, WithHook(HookTest{})) @@ -137,20 +130,19 @@ func TestEventEntity_Execute_Success(t *testing.T) { Ctx: context.TODO(), Data: CreateOrderPar{OrderId: "wuqq0223"}, } - state, err := entity.Execute(param) + state, err := entity.execute(param) wantState, wantErr := handler(param) if err != nil { - t.Errorf("Execute err %v ,want %v", err, wantErr) + t.Errorf("execute err %v ,want %v", err, wantErr) } if state != wantState { - t.Errorf("Execute state %v ,want %v", err, wantErr) + t.Errorf("execute state %v ,want %v", err, wantErr) } } func TestEventEntity_Execute_Err(t *testing.T) { - paidErr := fmt.Errorf("paid err") eventName := "paid_order" handler := func(opt *Param) (State, error) { @@ -165,13 +157,13 @@ func TestEventEntity_Execute_Err(t *testing.T) { Ctx: context.TODO(), Data: CreateOrderPar{OrderId: "wuqq0223"}, } - state, err := entity.Execute(param) + state, err := entity.execute(param) wantState, wantErr := handler(param) if err == nil || !errors.Is(err, paidErr) { - t.Errorf("Execute err %v ,want %v", err, wantErr) + t.Errorf("execute err %v ,want %v", err, wantErr) } if state != wantState { - t.Errorf("Execute state %v ,want %v", err, wantErr) + t.Errorf("execute state %v ,want %v", err, wantErr) } } diff --git a/fsm.go b/fsm.go index c37b54e..cdd68bc 100644 --- a/fsm.go +++ b/fsm.go @@ -45,7 +45,7 @@ func (f *FSM) Call(eventName EventName, opts ...ParamOption) (State, error) { } // call eventName func - state, err := eventEntity.Execute(param) + state, err := eventEntity.execute(param) if err != nil { return f.getState(), err } diff --git a/fsm_test.go b/fsm_test.go index bab41df..e3a5e57 100644 --- a/fsm_test.go +++ b/fsm_test.go @@ -12,7 +12,7 @@ var ( func Init() { args := DefaultArgList() for i := range DefaultArgList() { - RegisterStateMachine(DefaultBusinessName, args[i].state, &args[i].entity) + RegisterStateMachine(DefaultBusinessName, args[i].state, args[i].entity) } } @@ -30,8 +30,6 @@ func TestNewFSM(t *testing.T) { } func TestFSM_Call(t *testing.T) { - //clear - stateMachineMap = make(map[BusinessName]map[State]map[EventName]*EventEntity) Init() type ( wantRes struct { diff --git a/helper.go b/helper.go index 2db1ae1..cc34882 100644 --- a/helper.go +++ b/helper.go @@ -1,6 +1,9 @@ package easyfsm -import "github.com/wuqinqiang/easyfsm/log" +import ( + "context" + "github.com/wuqinqiang/easyfsm/log" +) func GoSafe(fn func()) { go goSafe(fn) @@ -14,3 +17,7 @@ func goSafe(fn func()) { }() fn() } + +type ForkCtxInterface interface { + ForkCtx(ctx context.Context) context.Context +} diff --git a/param.go b/param.go index 04909bc..caa0cea 100644 --- a/param.go +++ b/param.go @@ -16,6 +16,9 @@ type Param struct { func WithCtx(ctx context.Context) ParamOption { return func(opt *Param) { + if ctx == nil { + return + } opt.Ctx = ctx } } diff --git a/register_test.go b/register_test.go index 934766b..92d3ef4 100644 --- a/register_test.go +++ b/register_test.go @@ -1,6 +1,7 @@ package easyfsm import ( + "context" "sync" "testing" ) @@ -8,7 +9,7 @@ import ( type arg struct { state State eventName EventName - entity EventEntity + entity *EventEntity } func DefaultArgList() []arg { @@ -19,54 +20,46 @@ func DefaultArgList() []arg { args = append(args, arg{ state: 0, eventName: "crateOrder", - entity: EventEntity{ - eventName: "crateOrder", - eventFunc: func(opt *Param) (State, error) { - return State(1), nil - }, - }, + entity: NewEventEntity("crateOrder", func(opt *Param) (State, error) { + return State(1), nil + }), }, arg{ state: 1, eventName: "payOrder", - entity: EventEntity{ - eventName: "payOrder", - eventFunc: func(opt *Param) (State, error) { + entity: NewEventEntity( + "payOrder", + func(opt *Param) (State, error) { return State(2), nil - }, - }, + }), }, arg{ state: 1, eventName: "cancelOrder", - entity: EventEntity{ - eventName: "cancelOrder", - eventFunc: func(opt *Param) (State, error) { + entity: NewEventEntity( + "cancelOrder", + func(opt *Param) (State, error) { return State(3), nil - }, - }, + }), }, ) return args } +// func TestRegisterStateMachine(t *testing.T) { - businessName := BusinessName("business_order") + businessName := BusinessName("TestRegisterStateMachine") args := DefaultArgList() - // clear - stateMachineMap = make(map[BusinessName]map[State]map[EventName]*EventEntity) for i := range args { - RegisterStateMachine(businessName, args[i].state, &args[i].entity) + RegisterStateMachine(businessName, args[i].state, args[i].entity) } commonTest(args, businessName, t) } func TestRegisterStateMachineForConcurrent(t *testing.T) { - businessName := BusinessName("business_order") + businessName := BusinessName("TestRegisterStateMachineForConcurrent") args := DefaultArgList() - // clear - stateMachineMap = make(map[BusinessName]map[State]map[EventName]*EventEntity) var ( wg sync.WaitGroup ) @@ -74,11 +67,10 @@ func TestRegisterStateMachineForConcurrent(t *testing.T) { wg.Add(1) go func(index int) { defer wg.Done() - RegisterStateMachine(businessName, args[index].state, &args[index].entity) + RegisterStateMachine(businessName, args[index].state, args[index].entity) }(i) } wg.Wait() - commonTest(args, businessName, t) } @@ -101,8 +93,10 @@ func commonTest(args []arg, businessName BusinessName, t *testing.T) { t.Errorf("entity shouldn't be nil") } - state, err := entity.Execute(nil) - wantState, wantErr := args[j].entity.Execute(nil) + param := &Param{Ctx: context.TODO()} + + state, err := entity.execute(param) + wantState, wantErr := args[j].entity.execute(param) if err != nil { t.Errorf("err %v want:%v", err, wantErr) }