Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
thowell committed Feb 9, 2024
1 parent 41505e1 commit 4a06147
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions mjpc/planners/cross_entropy/planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
int num_trajectory = num_trajectory_;

// n_elite_ might change in the GUI - keep constant for in this function
n_elite_ = std::min(n_elite_, num_trajectory);
int n_elite = std::min(n_elite_, num_trajectory);
n_elite_ = std::min(n_elite_, num_trajectory - 1);
int n_elite = std::min(n_elite_, num_trajectory - 1);

// resize number of mjData
ResizeMjData(model, pool.NumThreads());
Expand Down Expand Up @@ -273,9 +273,9 @@ void CrossEntropyPlanner::NominalTrajectory(int horizon, ThreadPool& pool) {
};

// rollout nominal policy
trajectory[0].Rollout(nominal_policy, task, model, data_[0].get(),
state.data(), time, mocap.data(), userdata.data(),
horizon);
trajectory[elite_average_index_].Rollout(
nominal_policy, task, model, data_[0].get(), state.data(), time,
mocap.data(), userdata.data(), horizon);
}

// set action from policy
Expand Down Expand Up @@ -381,7 +381,7 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon,
s.resampled_policy.representation;

// sample noise
if (i > 0) s.AddNoiseToPolicy(i, std_min);
if (i != s.elite_average_index_) s.AddNoiseToPolicy(i, std_min);
}

// ----- rollout sample policy ----- //
Expand All @@ -405,7 +405,7 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon,

// returns the **nominal** trajectory (this is the purple trace)
const Trajectory* CrossEntropyPlanner::BestTrajectory() {
return &trajectory[0];
return &trajectory[elite_average_index_];
}

// visualize planner-specific traces
Expand Down

0 comments on commit 4a06147

Please sign in to comment.