@@ -2,7 +2,6 @@ package fsm
2
2
3
3
import (
4
4
"encoding/json"
5
- "sync"
6
5
)
7
6
8
7
// StateID is a type for state identifier
@@ -15,25 +14,45 @@ type Callback func(f *FSM, args ...any)
15
14
type FSM struct {
16
15
initialStateID StateID
17
16
callbacks map [StateID ]Callback
18
- userStatesMu sync.RWMutex
19
- userStates map [int64 ]StateID
20
- storageMx sync.Mutex
21
- storage map [int64 ]map [any ]any
17
+ userStates UserStateStorage
18
+ storage DataStorage
19
+ }
20
+
21
+ // UserStateStorage is an interface for user state storage
22
+ type UserStateStorage interface {
23
+ Set (userID int64 , stateID StateID ) error
24
+ Exists (userID int64 ) (bool , error )
25
+ Get (userID int64 ) (StateID , error )
26
+ MarshalJSON () ([]byte , error )
27
+ UnmarshalJSON (data []byte ) error
28
+ }
29
+
30
+ // DataStorage is an interface for data storage
31
+ type DataStorage interface {
32
+ Set (userID int64 , key , value any ) error
33
+ Get (userID int64 , key any ) (any , error )
34
+ Delete (userID int64 , key any ) error
35
+ MarshalJSON () ([]byte , error )
36
+ UnmarshalJSON (data []byte ) error
22
37
}
23
38
24
39
// New creates a new FSM
25
- func New (initialStateName StateID , callbacks map [StateID ]Callback ) * FSM {
40
+ func New (initialStateName StateID , callbacks map [StateID ]Callback , opts ... Option ) * FSM {
26
41
s := & FSM {
27
42
initialStateID : initialStateName ,
28
43
callbacks : make (map [StateID ]Callback ),
29
- userStates : make ( map [ int64 ] StateID ),
30
- storage : make ( map [ int64 ] map [ any ] any ),
44
+ userStates : initialUserStateStorage ( ),
45
+ storage : initialDataStorage ( ),
31
46
}
32
47
33
48
for stateID , callback := range callbacks {
34
49
s .callbacks [stateID ] = callback
35
50
}
36
51
52
+ for _ , opt := range opts {
53
+ opt (s )
54
+ }
55
+
37
56
return s
38
57
}
39
58
@@ -50,54 +69,55 @@ func (f *FSM) AddCallbacks(cb map[StateID]Callback) {
50
69
}
51
70
52
71
// Transition transitions the user to a new state
53
- func (f * FSM ) Transition (userID int64 , stateID StateID , args ... any ) {
54
- f .userStatesMu .Lock ()
55
-
56
- userStateID , okUserState := f .userStates [userID ]
57
- if ! okUserState {
58
- userStateID = f .initialStateID
59
- f .userStates [userID ] = userStateID
72
+ func (f * FSM ) Transition (userID int64 , stateID StateID , args ... any ) error {
73
+ err := f .userStates .Set (userID , stateID )
74
+ if err != nil {
75
+ return err
60
76
}
61
- f .userStates [userID ] = stateID
62
-
63
- f .userStatesMu .Unlock ()
64
77
65
78
cb , okCb := f .callbacks [stateID ]
66
79
if okCb {
67
80
cb (f , args ... )
68
81
}
82
+
83
+ return nil
69
84
}
70
85
71
86
// Current returns the current state of the user
72
- func (f * FSM ) Current (userID int64 ) StateID {
73
- f . userStatesMu . RLock ( )
74
- defer f . userStatesMu . RUnlock ()
75
-
76
- userStateID , ok := f . userStates [ userID ]
87
+ func (f * FSM ) Current (userID int64 ) ( StateID , error ) {
88
+ ok , err := f . userStates . Exists ( userID )
89
+ if err != nil {
90
+ return "" , err
91
+ }
77
92
if ! ok {
78
- f .userStates [userID ] = f .initialStateID
79
- return f .initialStateID
93
+ err = f .userStates .Set (userID , f .initialStateID )
94
+ if err != nil {
95
+ return "" , err
96
+ }
97
+
98
+ return f .initialStateID , nil
80
99
}
81
100
82
- return userStateID
101
+ state , err := f .userStates .Get (userID )
102
+ if err != nil {
103
+ return "" , err
104
+ }
105
+
106
+ return state , nil
83
107
}
84
108
85
109
// Reset resets the state of the user to the initial state
86
- func (f * FSM ) Reset (userID int64 ) {
87
- f .userStatesMu .Lock ()
88
- delete (f .userStates , userID )
89
- f .userStatesMu .Unlock ()
110
+ func (f * FSM ) Reset (userID int64 ) error {
111
+ return f .userStates .Set (userID , f .initialStateID )
90
112
}
91
113
92
114
// MarshalJSON marshals the FSM to JSON
93
115
func (f * FSM ) MarshalJSON () ([]byte , error ) {
94
- f .userStatesMu .RLock ()
95
- defer f .userStatesMu .RUnlock ()
96
116
97
117
type response struct {
98
- InitialStateID StateID `json:"initial_state_id"`
99
- UserStates map [ int64 ] StateID `json:"user_states"`
100
- Storage map [ int64 ] map [ any ] any `json:"storage"`
118
+ InitialStateID StateID `json:"initial_state_id"`
119
+ UserStates UserStateStorage `json:"user_states"`
120
+ Storage DataStorage `json:"storage"`
101
121
}
102
122
103
123
return json .Marshal (response {
@@ -109,13 +129,11 @@ func (f *FSM) MarshalJSON() ([]byte, error) {
109
129
110
130
// UnmarshalJSON unmarshals the FSM from JSON
111
131
func (f * FSM ) UnmarshalJSON (data []byte ) error {
112
- f .userStatesMu .Lock ()
113
- defer f .userStatesMu .Unlock ()
114
132
115
133
type response struct {
116
- InitialStateID StateID `json:"initial_state_id"`
117
- UserStates map [ int64 ] StateID `json:"user_states"`
118
- Storage map [ int64 ] map [ any ] any `json:"storage"`
134
+ InitialStateID StateID `json:"initial_state_id"`
135
+ UserStates UserStateStorage `json:"user_states"`
136
+ Storage DataStorage `json:"storage"`
119
137
}
120
138
121
139
var r response
@@ -131,25 +149,21 @@ func (f *FSM) UnmarshalJSON(data []byte) error {
131
149
}
132
150
133
151
// Set sets a value for a key for a user
134
- func (f * FSM ) Set (userID int64 , key , value any ) {
135
- f .storageMx .Lock ()
136
- defer f .storageMx .Unlock ()
137
- s , ok := f .storage [userID ]
138
- if ! ok {
139
- s = make (map [any ]any )
140
- f .storage [userID ] = s
152
+ func (f * FSM ) Set (userID int64 , key , value any ) error {
153
+ err := f .storage .Set (userID , key , value )
154
+ if err != nil {
155
+ return err
141
156
}
142
- s [key ] = value
157
+
158
+ return nil
143
159
}
144
160
145
161
// Get gets a value for a key for a user
146
- func (f * FSM ) Get (userID int64 , key any ) (any , bool ) {
147
- f .storageMx .Lock ()
148
- defer f .storageMx .Unlock ()
149
- s , ok := f .storage [userID ]
150
- if ! ok {
151
- return nil , false
162
+ func (f * FSM ) Get (userID int64 , key any ) (any , bool , error ) {
163
+ v , err := f .storage .Get (userID , key )
164
+ if err != nil {
165
+ return nil , false , err
152
166
}
153
- v , ok := s [ key ]
154
- return v , ok
167
+
168
+ return v , true , nil
155
169
}
0 commit comments