diff --git a/PPO_continuous.py b/PPO_continuous.py
index e76c09c..de73ea1 100644
--- a/PPO_continuous.py
+++ b/PPO_continuous.py
@@ -12,12 +12,14 @@ def __init__(self):
         self.states = []
         self.logprobs = []
         self.rewards = []
+        self.is_terminals = []
 
     def clear_memory(self):
         del self.actions[:]
         del self.states[:]
         del self.logprobs[:]
         del self.rewards[:]
+        del self.is_terminals[:]
 
 class ActorCritic(nn.Module):
     def __init__(self, state_dim, action_dim, action_std):
@@ -94,9 +96,11 @@ def update(self, memory):
         # Monte Carlo estimate of rewards:
         rewards = []
         discounted_reward = 0
-        for reward in reversed(memory.rewards):
+        for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
+            if is_terminal:
+                discounted_reward = 0
             discounted_reward = reward + (self.gamma * discounted_reward)
             rewards.insert(0, discounted_reward)
 
         # Normalizing the rewards:
         rewards = torch.tensor(rewards).to(device)
@@ -178,8 +182,10 @@ def main():
             # Running policy_old:
             action = ppo.select_action(state, memory)
             state, reward, done, _ = env.step(action)
-            # Saving reward:
+
+            # Saving reward and is_terminals:
             memory.rewards.append(reward)
+            memory.is_terminals.append(done)
 
             # update if its time
             if time_step % update_timestep == 0:
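
For context, a minimal standalone sketch of the return computation this diff introduces: when several episodes are stored back-to-back in one rollout buffer, `is_terminals` marks the episode boundaries so the discounted return is reset there instead of leaking from one episode into the previous one. The reward and terminal values below are made up for illustration; only the loop mirrors the change above.

```python
gamma = 0.99

# Two episodes stored back-to-back in the same rollout buffer (illustrative values).
rewards      = [1.0, 1.0, 1.0, 0.5, 0.5]
is_terminals = [False, False, True, False, True]

returns = []
discounted_reward = 0.0
for reward, is_terminal in zip(reversed(rewards), reversed(is_terminals)):
    if is_terminal:
        # Reset at an episode boundary so the return of this episode's last step
        # does not include discounted rewards from the next episode in the buffer.
        discounted_reward = 0.0
    discounted_reward = reward + gamma * discounted_reward
    returns.insert(0, discounted_reward)

print(returns)
# Without the is_terminal reset, the first episode's returns would wrongly
# accumulate rewards from the second episode.
```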