
Commit c94d70f

Author: apbose
Commit message: ddpg
1 parent 585e803 commit c94d70f

File tree

3 files changed

+721
-0
lines changed


DDPG/DDPG.py

Lines changed: 374 additions & 0 deletions
@@ -0,0 +1,374 @@
import gym
import pybulletgym
import pybulletgym.envs
import numpy as np
import math
import matplotlib.pyplot as plt
import queue
import random
from collections import deque
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
print(torch.__version__)

env = gym.make("modified_gym_env:ReacherPyBulletEnv-v1", rand_init=False)
env.reset()

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size_one, hidden_size_two):
        super(Actor, self).__init__()
        self.input_size = state_dim
        self.hidden_size_one = hidden_size_one
        self.hidden_size_two = hidden_size_two
        self.output_size = action_dim

        self.l1 = nn.Linear(self.input_size, self.hidden_size_one, bias=False)
        self.l2 = nn.Linear(self.hidden_size_one, self.hidden_size_two, bias=False)
        self.l3 = nn.Linear(self.hidden_size_two, self.output_size, bias=False)

        # two hidden ReLU layers, Tanh on the output to keep actions in [-1, 1]
        self.model = torch.nn.Sequential(
            self.l1,
            nn.ReLU(),
            self.l2,
            nn.ReLU(),
            self.l3,
            nn.Tanh()
        )
        self.model.apply(self.weights_init_uniform)

    # takes in a module and applies the specified weight initialization
    def weights_init_uniform(self, m):
        classname = m.__class__.__name__
        # apply a uniform distribution to the weights (the layers have no bias)
        if classname.find('Linear') != -1:
            m.weight.data.uniform_(-0.003, 0.003)

    def forward(self, state):
        return self.model(state)

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size_one, hidden_size_two):
        super(Critic, self).__init__()
        self.input_size = (state_dim + action_dim)
        self.hidden_size_one = hidden_size_one
        self.hidden_size_two = hidden_size_two
        self.output_size = 1

        self.l1 = nn.Linear(self.input_size, self.hidden_size_one, bias=False)
        self.l2 = nn.Linear(self.hidden_size_one, self.hidden_size_two, bias=False)
        self.l3 = nn.Linear(self.hidden_size_two, self.output_size, bias=False)
        self.model = torch.nn.Sequential(
            self.l1,
            nn.ReLU(),
            self.l2,
            nn.ReLU(),
            self.l3,
            nn.Tanh()
        )
        self.model.apply(self.weights_init_uniform)

    def weights_init_uniform(self, m):
        classname = m.__class__.__name__
        # apply a uniform distribution to the weights (the layers have no bias)
        if classname.find('Linear') != -1:
            m.weight.data.uniform_(-0.0003, 0.0003)

    def forward(self, state, action):
        # the critic scores the concatenated state-action pair
        stateAction = torch.cat([state, action], 1)
        return self.model(stateAction)

class replayBuffer:
    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        self.buffer = deque(maxlen=buffer_size)

    def push(self, state, action, next_state, reward, done):
        samples = (state, action, next_state, reward, done)
        self.buffer.append(samples)

    def sample(self, batch_size):
        state_batch = []
        action_batch = []
        next_state_batch = []
        reward_batch = []
        done_batch = []

        # draw a uniform random mini-batch of stored transitions
        batch_data = random.sample(self.buffer, batch_size)

        for samples in batch_data:
            state, action, next_state, reward, done = samples
            state_batch.append(state)
            action_batch.append(action)
            reward_batch.append(reward)
            next_state_batch.append(next_state)
            done_batch.append(done)
        return (state_batch, action_batch, next_state_batch, reward_batch, done_batch)

    def __len__(self):
        return len(self.buffer)

#### Parameters taken: d = 0.001, critic_lr = 0.0003, actor_lr = 0.0003, batch_size = 500, buffer_size = 10000 ####
class DDPG():
    def __init__(self,
                 env,
                 action_dim,
                 state_dim,
                 actor,
                 critic,
                 actor_target,
                 critic_target,
                 noise=1,
                 d_param=0.001,
                 critic_lr=0.0003,
                 actor_lr=0.0003,
                 gamma=0.99, batch_size=500, buffer_size=10000):
        """
        param: env: A gym environment
        param: action_dim: Size of the action space
        param: state_dim: Size of the state space
        param: actor: actor model
        param: critic: critic model
        param: critic_lr: Learning rate of the critic
        param: actor_lr: Learning rate of the actor
        param: gamma: The discount factor
        param: batch_size: The batch size for training
        """

        self.env = env
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.critic_lr = critic_lr
        self.actor_lr = actor_lr
        self.gamma = gamma
        self.batch_size = batch_size

        self.d = d_param
        self.noise = noise

        self.actor = actor
        self.critic = critic
        self.actor_target = actor_target
        self.critic_target = critic_target
        self.actor_optimizer = optim.Adam(self.actor.parameters())    # lr=self.actor_lr
        self.critic_optimizer = optim.Adam(self.critic.parameters())  # lr=self.critic_lr

        self.iterations = []
        self.return_history = []
        self.return_reward = []

        self.replay_buffer = replayBuffer(buffer_size)
        self.loss = nn.MSELoss()

    def updateQpolicy(self, batch_size, iterationNo):
        states, actions, state_next, rewards, _ = self.replay_buffer.sample(batch_size)
        states = torch.FloatTensor(states)
        actions = torch.FloatTensor(actions)
        rewards = torch.FloatTensor(rewards).reshape([batch_size, 1])
        state_next = torch.FloatTensor(state_next)
        Q_pres = self.critic.forward(states, actions)
        action_next = self.actor_target.forward(state_next).detach()
        # detach so that loss.backward() does not update the target network parameters
        Q_next = self.critic_target.forward(state_next, action_next).detach()
        Q_nexttarget = rewards + Q_next * self.gamma
        # critic loss: the Q network maps states and actions to Q values
        criticLoss = self.loss(Q_nexttarget, Q_pres)
        # actor loss: the policy maps states to actions that maximise the critic's Q value
        actorLoss = -1 * self.critic.forward(states, self.actor.forward(states)).mean()

        # update the Q (critic) parameters
        self.critic_optimizer.zero_grad()
        criticLoss.backward()
        self.critic_optimizer.step()

        # update the policy (actor) parameters
        self.actor_optimizer.zero_grad()
        actorLoss.backward()
        self.actor_optimizer.step()

        # soft-update the target networks towards the original networks
        for tar_param, src_param in zip(self.actor_target.parameters(), self.actor.parameters()):
            tar_param.data.copy_(self.d * src_param.data + (1.0 - self.d) * tar_param.data)

        for tar_param, src_param in zip(self.critic_target.parameters(), self.critic.parameters()):
            tar_param.data.copy_(self.d * src_param.data + (1.0 - self.d) * tar_param.data)

    def selectAction(self, state):
        state = Variable(torch.from_numpy(state).float().unsqueeze(0))
        action = self.actor.forward(state)
        action = action.detach().numpy()[0]
        return action

    def train(self, epochs):
        total_reward = 0
        for iterationNo in range(epochs):
            state = env.reset()
            batch_reward = 0

            steps = 0
            '''
            while(steps < self.batch_size):
                steps += 1
                action = self.selectAction(state)
                if(self.noise):
                    noise = np.random.normal(0, 0.1)
                    action[0] += noise
                    action[1] += noise
                new_state, reward, done, _ = env.step(action)

                batch_reward += reward
                total_reward += reward
                self.replay_buffer.push(state, action, new_state, reward, done)
                state = new_state

                if(done == True):
                    break
            '''
            # fill up the buffer with noisy exploration transitions
            while(len(self.replay_buffer) < self.batch_size):
                action = self.selectAction(state)
                if(self.noise):
                    noise = np.random.normal(0, 0.1)
                    action[0] += noise
                    action[1] += noise
                new_state, reward, done, _ = env.step(action)
                if(done == True):
                    state = env.reset()

                batch_reward += reward
                total_reward += reward
                self.replay_buffer.push(state, action, new_state, reward, done)
                state = new_state

            # once the buffer is filled, collect one transition per iteration
            action = self.selectAction(state)
            new_state, reward, done, _ = env.step(action)
            if(done == True):
                state = env.reset()
            batch_reward += reward
            total_reward += reward
            self.replay_buffer.push(state, action, new_state, reward, done)
            state = new_state

            self.updateQpolicy(self.batch_size, iterationNo)
            if((iterationNo % 1000 == 0 and iterationNo != 0) or iterationNo == 1):
                self.iterations.append(iterationNo)
                self.return_reward.append(total_reward / iterationNo)
                print("iteration No is", iterationNo, "reward is", total_reward / iterationNo)

            # periodically checkpoint the actor weights
            if(iterationNo % 2000 == 0 and iterationNo != 0):
                fileName = "model" + str(iterationNo)
                torch.save(self.actor.state_dict(), fileName)

    '''
    # earlier version of train, kept commented out for reference
    def train(self, epochs):
        batch_reward = 0
        for steps in range(epochs):

            for iterationNo in range(epochs):
                state = env.reset()
                done = False

                action = self.selectAction(state)
                if(self.noise):
                    noise = np.random.normal(0, 0.1)
                    action[0] += noise
                    action[1] += noise
                new_state, reward, done, _ = env.step(action)

                batch_reward += reward
                self.replay_buffer.push(state, action, new_state, reward, done)
                state = new_state

                if(done == True):
                    state = env.reset()

                if(iterationNo == self.batch_size and len(self.replay_buffer) >= self.batch_size):
                    self.updateQpolicy(self.batch_size, iterationNo)
                if((iterationNo % 100 == 0 and iterationNo != 0) or iterationNo == 1):
                    self.iterations.append(iterationNo)
                    self.return_reward.append(batch_reward / iterationNo)
                    print("iteration No is", iterationNo, "reward is", batch_reward / iterationNo)
    '''

num_states = 8
num_actions = 2

actor = Actor(num_states, num_actions, 400, 300)
actor_target = Actor(num_states, num_actions, 400, 300)

critic = Critic(num_states, num_actions, 400, 300)
critic_target = Critic(num_states, num_actions, 400, 300)

# initialise the target networks with the same weights as the original networks
for tar_param, src_param in zip(actor_target.parameters(), actor.parameters()):
    tar_param.data.copy_(src_param.data)

for tar_param, src_param in zip(critic_target.parameters(), critic.parameters()):
    tar_param.data.copy_(src_param.data)


ddpgLinkArm = DDPG(env, num_actions, num_states, actor, critic, actor_target, critic_target)
ddpgLinkArm.train(200000)

# drop the first (iteration 1) data point before plotting the average return
del ddpgLinkArm.iterations[0]
del ddpgLinkArm.return_reward[0]
plt.plot(ddpgLinkArm.iterations, ddpgLinkArm.return_reward, color='b')
plt.xlabel("iterations")
plt.ylabel("average return")
plt.show()
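
Note (not part of the commit): the torch.save call in train writes actor checkpoints such as "model2000" every 2000 iterations. The snippet below is a minimal evaluation sketch built on the classes defined in this file; the checkpoint name "model2000" and the single-episode rollout loop are assumptions for illustration, not code from the commit.

# minimal evaluation sketch (assumed, not part of this commit)
eval_actor = Actor(num_states, num_actions, 400, 300)
eval_actor.load_state_dict(torch.load("model2000"))   # "model2000" is a hypothetical checkpoint name
eval_actor.eval()

state = env.reset()
done = False
episode_return = 0.0
while not done:
    state_t = torch.from_numpy(state).float().unsqueeze(0)
    with torch.no_grad():
        action = eval_actor(state_t).numpy()[0]   # greedy action, no exploration noise
    state, reward, done, _ = env.step(action)
    episode_return += reward
print("evaluation return:", episode_return)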
