Skip to content

Commit

Permalink
feat(traffic_light_classifier): update classifier model (autowarefoun…
Browse files Browse the repository at this point in the history
…dation#2820)

* feat: update traffic light classifier

Signed-off-by: wep21 <border_goldenmarket@yahoo.co.jp>

* remove unused file

Signed-off-by: Daisuke Nishimatsu <border_goldenmarket@yahoo.co.jp>

* remove unused header

Signed-off-by: Daisuke Nishimatsu <border_goldenmarket@yahoo.co.jp>

* add variable to apply softmax

Signed-off-by: Daisuke Nishimatsu <border_goldenmarket@yahoo.co.jp>

* update visualization

Signed-off-by: Daisuke Nishimatsu <border_goldenmarket@yahoo.co.jp>

* apply pre-commit

Signed-off-by: Daisuke Nishimatsu <border_goldenmarket@yahoo.co.jp>

* add debug term

Signed-off-by: Daisuke Nishimatsu <border_goldenmarket@yahoo.co.jp>

* change default parameter

Signed-off-by: wep21 <border_goldenmarket@yahoo.co.jp>

* fix debug node

Signed-off-by: wep21 <border_goldenmarket@yahoo.co.jp>

* change default node name

Signed-off-by: wep21 <border_goldenmarket@yahoo.co.jp>

* change default

Signed-off-by: Daisuke Nishimatsu <border_goldenmarket@yahoo.co.jp>

---------

Signed-off-by: wep21 <border_goldenmarket@yahoo.co.jp>
Signed-off-by: Daisuke Nishimatsu <border_goldenmarket@yahoo.co.jp>
  • Loading branch information
wep21 authored Feb 16, 2023
1 parent 476afcb commit 0dc2247
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 41 deletions.
16 changes: 16 additions & 0 deletions perception/traffic_light_classifier/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,22 @@ if(TRT_AVAIL AND CUDA_AVAIL AND CUDNN_AVAIL)
EXECUTABLE traffic_light_classifier_node
)

ament_auto_add_library(single_image_debug_inference_node SHARED
src/cnn_classifier.cpp
src/color_classifier.cpp
src/nodelet.cpp
src/single_image_debug_inference_node.cpp
)
target_link_libraries(single_image_debug_inference_node
libutils
opencv_core
opencv_highgui
)
rclcpp_components_register_node(single_image_debug_inference_node
PLUGIN "traffic_light::SingleImageDebugInferenceNode"
EXECUTABLE single_image_debug_inference
)

ament_auto_package(INSTALL_TO_SHARE
data
launch
Expand Down
29 changes: 0 additions & 29 deletions perception/traffic_light_classifier/cfg/HSVFilter.cfg

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class CNNClassifier : public ClassifierInterface
void preProcess(cv::Mat & image, std::vector<float> & tensor, bool normalize = true);
bool postProcess(
std::vector<float> & output_data_host,
autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal);
autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal, bool apply_softmax = false);
bool readLabelfile(std::string filepath, std::vector<std::string> & labels);
bool isColorLabel(const std::string label);
void calcSoftmax(std::vector<float> & data, std::vector<float> & probs, int num_output);
Expand Down Expand Up @@ -106,6 +106,7 @@ class CNNClassifier : public ClassifierInterface
int input_c_;
int input_h_;
int input_w_;
bool apply_softmax_;
};

} // namespace traffic_light
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,12 @@ class TrafficLightClassifierNodelet : public rclcpp::Node
const autoware_auto_perception_msgs::msg::TrafficLightRoiArray::ConstSharedPtr &
input_rois_msg);

private:
enum ClassifierType {
HSVFilter = 0,
CNN = 1,
};

private:
void connectCb();

rclcpp::TimerBase::SharedPtr timer_;
Expand Down
18 changes: 12 additions & 6 deletions perception/traffic_light_classifier/src/cnn_classifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,13 @@ CNNClassifier::CNNClassifier(rclcpp::Node * node_ptr) : node_ptr_(node_ptr)
input_c_ = node_ptr_->declare_parameter("input_c", 3);
input_h_ = node_ptr_->declare_parameter("input_h", 224);
input_w_ = node_ptr_->declare_parameter("input_w", 224);
auto input_name = node_ptr_->declare_parameter("input_name", "input_0");
auto output_name = node_ptr_->declare_parameter("output_name", "output_0");
apply_softmax_ = node_ptr_->declare_parameter("apply_softmax", true);

