Skip to content

Commit dc2f30f

Browse files
committed
feat: add interfaces
1 parent ac51f41 commit dc2f30f

File tree

5 files changed

+240
-60
lines changed

5 files changed

+240
-60
lines changed

data_storage.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package fsm
2+
3+
import (
4+
"encoding/json"
5+
"sync"
6+
)
7+
8+
// defaultDataStorage is a type for default data storage
9+
type dataStorage struct {
10+
mu sync.Mutex
11+
storage map[int64]map[any]any
12+
}
13+
14+
// initialDataStorage creates in memory storage for user's data
15+
func initialDataStorage() *dataStorage {
16+
return &dataStorage{
17+
storage: make(map[int64]map[any]any),
18+
}
19+
}
20+
21+
// Set sets user's data to data storage
22+
func (d *dataStorage) Set(userID int64, key, value any) error {
23+
d.mu.Lock()
24+
defer d.mu.Unlock()
25+
s, ok := d.storage[userID]
26+
if !ok {
27+
s = make(map[any]any)
28+
d.storage[userID] = s
29+
}
30+
31+
s[key] = value
32+
33+
return nil
34+
}
35+
36+
// Get gets user's data from data storage
37+
func (d *dataStorage) Get(userID int64, key any) (any, error) {
38+
d.mu.Lock()
39+
defer d.mu.Unlock()
40+
41+
_, ok := d.storage[userID]
42+
if !ok {
43+
return nil, nil
44+
}
45+
46+
return d.storage[userID][key], nil
47+
}
48+
49+
// Delete deletes user's data from data storage
50+
func (d *dataStorage) Delete(userID int64, key any) error {
51+
d.mu.Lock()
52+
delete(d.storage, userID)
53+
d.mu.Unlock()
54+
return nil
55+
}
56+
57+
// MarshalJSON implements json.Marshaler
58+
func (d *dataStorage) MarshalJSON() ([]byte, error) {
59+
d.mu.Lock()
60+
defer d.mu.Unlock()
61+
62+
return json.Marshal(d.storage)
63+
}
64+
65+
// UnmarshalJSON implements json.Unmarshaler
66+
func (d *dataStorage) UnmarshalJSON(data []byte) error {
67+
d.mu.Lock()
68+
defer d.mu.Unlock()
69+
70+
var response map[int64]map[any]any
71+
if err := json.Unmarshal(data, &response); err != nil {
72+
return err
73+
}
74+
75+
d.storage = response
76+
77+
return nil
78+
}

