Commit f50d14f

fix q2 bugs

1 parent: e4dbe4c

File tree

1 file changed: +12 −4 lines changed

berkeley/hw4/model_based_policy.py

Lines changed: 12 additions & 4 deletions
```diff
@@ -136,17 +136,25 @@ def _setup_action_selection(self, state_ph):
         """
         ### PROBLEM 2
         ### YOUR CODE HERE
-        random_action_sequences = tf.random_uniform([self._num_random_action_selection, self._horizon], maxval=self._action_dim)
+        bs = self._num_random_action_selection
+        random_action_sequences = tf.random_uniform(shape=[bs, self._horizon, self._action_dim],
+                                                    minval=self._action_space_low,
+                                                    maxval=self._action_space_high)
+
+        cost = tf.zeros([bs])
+
+        # repeat the first state for each in batch
+        state_ph = tf.tile(state_ph[0:1, :], [bs, 1])
 
         for i in range(self._horizon):
-            actions = random_action_sequences[:, i]
+            actions = random_action_sequences[:, i, :]
             next_state_pred = self._dynamics_func(state_ph, actions, True)
-            cost = self._cost_fn(state_ph, actions, next_state_pred)
+            cost += self._cost_fn(state_ph, actions, next_state_pred)
             state_ph = next_state_pred
 
         best_sequence_index = tf.argmin(cost, axis=0)
-        best_action = random_action_sequences[best_sequence_index, 0]
 
+        best_action = random_action_sequences[best_sequence_index, 0, :]
         return best_action
 
     def _setup_graph(self):
```
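For reference, here is how the patched method reads once the diff is applied, assembled into one piece with descriptive comments added — a minimal sketch assuming TensorFlow 1.x (`tf.random_uniform` predates TF 2) and the class attributes visible in the diff (`self._num_random_action_selection`, `self._horizon`, `self._action_dim`, `self._action_space_low`/`high`, `self._dynamics_func`, `self._cost_fn`):

```python
import tensorflow as tf  # TensorFlow 1.x


def _setup_action_selection(self, state_ph):
    """Random-shooting action selection: sample candidate action sequences,
    roll each through the learned dynamics model, and return the first
    action of the cheapest sequence."""
    bs = self._num_random_action_selection

    # One uniform sample per (sequence, timestep, action dimension),
    # bounded by the action space rather than the action dimensionality.
    random_action_sequences = tf.random_uniform(
        shape=[bs, self._horizon, self._action_dim],
        minval=self._action_space_low,
        maxval=self._action_space_high)

    # Running total cost of each candidate sequence.
    cost = tf.zeros([bs])

    # Repeat the single current state for each candidate in the batch.
    state_ph = tf.tile(state_ph[0:1, :], [bs, 1])

    for i in range(self._horizon):
        actions = random_action_sequences[:, i, :]
        next_state_pred = self._dynamics_func(state_ph, actions, True)
        cost += self._cost_fn(state_ph, actions, next_state_pred)
        state_ph = next_state_pred

    # Score each sequence by its total cost over the horizon and
    # return the first action of the best one.
    best_sequence_index = tf.argmin(cost, axis=0)
    best_action = random_action_sequences[best_sequence_index, 0, :]
    return best_action
```

The q2 fixes visible in the diff are: sampling actions with shape `[bs, horizon, action_dim]` bounded by the actual action-space limits (the old call misused `self._action_dim` as `maxval` and omitted the action dimension from the shape), tiling the input state across the batch of candidates, accumulating `cost` with `+=` instead of overwriting it each step so sequences are ranked by total cost over the horizon, and indexing the full action vector with a trailing `:`.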
