Skip to content

Commit

Permalink
Merge pull request #18 from wuqinqiang/feat/fork-ctx-220313
Browse files Browse the repository at this point in the history
Feat/fork ctx 220313
  • Loading branch information
wuqinqiang authored Mar 16, 2022
2 parents a313d9e + 14bc583 commit fcd10d7
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 66 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@

# Dependency directories (remove the comment below to include it)
# vendor/

.idea/
18 changes: 16 additions & 2 deletions event_entity.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package easyfsm

import "context"

type (
// EventEntity is the core that wraps the basic Event methods.
EventEntity struct {
hook EventHook
eventName EventName
observers []EventObserver
eventFunc EventFunc
// issue:https://github.com/wuqinqiang/easyfsm/issues/16
forkCtxFunc func(ctx context.Context) context.Context
}

EventEntityOpt func(entity *EventEntity)
Expand Down Expand Up @@ -34,6 +38,9 @@ func NewEventEntity(event EventName, handler EventFunc,
eventName: event,
eventFunc: handler,
observers: make([]EventObserver, 0),
forkCtxFunc: func(ctx context.Context) context.Context {
return context.Background()
},
}
for _, opt := range opts {
opt(entity)
Expand Down Expand Up @@ -61,8 +68,14 @@ func WithHook(hook EventHook) EventEntityOpt {
}
}

// Execute executes the event.
func (e *EventEntity) Execute(param *Param) (State, error) {
func WithForkCtxFunc(fn func(ctx context.Context) context.Context) EventEntityOpt {
return func(entity *EventEntity) {
entity.forkCtxFunc = fn
}
}

// execute executes the event.
func (e *EventEntity) execute(param *Param) (State, error) {
if e.hook != nil {
e.hook.Before(param)
}
Expand All @@ -78,6 +91,7 @@ func (e *EventEntity) Execute(param *Param) (State, error) {

// Asynchronous notify observers
GoSafe(func() {
param.Ctx = e.forkCtxFunc(param.Ctx)
e.notify(param)
})
return state, nil
Expand Down
54 changes: 23 additions & 31 deletions event_entity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,17 @@ func TestNewEventEntityNoOption(t *testing.T) {
handler := func(opt *Param) (State, error) {
return State(1), nil
}
testKey := "remember"
testVal := "ok"

wantEntity := EventEntity{
eventName: EventName(eventName),
eventFunc: handler,
observers: make([]EventObserver, 0),
forkCtxFunc := func(ctx context.Context) context.Context {
return context.WithValue(ctx, testKey, testVal)
}

wantEntity := NewEventEntity(EventName(eventName), handler, WithForkCtxFunc(forkCtxFunc))

t.Run(businessName, func(t *testing.T) {
got := NewEventEntity(EventName(eventName), handler)
got := NewEventEntity(EventName(eventName), handler, WithForkCtxFunc(forkCtxFunc))

if !reflect.DeepEqual(got.eventName, wantEntity.eventName) {
t.Errorf("eventEntity name =%v,want %v", got.eventName, wantEntity.eventName)
Expand All @@ -66,6 +68,12 @@ func TestNewEventEntityNoOption(t *testing.T) {
if !funcEqual(got.eventFunc, wantEntity.eventFunc) {
t.Errorf("eventEntity handler =%v,want %v", got.eventFunc, wantEntity.eventFunc)
}
ctx := got.forkCtxFunc(context.Background())
val := ctx.Value(testKey).(string)
if val != testVal {
t.Errorf("eventEntity ctxText =%v,want %v", val, testVal)
}

})
}

Expand All @@ -76,23 +84,13 @@ func TestNewEventEntityWithObservers(t *testing.T) {
return State(1), nil
}

wantEntity := EventEntity{
eventName: EventName(eventName),
eventFunc: handler,
observers: []EventObserver{ObserverTest{}},
}
wantEntity := NewEventEntity(EventName(eventName), handler, WithObservers(ObserverTest{}))

t.Run(businessName, func(t *testing.T) {
got := NewEventEntity(EventName(eventName), handler, WithObservers(ObserverTest{}))

if !reflect.DeepEqual(got.observers, wantEntity.observers) {
t.Errorf("eventEntity observers =%v,want %v", got.observers, wantEntity.observers)
if len(got.observers) != 1 {
t.Errorf("eventEntity observers len =%v,want %v", got.observers, wantEntity.observers)
}

if !reflect.DeepEqual(got.hook, wantEntity.hook) {
t.Errorf("eventEntity hook =%v,want %v", got.hook, wantEntity.hook)
}

})
}

Expand All @@ -103,12 +101,7 @@ func TestNewEventEntityWithHook(t *testing.T) {
return State(1), nil
}

wantEntity := EventEntity{
eventName: EventName(eventName),
eventFunc: handler,
hook: HookTest{},
observers: make([]EventObserver, 0),
}
wantEntity := NewEventEntity(EventName(eventName), handler, WithHook(HookTest{}))

t.Run(businessName, func(t *testing.T) {
got := NewEventEntity(EventName(eventName), handler, WithHook(HookTest{}))
Expand Down Expand Up @@ -137,20 +130,19 @@ func TestEventEntity_Execute_Success(t *testing.T) {
Ctx: context.TODO(),
Data: CreateOrderPar{OrderId: "wuqq0223"},
}
state, err := entity.Execute(param)
state, err := entity.execute(param)

wantState, wantErr := handler(param)
if err != nil {
t.Errorf("Execute err %v ,want %v", err, wantErr)
t.Errorf("execute err %v ,want %v", err, wantErr)
}

if state != wantState {
t.Errorf("Execute state %v ,want %v", err, wantErr)
t.Errorf("execute state %v ,want %v", err, wantErr)
}
}

func TestEventEntity_Execute_Err(t *testing.T) {

paidErr := fmt.Errorf("paid err")
eventName := "paid_order"
handler := func(opt *Param) (State, error) {
Expand All @@ -165,13 +157,13 @@ func TestEventEntity_Execute_Err(t *testing.T) {
Ctx: context.TODO(),
Data: CreateOrderPar{OrderId: "wuqq0223"},
}
state, err := entity.Execute(param)
state, err := entity.execute(param)

wantState, wantErr := handler(param)
if err == nil || !errors.Is(err, paidErr) {
t.Errorf("Execute err %v ,want %v", err, wantErr)
t.Errorf("execute err %v ,want %v", err, wantErr)
}
if state != wantState {
t.Errorf("Execute state %v ,want %v", err, wantErr)
t.Errorf("execute state %v ,want %v", err, wantErr)
}
}
2 changes: 1 addition & 1 deletion fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (f *FSM) Call(eventName EventName, opts ...ParamOption) (State, error) {
}

// call eventName func
state, err := eventEntity.Execute(param)
state, err := eventEntity.execute(param)
if err != nil {
return f.getState(), err
}
Expand Down
4 changes: 1 addition & 3 deletions fsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ var (
func Init() {
args := DefaultArgList()
for i := range DefaultArgList() {
RegisterStateMachine(DefaultBusinessName, args[i].state, &args[i].entity)
RegisterStateMachine(DefaultBusinessName, args[i].state, args[i].entity)
}
}

Expand All @@ -30,8 +30,6 @@ func TestNewFSM(t *testing.T) {
}

func TestFSM_Call(t *testing.T) {
//clear
stateMachineMap = make(map[BusinessName]map[State]map[EventName]*EventEntity)
Init()
type (
wantRes struct {
Expand Down
9 changes: 8 additions & 1 deletion helper.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package easyfsm

import "github.com/wuqinqiang/easyfsm/log"
import (
"context"
"github.com/wuqinqiang/easyfsm/log"
)

func GoSafe(fn func()) {
go goSafe(fn)
Expand All @@ -14,3 +17,7 @@ func goSafe(fn func()) {
}()
fn()
}

type ForkCtxInterface interface {
ForkCtx(ctx context.Context) context.Context
}
3 changes: 3 additions & 0 deletions param.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ type Param struct {

func WithCtx(ctx context.Context) ParamOption {
return func(opt *Param) {
if ctx == nil {
return
}
opt.Ctx = ctx
}
}
Expand Down
50 changes: 22 additions & 28 deletions register_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package easyfsm

import (
"context"
"sync"
"testing"
)

type arg struct {
state State
eventName EventName
entity EventEntity
entity *EventEntity
}

func DefaultArgList() []arg {
Expand All @@ -19,66 +20,57 @@ func DefaultArgList() []arg {
args = append(args, arg{
state: 0,
eventName: "crateOrder",
entity: EventEntity{
eventName: "crateOrder",
eventFunc: func(opt *Param) (State, error) {
return State(1), nil
},
},
entity: NewEventEntity("crateOrder", func(opt *Param) (State, error) {
return State(1), nil
}),
},
arg{
state: 1,
eventName: "payOrder",
entity: EventEntity{
eventName: "payOrder",
eventFunc: func(opt *Param) (State, error) {
entity: NewEventEntity(
"payOrder",
func(opt *Param) (State, error) {
return State(2), nil
},
},
}),
},
arg{
state: 1,
eventName: "cancelOrder",
entity: EventEntity{
eventName: "cancelOrder",
eventFunc: func(opt *Param) (State, error) {
entity: NewEventEntity(
"cancelOrder",
func(opt *Param) (State, error) {
return State(3), nil
},
},
}),
},
)
return args
}

//
func TestRegisterStateMachine(t *testing.T) {
businessName := BusinessName("business_order")
businessName := BusinessName("TestRegisterStateMachine")
args := DefaultArgList()
// clear
stateMachineMap = make(map[BusinessName]map[State]map[EventName]*EventEntity)
for i := range args {
RegisterStateMachine(businessName, args[i].state, &args[i].entity)
RegisterStateMachine(businessName, args[i].state, args[i].entity)
}
commonTest(args, businessName, t)

}

func TestRegisterStateMachineForConcurrent(t *testing.T) {
businessName := BusinessName("business_order")
businessName := BusinessName("TestRegisterStateMachineForConcurrent")
args := DefaultArgList()
// clear
stateMachineMap = make(map[BusinessName]map[State]map[EventName]*EventEntity)
var (
wg sync.WaitGroup
)
for i := range args {
wg.Add(1)
go func(index int) {
defer wg.Done()
RegisterStateMachine(businessName, args[index].state, &args[index].entity)
RegisterStateMachine(businessName, args[index].state, args[index].entity)
}(i)
}
wg.Wait()

commonTest(args, businessName, t)
}

Expand All @@ -101,8 +93,10 @@ func commonTest(args []arg, businessName BusinessName, t *testing.T) {
t.Errorf("entity shouldn't be nil")
}

state, err := entity.Execute(nil)
wantState, wantErr := args[j].entity.Execute(nil)
param := &Param{Ctx: context.TODO()}

state, err := entity.execute(param)
wantState, wantErr := args[j].entity.execute(param)
if err != nil {
t.Errorf("err %v want:%v", err, wantErr)
}
Expand Down

0 comments on commit fcd10d7

Please sign in to comment.