Skip to content

Commit

Permalink
refactor(mpf::predictor) resampling interval control in out of resamp…
Browse files Browse the repository at this point in the history
…ler (autowarefoundation#20)

* resampling interval management should be done out of resample()

Signed-off-by: Kento Yabuuchi <kento.yabuuchi.2@tier4.jp>

* resampler class throw exeption rather than optional

Signed-off-by: Kento Yabuuchi <kento.yabuuchi.2@tier4.jp>

* split files for resampling_history

Signed-off-by: Kento Yabuuchi <kento.yabuuchi.2@tier4.jp>

* split files for experimental/suspention_adaptor

Signed-off-by: Kento Yabuuchi <kento.yabuuchi.2@tier4.jp>

---------

Signed-off-by: Kento Yabuuchi <kento.yabuuchi.2@tier4.jp>
  • Loading branch information
KYabuuchi committed Jun 6, 2023
1 parent 966aedc commit db7768a
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 235 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright 2023 TIER IV, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <rclcpp/rclcpp.hpp>

#include <geometry_msgs/msg/pose_with_covariance_stamped.hpp>
#include <sensor_msgs/msg/image.hpp>

#include <optional>

namespace pcdless::modularized_particle_filter
{
struct SwapModeAdaptor
{
using Image = sensor_msgs::msg::Image;
using PoseCovStamped = geometry_msgs::msg::PoseWithCovarianceStamped;

SwapModeAdaptor(rclcpp::Node * node) : logger_(rclcpp::get_logger("swap_adaptor"))
{
auto on_ekf_pose = [this](const PoseCovStamped & pose) -> void { init_pose_opt_ = pose; };
auto on_image = [this](const Image & msg) -> void {
stamp_opt_ = rclcpp::Time(msg.header.stamp);
};

sub_image_ = node->create_subscription<Image>("image", 1, on_image);
sub_pose_ = node->create_subscription<PoseCovStamped>("pose_cov", 1, on_ekf_pose);
clock_ = node->get_clock();

state_is_active = false;
state_is_activating = false;
}

std::optional<rclcpp::Time> stamp_opt_{std::nullopt};
std::optional<PoseCovStamped> init_pose_opt_{std::nullopt};
rclcpp::Subscription<PoseCovStamped>::SharedPtr sub_pose_;
rclcpp::Subscription<Image>::SharedPtr sub_image_;
rclcpp::Clock::SharedPtr clock_;
rclcpp::Logger logger_;

bool state_is_active;
bool state_is_activating;

PoseCovStamped init_pose() { return init_pose_opt_.value(); }

bool should_call_initialize()
{
if (state_is_activating && init_pose_opt_.has_value()) {
if (!state_is_active) {
state_is_active = true;
state_is_activating = false;
return true;
}
}
return false;
}

bool should_keep_update()
{
if (!stamp_opt_.has_value()) {
RCLCPP_INFO_STREAM_THROTTLE(logger_, *clock_, 1000, "yabloc should stop");
state_is_active = false;
return false;
}

const double dt = (clock_->now() - stamp_opt_.value()).seconds();
if (dt > 3) {
RCLCPP_INFO_STREAM_THROTTLE(logger_, *clock_, 1000, "yabloc should stop");
state_is_active = false;
return false;
} else {
RCLCPP_INFO_STREAM_THROTTLE(logger_, *clock_, 1000, "yabloc should keep estimation");
state_is_activating = true;
return true;
}
}
};

} // namespace pcdless::modularized_particle_filter
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#define MODULARIZED_PARTICLE_FILTER__PREDICTION__PREDICTOR_HPP_

#include "modularized_particle_filter/common/visualize.hpp"
#include "modularized_particle_filter/prediction/experimental/suspension_adaptor.hpp"
#include "modularized_particle_filter/prediction/resampler.hpp"

#include <rclcpp/rclcpp.hpp>
Expand All @@ -25,7 +26,6 @@
#include <geometry_msgs/msg/twist_stamped.hpp>
#include <geometry_msgs/msg/twist_with_covariance_stamped.hpp>
#include <modularized_particle_filter_msgs/msg/particle_array.hpp>
#include <sensor_msgs/msg/image.hpp>
#include <std_msgs/msg/float32.hpp>

#include <tf2_ros/transform_broadcaster.h>
Expand All @@ -43,7 +43,6 @@ class Predictor : public rclcpp::Node
using TwistCovStamped = geometry_msgs::msg::TwistWithCovarianceStamped;
using TwistStamped = geometry_msgs::msg::TwistStamped;
using OptParticleArray = std::optional<ParticleArray>;
using Image = sensor_msgs::msg::Image;

Predictor();

Expand All @@ -66,6 +65,7 @@ class Predictor : public rclcpp::Node

const bool visualize_;
const int number_of_particles_;
// The minimum resampling interval is longer than this.
const float resampling_interval_seconds_;
const float static_linear_covariance_;
const float static_angular_covariance_;
Expand All @@ -76,6 +76,7 @@ class Predictor : public rclcpp::Node
std::shared_ptr<RetroactiveResampler> resampler_ptr_{nullptr};
std::optional<ParticleArray> particle_array_opt_{std::nullopt};
std::optional<TwistCovStamped> twist_opt_{std::nullopt};
std::optional<double> previous_resampling_time_opt_{std::nullopt};

// Callback
void on_gnss_pose(const PoseStamped::ConstSharedPtr initialpose);
Expand All @@ -92,23 +93,6 @@ class Predictor : public rclcpp::Node

void publish_mean_pose(const geometry_msgs::msg::Pose & mean_pose, const rclcpp::Time & stamp);

struct SwapModeAdaptor
{
SwapModeAdaptor(rclcpp::Node * node);
std::optional<rclcpp::Time> stamp_opt_{std::nullopt};
std::optional<PoseCovStamped> init_pose_opt_{std::nullopt};
rclcpp::Subscription<PoseCovStamped>::SharedPtr sub_pose_;
rclcpp::Subscription<Image>::SharedPtr sub_image_;
rclcpp::Clock::SharedPtr clock_;

bool state_is_active;
bool state_is_activating;

bool should_keep_update();
bool should_call_initialize();
PoseCovStamped init_pose() { return init_pose_opt_.value(); }
};

std::unique_ptr<SwapModeAdaptor> swap_mode_adaptor_{nullptr};
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,97 +12,41 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MODULARIZED_PARTICLE_FILTER__CORRECTION__RETROACTIVE_RESAMPLER_HPP_
#define MODULARIZED_PARTICLE_FILTER__CORRECTION__RETROACTIVE_RESAMPLER_HPP_
#ifndef MODULARIZED_PARTICLE_FILTER__PREDICTION__RESAMPLER_HPP_
#define MODULARIZED_PARTICLE_FILTER__PREDICTION__RESAMPLER_HPP_

#include "modularized_particle_filter/prediction/resampling_history.hpp"

#include <rclcpp/logger.hpp>

#include "modularized_particle_filter_msgs/msg/particle_array.hpp"

#include <algorithm>
#include <iostream>
#include <numeric>
#include <optional>
namespace pcdless::modularized_particle_filter
{

class History
{
public:
History(int max_history_num, int number_of_particles)
: max_history_num_(max_history_num), number_of_particles_(number_of_particles)
{
resampling_history_.resize(max_history_num);

for (auto & generation : resampling_history_) {
generation.resize(number_of_particles);
std::iota(generation.begin(), generation.end(), 0);
}
}

bool check_history_validity() const
{
for (auto & generation : resampling_history_) {
bool result = std::any_of(generation.begin(), generation.end(), [this](int x) {
return x < 0 || x >= number_of_particles_;
});

if (result) {
return false;
}
}
return true;
}

std::vector<int> & operator[](int generation_id)
{
return resampling_history_.at(generation_id % max_history_num_);
}

const std::vector<int> & operator[](int generation_id) const
{
return resampling_history_.at(generation_id % max_history_num_);
}

private:
// Number of updates to keep resampling history.
// Resampling records prior to this will not be kept.
const int max_history_num_;
const int number_of_particles_;
std::vector<std::vector<int>> resampling_history_;
};

class RetroactiveResampler
{
public:
using Particle = modularized_particle_filter_msgs::msg::Particle;
using ParticleArray = modularized_particle_filter_msgs::msg::ParticleArray;
using OptParticleArray = std::optional<ParticleArray>;

RetroactiveResampler(
float resampling_interval_seconds, int number_of_particles, int max_history_num);
RetroactiveResampler(int number_of_particles, int max_history_num);

OptParticleArray add_weight_retroactively(
ParticleArray add_weight_retroactively(
const ParticleArray & predicted_particles, const ParticleArray & weighted_particles);

std::optional<ParticleArray> resample(const ParticleArray & predicted_particles);
ParticleArray resample(const ParticleArray & predicted_particles);

private:
// The minimum resampling interval is longer than this.
// It is assumed that users will call the resampling() function frequently.
const float resampling_interval_seconds_;
// Number of updates to keep resampling history.
// Resampling records prior to this will not be kept.
const int max_history_num_;
// Number of particles to be managed.
const int number_of_particles_;
//
// ROS logger
rclcpp::Logger logger_;
// Previous resampling time
std::optional<double> previous_resampling_time_opt_{std::nullopt};
// This is handled like ring buffer.
// It keeps track of which particles each particle has transformed into at each resampling.
History resampling_history_;
ResamplingHistory resampling_history_;
// Indicates how many times the particles were resampled.
int latest_resampling_generation_;

Expand All @@ -113,4 +57,4 @@ class RetroactiveResampler
};
} // namespace pcdless::modularized_particle_filter

#endif // MODULARIZED_PARTICLE_FILTER__CORRECTION__RETROACTIVE_RESAMPLER_HPP_
#endif // MODULARIZED_PARTICLE_FILTER__PREDICTION__RESAMPLER_HPP_
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright 2023 TIER IV, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MODULARIZED_PARTICLE_FILTER__PREDICTION__RESAMPLING_HISTORY_HPP_
#define MODULARIZED_PARTICLE_FILTER__PREDICTION__RESAMPLING_HISTORY_HPP_

#include <algorithm>
#include <numeric>
#include <vector>

namespace pcdless::modularized_particle_filter
{
class ResamplingHistory
{
public:
ResamplingHistory(int max_history_num, int number_of_particles)
: max_history_num_(max_history_num), number_of_particles_(number_of_particles)
{
resampling_history_.resize(max_history_num);

for (auto & generation : resampling_history_) {
generation.resize(number_of_particles);
std::iota(generation.begin(), generation.end(), 0);
}
}

bool check_history_validity() const
{
for (auto & generation : resampling_history_) {
bool result = std::any_of(generation.begin(), generation.end(), [this](int x) {
return x < 0 || x >= number_of_particles_;
});

if (result) {
return false;
}
}
return true;
}

std::vector<int> & operator[](int generation_id)
{
return resampling_history_.at(generation_id % max_history_num_);
}

const std::vector<int> & operator[](int generation_id) const
{
return resampling_history_.at(generation_id % max_history_num_);
}

private:
// Number of updates to keep resampling history.
// Resampling records prior to this will not be kept.
const int max_history_num_;
const int number_of_particles_;
std::vector<std::vector<int>> resampling_history_;
};

} // namespace pcdless::modularized_particle_filter

#endif // MODULARIZED_PARTICLE_FILTER__PREDICTION__RESAMPLING_HISTORY_HPP_
Loading

0 comments on commit db7768a

Please sign in to comment.