Skip to content

Commit 8eaeb2e

Browse files
committed
[IBR-2091] Convert action type to numpy array in select_action function
1 parent 294837d commit 8eaeb2e

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

rl_algorithms/gail/agent.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,10 @@ def train(self):
168168
# gail reward (imitation reward)
169169
gail_reward = compute_gail_reward(
170170
self.learner.discriminator(
171-
(numpy2floattensor(state, self.learner.device), action)
171+
(
172+
numpy2floattensor(state, self.learner.device),
173+
numpy2floattensor(action, self.learner.device),
174+
)
172175
)
173176
)
174177

rl_algorithms/ppo/agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def select_action(self, state: np.ndarray) -> torch.Tensor:
135135
with torch.no_grad():
136136
state = numpy2floattensor(state, self.learner.device)
137137
selected_action, dist = self.learner.actor(state)
138+
selected_action = selected_action.detach()
138139
log_prob = dist.log_prob(selected_action)
139140
value = self.learner.critic(state)
140141

@@ -155,7 +156,7 @@ def select_action(self, state: np.ndarray) -> torch.Tensor:
155156
self.values.append(value)
156157
self.log_probs.append(_log_prob)
157158

158-
return selected_action
159+
return selected_action.detach().cpu().numpy()
159160

160161
def step(
161162
self, action: Union[np.ndarray, torch.Tensor]

0 commit comments

Comments
 (0)