Merge pull request google-deepmind#284 from thowell/ce
Cross entropy planner changes
erez-tom authored Feb 13, 2024
2 parents 4384ef0 + ed9e635 commit bd6eda6
Showing 2 changed files with 35 additions and 70 deletions.
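For orientation: the change drops the elite_avg trajectory bookkeeping in favor of a plain averaged elite return plus a separate nominal_trajectory. Below is a minimal standalone sketch of the elite update that the revised OptimizePolicy performs, assuming one flat parameter vector per candidate; the Rollout/EliteUpdate containers and the function name are illustrative only, not the mjpc API.

// Sketch of the cross-entropy elite update: average the spline parameters and
// returns of the n_elite best rollouts, then estimate a per-parameter variance
// for the next sampling iteration (all names illustrative).
#include <algorithm>
#include <numeric>
#include <vector>

struct Rollout {
  std::vector<double> parameters;  // candidate policy spline parameters
  double total_return;             // cost accumulated over the horizon
};

struct EliteUpdate {
  std::vector<double> mean;      // averaged elite parameters
  std::vector<double> variance;  // per-parameter spread among the elites
  double avg_return;             // averaged elite return
};

EliteUpdate ComputeEliteUpdate(const std::vector<Rollout>& rollouts,
                               int n_elite) {
  // order rollout indices by total return (lower cost is better)
  std::vector<int> order(rollouts.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(), [&](int a, int b) {
    return rollouts[a].total_return < rollouts[b].total_return;
  });

  int num_parameters = static_cast<int>(rollouts[0].parameters.size());
  EliteUpdate out{std::vector<double>(num_parameters, 0.0),
                  std::vector<double>(num_parameters, 0.0), 0.0};

  // average parameters and returns over the elite set
  for (int i = 0; i < n_elite; i++) {
    const Rollout& r = rollouts[order[i]];
    for (int k = 0; k < num_parameters; k++) out.mean[k] += r.parameters[k];
    out.avg_return += r.total_return;
  }
  for (int k = 0; k < num_parameters; k++) out.mean[k] /= n_elite;
  out.avg_return /= n_elite;

  // per-parameter variance of the elites about the new mean
  for (int i = 0; i < n_elite; i++) {
    const Rollout& r = rollouts[order[i]];
    for (int k = 0; k < num_parameters; k++) {
      double d = r.parameters[k] - out.mean[k];
      out.variance[k] += d * d / n_elite;
    }
  }
  return out;
}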
99 changes: 33 additions & 66 deletions mjpc/planners/cross_entropy/planner.cc
@@ -106,11 +106,9 @@ void CrossEntropyPlanner::Allocate() {
trajectory[i].Allocate(kMaxTrajectoryHorizon);
candidate_policy[i].Allocate(model, *task, kMaxTrajectoryHorizon);
}

// elite average trajectory
elite_avg.Initialize(num_state, model->nu, task->num_residual,
task->num_trace, kMaxTrajectoryHorizon);
elite_avg.Allocate(kMaxTrajectoryHorizon);
nominal_trajectory.Initialize(num_state, model->nu, task->num_residual,
task->num_trace, kMaxTrajectoryHorizon);
nominal_trajectory.Allocate(kMaxTrajectoryHorizon);
}

// reset memory to zeros
@@ -143,7 +141,7 @@ void CrossEntropyPlanner::Reset(int horizon,
trajectory[i].Reset(kMaxTrajectoryHorizon);
candidate_policy[i].Reset(horizon);
}
elite_avg.Reset(kMaxTrajectoryHorizon);
nominal_trajectory.Reset(kMaxTrajectoryHorizon);

for (const auto& d : data_) {
mju_zero(d->ctrl, model->nu);
@@ -161,11 +159,6 @@ void CrossEntropyPlanner::SetState(const State& state) {

// optimize nominal policy using random sampling
void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
// check horizon
if (horizon != elite_avg.horizon) {
NominalTrajectory(horizon, pool);
}

// if num_trajectory_ has changed, use it in this new iteration.
// num_trajectory_ might change while this function runs. Keep it constant
// for the duration of this function.
@@ -220,66 +213,29 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
int num_spline_points = resampled_policy.num_spline_points;
int num_parameters = resampled_policy.num_parameters;

// reset parameters scratch to zero
std::fill(parameters_scratch.begin(), parameters_scratch.end(), 0.0);

// reset elite average
elite_avg.Reset(horizon);

// set elite average trajectory times
for (int tt = 0; tt <= horizon; tt++) {
elite_avg.times[tt] = time + tt * model->opt.timestep;
}

// best elite
int idx = trajectory_order[0];
// averaged return over elites
double avg_return = 0.0;

// add parameters
mju_copy(parameters_scratch.data(), candidate_policy[idx].parameters.data(),
num_parameters);
// reset parameters scratch
std::fill(parameters_scratch.begin(), parameters_scratch.end(), 0.0);

// copy first elite trajectory
mju_copy(elite_avg.actions.data(), trajectory[idx].actions.data(),
model->nu * (horizon - 1));
mju_copy(elite_avg.trace.data(), trajectory[idx].trace.data(),
trajectory[idx].dim_trace * horizon);
mju_copy(elite_avg.residual.data(), trajectory[idx].residual.data(),
elite_avg.dim_residual * horizon);
mju_copy(elite_avg.costs.data(), trajectory[idx].costs.data(), horizon);
elite_avg.total_return = trajectory[idx].total_return;

// loop over remaining elites to compute average
for (int i = 1; i < n_elite; i++) {
// loop over elites to compute average
for (int i = 0; i < n_elite; i++) {
// ordered trajectory index
int idx = trajectory_order[i];

// add parameters
mju_addTo(parameters_scratch.data(),
candidate_policy[idx].parameters.data(), num_parameters);

// add elite trajectory
mju_addTo(elite_avg.actions.data(), trajectory[idx].actions.data(),
model->nu * (horizon - 1));
mju_addTo(elite_avg.trace.data(), trajectory[idx].trace.data(),
trajectory[idx].dim_trace * horizon);
mju_addTo(elite_avg.residual.data(), trajectory[idx].residual.data(),
elite_avg.dim_residual * horizon);
mju_addTo(elite_avg.costs.data(), trajectory[idx].costs.data(), horizon);
elite_avg.total_return += trajectory[idx].total_return;
// add total return
avg_return += trajectory[idx].total_return;
}

// normalize
mju_scl(parameters_scratch.data(), parameters_scratch.data(), 1.0 / n_elite,
num_parameters);
mju_scl(elite_avg.actions.data(), elite_avg.actions.data(), 1.0 / n_elite,
model->nu * (horizon - 1));
mju_scl(elite_avg.trace.data(), elite_avg.trace.data(), 1.0 / n_elite,
elite_avg.dim_trace * horizon);
mju_scl(elite_avg.residual.data(), elite_avg.residual.data(), 1.0 / n_elite,
elite_avg.dim_residual * horizon);
mju_scl(elite_avg.costs.data(), elite_avg.costs.data(), 1.0 / n_elite,
horizon);
elite_avg.total_return /= n_elite;
avg_return /= n_elite;

// loop over elites to compute variance
std::fill(variance.begin(), variance.end(), 0.0); // reset variance to zero
Expand All @@ -304,25 +260,28 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
}

// improvement: compare nominal to elite average
improvement = mju_max(
elite_avg.total_return - trajectory[trajectory_order[0]].total_return,
0.0);
improvement =
mju_max(avg_return - trajectory[trajectory_order[0]].total_return, 0.0);

// stop timer
policy_update_compute_time = GetDuration(policy_update_start);
}

// compute trajectory using nominal policy
void CrossEntropyPlanner::NominalTrajectory(int horizon, ThreadPool& pool) {
void CrossEntropyPlanner::NominalTrajectory(int horizon) {
// set policy
auto nominal_policy = [&cp = resampled_policy](
double* action, const double* state, double time) {
cp.Action(action, state, time);
};

// rollout nominal policy
elite_avg.Rollout(nominal_policy, task, model, data_[0].get(), state.data(),
time, mocap.data(), userdata.data(), horizon);
nominal_trajectory.Rollout(nominal_policy, task, model,
data_[ThreadPool::WorkerId()].get(), state.data(),
time, mocap.data(), userdata.data(), horizon);
}
void CrossEntropyPlanner::NominalTrajectory(int horizon, ThreadPool& pool) {
NominalTrajectory(horizon);
}

// set action from policy
@@ -363,6 +322,8 @@ void CrossEntropyPlanner::ResamplePolicy(int horizon) {

LinearRange(resampled_policy.times.data(), time_shift,
resampled_policy.times[0], num_spline_points);

resampled_policy.representation = policy.representation;
}

// add random noise to nominal policy
@@ -446,12 +407,18 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon,
state.data(), time, mocap.data(), userdata.data(), horizon);
});
}
pool.WaitCount(count_before + num_trajectory);
// nominal
pool.Schedule([&s = *this, horizon]() { s.NominalTrajectory(horizon); });

// wait
pool.WaitCount(count_before + num_trajectory + 1);
pool.ResetCount();
}

// returns the nominal trajectory (this is the purple trace)
const Trajectory* CrossEntropyPlanner::BestTrajectory() { return &elite_avg; }
// returns the **nominal** trajectory (this is the purple trace)
const Trajectory* CrossEntropyPlanner::BestTrajectory() {
return &nominal_trajectory;
}

// visualize planner-specific traces
void CrossEntropyPlanner::Traces(mjvScene* scn) {
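The variance accumulated over the elites above is what the planner's noise step (the "add random noise to nominal policy" section) can draw from when generating the next batch of candidates. A short sketch of that sampling step, assuming independent per-parameter Gaussians; the function name and the standard-deviation floor are illustrative, not the repository's exact implementation.

// Sketch: perturb the averaged (nominal) parameters with zero-mean Gaussian
// noise scaled by the elite variance; a floor on the standard deviation keeps
// exploration alive (floor value illustrative).
#include <algorithm>
#include <cmath>
#include <random>
#include <vector>

std::vector<double> SampleCandidate(const std::vector<double>& mean,
                                    const std::vector<double>& variance,
                                    std::mt19937& rng,
                                    double min_std = 1.0e-3) {
  std::vector<double> sample(mean.size());
  for (std::size_t k = 0; k < mean.size(); k++) {
    double std_dev = std::max(std::sqrt(variance[k]), min_std);
    std::normal_distribution<double> noise(0.0, std_dev);
    sample[k] = mean[k] + noise(rng);
  }
  return sample;
}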
6 changes: 2 additions & 4 deletions mjpc/planners/cross_entropy/planner.h
@@ -57,6 +57,7 @@ class CrossEntropyPlanner : public Planner {

// compute trajectory using nominal policy
void NominalTrajectory(int horizon, ThreadPool& pool) override;
void NominalTrajectory(int horizon);

// set action from policy
void ActionFromPolicy(double* action, const double* state, double time,
@@ -111,7 +112,7 @@ class CrossEntropyPlanner : public Planner {

// trajectories
Trajectory trajectory[kMaxTrajectory];
Trajectory elite_avg;
Trajectory nominal_trajectory;

// order of indices of rolled out trajectories, ordered by total return
std::vector<int> trajectory_order;
@@ -129,9 +130,6 @@
// improvement
double improvement;

// flags
int processed_noise_status;

// timing
std::atomic<double> noise_compute_time;
double rollouts_compute_time;
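The single-argument NominalTrajectory overload declared above is what lets Rollouts() queue the nominal rollout on the worker pool together with the sampled rollouts and then wait for num_trajectory + 1 completions. A sketch of that pattern using std::async in place of the repository's ThreadPool; RolloutCandidate and RolloutNominal are hypothetical stand-ins for the per-trajectory rollout work.

// Sketch: launch sampled rollouts plus one nominal rollout concurrently,
// then join all num_trajectory + 1 jobs before ranking elites.
#include <future>
#include <vector>

void RolloutCandidate(int /*candidate*/, int /*horizon*/) {}  // placeholder rollout
void RolloutNominal(int /*horizon*/) {}                       // placeholder rollout

void RolloutAll(int num_trajectory, int horizon) {
  std::vector<std::future<void>> jobs;
  jobs.reserve(num_trajectory + 1);

  // sampled candidate policies
  for (int i = 0; i < num_trajectory; i++) {
    jobs.push_back(std::async(std::launch::async,
                              [i, horizon] { RolloutCandidate(i, horizon); }));
  }

  // nominal (unperturbed) policy, rolled out alongside the candidates
  jobs.push_back(
      std::async(std::launch::async, [horizon] { RolloutNominal(horizon); }));

  // wait for all num_trajectory + 1 rollouts
  for (auto& job : jobs) job.wait();
}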
