Skip to content

Commit 78ddb07

Browse files
committed
created utils for pooling code common to more than one part; created optimwrapper; added (normal RL) PCL implementation
1 parent 6fd9589 commit 78ddb07

12 files changed

+577
-1
lines changed
Binary file not shown.
Binary file not shown.
0 Bytes
Binary file not shown.

IRL/GradientIRL/data.txt

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
hello
2+
3+
[[-1.18900898e-02 4.96903900e-02 1.04406555e-01 1.58730417e-02
4+
-4.23941544e-03 -8.72280802e-05 -1.12692773e-02 3.76846329e-02
5+
-1.44644651e-02 -9.29604415e-03 4.32957681e-06 9.65612129e-04
6+
2.24042062e-04 -2.53010936e-03 -6.93311306e-04 1.00279959e-09
7+
2.02489775e-06 1.71430447e-04 1.07611322e-06 -5.75768644e-05
8+
5.17360376e-17 1.35899597e-12 6.99835356e-09 1.30627524e-07
9+
-8.31919421e-10]
10+
[ 4.31977934e-03 -3.10926976e-02 -3.25875064e-02 -1.90753612e-04
11+
7.91306133e-04 2.26136559e-05 2.76141387e-03 -4.86423932e-02
12+
1.97018338e-02 -3.34931567e-03 -1.54217847e-06 -4.10223872e-05
13+
8.39226736e-05 9.40446507e-03 -1.80195315e-03 -4.43447068e-10
14+
-9.14381650e-07 2.26588982e-04 4.62295143e-04 -1.88973112e-04
15+
-3.02221762e-17 -4.31908670e-13 1.06335015e-08 1.13591239e-07
16+
-1.10257323e-08]
17+
[ 7.57031043e-03 -1.85976924e-02 -7.18190487e-02 -1.56822881e-02
18+
3.44810931e-03 6.46144243e-05 8.50786344e-03 1.09577603e-02
19+
-5.23736864e-03 1.26453598e-02 -2.78739834e-06 -9.24589742e-04
20+
-3.07964736e-04 -6.87435571e-03 2.49526445e-03 -5.59352521e-10
21+
-1.11051610e-06 -3.98019429e-04 -4.63371256e-04 2.46549977e-04
22+
-2.15138613e-17 -9.27087305e-13 -1.76318551e-08 -2.44218763e-07
23+
1.18576517e-08]]
24+
25+
26+
[[-7.44378261e-03 1.03476025e-02 5.55985211e-02 1.79939048e-02
27+
-1.83442126e-02 -4.34513747e-04 -1.72593141e-02 1.23381006e-02
28+
-2.41314009e-03 -1.00333097e-02 2.13844881e-05 5.59763726e-04
29+
2.87668824e-03 -2.69822164e-03 -7.11075425e-04 1.65680790e-09
30+
1.57166751e-06 5.87768688e-05 1.26117189e-06 -4.07846578e-05
31+
1.25194551e-16 1.68596934e-12 1.67656117e-08 4.23171197e-08
32+
-6.24737306e-10]
33+
[ 2.69890893e-03 -6.48616999e-03 -1.74207249e-02 -2.65889039e-04
34+
3.41127553e-03 1.11432297e-04 4.20024488e-03 -1.59521272e-02
35+
3.27676022e-03 -3.62028185e-03 -7.62726781e-06 -2.42766160e-05
36+
1.07544713e-03 1.00014054e-02 -1.84757090e-03 -7.33240323e-10
37+
-7.10326089e-07 7.76888100e-05 7.95034961e-04 -1.33798329e-04
38+
-7.31920602e-17 -5.36533099e-13 2.54746966e-08 3.67819945e-08
39+
-8.25541905e-09]
40+
[ 4.74487368e-03 -3.86143255e-03 -3.81777963e-02 -1.77280157e-02
41+
1.49329371e-02 3.23081451e-04 1.30590693e-02 3.61402658e-03
42+
-8.63620131e-04 1.36535916e-02 -1.37572203e-05 -5.35487110e-04
43+
-3.95213537e-03 -7.30318374e-03 2.55864632e-03 -9.23567578e-10
44+
-8.61341419e-07 -1.36465679e-04 -7.96296133e-04 1.74582987e-04
45+
-5.20024904e-17 -1.14943624e-12 -4.22403083e-08 -7.90991142e-08
46+
8.88015635e-09]][[-9.92087548e-03 2.28633497e-02 1.55308595e-02 1.72920464e-02
47+
-1.29134687e-02 -5.71800192e-04 -8.09621854e-03 1.22918435e-02
48+
-3.46049370e-03 -1.01573588e-02 1.01556829e-05 1.83615773e-03
49+
3.34707702e-03 -5.30968134e-03 -7.49743952e-04 5.52985339e-10
50+
2.43187571e-06 6.70649357e-06 8.26933087e-06 -3.37988580e-05
51+
7.19197484e-18 1.27523366e-12 1.23515609e-08 1.22079467e-08
52+
-7.29325710e-10]
53+
[ 3.60172215e-03 -1.43018657e-02 -4.85617612e-03 -2.69481309e-04
54+
2.40347666e-03 1.47205882e-04 1.97561451e-03 -1.58400106e-02
55+
4.72896839e-03 -3.67641370e-03 -3.61951651e-06 -8.01815159e-05
56+
1.24520901e-03 1.98555401e-02 -1.95361509e-03 -2.44582906e-10
57+
-1.09837125e-06 8.85600246e-06 1.70244365e-03 -1.11273571e-04
58+
-4.20112071e-18 -4.05574962e-13 1.87541036e-08 1.05691766e-08
59+
-9.78063016e-09]
60+
[ 6.31915332e-03 -8.56148397e-03 -1.06746834e-02 -1.70225651e-02
61+
1.05099920e-02 4.24594310e-04 6.12060403e-03 3.54816705e-03
62+
-1.26847469e-03 1.38337725e-02 -6.53616639e-06 -1.75597621e-03
63+
-4.59228603e-03 -1.45458588e-02 2.70335904e-03 -3.08402433e-10
64+
-1.33350446e-06 -1.55624960e-05 -1.71071298e-03 1.45072429e-04
65+
-2.99085413e-18 -8.69658701e-13 -3.11056645e-08 -2.27771232e-08
66+
1.05099559e-08]]

