From 4a061476ff8b049dbfda54fb18951dc57426b840 Mon Sep 17 00:00:00 2001 From: taylor howell Date: Fri, 9 Feb 2024 12:09:16 -0700 Subject: [PATCH] minor fix --- mjpc/planners/cross_entropy/planner.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mjpc/planners/cross_entropy/planner.cc b/mjpc/planners/cross_entropy/planner.cc index 4a00b4c8d..044119cb2 100644 --- a/mjpc/planners/cross_entropy/planner.cc +++ b/mjpc/planners/cross_entropy/planner.cc @@ -161,8 +161,8 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) { int num_trajectory = num_trajectory_; // n_elite_ might change in the GUI - keep constant for in this function - n_elite_ = std::min(n_elite_, num_trajectory); - int n_elite = std::min(n_elite_, num_trajectory); + n_elite_ = std::min(n_elite_, num_trajectory - 1); + int n_elite = std::min(n_elite_, num_trajectory - 1); // resize number of mjData ResizeMjData(model, pool.NumThreads()); @@ -273,9 +273,9 @@ void CrossEntropyPlanner::NominalTrajectory(int horizon, ThreadPool& pool) { }; // rollout nominal policy - trajectory[0].Rollout(nominal_policy, task, model, data_[0].get(), - state.data(), time, mocap.data(), userdata.data(), - horizon); + trajectory[elite_average_index_].Rollout( + nominal_policy, task, model, data_[0].get(), state.data(), time, + mocap.data(), userdata.data(), horizon); } // set action from policy @@ -381,7 +381,7 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon, s.resampled_policy.representation; // sample noise - if (i > 0) s.AddNoiseToPolicy(i, std_min); + if (i != s.elite_average_index_) s.AddNoiseToPolicy(i, std_min); } // ----- rollout sample policy ----- // @@ -405,7 +405,7 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon, // returns the **nominal** trajectory (this is the purple trace) const Trajectory* CrossEntropyPlanner::BestTrajectory() { - return &trajectory[0]; + return &trajectory[elite_average_index_]; } // visualize planner-specific traces