@@ -538,16 +538,42 @@ void Optimizer::updateControlSequence()
538538 const bool is_holo = isHolonomic ();
539539 auto & s = settings_;
540540
541- auto vx_T = control_sequence_.vx .transpose ();
542- auto bounded_noises_vx = state_.cvx .rowwise () - vx_T;
541+ // Paper
542+ auto vx = state_.cvx ; // K×N
543+ std::cout << " Vx: size " << vx.rows () << " x" << vx.cols () << std::endl;
544+ auto ux = control_sequence_.vx ; // Nx1
545+ std::cout << " ux: size " << ux.rows () << " x" << ux.cols () << std::endl;
546+
547+ // u_t^T * v_t (paper) = v_t^T * u_t per sample
548+ // Eigen::VectorXf cross = (Vx * ux).eval(); // (KxN) * (Nx1) -> (Kx1)
549+ // Since all vars are Eigen::Array, need to do element-wise multiplication and then sum rows
550+ Eigen::ArrayXf cross_vx = (vx.rowwise () * ux.transpose ()) // K×N, broadcast ux over rows
551+ .rowwise ()
552+ .sum (); // Kx1
553+ std::cout << " cross_vx: " << cross_vx.rows () << " x" << cross_vx.cols () << " : " << cross_vx (Eigen::seq (0 , 9 )).transpose () << " \n " ;
554+
555+ // original mppi
556+ auto vx_T = control_sequence_.vx .transpose (); // 1xN
557+ auto bounded_noises_vx = state_.cvx .rowwise () - vx_T; // KxN
558+ auto costs_vx = ((bounded_noises_vx.rowwise () * vx_T).rowwise ().sum ()).eval ();
559+ // costs_ += (gamma_vx * (ux_T * vx).rowwise().sum()).eval();
560+ std::cout << " costs_vx: " << costs_vx.rows () << " x" << costs_vx.cols () << " : " << costs_vx (Eigen::seq (0 , 9 )).transpose () << " \n " ;
561+
543562 const float gamma_vx = s.gamma / (s.sampling_std .vx * s.sampling_std .vx );
544- costs_ += ( gamma_vx * (bounded_noises_vx. rowwise () * vx_T). rowwise (). sum ()). eval () ;
563+ costs_ += gamma_vx * cross_vx ;
545564
546565 if (s.sampling_std .wz > 0 .0f ) {
547- auto wz_T = control_sequence_.wz .transpose ();
548- auto bounded_noises_wz = state_.cwz .rowwise () - wz_T;
566+ // auto wz_T = control_sequence_.wz.transpose(); // 1xN
567+ // auto bounded_noises_wz = state_.cwz.rowwise() - wz_T; // KxN
549568 const float gamma_wz = s.gamma / (s.sampling_std .wz * s.sampling_std .wz );
550- costs_ += (gamma_wz * (bounded_noises_wz.rowwise () * wz_T).rowwise ().sum ()).eval ();
569+ // costs_ += (gamma_wz * (bounded_noises_wz.rowwise() * wz_T).rowwise().sum()).eval();
570+
571+ auto wz = state_.cwz ;
572+ auto uz = control_sequence_.wz ; // Nx1
573+ Eigen::ArrayXf cross_wz = (wz.rowwise () * uz.transpose ()) // (K×N), broadcast ux over rows
574+ .rowwise ()
575+ .sum ();
576+ costs_ += gamma_wz * cross_wz;
551577 }
552578
553579 if (is_holo) {
0 commit comments