File tree Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -168,7 +168,10 @@ def train(self):
168
168
# gail reward (imitation reward)
169
169
gail_reward = compute_gail_reward (
170
170
self .learner .discriminator (
171
- (numpy2floattensor (state , self .learner .device ), action )
171
+ (
172
+ numpy2floattensor (state , self .learner .device ),
173
+ numpy2floattensor (action , self .learner .device ),
174
+ )
172
175
)
173
176
)
174
177
Original file line number Diff line number Diff line change @@ -135,6 +135,7 @@ def select_action(self, state: np.ndarray) -> torch.Tensor:
135
135
with torch .no_grad ():
136
136
state = numpy2floattensor (state , self .learner .device )
137
137
selected_action , dist = self .learner .actor (state )
138
+ selected_action = selected_action .detach ()
138
139
log_prob = dist .log_prob (selected_action )
139
140
value = self .learner .critic (state )
140
141
@@ -155,7 +156,7 @@ def select_action(self, state: np.ndarray) -> torch.Tensor:
155
156
self .values .append (value )
156
157
self .log_probs .append (_log_prob )
157
158
158
- return selected_action
159
+ return selected_action . detach (). cpu (). numpy ()
159
160
160
161
def step (
161
162
self , action : Union [np .ndarray , torch .Tensor ]
You can’t perform that action at this time.
0 commit comments