-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGridWorld.py
115 lines (91 loc) · 3.92 KB
/
GridWorld.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 1 19:24:00 2020
@author: joser
"""
import numpy as np
class GridWorldEnv():
    """A deterministic grid-world environment with a gym-like step/reset API.

    States are encoded as two-character strings 'rc' (row then column), which
    caps the usable grid at 9x9. Each step costs `turnPenalty`; entering a
    ditch cell adds `ditchPenalty`; reaching a terminal cell adds `winReward`
    and ends the episode.
    """
    def __init__(self, gridsize=4, startState='00', terminalStates=None, ditches=None,
                 ditchPenalty=-10, turnPenalty=-1, winReward=100, mode='prod'):
        """Build state/action spaces and the reward configuration.

        Args:
            gridsize: side length of the square grid (capped at 9 because
                states are single-digit 'rowcol' strings).
            startState: 'rc' string the agent is placed on at reset().
            terminalStates: list of 'rc' strings ending the episode
                (default ['33']).
            ditches: list of 'rc' penalty cells (default ['12']).
            ditchPenalty: reward added when entering a ditch cell.
            turnPenalty: reward added on every step.
            winReward: reward added when entering a terminal cell.
            mode: 'debug' enables verbose prints; anything else is quiet.
        """
        self.mode = mode
        self.gridSize = min(gridsize, 9)  # 2-char state encoding limits us to 9x9
        self.create_stateSpace()
        self.actionSpace = [0, 1, 2, 3]
        self.actionDict = {0: 'UP', 1: 'DOWN', 2: 'LEFT', 3: 'RIGHT'}
        self.startState = startState
        # Bug fix: avoid shared mutable default arguments; use None sentinels
        # with the same effective defaults as before.
        self.terminalStates = ['33'] if terminalStates is None else terminalStates
        self.ditches = ['12'] if ditches is None else ditches
        self.winReward = winReward
        self.ditchPenalty = ditchPenalty
        self.turnPenalty = turnPenalty
        self.stateCount = self.get_stateSpace_len()
        self.actionCount = self.get_actionSpace_len()
        self.stateDict = {k: v for k, v in zip(self.stateSpace, range(self.stateCount))}
        self.currentState = self.startState
        # Bug fix: initialize episode bookkeeping here as well, so calling
        # step() before the first reset() does not raise AttributeError.
        self.isGameEnd = False
        self.totalAccumulatedReward = 0
        self.totalTurns = 0
        if self.mode == 'debug':
            print("State Space", self.stateSpace)
            print("State Dict", self.stateDict)
            print("Action Space", self.actionSpace)
            print("Action Dict", self.actionDict)
            print("Start State", self.startState)
            print("Terminal States", self.terminalStates)
            print("Ditches", self.ditches)
            print("WinReward:{}, TurnPenalty:{}, DitchPenalty:{}".format(self.winReward, self.turnPenalty, self.ditchPenalty))

    def create_stateSpace(self):
        """Populate self.stateSpace with all 'rc' strings of the grid."""
        self.stateSpace = [str(row) + str(col)
                           for row in range(self.gridSize)
                           for col in range(self.gridSize)]

    def set_mode(self, mode):
        """Switch between 'prod' (quiet) and 'debug' (verbose) modes."""
        self.mode = mode

    def get_stateSpace(self):
        """Return the list of all state strings."""
        return self.stateSpace

    def get_actionSpace(self):
        """Return the list of action codes [0, 1, 2, 3]."""
        return self.actionSpace

    def get_actionDict(self):
        """Return the mapping from action code to action name."""
        return self.actionDict

    def get_stateSpace_len(self):
        """Return the number of states (gridSize ** 2)."""
        return len(self.stateSpace)

    def get_actionSpace_len(self):
        """Return the number of actions (4)."""
        return len(self.actionSpace)

    def next_state(self, current_state, action):
        """Return the state reached from `current_state` by `action`.

        Movement is clamped at the grid edges (bumping a wall keeps the
        agent in place). Entering a terminal state flags the game as ended.
        """
        s_row = int(current_state[0])
        s_col = int(current_state[1])
        next_row = s_row
        next_col = s_col
        if action == 0: next_row = max(0, s_row - 1)
        if action == 1: next_row = min(self.gridSize - 1, s_row + 1)
        if action == 2: next_col = max(0, s_col - 1)
        if action == 3: next_col = min(self.gridSize - 1, s_col + 1)
        new_state = str(next_row) + str(next_col)
        if new_state in self.stateSpace:
            if new_state in self.terminalStates: self.isGameEnd = True
            if self.mode == 'debug':
                print("CurrentState:{}, Action:{}, NextState:{}".format(current_state, action, new_state))
            return new_state
        else:
            # Defensive fallback: clamping should always keep us in-grid.
            return current_state

    def compute_reward(self, state):
        """Return the reward for ENTERING `state` (turn + ditch + win parts)."""
        reward = self.turnPenalty
        if state in self.ditches: reward += self.ditchPenalty
        if state in self.terminalStates: reward += self.winReward
        return reward

    def reset(self):
        """Start a new episode; return the initial observation (startState)."""
        self.isGameEnd = False
        self.totalAccumulatedReward = 0
        self.totalTurns = 0
        self.currentState = self.startState
        return self.currentState

    def step(self, action):
        """Apply `action`; return (observation, reward, done, totalTurns).

        Raises:
            RuntimeError: if called after the episode has ended.
            ValueError: if `action` is not in the action space.
        """
        if self.isGameEnd:
            # Bug fix: `raise("...")` raised a TypeError (strings are not
            # exceptions); use proper exception types.
            raise RuntimeError("Game is Over Exception")
        if action not in self.actionSpace:
            raise ValueError("Invalid Action Exception")
        self.currentState = self.next_state(self.currentState, action)
        obs = self.currentState
        reward = self.compute_reward(obs)
        done = self.isGameEnd
        # Bug fix: these counters were reset in reset() but never updated,
        # so step() always reported 0 turns.
        self.totalTurns += 1
        self.totalAccumulatedReward += reward
        if self.mode == 'debug':
            print("Obs:{}, Reward:{}, Done:{}, TotalTurns:{}".format(obs, reward, done, self.totalTurns))
        return obs, reward, done, self.totalTurns