Skip to content

Commit 944a547

Browse files
committed
[mppi] implement control cost from paper
1 parent 08d7cf6 commit 944a547

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

nav2_mppi_controller/src/optimizer.cpp

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -538,16 +538,42 @@ void Optimizer::updateControlSequence()
538538
const bool is_holo = isHolonomic();
539539
auto & s = settings_;
540540

541-
auto vx_T = control_sequence_.vx.transpose();
542-
auto bounded_noises_vx = state_.cvx.rowwise() - vx_T;
541+
// Paper
542+
auto vx = state_.cvx; // K×N
543+
std::cout << "Vx: size " << vx.rows() << "x" << vx.cols() << std::endl;
544+
auto ux = control_sequence_.vx; // Nx1
545+
std::cout << "ux: size " << ux.rows() << "x" << ux.cols() << std::endl;
546+
547+
// u_t^T * v_t (paper) = v_t^T * u_t per sample
548+
// Eigen::VectorXf cross = (Vx * ux).eval(); // (KxN) * (Nx1) -> (Kx1)
549+
// Since all vars are Eigen::Array, need to do element-wise multiplication and then sum rows
550+
Eigen::ArrayXf cross_vx = (vx.rowwise() * ux.transpose()) // K×N, broadcast ux over rows
551+
.rowwise()
552+
.sum(); // Kx1
553+
std::cout << "cross_vx: " << cross_vx.rows() << "x" << cross_vx.cols() << " : " << cross_vx(Eigen::seq(0, 9)).transpose() << "\n";
554+
555+
// original mppi
556+
auto vx_T = control_sequence_.vx.transpose(); // 1xN
557+
auto bounded_noises_vx = state_.cvx.rowwise() - vx_T; // KxN
558+
auto costs_vx = ((bounded_noises_vx.rowwise() * vx_T).rowwise().sum()).eval();
559+
// costs_ += (gamma_vx * (ux_T * vx).rowwise().sum()).eval();
560+
std::cout << "costs_vx: " << costs_vx.rows() << "x" << costs_vx.cols() << " : " << costs_vx(Eigen::seq(0, 9)).transpose() << "\n";
561+
543562
const float gamma_vx = s.gamma / (s.sampling_std.vx * s.sampling_std.vx);
544-
costs_ += (gamma_vx * (bounded_noises_vx.rowwise() * vx_T).rowwise().sum()).eval();
563+
costs_ += gamma_vx * cross_vx;
545564

546565
if (s.sampling_std.wz > 0.0f) {
547-
auto wz_T = control_sequence_.wz.transpose();
548-
auto bounded_noises_wz = state_.cwz.rowwise() - wz_T;
566+
// auto wz_T = control_sequence_.wz.transpose(); // 1xN
567+
// auto bounded_noises_wz = state_.cwz.rowwise() - wz_T; // KxN
549568
const float gamma_wz = s.gamma / (s.sampling_std.wz * s.sampling_std.wz);
550-
costs_ += (gamma_wz * (bounded_noises_wz.rowwise() * wz_T).rowwise().sum()).eval();
569+
// costs_ += (gamma_wz * (bounded_noises_wz.rowwise() * wz_T).rowwise().sum()).eval();
570+
571+
auto wz = state_.cwz;
572+
auto uz = control_sequence_.wz; // Nx1
573+
Eigen::ArrayXf cross_wz = (wz.rowwise() * uz.transpose()) // (K×N), broadcast ux over rows
574+
.rowwise()
575+
.sum();
576+
costs_ += gamma_wz * cross_wz;
551577
}
552578

553579
if (is_holo) {

0 commit comments

Comments
 (0)