IRL/GradientIRL/gradientIRL.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
import scipy.optimize as opt
3-
from numpy.linalg import inv
3+
from numpy.linalg import inv, norm
44
from tqdm import tqdm
55

66
class Reward():
@@ -93,6 +93,10 @@ def compute_jacobian(self):
9393
for l in tqdm(range(len(self.reward.params))):
9494
self.jacobian[:, l] = self.compute_gradient(l)
9595

96+
def print_jacobian(self, path='data.txt'):
    """Append the current Jacobian (as its string repr) to a text file.

    Parameters
    ----------
    path : str, optional
        File to append to. Defaults to 'data.txt' (the original
        hard-coded destination), so existing callers are unaffected.
    """
    with open(path, 'a') as f:
        f.write(str(self.jacobian))
99+
96100
def objective(self, alpha):
    """Return the quadratic form alpha^T (J^T J) alpha for weights alpha."""
    # Gram matrix of the Jacobian, applied to alpha first.
    weighted = np.dot(np.dot(self.jacobian.T, self.jacobian), alpha)
    return np.dot(alpha, weighted)
@@ -108,6 +112,7 @@ def solve(self):
108112
result = opt.minimize(self.objective, alpha0, constraints=eq_cons)
109113
if not result.success:
110114
print(result.message)
115+
print(result)
111116
alpha = result.x
112117
return alpha
113118

IRL/GradientIRL/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
girl = irl.GIRL(reward, data, policy)
6262
girl.compute_jacobian()
6363
print(girl.jacobian)
64+
girl.print_jacobian()
6465
alphas = girl.solve()
6566

6667
plt.plot(alphas)
0 Bytes
Binary file not shown.

RL/pcl/PCL.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Implementation of Path Consistency Learning
2+
# Author : Laetitia Teodorecu
3+
#
4+
# Reference : Nachum et. al. 2017
5+
6+
import sys

import numpy as np
from tqdm import tqdm

sys.path.append('../..')

