Skip to content

Commit 69f541c

Browse files
committed
added reward file, minor modifications
1 parent bc2f020 commit 69f541c

File tree

4 files changed

+242
-68
lines changed

4 files changed

+242
-68
lines changed
Binary file not shown.

IRL/GradientIRL/data.txt

Lines changed: 173 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,4 +429,176 @@ hello
429429
-3.43639952e-01 -7.65767498e-01 -3.27241543e-02 -2.50201562e-07
430430
-5.05102068e-04 5.08214586e-02 -5.09864219e-02 -9.25494968e-04
431431
-2.43519848e-13 3.26811300e-09 5.12168478e-06 -4.23215557e-06
432-
-7.47071173e-08]]
432+
-7.47071173e-08]][[ 1.52216984e-01 5.36240261e-01 7.45492183e-01 7.55108373e-01
433+
1.76604628e-01 5.06150549e-02 6.94948232e-01 5.03745090e-01
434+
9.33423216e-01 1.32251893e-01 7.91397828e-04 1.05346677e-01
435+
1.36618240e-01 6.67025513e-01 1.48446523e-02 1.56539037e-07
436+
2.70227936e-04 -3.99613657e-02 4.35692156e-02 3.00328441e-04
437+
1.36453140e-13 -4.41400014e-09 -3.79609480e-06 3.62530414e-06
438+
2.16955436e-08]
439+
[ 1.44473648e-01 5.87504238e-02 -1.15104288e-01 1.27960382e-01
440+
1.29991202e-01 3.90641726e-02 2.20072654e-01 7.31356882e-02
441+
1.52545244e-01 1.04910648e-01 5.43836963e-04 7.04120383e-02
442+
2.07021711e-01 9.87419853e-02 1.78795019e-02 9.36625258e-08
443+
2.34874132e-04 -1.08600929e-02 7.41720621e-03 6.25166527e-04
444+
1.07066709e-13 1.14588715e-09 -1.32558998e-06 6.06851426e-07
445+
5.30115737e-08]
446+
[-2.96690633e-01 -5.94990685e-01 -6.30387895e-01 -8.83068755e-01
447+
-3.06595830e-01 -8.96792275e-02 -9.15020886e-01 -5.76880778e-01
448+
-1.08596846e+00 -2.37162541e-01 -1.33523479e-03 -1.75758715e-01
449+
-3.43639952e-01 -7.65767498e-01 -3.27241543e-02 -2.50201562e-07
450+
-5.05102068e-04 5.08214586e-02 -5.09864219e-02 -9.25494968e-04
451+
-2.43519848e-13 3.26811300e-09 5.12168478e-06 -4.23215557e-06
452+
-7.47071173e-08]][[ 1.52216984e-01 5.36240261e-01 7.45492183e-01 7.55108373e-01
453+
1.76604628e-01 1.23032835e-01 5.74048154e-01 1.87986644e-01
454+
8.62077249e-01 1.94078736e-01 6.48181703e-02 7.01463090e-01
455+
3.07926973e-01 9.42892017e-01 1.49646791e-01 1.89401626e-02
456+
5.26491910e-01 9.28261800e-01 8.30158650e-01 7.83563514e-02
457+
2.76800133e-03 2.07759823e-01 4.88878957e-01 7.15663305e-01
458+
2.77725838e-02 1.82402168e-04 4.49873474e-02 -1.11776412e-01
459+
5.57286461e-01 7.63033952e-03 4.74548559e-06 3.95894658e-03
460+
-1.46580466e-01 2.02055488e-01 1.61080603e-03 4.39106976e-08
461+
8.89222055e-05 -2.10955585e-02 2.20887117e-02 1.48199229e-04
462+
1.36967095e-10 -3.68105587e-08 -6.14583511e-04 5.96004567e-04
463+
3.73632857e-06 1.36453140e-13 -4.41400014e-09 -3.79609480e-06
464+
3.62530414e-06 2.16955436e-08]
465+
[ 1.44473648e-01 5.87504238e-02 -1.15104288e-01 1.27960382e-01
466+
1.29991202e-01 1.06595752e-01 9.15120572e-02 -1.86908132e-02
467+
1.11396236e-01 1.45564125e-01 5.09812021e-02 1.95106213e-01
468+
3.87585225e-02 1.46559718e-01 1.16611721e-01 1.40232132e-02
469+
2.39092364e-01 2.36491097e-01 1.43288820e-01 6.73828167e-02
470+
1.95981281e-03 1.31225614e-01 3.09388413e-01 1.07701738e-01
471+
2.88693913e-02 1.20827847e-04 3.02441874e-02 7.72434093e-02
472+
8.57211606e-02 1.08945557e-02 2.91451695e-06 2.77997627e-03
473+
-2.72797891e-02 3.36565626e-02 3.01109141e-03 2.63580460e-08
474+
8.87571988e-05 -6.06906633e-03 3.76864979e-03 3.16359808e-04
475+
8.97362290e-11 7.56676336e-07 -2.01388846e-04 1.01256571e-04
476+
8.60734069e-06 1.07066709e-13 1.14588715e-09 -1.32558998e-06
477+
6.06851426e-07 5.30115737e-08]
478+
[-2.96690633e-01 -5.94990685e-01 -6.30387895e-01 -8.83068755e-01
479+
-3.06595830e-01 -2.29628587e-01 -6.65560212e-01 -1.69295831e-01
480+
-9.73473484e-01 -3.39642861e-01 -1.15799372e-01 -8.96569303e-01
481+
-3.46685495e-01 -1.08945174e+00 -2.66258512e-01 -3.29633759e-02
482+
-7.65584274e-01 -1.16475290e+00 -9.73447470e-01 -1.45739168e-01
483+
-4.72781413e-03 -3.38985437e-01 -7.98267370e-01 -8.23365043e-01
484+
-5.66419751e-02 -3.03230015e-04 -7.52315347e-02 3.45330028e-02
485+
-6.43007622e-01 -1.85248952e-02 -7.66000254e-06 -6.73892285e-03
486+
1.73860256e-01 -2.35712051e-01 -4.62189744e-03 -7.02687436e-08
487+
-1.77679404e-04 2.71646249e-02 -2.58573615e-02 -4.64559037e-04
488+
-2.26703324e-10 -7.19865777e-07 8.15972357e-04 -6.97261138e-04
489+
-1.23436693e-05 -2.43519848e-13 3.26811300e-09 5.12168478e-06
490+
-4.23215557e-06 -7.47071173e-08]][[ 1.52216984e-01 5.36240261e-01 7.45492183e-01 7.55108373e-01
491+
1.76604628e-01 1.23032835e-01 5.74048154e-01 1.87986644e-01
492+
8.62077249e-01 1.94078736e-01 6.48181703e-02 7.01463090e-01
493+
3.07926973e-01 9.42892017e-01 1.49646791e-01 1.89401626e-02
494+
5.26491910e-01 9.28261800e-01 8.30158650e-01 7.83563514e-02
495+
2.76800133e-03 2.07759823e-01 4.88878957e-01 7.15663305e-01
496+
2.77725838e-02 1.82402168e-04 4.49873474e-02 -1.11776412e-01
497+
5.57286461e-01 7.63033952e-03 4.74548559e-06 3.95894658e-03
498+
-1.46580466e-01 2.02055488e-01 1.61080603e-03 4.39106976e-08
499+
8.89222055e-05 -2.10955585e-02 2.20887117e-02 1.48199229e-04
500+
1.36967095e-10 -3.68105587e-08 -6.14583511e-04 5.96004567e-04
501+
3.73632857e-06 1.36453140e-13 -4.41400014e-09 -3.79609480e-06
502+
3.62530414e-06 2.16955436e-08]
503+
[ 1.44473648e-01 5.87504238e-02 -1.15104288e-01 1.27960382e-01
504+
1.29991202e-01 1.06595752e-01 9.15120572e-02 -1.86908132e-02
505+
1.11396236e-01 1.45564125e-01 5.09812021e-02 1.95106213e-01
506+
3.87585225e-02 1.46559718e-01 1.16611721e-01 1.40232132e-02
507+
2.39092364e-01 2.36491097e-01 1.43288820e-01 6.73828167e-02
508+
1.95981281e-03 1.31225614e-01 3.09388413e-01 1.07701738e-01
509+
2.88693913e-02 1.20827847e-04 3.02441874e-02 7.72434093e-02
510+
8.57211606e-02 1.08945557e-02 2.91451695e-06 2.77997627e-03
511+
-2.72797891e-02 3.36565626e-02 3.01109141e-03 2.63580460e-08
512+
8.87571988e-05 -6.06906633e-03 3.76864979e-03 3.16359808e-04
513+
8.97362290e-11 7.56676336e-07 -2.01388846e-04 1.01256571e-04
514+
8.60734069e-06 1.07066709e-13 1.14588715e-09 -1.32558998e-06
515+
6.06851426e-07 5.30115737e-08]
516+
[-2.96690633e-01 -5.94990685e-01 -6.30387895e-01 -8.83068755e-01
517+
-3.06595830e-01 -2.29628587e-01 -6.65560212e-01 -1.69295831e-01
518+
-9.73473484e-01 -3.39642861e-01 -1.15799372e-01 -8.96569303e-01
519+
-3.46685495e-01 -1.08945174e+00 -2.66258512e-01 -3.29633759e-02
520+
-7.65584274e-01 -1.16475290e+00 -9.73447470e-01 -1.45739168e-01
521+
-4.72781413e-03 -3.38985437e-01 -7.98267370e-01 -8.23365043e-01
522+
-5.66419751e-02 -3.03230015e-04 -7.52315347e-02 3.45330028e-02
523+
-6.43007622e-01 -1.85248952e-02 -7.66000254e-06 -6.73892285e-03
524+
1.73860256e-01 -2.35712051e-01 -4.62189744e-03 -7.02687436e-08
525+
-1.77679404e-04 2.71646249e-02 -2.58573615e-02 -4.64559037e-04
526+
-2.26703324e-10 -7.19865777e-07 8.15972357e-04 -6.97261138e-04
527+
-1.23436693e-05 -2.43519848e-13 3.26811300e-09 5.12168478e-06
528+
-4.23215557e-06 -7.47071173e-08]][[ 1.52216984e-01 5.36240261e-01 7.45492183e-01 7.55108373e-01
529+
1.76604628e-01 1.23032835e-01 5.74048154e-01 1.87986644e-01
530+
8.62077249e-01 1.94078736e-01 6.48181703e-02 7.01463090e-01
531+
3.07926973e-01 9.42892017e-01 1.49646791e-01 1.89401626e-02
532+
5.26491910e-01 9.28261800e-01 8.30158650e-01 7.83563514e-02
533+
2.76800133e-03 2.07759823e-01 4.88878957e-01 7.15663305e-01
534+
2.77725838e-02 1.82402168e-04 4.49873474e-02 -1.11776412e-01
535+
5.57286461e-01 7.63033952e-03 4.74548559e-06 3.95894658e-03
536+
-1.46580466e-01 2.02055488e-01 1.61080603e-03 4.39106976e-08
537+
8.89222055e-05 -2.10955585e-02 2.20887117e-02 1.48199229e-04
538+
1.36967095e-10 -3.68105587e-08 -6.14583511e-04 5.96004567e-04
539+
3.73632857e-06 1.36453140e-13 -4.41400014e-09 -3.79609480e-06
540+
3.62530414e-06 2.16955436e-08]
541+
[ 1.44473648e-01 5.87504238e-02 -1.15104288e-01 1.27960382e-01
542+
1.29991202e-01 1.06595752e-01 9.15120572e-02 -1.86908132e-02
543+
1.11396236e-01 1.45564125e-01 5.09812021e-02 1.95106213e-01
544+
3.87585225e-02 1.46559718e-01 1.16611721e-01 1.40232132e-02
545+
2.39092364e-01 2.36491097e-01 1.43288820e-01 6.73828167e-02
546+
1.95981281e-03 1.31225614e-01 3.09388413e-01 1.07701738e-01
547+
2.88693913e-02 1.20827847e-04 3.02441874e-02 7.72434093e-02
548+
8.57211606e-02 1.08945557e-02 2.91451695e-06 2.77997627e-03
549+
-2.72797891e-02 3.36565626e-02 3.01109141e-03 2.63580460e-08
550+
8.87571988e-05 -6.06906633e-03 3.76864979e-03 3.16359808e-04
551+
8.97362290e-11 7.56676336e-07 -2.01388846e-04 1.01256571e-04
552+
8.60734069e-06 1.07066709e-13 1.14588715e-09 -1.32558998e-06
553+
6.06851426e-07 5.30115737e-08]
554+
[-2.96690633e-01 -5.94990685e-01 -6.30387895e-01 -8.83068755e-01
555+
-3.06595830e-01 -2.29628587e-01 -6.65560212e-01 -1.69295831e-01
556+
-9.73473484e-01 -3.39642861e-01 -1.15799372e-01 -8.96569303e-01
557+
-3.46685495e-01 -1.08945174e+00 -2.66258512e-01 -3.29633759e-02
558+
-7.65584274e-01 -1.16475290e+00 -9.73447470e-01 -1.45739168e-01
559+
-4.72781413e-03 -3.38985437e-01 -7.98267370e-01 -8.23365043e-01
560+
-5.66419751e-02 -3.03230015e-04 -7.52315347e-02 3.45330028e-02
561+
-6.43007622e-01 -1.85248952e-02 -7.66000254e-06 -6.73892285e-03
562+
1.73860256e-01 -2.35712051e-01 -4.62189744e-03 -7.02687436e-08
563+
-1.77679404e-04 2.71646249e-02 -2.58573615e-02 -4.64559037e-04
564+
-2.26703324e-10 -7.19865777e-07 8.15972357e-04 -6.97261138e-04
565+
-1.23436693e-05 -2.43519848e-13 3.26811300e-09 5.12168478e-06
566+
-4.23215557e-06 -7.47071173e-08]][[ 1.52216984e-01 5.36240261e-01 7.45492183e-01 7.55108373e-01
567+
1.76604628e-01 1.23032835e-01 5.74048154e-01 1.87986644e-01
568+
8.62077249e-01 1.94078736e-01 6.48181703e-02 7.01463090e-01
569+
3.07926973e-01 9.42892017e-01 1.49646791e-01 1.89401626e-02
570+
5.26491910e-01 9.28261800e-01 8.30158650e-01 7.83563514e-02
571+
2.76800133e-03 2.07759823e-01 4.88878957e-01 7.15663305e-01
572+
2.77725838e-02 1.82402168e-04 4.49873474e-02 -1.11776412e-01
573+
5.57286461e-01 7.63033952e-03 4.74548559e-06 3.95894658e-03
574+
-1.46580466e-01 2.02055488e-01 1.61080603e-03 4.39106976e-08
575+
8.89222055e-05 -2.10955585e-02 2.20887117e-02 1.48199229e-04
576+
1.36967095e-10 -3.68105587e-08 -6.14583511e-04 5.96004567e-04
577+
3.73632857e-06 1.36453140e-13 -4.41400014e-09 -3.79609480e-06
578+
3.62530414e-06 2.16955436e-08]
579+
[ 1.44473648e-01 5.87504238e-02 -1.15104288e-01 1.27960382e-01
580+
1.29991202e-01 1.06595752e-01 9.15120572e-02 -1.86908132e-02
581+
1.11396236e-01 1.45564125e-01 5.09812021e-02 1.95106213e-01
582+
3.87585225e-02 1.46559718e-01 1.16611721e-01 1.40232132e-02
583+
2.39092364e-01 2.36491097e-01 1.43288820e-01 6.73828167e-02
584+
1.95981281e-03 1.31225614e-01 3.09388413e-01 1.07701738e-01
585+
2.88693913e-02 1.20827847e-04 3.02441874e-02 7.72434093e-02
586+
8.57211606e-02 1.08945557e-02 2.91451695e-06 2.77997627e-03
587+
-2.72797891e-02 3.36565626e-02 3.01109141e-03 2.63580460e-08
588+
8.87571988e-05 -6.06906633e-03 3.76864979e-03 3.16359808e-04
589+
8.97362290e-11 7.56676336e-07 -2.01388846e-04 1.01256571e-04
590+
8.60734069e-06 1.07066709e-13 1.14588715e-09 -1.32558998e-06
591+
6.06851426e-07 5.30115737e-08]
592+
[-2.96690633e-01 -5.94990685e-01 -6.30387895e-01 -8.83068755e-01
593+
-3.06595830e-01 -2.29628587e-01 -6.65560212e-01 -1.69295831e-01
594+
-9.73473484e-01 -3.39642861e-01 -1.15799372e-01 -8.96569303e-01
595+
-3.46685495e-01 -1.08945174e+00 -2.66258512e-01 -3.29633759e-02
596+
-7.65584274e-01 -1.16475290e+00 -9.73447470e-01 -1.45739168e-01
597+
-4.72781413e-03 -3.38985437e-01 -7.98267370e-01 -8.23365043e-01
598+
-5.66419751e-02 -3.03230015e-04 -7.52315347e-02 3.45330028e-02
599+
-6.43007622e-01 -1.85248952e-02 -7.66000254e-06 -6.73892285e-03
600+
1.73860256e-01 -2.35712051e-01 -4.62189744e-03 -7.02687436e-08
601+
-1.77679404e-04 2.71646249e-02 -2.58573615e-02 -4.64559037e-04
602+
-2.26703324e-10 -7.19865777e-07 8.15972357e-04 -6.97261138e-04
603+
-1.23436693e-05 -2.43519848e-13 3.26811300e-09 5.12168478e-06
604+
-4.23215557e-06 -7.47071173e-08]]