examples/simple/main.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ func (app *Application) handlerCancel(ctx context.Context, b *bot.Bot, update *m
7070
userID := update.Message.From.ID
7171
chatID := update.Message.Chat.ID
7272

73-
if app.f.Current(userID) == stateDefault {
73+
currentState, _ := app.f.Current(userID)
74+
75+
if currentState == stateDefault {
7476
return
7577
}
7678

@@ -86,7 +88,9 @@ func (app *Application) handlerForm(ctx context.Context, b *bot.Bot, update *mod
8688
userID := update.Message.From.ID
8789
chatID := update.Message.Chat.ID
8890

89-
if app.f.Current(userID) != stateDefault {
91+
currentState, _ := app.f.Current(userID)
92+
93+
if currentState != stateDefault {
9094
return
9195
}
9296

@@ -101,7 +105,9 @@ func (app *Application) handlerDefault(ctx context.Context, b *bot.Bot, update *
101105
userID := update.Message.From.ID
102106
chatID := update.Message.Chat.ID
103107

104-
switch app.f.Current(userID) {
108+
currentState, _ := app.f.Current(userID)
109+
110+
switch currentState {
105111
case stateDefault:
106112
b.SendMessage(ctx, &bot.SendMessageParams{
107113
ChatID: chatID,
@@ -150,7 +156,7 @@ func (app *Application) handlerDefault(ctx context.Context, b *bot.Bot, update *
150156
app.f.Transition(userID, stateFinish, chatID, userID)
151157

152158
default:
153-
fmt.Printf("unexpected state %s\n", app.f.Current(userID))
159+
fmt.Printf("unexpected state %s\n", currentState)
154160
}
155161
}
156162

fsm.go

Lines changed: 70 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package fsm
22

33
import (
44
"encoding/json"
5-
"sync"
65
)
76

87
// StateID is a type for state identifier
@@ -15,25 +14,45 @@ type Callback func(f *FSM, args ...any)
1514
type FSM struct {
1615
initialStateID StateID
1716
callbacks map[StateID]Callback
18-
userStatesMu sync.RWMutex
19-
userStates map[int64]StateID
20-
storageMx sync.Mutex
21-
storage map[int64]map[any]any
17+
userStates UserStateStorage
18+
storage DataStorage
19+
}
20+
21+
// UserStateStorage is an interface for user state storage
22+
type UserStateStorage interface {
23+
Set(userID int64, stateID StateID) error
24+
Exists(userID int64) (bool, error)
25+
Get(userID int64) (StateID, error)
26+
MarshalJSON() ([]byte, error)
27+
UnmarshalJSON(data []byte) error
28+
}
29+
30+
// DataStorage is an interface for data storage
31+
type DataStorage interface {
32+
Set(userID int64, key, value any) error
33+
Get(userID int64, key any) (any, error)
34+
Delete(userID int64, key any) error
35+
MarshalJSON() ([]byte, error)
36+
UnmarshalJSON(data []byte) error
2237
}
2338

2439
// New creates a new FSM
25-
func New(initialStateName StateID, callbacks map[StateID]Callback) *FSM {
40+
func New(initialStateName StateID, callbacks map[StateID]Callback, opts ...Option) *FSM {
2641
s := &FSM{
2742
initialStateID: initialStateName,
2843
callbacks: make(map[StateID]Callback),
29-
userStates: make(map[int64]StateID),
30-
storage: make(map[int64]map[any]any),
44+
userStates: initialUserStateStorage(),
45+
storage: initialDataStorage(),
3146
}
3247

3348
for stateID, callback := range callbacks {
3449
s.callbacks[stateID] = callback
3550
}
3651

52+
for _, opt := range opts {
53+
opt(s)
54+
}
55+
3756
return s
3857
}
3958

@@ -50,54 +69,55 @@ func (f *FSM) AddCallbacks(cb map[StateID]Callback) {
5069
}
5170

5271
// Transition transitions the user to a new state
53-
func (f *FSM) Transition(userID int64, stateID StateID, args ...any) {
54-
f.userStatesMu.Lock()
55-
56-
userStateID, okUserState := f.userStates[userID]
57-
if !okUserState {
58-
userStateID = f.initialStateID
59-
f.userStates[userID] = userStateID
72+
func (f *FSM) Transition(userID int64, stateID StateID, args ...any) error {
73+
err := f.userStates.Set(userID, stateID)
74+
if err != nil {
75+
return err
6076
}
61-
f.userStates[userID] = stateID
62-
63-
f.userStatesMu.Unlock()
6477

6578
cb, okCb := f.callbacks[stateID]
6679
if okCb {
6780
cb(f, args...)
6881
}
82+
83+
return nil
6984
}
7085

7186
// Current returns the current state of the user
72-
func (f *FSM) Current(userID int64) StateID {
73-
f.userStatesMu.RLock()
74-
defer f.userStatesMu.RUnlock()
75-
76-
userStateID, ok := f.userStates[userID]
87+
func (f *FSM) Current(userID int64) (StateID, error) {
88+
ok, err := f.userStates.Exists(userID)
89+
if err != nil {
90+
return "", err
91+
}
7792
if !ok {
78-
f.userStates[userID] = f.initialStateID
79-
return f.initialStateID
93+
err = f.userStates.Set(userID, f.initialStateID)
94+
if err != nil {
95+
return "", err
96+
}
97+
98+
return f.initialStateID, nil
8099
}
81100

82-
return userStateID
101+
state, err := f.userStates.Get(userID)
102+
if err != nil {
103+
return "", err
104+
}
105+
106+
return state, nil
83107
}
84108

85109
// Reset resets the state of the user to the initial state
86-
func (f *FSM) Reset(userID int64) {
87-
f.userStatesMu.Lock()
88-
delete(f.userStates, userID)
89-
f.userStatesMu.Unlock()
110+
func (f *FSM) Reset(userID int64) error {
111+
return f.userStates.Set(userID, f.initialStateID)
90112
}
91113

92114
// MarshalJSON marshals the FSM to JSON
93115
func (f *FSM) MarshalJSON() ([]byte, error) {
94-
f.userStatesMu.RLock()
95-
defer f.userStatesMu.RUnlock()
96116

97117
type response struct {
98-
InitialStateID StateID `json:"initial_state_id"`
99-
UserStates map[int64]StateID `json:"user_states"`
100-
Storage map[int64]map[any]any `json:"storage"`
118+
InitialStateID StateID `json:"initial_state_id"`
119+
UserStates UserStateStorage `json:"user_states"`
120+
Storage DataStorage `json:"storage"`
101121
}
102122

103123
return json.Marshal(response{
@@ -109,13 +129,11 @@ func (f *FSM) MarshalJSON() ([]byte, error) {
109129

110130
// UnmarshalJSON unmarshals the FSM from JSON
111131
func (f *FSM) UnmarshalJSON(data []byte) error {
112-
f.userStatesMu.Lock()
113-
defer f.userStatesMu.Unlock()
114132

115133
type response struct {
116-
InitialStateID StateID `json:"initial_state_id"`
117-
UserStates map[int64]StateID `json:"user_states"`
118-
Storage map[int64]map[any]any `json:"storage"`
134+
InitialStateID StateID `json:"initial_state_id"`
135+
UserStates UserStateStorage `json:"user_states"`
136+
Storage DataStorage `json:"storage"`
119137
}
120138

121139
var r response
@@ -131,25 +149,21 @@ func (f *FSM) UnmarshalJSON(data []byte) error {
131149
}
132150

133151
// Set sets a value for a key for a user
134-
func (f *FSM) Set(userID int64, key, value any) {
135-
f.storageMx.Lock()
136-
defer f.storageMx.Unlock()
137-
s, ok := f.storage[userID]
138-
if !ok {
139-
s = make(map[any]any)
140-
f.storage[userID] = s
152+
func (f *FSM) Set(userID int64, key, value any) error {
153+
err := f.storage.Set(userID, key, value)
154+
if err != nil {
155+
return err
141156
}
142-
s[key] = value
157+
158+
return nil
143159
}
144160

145161
// Get gets a value for a key for a user
146-
func (f *FSM) Get(userID int64, key any) (any, bool) {
147-
f.storageMx.Lock()
148-
defer f.storageMx.Unlock()
149-
s, ok := f.storage[userID]
150-
if !ok {
151-
return nil, false
162+
func (f *FSM) Get(userID int64, key any) (any, bool, error) {
163+
v, err := f.storage.Get(userID, key)
164+
if err != nil {
165+
return nil, false, err
152166
}
153-
v, ok := s[key]
154-
return v, ok
167+
168+
return v, true, nil
155169
}

options.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package fsm
2+
3+
// Option is a type for FSM options
4+
type Option func(*FSM)
5+
6+
// WithUserStateStorage sets userStateStorage FSM
7+
func WithUserStateStorage(storage UserStateStorage) Option {
8+
return func(fsm *FSM) {
9+
fsm.userStates = storage
10+
}
11+
}
12+
13+
// WithDataStorage sets a data storage for FSM
14+
func WithDataStorage(storage DataStorage) Option {
15+
return func(fsm *FSM) {
16+
fsm.storage = storage
17+
}
18+
}

0 commit comments

Comments
 (0)