# Note: you may need to update your version of future
# sudo pip install -U future

-
import numpy as np


+ACTION_SPACE = ('U', 'D', 'L', 'R')
+
+
class Grid: # Environment
  def __init__(self, rows, cols, start):
    self.rows = rows
@@ -32,6 +34,22 @@ def current_state(self):
  def is_terminal(self, s):
    return s not in self.actions

+  def get_next_state(self, s, a):
+    # this answers: where would I end up if I perform action 'a' in state 's'?
+    i, j = s[0], s[1]
+
+    # if this action moves you somewhere else, then it will be in this dictionary
+    if a in self.actions[(i, j)]:
+      if a == 'U':
+        i -= 1
+      elif a == 'D':
+        i += 1
+      elif a == 'R':
+        j += 1
+      elif a == 'L':
+        j -= 1
+    return i, j
+
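+  # A usage sketch (illustrative only): since get_next_state() never mutates
+  # self.i / self.j, it can enumerate the deterministic transition for every
+  # non-terminal state and every action in ACTION_SPACE, e.g.
+  #
+  #   g = negative_grid()
+  #   for s in g.actions.keys():          # non-terminal states only
+  #     for a in ACTION_SPACE:
+  #       print(s, a, '->', g.get_next_state(s, a))
+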
  def move(self, action):
    # check if legal move first
    if action in self.actions[(self.i, self.j)]:
@@ -116,3 +134,153 @@ def negative_grid(step_cost=-0.1):
  })
  return g

+
+
+
+
+class WindyGrid:
+  def __init__(self, rows, cols, start):
+    self.rows = rows
+    self.cols = cols
+    self.i = start[0]
+    self.j = start[1]
+
+  def set(self, rewards, actions, probs):
+    # rewards should be a dict of: (i, j): r (row, col): reward
+    # actions should be a dict of: (i, j): A (row, col): list of possible actions
+    self.rewards = rewards
+    self.actions = actions
+    self.probs = probs
+
+  def set_state(self, s):
+    self.i = s[0]
+    self.j = s[1]
+
+  def current_state(self):
+    return (self.i, self.j)
+
+  def is_terminal(self, s):
+    return s not in self.actions
+
+  def move(self, action):
+    s = (self.i, self.j)
+    a = action
+
+    next_state_probs = self.probs[(s, a)]
+    next_states = list(next_state_probs.keys())
+    next_probs = list(next_state_probs.values())
+    # np.random.choice needs a 1-D array, so sample an index, not the tuples
+    next_state_idx = np.random.choice(len(next_states), p=next_probs)
+    s2 = next_states[next_state_idx]
+
+    # update the current state
+    self.i, self.j = s2
+
+    # return a reward (if any)
+    return self.rewards.get(s2, 0)
+
+  def game_over(self):
+    # returns true if game is over, else false
+    # true if we are in a state where no actions are possible
+    return (self.i, self.j) not in self.actions
+
+  def all_states(self):
+    # possibly buggy but simple way to get all states
+    # either a position that has possible next actions
+    # or a position that yields a reward
+    return set(self.actions.keys()) | set(self.rewards.keys())
+
+
+def windy_grid():
+  g = WindyGrid(3, 4, (2, 0))
+  rewards = {(0, 3): 1, (1, 3): -1}
+  actions = {
+    (0, 0): ('D', 'R'),
+    (0, 1): ('L', 'R'),
+    (0, 2): ('L', 'D', 'R'),
+    (1, 0): ('U', 'D'),
+    (1, 2): ('U', 'D', 'R'),
+    (2, 0): ('U', 'R'),
+    (2, 1): ('L', 'R'),
+    (2, 2): ('L', 'R', 'U'),
+    (2, 3): ('L', 'U'),
+  }
+
+  # p(s' | s, a) represented as:
+  # KEY: (s, a) --> VALUE: {s': p(s' | s, a)}
+  probs = {
+    ((2, 0), 'U'): {(1, 0): 1.0},
+    ((2, 0), 'D'): {(2, 0): 1.0},
+    ((2, 0), 'L'): {(2, 0): 1.0},
+    ((2, 0), 'R'): {(2, 1): 1.0},
+    ((1, 0), 'U'): {(0, 0): 1.0},
+    ((1, 0), 'D'): {(2, 0): 1.0},
+    ((1, 0), 'L'): {(1, 0): 1.0},
+    ((1, 0), 'R'): {(1, 0): 1.0},
+    ((0, 0), 'U'): {(0, 0): 1.0},
+    ((0, 0), 'D'): {(1, 0): 1.0},
+    ((0, 0), 'L'): {(0, 0): 1.0},
+    ((0, 0), 'R'): {(0, 1): 1.0},
+    ((0, 1), 'U'): {(0, 1): 1.0},
+    ((0, 1), 'D'): {(0, 1): 1.0},
+    ((0, 1), 'L'): {(0, 0): 1.0},
+    ((0, 1), 'R'): {(0, 2): 1.0},
+    ((0, 2), 'U'): {(0, 2): 1.0},
+    ((0, 2), 'D'): {(1, 2): 1.0},
+    ((0, 2), 'L'): {(0, 1): 1.0},
+    ((0, 2), 'R'): {(0, 3): 1.0},
+    ((2, 1), 'U'): {(2, 1): 1.0},
+    ((2, 1), 'D'): {(2, 1): 1.0},
+    ((2, 1), 'L'): {(2, 0): 1.0},
+    ((2, 1), 'R'): {(2, 2): 1.0},
+    ((2, 2), 'U'): {(1, 2): 1.0},
+    ((2, 2), 'D'): {(2, 2): 1.0},
+    ((2, 2), 'L'): {(2, 1): 1.0},
+    ((2, 2), 'R'): {(2, 3): 1.0},
+    ((2, 3), 'U'): {(1, 3): 1.0},
+    ((2, 3), 'D'): {(2, 3): 1.0},
+    ((2, 3), 'L'): {(2, 2): 1.0},
+    ((2, 3), 'R'): {(2, 3): 1.0},
+    ((1, 2), 'U'): {(0, 2): 0.5, (1, 3): 0.5},
+    ((1, 2), 'D'): {(2, 2): 1.0},
+    ((1, 2), 'L'): {(1, 2): 1.0},
+    ((1, 2), 'R'): {(1, 3): 1.0},
+  }
+  g.set(rewards, actions, probs)
+  return g
+
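+# A minimal usage sketch (illustrative): roll out one episode in the windy
+# grid with uniformly random actions; each call to move() samples the next
+# state from the p(s' | s, a) table stored in g.probs.
+#
+#   g = windy_grid()
+#   s = g.current_state()
+#   episode_return = 0
+#   while not g.game_over():
+#     a = np.random.choice(g.actions[s])  # random legal action in state s
+#     r = g.move(a)                       # stochastic transition; returns reward
+#     episode_return += r
+#     s = g.current_state()
+#   print("episode return:", episode_return)
+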
+
+
+
+def grid_5x5(step_cost=-0.1):
+  g = Grid(5, 5, (4, 0))
+  rewards = {(0, 4): 1, (1, 4): -1}
+  actions = {
+    (0, 0): ('D', 'R'),
+    (0, 1): ('L', 'R'),
+    (0, 2): ('L', 'R'),
+    (0, 3): ('L', 'D', 'R'),
+    (1, 0): ('U', 'D', 'R'),
+    (1, 1): ('U', 'D', 'L'),
+    (1, 3): ('U', 'D', 'R'),
+    (2, 0): ('U', 'D', 'R'),
+    (2, 1): ('U', 'L', 'R'),
+    (2, 2): ('L', 'R', 'D'),
+    (2, 3): ('L', 'R', 'U'),
+    (2, 4): ('L', 'U', 'D'),
+    (3, 0): ('U', 'D'),
+    (3, 2): ('U', 'D'),
+    (3, 4): ('U', 'D'),
+    (4, 0): ('U', 'R'),
+    (4, 1): ('L', 'R'),
+    (4, 2): ('L', 'R', 'U'),
+    (4, 3): ('L', 'R'),
+    (4, 4): ('L', 'U'),
+  }
+  g.set(rewards, actions)
+
+  # non-terminal states
+  visitable_states = actions.keys()
+  for s in visitable_states:
+    g.rewards[s] = step_cost
+
+  return g
+
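+# A sketch (illustrative) tying the two representations together: for the
+# deterministic Grid environments, the same p(s' | s, a) dictionary format
+# used by windy_grid() can be rebuilt from get_next_state(), with probability
+# 1.0 on the single reachable next state.
+#
+#   g = grid_5x5()
+#   probs = {}
+#   for s in g.actions.keys():            # non-terminal states
+#     for a in g.actions[s]:              # legal actions in s
+#       probs[(s, a)] = {g.get_next_state(s, a): 1.0}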