IRL/GradientIRL/gradientIRL.py

Lines changed: 5 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -7,59 +7,6 @@
77

88
sys.path.append('../..')
99

10-
class Reward():
11-
"""
12-
Reward is the class defining a reward function for the IRL problem.
13-
Reward is a linear combination of (Gaussian) radial basis functions.
14-
15-
dx -> number of basis functions on the position dimension;
16-
dv -> number of basis functions on the velocity dimension.
17-
"""
18-
def __init__(self, dx, dv):
19-
self.dx = dx
20-
self.dv = dv
21-
self.lx = 1.8 # length of the position interval
22-
self.lv = 0.14 # length of the velocity interval
23-
self.zx = 0.6 # zero of the position interval
24-
self.zv = 0.07 # zero of the velocity interval
25-
# tune sigma according to the discretization
26-
self.sigma_inv = inv(np.array([[.05, 0. ],
27-
[0., .0003]]))
28-
self.params = np.zeros(dx * dv)
29-
30-
def value(self, state, action):
31-
r = 0.
32-
for i in range(self.dx):
33-
for j in range(self.dv):
34-
r += self.params[i, j] * self.basis(state, i, j)
35-
36-
def basis(self, state, idx):
37-
j = idx % self.dv
38-
i = (idx - j)/self.dv
39-
x, v = state
40-
xi = i / (self.dx-1) * self.lx - self.zx
41-
vj = j / (self.dv-1) * self.lv - self.zv
42-
s = np.array([x, v])
43-
si = np.array([xi, vj])
44-
return np.exp(-np.dot((s - si), np.dot(self.sigma_inv, (s - si))))
45-
46-
def partial_value(self, state, action, idx):
47-
j = idx % self.dv
48-
i = (idx - j)/self.dv
49-
return self.params[idx] * self.basis(state, i, j)
50-
51-
def partial_traj(self, traj, idx):
52-
r = 0.
53-
for state, action in traj:
54-
r += self.partial_value(state, action, idx)
55-
return r
56-
57-
def basis_traj(self, traj, idx):
58-
r = 0.
59-
for state, _ in traj:
60-
r += self.basis(state, idx)
61-
return r
62-
6310
class GIRL():
6411
"""
6512
A class for estimating the parameters of the reward given some trajectory data.
@@ -113,6 +60,11 @@ def objective(self, alpha):
11360
M = np.dot(self.jacobian.T, self.jacobian)
11461
return np.dot(alpha, np.dot(M, alpha))
11562

63+
def loss(self, trajs):
64+
M = np.dot(self.jacobian.T, self.jacobian)
65+
alpha = self.reward.params
66+
return np.dot(alpha, np.dot(M, alpha))
67+
11668
def solve(self):
11769
# Define constraints
11870
h = lambda x: norm(x, 1) - 1 # sum of all the alphas must be 1

IRL/GradientIRL/main.py

Lines changed: 64 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,17 @@
88
import gym
99
import matplotlib.pyplot as plt
1010
import readtrajectory as read
11-
import estimatepolicy as estim
11+
#import estimatepolicy as estim
1212
import utils.gibbspolicy as gp
13+
import utils.reward as rew
1314
import gradientIRL as irl
1415
from mpl_toolkits.mplot3d import Axes3D
16+
from matplotlib import cm
1517

1618
env = gym.make('MountainCar-v0')
1719
T = 1000
1820
data_path = '../../data/data_long.txt'
21+
write_path = 'reward_params.txt'
1922

2023
# Read the data
2124

@@ -27,29 +30,31 @@
2730

2831
print('fitting policy to data...')
2932

30-
trace = policy.fit(data, 200)
33+
#trace = policy.fit(data, 200)
3134

3235
#print(trace[-1])
3336
#print(policy.get_theta())
34-
policy.episode(render=True)
35-
for i in range(10):
36-
policy.episode()
3737

3838
#plt.plot([t[0] for t in trace])
3939
#plt.plot([t[2] for t in trace])
4040
#plt.show()
4141

4242

43-
#policy.set_theta(np.array([-18, -1, 18]))
43+
policy.set_theta(np.array([-18, -1, 18]))
44+
#policy.episode(render=True)
45+
4446
#policy.episode(render=True)
47+
for i in range(10):
48+
policy.episode()
49+
4550
env.close()
4651

4752
print('solving the IRL problem:')
4853

49-
dx = 5
54+
dx = 10
5055
dv = 5
5156

52-
reward = irl.Reward(dx, dv)
57+
reward = rew.Reward(dx, dv)
5358

5459
L = dx*dv
5560

@@ -77,15 +82,60 @@ def plot(p):
7782
'''
7883

7984
girl = irl.GIRL(reward, data, policy)
80-
girl.compute_jacobian()
81-
print(girl.jacobian)
82-
girl.print_jacobian()
83-
alphas = girl.solve()
85+
#girl.compute_jacobian()
86+
#print(girl.jacobian)
87+
#alphas = girl.solve()
88+
89+
#plt.plot(alphas)
90+
#plt.show()
91+
92+
#plot(alphas)
93+
94+
#reward.set_params(alphas)
95+
#reward.export_to_file(write_path)
96+
97+
reward.import_from_file(write_path)
98+
99+
X = 50
100+
V = 50
101+
102+
103+
104+
x = np.arange(-0.6, 1.2, 0.1)
105+
v = np.arange(-0.07, 0.07, 0.005)
106+
X = len(x)
107+
V = len(v)
108+
print(X)
109+
print(V)
110+
x, v = np.meshgrid(x, v)
111+
112+
r = np.zeros([X, V])
113+
114+
fig = plt.figure()
115+
ax = fig.gca(projection='3d')
116+
for i in range(X):
117+
for j in range(V):
118+
xi = i / (X-1) * 1.8 - 0.6
119+
vj = j / (V-1) * 0.14 - 0.07
120+
r[i, j] = reward.value([xi, vj], 1)
121+
print(x.shape)
122+
print(v.shape)
123+
print(r.shape)
124+
ax.plot_surface(x, v, r.T, cmap=cm.coolwarm,
125+
linewidth=0, antialiased=False)
84126

85-
plt.plot(alphas)
86127
plt.show()
87128

88-
plot(alphas)
129+
'''
130+
fig = plt.figure()
131+
ax = fig.add_subplot(111, projection='3d')
132+
for i in range(X):
133+
for j in range(V):
134+
xi = i / (X-1) * 1.8 - 0.6
135+
vj = j / (V-1) * 0.14 - 0.07
136+
ax.scatter(i, j, reward.value([xi, vj], 1), c='r')
137+
plt.show()
138+
'''
89139

90140

91141

0 commit comments

Comments
 (0)