import utils.gibbspolicy as gp
import parametricvalue as pv
14+
class PCL():
    """Path Consistency Learning (Nachum et al., 2017) for a Gibbs policy
    with a parametric value function.

    Parameters
    ----------
    env : gym-like environment (reset() / step() / render())
    T : int
        Maximum episode length.
    eta_pi : float
        Policy (theta) learning rate.
    eta_psi : float
        Value-function (psi) learning rate.
    discount : float
        Discount factor.
    d : int
        Rollout length used in the path-consistency residual.
    tau : float
        Entropy-regularization temperature.
    B : int, optional
        Replay-buffer size (unused until the replay buffer is implemented).
    alpha : float, optional
        Reserved mixing coefficient.
    """

    def __init__(self, env, T, eta_pi, eta_psi, discount, d, tau, B=0, alpha=0):
        self.env = env
        self.T = T
        self.eta_pi = eta_pi
        self.eta_psi = eta_psi
        self.discount = discount
        self.d = d
        self.tau = tau
        self.B = B  # fix: was `self.B = N` — NameError, N is undefined here
        self.alpha = alpha
        self.pi_theta = gp.GibbsPolicy(self.env, self.T, .5, gamma=self.discount)
        self.V_psi = pv.ParametricValueFunction()

    def done(self, state):
        """Episode termination test: position (state[0]) has reached the goal.

        NOTE(review): 0.5 looks like the MountainCar goal position — confirm.
        """
        return bool(state[0] >= 0.5)

    def rollout_traj(self, traj, idx):
        """Return the sub-trajectory of at most self.d transitions from idx."""
        # fix: original used bare `d`, which is undefined — it lives on self.
        if len(traj) - idx > self.d:
            return traj[idx:idx + self.d]
        return traj[idx:]

    def R(self, s):
        """Total (undiscounted) reward along sub-trajectory s."""
        return sum(rew for _, _, rew, _ in s)

    def C(self, s):
        """Path-consistency residual of sub-trajectory s.

        fix: the original split this expression over two statements — the
        second line was a dead no-op expression — and returned nothing.
        """
        return (- self.V_psi(s[0])
                + self.discount**(len(s) - 1) * self.V_psi(s[-1])
                + self.R(s) - self.tau * self.G(s))

    def G(self, s):
        """Discounted sum of action log-probabilities along s."""
        g = 0
        for i, [state, action, _, _] in enumerate(s):
            # fix: `gamma` was undefined; the discount factor is self.discount.
            g += self.discount**i * self.pi_theta.logproba(action, state)
        return g  # fix: original returned nothing

    def grad_G(self, s):  # fix: original def line was missing its colon
        """Gradient of G(s) with respect to the policy parameters theta."""
        grad_g = 0
        for i, [state, action, _, _] in enumerate(s):
            # fix: `gamma` was undefined here as well.
            grad_g += self.discount**i * self.pi_theta.gradlog(action, state)
        return grad_g  # fix: original returned nothing

    def grad_V(self, state):
        """Gradient of the value function with respect to psi."""
        return self.V_psi.grad(state)

    def gradients(self, traj):
        """Accumulate the PCL updates for theta and psi over all rollouts of traj."""
        delta_theta = 0
        delta_psi = 0
        # fix: original iterated `for idx in len(traj)` (TypeError).
        for idx in range(len(traj)):
            s = self.rollout_traj(traj, idx)
            c = self.C(s)
            # fix: grad_G and grad_V are methods — they need `self.`.
            delta_theta += c * self.grad_G(s)
            # fix: original used undefined `d`; use len(s)-1 so truncated
            # rollouts at the episode end are discounted consistently with C().
            delta_psi += c * (self.grad_V(s[0])
                              - self.discount**(len(s) - 1) * self.grad_V(s[-1]))
        return delta_theta, delta_psi

    def episode(self, render=False):
        """Roll out one episode of at most T steps.

        Returns a list of [state, action, reward, next_state] transitions.
        """
        traj = []
        state = self.env.reset()
        for t in range(self.T):
            if self.done(state):
                break
            if render:
                self.env.render()
            # fix: PCL has no `sample` method — actions come from the policy.
            # NOTE(review): assumes GibbsPolicy exposes sample(state) — confirm.
            action = self.pi_theta.sample(state)
            next_state, reward, _, _ = self.env.step(action)
            traj.append([state, action, reward, next_state])
            state = next_state
        return traj

    def learn(self, N):
        """Run N episodes of online PCL updates."""
        # fix: parameters were re-zeroed INSIDE the loop, so updates could
        # never accumulate across episodes; initialize once, before it.
        theta = self.pi_theta.zero()
        psi = self.V_psi.zero()
        for n in tqdm(range(N)):
            traj = self.episode()
            # fix: original unpacked `delta_phi` then used `delta_psi` (NameError).
            delta_theta, delta_psi = self.gradients(traj)
            theta += self.eta_pi * delta_theta
            psi += self.eta_psi * delta_psi
            # TODO(review): theta/psi are never written back into pi_theta /
            # V_psi — the policy and value function need a set-params hook.
            # TODO : replay buffer
101+
102+
103+
104+
105+
106+
107+
108+
109+
110+
111+
112+
113+
114+
115+
116+
117+
118+
119+
120+
121+
122+
123+
124+
125+
126+
127+
128+
129+
130+
131+
132+
133+
134+
135+
136+
137+
138+
139+
140+
141+
142+
143+
144+
145+
146+
147+
148+
149+
150+
Binary file not shown.

0 commit comments

Comments
 (0)