readLabelfile(label_file_path, labels_);

trt_ = std::make_shared<Tn::TrtCommon>(model_file_path, precision);
trt_ = std::make_shared<Tn::TrtCommon>(model_file_path, precision, input_name, output_name);
trt_->setup();
}

Expand Down Expand Up @@ -79,7 +82,7 @@ bool CNNClassifier::getTrafficSignal(
output_data_host.data(), output_data_device.get(), num_output * sizeof(float),
cudaMemcpyDeviceToHost);

postProcess(output_data_host, traffic_signal);
postProcess(output_data_host, traffic_signal, apply_softmax_);

/* debug */
if (0 < image_pub_.getNumSubscribers()) {
Expand Down Expand Up @@ -153,19 +156,22 @@ void CNNClassifier::preProcess(cv::Mat & image, std::vector<float> & input_tenso

bool CNNClassifier::postProcess(
std::vector<float> & output_tensor,
autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal)
autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal, bool apply_softmax)
{
std::vector<float> probs;
int num_output = trt_->getNumOutput();
calcSoftmax(output_tensor, probs, num_output);
if (apply_softmax) {
calcSoftmax(output_tensor, probs, num_output);
}
std::vector<size_t> sorted_indices = argsort(output_tensor, num_output);

// ROS_INFO("label: %s, score: %.2f\%",
// labels_[sorted_indices[0]].c_str(),
// probs[sorted_indices[0]] * 100);

std::string match_label = labels_[sorted_indices[0]];
float probability = probs[sorted_indices[0]];
size_t max_indice = sorted_indices.front();
std::string match_label = labels_[max_indice];
float probability = apply_softmax ? probs[max_indice] : output_tensor[max_indice];

// label names are assumed to be comma-separated to represent each lamp
// e.g.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// Copyright 2023 Tier IV, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <rclcpp/rclcpp.hpp>

#if ENABLE_GPU
#include <traffic_light_classifier/cnn_classifier.hpp>
#endif

#include <traffic_light_classifier/color_classifier.hpp>
#include <traffic_light_classifier/nodelet.hpp>

#include <memory>
#include <string>

namespace
{
std::string toString(const uint8_t state)
{
if (state == autoware_auto_perception_msgs::msg::TrafficLight::RED) {
return "red";
} else if (state == autoware_auto_perception_msgs::msg::TrafficLight::AMBER) {
return "yellow";
} else if (state == autoware_auto_perception_msgs::msg::TrafficLight::GREEN) {
return "green";
} else if (state == autoware_auto_perception_msgs::msg::TrafficLight::WHITE) {
return "white";
} else if (state == autoware_auto_perception_msgs::msg::TrafficLight::CIRCLE) {
return "circle";
} else if (state == autoware_auto_perception_msgs::msg::TrafficLight::LEFT_ARROW) {
return "left";
} else if (state == autoware_auto_perception_msgs::msg::TrafficLight::RIGHT_ARROW) {
return "right";
} else if (state == autoware_auto_perception_msgs::msg::TrafficLight::UP_ARROW) {
return "straight";
} else if (state == autoware_auto_perception_msgs::msg::TrafficLight::DOWN_ARROW) {
return "down";
} else if (state == autoware_auto_perception_msgs::msg::TrafficLight::DOWN_LEFT_ARROW) {
return "down_left";
} else if (state == autoware_auto_perception_msgs::msg::TrafficLight::DOWN_RIGHT_ARROW) {
return "down_right";
} else if (state == autoware_auto_perception_msgs::msg::TrafficLight::CROSS) {
return "cross";
} else if (state == autoware_auto_perception_msgs::msg::TrafficLight::UNKNOWN) {
return "unknown";
} else {
return "";
}
}
} // namespace

namespace traffic_light
{
class SingleImageDebugInferenceNode : public rclcpp::Node
{
public:
explicit SingleImageDebugInferenceNode(const rclcpp::NodeOptions & node_options)
: Node("single_image_debug_inference", node_options)
{
const auto image_path = declare_parameter("image_path", "");

int classifier_type = this->declare_parameter(
"classifier_type",
static_cast<int>(TrafficLightClassifierNodelet::ClassifierType::HSVFilter));
if (classifier_type == TrafficLightClassifierNodelet::ClassifierType::HSVFilter) {
classifier_ptr_ = std::make_unique<ColorClassifier>(this);
} else if (classifier_type == TrafficLightClassifierNodelet::ClassifierType::CNN) {
#if ENABLE_GPU
classifier_ptr_ = std::make_unique<CNNClassifier>(this);
#else
RCLCPP_ERROR(get_logger(), "please install CUDA, CUDNN and TensorRT to use cnn classifier");
#endif
}

image_ = cv::imread(image_path);
if (image_.empty()) {
RCLCPP_ERROR(get_logger(), "image is empty");
return;
}
cv::namedWindow("inference image", cv::WINDOW_NORMAL);
cv::setMouseCallback("inference image", SingleImageDebugInferenceNode::onMouse, this);

cv::imshow("inference image", image_);

// loop until q character is pressed
while (cv::waitKey(0) != 113) {
}
cv::destroyAllWindows();
rclcpp::shutdown();
}

private:
static void onMouse(int event, int x, int y, int flags, void * param)
{
SingleImageDebugInferenceNode * node = static_cast<SingleImageDebugInferenceNode *>(param);
if (node) {
node->inferWithCrop(event, x, y, flags);
}
}

void inferWithCrop(int action, int x, int y, [[maybe_unused]] int flags)
{
if (action == cv::EVENT_LBUTTONDOWN) {
top_left_corner_ = cv::Point(x, y);
} else if (action == cv::EVENT_LBUTTONUP) {
bottom_right_corner_ = cv::Point(x, y);
cv::Mat tmp = image_.clone();
cv::Mat crop = image_(cv::Rect{top_left_corner_, bottom_right_corner_}).clone();
if (crop.empty()) {
RCLCPP_ERROR(get_logger(), "crop image is empty");
return;
}
cv::cvtColor(crop, crop, cv::COLOR_BGR2RGB);
autoware_auto_perception_msgs::msg::TrafficSignal traffic_signal;
if (!classifier_ptr_->getTrafficSignal(crop, traffic_signal)) {
RCLCPP_ERROR(get_logger(), "failed to classify image");
return;
}
cv::Scalar color;
cv::Scalar text_color;
for (const auto & light : traffic_signal.lights) {
auto color_str = toString(light.color);
auto shape_str = toString(light.shape);
auto confidence_str = std::to_string(light.confidence);
if (shape_str == "circle") {
if (color_str == "red") {
color = cv::Scalar(0, 0, 255);
} else if (color_str == "green") {
color = cv::Scalar(0, 255, 0);
} else if (color_str == "yellow") {
color = cv::Scalar(0, 255, 255);
} else if (color_str == "white") {
color = cv::Scalar(0, 0, 0);
} else {
color = cv::Scalar(255, 255, 255);
}
}
RCLCPP_INFO_STREAM(get_logger(), color_str << " " << shape_str << " " << confidence_str);
}
cv::rectangle(tmp, top_left_corner_, bottom_right_corner_, color, 2, 8);
cv::imshow("inference image", tmp);
}
}

cv::Point top_left_corner_;
cv::Point bottom_right_corner_;
cv::Mat image_;
std::unique_ptr<ClassifierInterface> classifier_ptr_;
};
} // namespace traffic_light

#include "rclcpp_components/register_node_macro.hpp"
RCLCPP_COMPONENTS_REGISTER_NODE(traffic_light::SingleImageDebugInferenceNode)
7 changes: 4 additions & 3 deletions perception/traffic_light_classifier/utils/trt_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ void check_error(const ::cudaError_t e, decltype(__FILE__) f, decltype(__LINE__)
}
}

TrtCommon::TrtCommon(std::string model_path, std::string precision)
TrtCommon::TrtCommon(
std::string model_path, std::string precision, std::string input_name, std::string output_name)
: model_file_path_(model_path),
precision_(precision),
input_name_("input_0"),
output_name_("output_0"),
input_name_(input_name),
output_name_(output_name),
is_initialized_(false)
{
runtime_ = UniquePtr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(logger_));
Expand Down
3 changes: 2 additions & 1 deletion perception/traffic_light_classifier/utils/trt_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ Tn::UniquePtr<T> make_unique()
class TrtCommon
{
public:
TrtCommon(std::string model_path, std::string precision);
TrtCommon(
std::string model_path, std::string precision, std::string input_name, std::string output_name);
~TrtCommon() {}

bool loadEngine(std::string engine_file_path);
Expand Down

0 comments on commit 0dc2247

Please sign in to comment.