Skip to content

Commit 7ef4e24

Browse files
committed
sajhdsdf
1 parent 2f0e017 commit 7ef4e24

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
data/data.txt
2+
Images/

utils/reward.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@ class Reward():
1313
dv -> number of basis functions on the velocity dimension.
1414
"""
1515
def __init__(self, dx, dv, env):
16+
sp = env.observation_space
1617
self.dx = dx
1718
self.dv = dv
18-
self.lx = env.low[0] # length of the position interval
19-
self.lv = env.low[1] # length of the velocity interval
20-
self.zx = env.high[0] # zero of the position interval
21-
self.zv = env.high[1] # zero of the velocity interval
19+
self.lx = sp.high[0] - sp.low[0] # length of the position interval
20+
self.lv = sp.high[1] - sp.low[1] # length of the velocity interval
21+
self.zx = -sp.low[0] # zero of the position interval
22+
self.zv = -sp.low[1] # zero of the velocity interval
2223
# tune sigma according to the discretization
2324
self.sigma_inv = inv(np.array([[.05, 0. ],
2425
[0., .0003]]))
@@ -32,7 +33,7 @@ def value(self, state, action):
3233

3334
def basis(self, state, idx):
3435
j = idx % self.dv
35-
i = (idx - j)/self.dv
36+
i = (idx - j)//self.dv
3637
x, v = state
3738
xi = i / (self.dx-1) * self.lx - self.zx
3839
vj = j / (self.dv-1) * self.lv - self.zv
@@ -42,7 +43,7 @@ def basis(self, state, idx):
4243

4344
def partial_value(self, state, action, idx):
4445
j = idx % self.dv
45-
i = (idx - j)/self.dv
46+
i = (idx - j)//self.dv
4647
return self.params[idx] * self.basis(state, i, j)
4748

4849
def partial_traj(self, traj, idx):
@@ -71,12 +72,12 @@ def export_to_file(self, file_path):
7172
def plot(self):
7273
ax = fig.gca(projection='3d')
7374
for i in range(self.dx):
74-
for j in range(self.dv):
75-
xi = i / (X-1) * 1.8 - 0.6
76-
vj = j / (V-1) * 0.14 - 0.07
77-
r[i, j] = reward.value([xi, vj], 1)
78-
ax.plot_surface(x, v, r.T, cmap=cm.coolwarm,
79-
linewidth=0, antialiased=False)
75+
for j in range(self.dv):
76+
xi = i / (X-1) * 1.8 - 0.6
77+
vj = j / (V-1) * 0.14 - 0.07
78+
r[i, j] = reward.value([xi, vj], 1)
79+
ax.plot_surface(x, v, r.T, cmap=cm.coolwarm,
80+
linewidth=0, antialiased=False)
8081

8182
plt.show()
8283

0 commit comments

Comments
 (0)