Skip to content

Commit 7f68dd3

Browse files
committed
update
1 parent e2a9d46 commit 7f68dd3

File tree

1 file changed

+169
-1
lines changed

1 file changed

+169
-1
lines changed

rl/grid_world.py

+169-1
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
# Note: you may need to update your version of future
66
# sudo pip install -U future
77

8-
98
import numpy as np
109

1110

11+
ACTION_SPACE = ('U', 'D', 'L', 'R')
12+
13+
1214
class Grid: # Environment
1315
def __init__(self, rows, cols, start):
1416
self.rows = rows
@@ -32,6 +34,22 @@ def current_state(self):
3234
def is_terminal(self, s):
3335
return s not in self.actions
3436

37+
def get_next_state(self, s, a):
    """Return the state reached by taking action `a` in state `s`.

    Purely functional lookup: the environment's current position is NOT
    modified. Illegal actions (not listed for `s`) leave the agent in place.
    """
    i, j = s[0], s[1]

    # Row/col displacement for each legal move; unknown actions are a no-op.
    deltas = {'U': (-1, 0), 'D': (1, 0), 'L': (0, -1), 'R': (0, 1)}

    # An action only moves the agent if it is legal in this state.
    if a in self.actions[(i, j)]:
        di, dj = deltas.get(a, (0, 0))
        i, j = i + di, j + dj
    return i, j
3553
def move(self, action):
3654
# check if legal move first
3755
if action in self.actions[(self.i, self.j)]:
@@ -116,3 +134,153 @@ def negative_grid(step_cost=-0.1):
116134
})
117135
return g
118136

137+
138+
139+
140+
141+
class WindyGrid:
    """Stochastic ("windy") gridworld environment.

    Unlike the deterministic Grid, taking action `a` in state `s` leads to a
    next state s' drawn from an explicit distribution p(s' | s, a).
    """

    def __init__(self, rows, cols, start):
        self.rows = rows
        self.cols = cols
        self.i = start[0]
        self.j = start[1]

    def set(self, rewards, actions, probs):
        # rewards: dict of (i, j) -> reward received on entering that state
        # actions: dict of (i, j) -> tuple of legal actions in that state
        # probs:   dict of ((i, j), a) -> {(i', j'): p(s' | s, a)}
        self.rewards = rewards
        self.actions = actions
        self.probs = probs

    def set_state(self, s):
        # Teleport the agent to state s (used by planning/evaluation loops).
        self.i = s[0]
        self.j = s[1]

    def current_state(self):
        return (self.i, self.j)

    def is_terminal(self, s):
        # Terminal states are exactly those with no legal actions.
        return s not in self.actions

    def move(self, action):
        """Sample a transition for `action`, update state, return the reward."""
        s = (self.i, self.j)

        next_state_probs = self.probs[(s, action)]
        next_states = list(next_state_probs.keys())
        next_probs = list(next_state_probs.values())

        # BUG FIX: np.random.choice cannot sample from a list of tuples —
        # it would coerce them into a 2-D array and raise
        # "ValueError: a must be 1-dimensional". Sample an index instead.
        idx = np.random.choice(len(next_states), p=next_probs)
        self.i, self.j = next_states[idx]

        # Return the reward for the state we landed in (0 if none defined).
        return self.rewards.get((self.i, self.j), 0)

    def game_over(self):
        # True if we are in a state where no actions are possible.
        return (self.i, self.j) not in self.actions

    def all_states(self):
        # Possibly buggy but simple way to get all states:
        # either a position that has possible next actions
        # or a position that yields a reward.
        return set(self.actions.keys()) | set(self.rewards.keys())
