Skip to content

Commit b6a2122

Browse files
committed
fix floating point conversion
1 parent f3a84f9 commit b6a2122

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

berkeley/hw3/dqn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def __init__(
167167
action_indices = tf.stack([row_indices, self.act_t_ph], axis=1)
168168
yhat = tf.gather_nd(self.q_t, action_indices)
169169

170-
q_target = q_func(self.obs_tp1_ph, self.num_actions, scope="target_q_func")
170+
q_target = q_func(obs_tp1_float, self.num_actions, scope="target_q_func")
171171
max_target_q_val = tf.reduce_max(q_target, axis=-1)
172172
y = self.rew_t_ph + gamma * max_target_q_val * (1 - self.done_mask_ph)
173173

0 commit comments

Comments
 (0)