@@ -136,17 +136,25 @@ def _setup_action_selection(self, state_ph):
136
136
"""
137
137
### PROBLEM 2
138
138
### YOUR CODE HERE
139
- random_action_sequences = tf .random_uniform ([self ._num_random_action_selection , self ._horizon ], maxval = self ._action_dim )
139
+ bs = self ._num_random_action_selection
140
+ random_action_sequences = tf .random_uniform (shape = [bs , self ._horizon , self ._action_dim ],
141
+ minval = self ._action_space_low ,
142
+ maxval = self ._action_space_high )
143
+
144
+ cost = tf .zeros ([bs ])
145
+
146
+ # repeat the first state once per random action sequence so every rollout in the batch starts from it
147
+ state_ph = tf .tile (state_ph [0 :1 , :], [bs , 1 ])
140
148
141
149
for i in range (self ._horizon ):
142
- actions = random_action_sequences [:, i ]
150
+ actions = random_action_sequences [:, i , : ]
143
151
next_state_pred = self ._dynamics_func (state_ph , actions , True )
144
- cost = self ._cost_fn (state_ph , actions , next_state_pred )
152
+ cost + = self ._cost_fn (state_ph , actions , next_state_pred )
145
153
state_ph = next_state_pred
146
154
147
155
best_sequence_index = tf .argmin (cost , axis = 0 )
148
- best_action = random_action_sequences [best_sequence_index , 0 ]
149
156
157
+ best_action = random_action_sequences [best_sequence_index , 0 , :]
150
158
return best_action
151
159
152
160
def _setup_graph (self ):
0 commit comments