192+
def windy_grid():
    """Build the standard 3x4 windy gridworld.

    Start state is (2, 0); (0, 3) is the +1 terminal and (1, 3) the -1
    terminal. Every transition is deterministic except (1, 2) with 'U',
    where "wind" pushes the agent into the losing state half the time.
    """
    rewards = {(0, 3): 1, (1, 3): -1}

    actions = {
        (0, 0): ('D', 'R'),
        (0, 1): ('L', 'R'),
        (0, 2): ('L', 'D', 'R'),
        (1, 0): ('U', 'D'),
        (1, 2): ('U', 'D', 'R'),
        (2, 0): ('U', 'R'),
        (2, 1): ('L', 'R'),
        (2, 2): ('L', 'R', 'U'),
        (2, 3): ('L', 'U'),
    }

    # Transition model p(s' | s, a), stored as (s, a) -> {s': probability}.
    # Grouped by source state; entry order matches the canonical definition.
    probs = {
        # from (2, 0)
        ((2, 0), 'U'): {(1, 0): 1.0},
        ((2, 0), 'D'): {(2, 0): 1.0},
        ((2, 0), 'L'): {(2, 0): 1.0},
        ((2, 0), 'R'): {(2, 1): 1.0},
        # from (1, 0)
        ((1, 0), 'U'): {(0, 0): 1.0},
        ((1, 0), 'D'): {(2, 0): 1.0},
        ((1, 0), 'L'): {(1, 0): 1.0},
        ((1, 0), 'R'): {(1, 0): 1.0},
        # from (0, 0)
        ((0, 0), 'U'): {(0, 0): 1.0},
        ((0, 0), 'D'): {(1, 0): 1.0},
        ((0, 0), 'L'): {(0, 0): 1.0},
        ((0, 0), 'R'): {(0, 1): 1.0},
        # from (0, 1)
        ((0, 1), 'U'): {(0, 1): 1.0},
        ((0, 1), 'D'): {(0, 1): 1.0},
        ((0, 1), 'L'): {(0, 0): 1.0},
        ((0, 1), 'R'): {(0, 2): 1.0},
        # from (0, 2)
        ((0, 2), 'U'): {(0, 2): 1.0},
        ((0, 2), 'D'): {(1, 2): 1.0},
        ((0, 2), 'L'): {(0, 1): 1.0},
        ((0, 2), 'R'): {(0, 3): 1.0},
        # from (2, 1)
        ((2, 1), 'U'): {(2, 1): 1.0},
        ((2, 1), 'D'): {(2, 1): 1.0},
        ((2, 1), 'L'): {(2, 0): 1.0},
        ((2, 1), 'R'): {(2, 2): 1.0},
        # from (2, 2)
        ((2, 2), 'U'): {(1, 2): 1.0},
        ((2, 2), 'D'): {(2, 2): 1.0},
        ((2, 2), 'L'): {(2, 1): 1.0},
        ((2, 2), 'R'): {(2, 3): 1.0},
        # from (2, 3)
        ((2, 3), 'U'): {(1, 3): 1.0},
        ((2, 3), 'D'): {(2, 3): 1.0},
        ((2, 3), 'L'): {(2, 2): 1.0},
        ((2, 3), 'R'): {(2, 3): 1.0},
        # from (1, 2) — the "windy" state: 'U' is stochastic
        ((1, 2), 'U'): {(0, 2): 0.5, (1, 3): 0.5},
        ((1, 2), 'D'): {(2, 2): 1.0},
        ((1, 2), 'L'): {(1, 2): 1.0},
        ((1, 2), 'R'): {(1, 3): 1.0},
    }

    grid = WindyGrid(3, 4, (2, 0))
    grid.set(rewards, actions, probs)
    return grid
250+
251+
252+
253+
def grid_5x5(step_cost=-0.1):
    """Build a 5x5 deterministic gridworld with a per-step cost.

    Start state is (4, 0); (0, 4) is the +1 terminal and (1, 4) the -1
    terminal. Every non-terminal (visitable) state also carries `step_cost`
    so the agent is encouraged to finish quickly.
    """
    rewards = {(0, 4): 1, (1, 4): -1}

    # Legal actions per state; states absent from this dict (walls and the
    # two terminals) cannot be acted from.
    actions = {
        (0, 0): ('D', 'R'),
        (0, 1): ('L', 'R'),
        (0, 2): ('L', 'R'),
        (0, 3): ('L', 'D', 'R'),
        (1, 0): ('U', 'D', 'R'),
        (1, 1): ('U', 'D', 'L'),
        (1, 3): ('U', 'D', 'R'),
        (2, 0): ('U', 'D', 'R'),
        (2, 1): ('U', 'L', 'R'),
        (2, 2): ('L', 'R', 'D'),
        (2, 3): ('L', 'R', 'U'),
        (2, 4): ('L', 'U', 'D'),
        (3, 0): ('U', 'D'),
        (3, 2): ('U', 'D'),
        (3, 4): ('U', 'D'),
        (4, 0): ('U', 'R'),
        (4, 1): ('L', 'R'),
        (4, 2): ('L', 'R', 'U'),
        (4, 3): ('L', 'R'),
        (4, 4): ('L', 'U'),
    }

    grid = Grid(5, 5, (4, 0))
    grid.set(rewards, actions)

    # Charge the step cost in every non-terminal state. Terminals are not in
    # `actions`, so their +1/-1 rewards are left untouched.
    grid.rewards.update({s: step_cost for s in actions})

    return grid

0 commit comments

Comments
 (0)