diff --git a/mjpc/planners/gradient/planner.cc b/mjpc/planners/gradient/planner.cc index 190e24269..ba36ea152 100644 --- a/mjpc/planners/gradient/planner.cc +++ b/mjpc/planners/gradient/planner.cc @@ -144,6 +144,9 @@ void GradientPlanner::Reset(int horizon, expected = 0.0; improvement = 0.0; surprise = 0.0; + + // derivative skip + derivative_skip_ = GetNumberOrDefault(0, model, "derivative_skip"); } // set state @@ -191,6 +194,7 @@ void GradientPlanner::OptimizePolicy(int horizon, ThreadPool& pool) { // update policy double c_best = c_prev; + int skip = derivative_skip_; for (int i = 0; i < settings.max_rollout; i++) { // ----- model derivatives ----- // // start timer @@ -200,7 +204,8 @@ void GradientPlanner::OptimizePolicy(int horizon, ThreadPool& pool) { model_derivative.Compute( model, data_, trajectory[0].states.data(), trajectory[0].actions.data(), trajectory[0].times.data(), dim_state, dim_state_derivative, dim_action, - dim_sensor, horizon, settings.fd_tolerance, settings.fd_mode, pool); + dim_sensor, horizon, settings.fd_tolerance, settings.fd_mode, pool, + skip); // stop timer model_derivative_time += GetDuration(model_derivative_start); @@ -468,6 +473,7 @@ void GradientPlanner::GUI(mjUI& ui) { {mjITEM_SELECT, "Spline", 2, &policy.representation, "Zero\nLinear\nCubic"}, {mjITEM_SLIDERINT, "Spline Pts", 2, &policy.num_spline_points, "0 1"}, + {mjITEM_SLIDERINT, "Deriv. 
Skip", 2, &derivative_skip_, "0 16"}, {mjITEM_END}}; // set number of trajectory slider limits diff --git a/mjpc/planners/gradient/planner.h b/mjpc/planners/gradient/planner.h index ee5b445f5..0d11187fd 100644 --- a/mjpc/planners/gradient/planner.h +++ b/mjpc/planners/gradient/planner.h @@ -160,6 +160,7 @@ class GradientPlanner : public Planner { private: mutable std::shared_mutex mtx_; + int derivative_skip_ = 0; }; } // namespace mjpc diff --git a/mjpc/planners/ilqg/planner.cc b/mjpc/planners/ilqg/planner.cc index 44a52698a..cfb76e18c 100644 --- a/mjpc/planners/ilqg/planner.cc +++ b/mjpc/planners/ilqg/planner.cc @@ -137,6 +137,9 @@ void iLQGPlanner::Reset(int horizon, const double* initial_repeated_action) { improvement = 0.0; expected = 0.0; surprise = 0.0; + + // derivative skip + derivative_skip_ = GetNumberOrDefault(0, model, "derivative_skip"); } // set state @@ -248,6 +251,7 @@ void iLQGPlanner::GUI(mjUI& ui) { "Zero\nLinear\nCubic"}, {mjITEM_SELECT, "Reg. Type", 2, &settings.regularization_type, "Control\nFeedback\nValue\nNone"}, + {mjITEM_SLIDERINT, "Deriv. 
Skip", 2, &derivative_skip_, "0 16"}, {mjITEM_CHECKINT, "Terminal Print", 2, &settings.verbose, ""}, {mjITEM_END}}; @@ -393,7 +397,7 @@ void iLQGPlanner::Iteration(int horizon, ThreadPool& pool) { candidate_policy[0].trajectory.actions.data(), candidate_policy[0].trajectory.times.data(), dim_state, dim_state_derivative, dim_action, dim_sensor, horizon, - settings.fd_tolerance, settings.fd_mode, pool); + settings.fd_tolerance, settings.fd_mode, pool, derivative_skip_); // stop timer double model_derivative_time = GetDuration(model_derivative_start); diff --git a/mjpc/planners/ilqg/planner.h b/mjpc/planners/ilqg/planner.h index fc3d1bf1d..d30b1ad56 100644 --- a/mjpc/planners/ilqg/planner.h +++ b/mjpc/planners/ilqg/planner.h @@ -157,6 +157,7 @@ class iLQGPlanner : public Planner { private: int num_trajectory_ = 1; int num_rollouts_gui_ = 1; + int derivative_skip_ = 0; }; } // namespace mjpc diff --git a/mjpc/planners/model_derivatives.cc b/mjpc/planners/model_derivatives.cc index 1ce85e702..03f67529a 100644 --- a/mjpc/planners/model_derivatives.cc +++ b/mjpc/planners/model_derivatives.cc @@ -48,40 +48,119 @@ void ModelDerivatives::Compute(const mjModel* m, const double* h, int dim_state, int dim_state_derivative, int dim_action, int dim_sensor, int T, double tol, int mode, - ThreadPool& pool) { - { - int count_before = pool.GetCount(); - for (int t = 0; t < T; t++) { - pool.Schedule([&m, &data, &A = A, &B = B, &C = C, &D = D, &x, &u, &h, - dim_state, dim_state_derivative, dim_action, dim_sensor, - tol, mode, t, T]() { - mjData* d = data[ThreadPool::WorkerId()].get(); - // set state - SetState(m, d, x + t * dim_state); - d->time = h[t]; - - // set action - mju_copy(d->ctrl, u + t * dim_action, dim_action); + ThreadPool& pool, int skip) { + // reset indices + evaluate_.clear(); + interpolate_.clear(); - // Jacobians - if (t == T - 1) { - // Jacobians - mjd_transitionFD(m, d, tol, mode, nullptr, nullptr, - DataAt(C, t * (dim_sensor * dim_state_derivative)), - nullptr); - 
} else { - // derivatives - mjd_transitionFD( - m, d, tol, mode, - DataAt(A, t * (dim_state_derivative * dim_state_derivative)), - DataAt(B, t * (dim_state_derivative * dim_action)), - DataAt(C, t * (dim_sensor * dim_state_derivative)), - DataAt(D, t * (dim_sensor * dim_action))); - } - }); + // evaluate indices + int s = skip + 1; + evaluate_.push_back(0); + for (int t = s; t < T - s; t += s) { + evaluate_.push_back(t); + } + evaluate_.push_back(T - 2); + evaluate_.push_back(T - 1); + + // interpolate indices + for (int t = 0, e = 0; t < T; t++) { + if (e == evaluate_.size() || evaluate_[e] > t) { + interpolate_.push_back(t); + } else { + e++; } - pool.WaitCount(count_before + T); } + + // evaluate derivatives + int count_before = pool.GetCount(); + for (int t : evaluate_) { + pool.Schedule([&m, &data, &A = A, &B = B, &C = C, &D = D, &x, &u, &h, + dim_state, dim_state_derivative, dim_action, dim_sensor, tol, + mode, t, T]() { + mjData* d = data[ThreadPool::WorkerId()].get(); + // set state + SetState(m, d, x + t * dim_state); + d->time = h[t]; + + // set action + mju_copy(d->ctrl, u + t * dim_action, dim_action); + + // Jacobians + if (t == T - 1) { + // Jacobians + mjd_transitionFD(m, d, tol, mode, nullptr, nullptr, + DataAt(C, t * (dim_sensor * dim_state_derivative)), + nullptr); + } else { + // derivatives + mjd_transitionFD( + m, d, tol, mode, + DataAt(A, t * (dim_state_derivative * dim_state_derivative)), + DataAt(B, t * (dim_state_derivative * dim_action)), + DataAt(C, t * (dim_sensor * dim_state_derivative)), + DataAt(D, t * (dim_sensor * dim_action))); + } + }); + } + pool.WaitCount(count_before + evaluate_.size()); + pool.ResetCount(); + + // interpolate derivatives + count_before = pool.GetCount(); + for (int t : interpolate_) { + pool.Schedule([&A = A, &B = B, &C = C, &D = D, &evaluate_ = this->evaluate_, + dim_state_derivative, dim_action, dim_sensor, t]() { + // find interval + int bounds[2]; + FindInterval(bounds, evaluate_, t, evaluate_.size()); + 
int e0 = evaluate_[bounds[0]]; + int e1 = evaluate_[bounds[1]]; + + // normalized input + double tt = double(t - e0) / double(e1 - e0); + if (bounds[0] == bounds[1]) { + tt = 0.0; + } + + // A + int nA = dim_state_derivative * dim_state_derivative; + double* Ai = DataAt(A, t * nA); + const double* AL = DataAt(A, e0 * nA); + const double* AU = DataAt(A, e1 * nA); + + mju_scl(Ai, AL, 1.0 - tt, nA); + mju_addToScl(Ai, AU, tt, nA); + + // B + int nB = dim_state_derivative * dim_action; + double* Bi = DataAt(B, t * nB); + const double* BL = DataAt(B, e0 * nB); + const double* BU = DataAt(B, e1 * nB); + + mju_scl(Bi, BL, 1.0 - tt, nB); + mju_addToScl(Bi, BU, tt, nB); + + // C + int nC = dim_sensor * dim_state_derivative; + double* Ci = DataAt(C, t * nC); + const double* CL = DataAt(C, e0 * nC); + const double* CU = DataAt(C, e1 * nC); + + mju_scl(Ci, CL, 1.0 - tt, nC); + mju_addToScl(Ci, CU, tt, nC); + + // D + int nD = dim_sensor * dim_action; + double* Di = DataAt(D, t * nD); + const double* DL = DataAt(D, e0 * nD); + const double* DU = DataAt(D, e1 * nD); + + mju_scl(Di, DL, 1.0 - tt, nD); + mju_addToScl(Di, DU, tt, nD); + }); + } + + pool.WaitCount(count_before + interpolate_.size()); pool.ResetCount(); } diff --git a/mjpc/planners/model_derivatives.h b/mjpc/planners/model_derivatives.h index 1ea295b83..6763ce8be 100644 --- a/mjpc/planners/model_derivatives.h +++ b/mjpc/planners/model_derivatives.h @@ -45,7 +45,7 @@ class ModelDerivatives { void Compute(const mjModel* m, const std::vector& data, const double* x, const double* u, const double* h, int dim_state, int dim_state_derivative, int dim_action, int dim_sensor, int T, - double tol, int mode, ThreadPool& pool); + double tol, int mode, ThreadPool& pool, int skip = 0); // Jacobians std::vector A; // model Jacobians wrt state @@ -56,6 +56,10 @@ class ModelDerivatives { // (T * dim_sensor * dim_state_derivative) std::vector D; // output Jacobians wrt action // (T * dim_sensor * dim_action) + + // indices + 
// find interval in monotonic sequence containing value
//
// Considers the first `length` elements of `sequence` (assumed sorted in
// non-decreasing order, as required by std::upper_bound) and writes two
// indices into `bounds`:
//   bounds[0]: index of the last element that is <= value
//   bounds[1]: index of the first element that is > value
// When value lies before the first element (or length is 0) both indices are
// 0; when value is >= the last element both indices are length - 1.
template <typename T>
void FindInterval(int* bounds, const std::vector<T>& sequence, double value,
                  int length) {
  // position of the first element strictly greater than value, in [0, length]
  const auto first = sequence.begin();
  const int upper =
      static_cast<int>(std::upper_bound(first, first + length, value) - first);
  const int lower = upper - 1;

  // value precedes every element (or the range is empty)
  if (lower < 0) {
    bounds[0] = 0;
    bounds[1] = 0;
    return;
  }

  // lower is already within [0, length - 1]; only upper can run off the end
  bounds[0] = lower;
  bounds[1] = std::min(upper, length - 1);
}