From 0dc2247f40f21ae9d987a83b250314604b03b7d3 Mon Sep 17 00:00:00 2001 From: Daisuke Nishimatsu <42202095+wep21@users.noreply.github.com> Date: Thu, 16 Feb 2023 23:22:18 +0900 Subject: [PATCH] feat(traffic_light_classifier): update classifier model (#2820) * feat: update traffic light classifier Signed-off-by: wep21 * remove unused file Signed-off-by: Daisuke Nishimatsu * remove unused header Signed-off-by: Daisuke Nishimatsu * add variable to apply softmax Signed-off-by: Daisuke Nishimatsu * update visualization Signed-off-by: Daisuke Nishimatsu * apply pre-commit Signed-off-by: Daisuke Nishimatsu * add debug term Signed-off-by: Daisuke Nishimatsu * change default parameter Signed-off-by: wep21 * fix debug node Signed-off-by: wep21 * change default node name Signed-off-by: wep21 * change default Signed-off-by: Daisuke Nishimatsu --------- Signed-off-by: wep21 Signed-off-by: Daisuke Nishimatsu --- .../traffic_light_classifier/CMakeLists.txt | 16 ++ .../cfg/HSVFilter.cfg | 29 ---- .../cnn_classifier.hpp | 3 +- .../traffic_light_classifier/nodelet.hpp | 3 +- .../src/cnn_classifier.cpp | 18 +- .../src/single_image_debug_inference_node.cpp | 164 ++++++++++++++++++ .../utils/trt_common.cpp | 7 +- .../utils/trt_common.hpp | 3 +- 8 files changed, 202 insertions(+), 41 deletions(-) delete mode 100755 perception/traffic_light_classifier/cfg/HSVFilter.cfg create mode 100644 perception/traffic_light_classifier/src/single_image_debug_inference_node.cpp diff --git a/perception/traffic_light_classifier/CMakeLists.txt b/perception/traffic_light_classifier/CMakeLists.txt index 1b61e755567a2..e125abc7b435a 100644 --- a/perception/traffic_light_classifier/CMakeLists.txt +++ b/perception/traffic_light_classifier/CMakeLists.txt @@ -146,6 +146,22 @@ if(TRT_AVAIL AND CUDA_AVAIL AND CUDNN_AVAIL) EXECUTABLE traffic_light_classifier_node ) + ament_auto_add_library(single_image_debug_inference_node SHARED + src/cnn_classifier.cpp + src/color_classifier.cpp + src/nodelet.cpp + src/single_image_debug_inference_node.cpp + ) + target_link_libraries(single_image_debug_inference_node + libutils + opencv_core + opencv_highgui + ) + rclcpp_components_register_node(single_image_debug_inference_node + PLUGIN "traffic_light::SingleImageDebugInferenceNode" + EXECUTABLE single_image_debug_inference + ) + ament_auto_package(INSTALL_TO_SHARE data launch diff --git a/perception/traffic_light_classifier/cfg/HSVFilter.cfg b/perception/traffic_light_classifier/cfg/HSVFilter.cfg deleted file mode 100755 index a662704741d2a..0000000000000 --- a/perception/traffic_light_classifier/cfg/HSVFilter.cfg +++ /dev/null @@ -1,29 +0,0 @@ -#! /usr/bin/env python - -# set up parameters that we care about -PACKAGE = 'traffic_light_classifier' - -from dynamic_reconfigure.parameter_generator_catkin import * - -gen = ParameterGenerator () -# def add (self, name, paramtype, level, description, default = None, min = None, max = None, edit_method = ""): -gen.add ("green_min_h", int_t, 0, "min h green", 50, 0, 180) -gen.add ("green_max_h", int_t, 0, "max h green", 120, 0, 180) -gen.add ("green_min_s", int_t, 0, "min s green", 100, 0, 255) -gen.add ("green_max_s", int_t, 0, "max s green", 200, 0, 255) -gen.add ("green_min_v", int_t, 0, "min v green", 150, 0, 255) -gen.add ("green_max_v", int_t, 0, "max v green", 255, 0, 255) -gen.add ("yellow_min_h", int_t, 0, "min h yellow", 0, 0, 180) -gen.add ("yellow_max_h", int_t, 0, "max h yellow", 50, 0, 180) -gen.add ("yellow_min_s", int_t, 0, "min s yellow", 80, 0, 255) -gen.add ("yellow_max_s", int_t, 0, "max s yellow", 200, 0, 255) -gen.add ("yellow_min_v", int_t, 0, "min v yellow", 150, 0, 255) -gen.add ("yellow_max_v", int_t, 0, "max v yellow", 255, 0, 255) -gen.add ("red_min_h", int_t, 0, "min h red", 160, 0, 180) -gen.add ("red_max_h", int_t, 0, "max h red", 180, 0, 180) -gen.add ("red_min_s", int_t, 0, "min s red", 100, 0, 255) -gen.add ("red_max_s", int_t, 0, "max s red", 255, 0, 255) -gen.add ("red_min_v", int_t, 0, "min v red", 150, 0, 255) -gen.add ("red_max_v", int_t, 0, "max v red", 255, 0, 255) - -exit (gen.generate (PACKAGE, "traffic_light_classifier", "HSVFilter")) diff --git a/perception/traffic_light_classifier/include/traffic_light_classifier/cnn_classifier.hpp b/perception/traffic_light_classifier/include/traffic_light_classifier/cnn_classifier.hpp index ef91f58310caa..04914889f0756 100644 --- a/perception/traffic_light_classifier/include/traffic_light_classifier/cnn_classifier.hpp +++ b/perception/traffic_light_classifier/include/traffic_light_classifier/cnn_classifier.hpp @@ -48,7 +48,7 @@ class CNNClassifier : public ClassifierInterface void preProcess(cv::Mat & image, std::vector & tensor, bool normalize = true); bool postProcess( std::vector & output_data_host, - autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal); + autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal, bool apply_softmax = false); bool readLabelfile(std::string filepath, std::vector & labels); bool isColorLabel(const std::string label); void calcSoftmax(std::vector & data, std::vector & probs, int num_output); @@ -106,6 +106,7 @@ class CNNClassifier : public ClassifierInterface int input_c_; int input_h_; int input_w_; + bool apply_softmax_; }; } // namespace traffic_light diff --git a/perception/traffic_light_classifier/include/traffic_light_classifier/nodelet.hpp b/perception/traffic_light_classifier/include/traffic_light_classifier/nodelet.hpp index c04e610eb939b..8216d1fd2edab 100644 --- a/perception/traffic_light_classifier/include/traffic_light_classifier/nodelet.hpp +++ b/perception/traffic_light_classifier/include/traffic_light_classifier/nodelet.hpp @@ -58,11 +58,12 @@ class TrafficLightClassifierNodelet : public rclcpp::Node const autoware_auto_perception_msgs::msg::TrafficLightRoiArray::ConstSharedPtr & input_rois_msg); -private: enum ClassifierType { HSVFilter = 0, CNN = 1, }; + +private: void connectCb(); rclcpp::TimerBase::SharedPtr timer_; diff --git a/perception/traffic_light_classifier/src/cnn_classifier.cpp b/perception/traffic_light_classifier/src/cnn_classifier.cpp index 9147fbb8773a4..cb09b675ff010 100644 --- a/perception/traffic_light_classifier/src/cnn_classifier.cpp +++ b/perception/traffic_light_classifier/src/cnn_classifier.cpp @@ -39,10 +39,13 @@ CNNClassifier::CNNClassifier(rclcpp::Node * node_ptr) : node_ptr_(node_ptr) input_c_ = node_ptr_->declare_parameter("input_c", 3); input_h_ = node_ptr_->declare_parameter("input_h", 224); input_w_ = node_ptr_->declare_parameter("input_w", 224); + auto input_name = node_ptr_->declare_parameter("input_name", "input_0"); + auto output_name = node_ptr_->declare_parameter("output_name", "output_0"); + apply_softmax_ = node_ptr_->declare_parameter("apply_softmax", true); readLabelfile(label_file_path, labels_); - trt_ = std::make_shared(model_file_path, precision); + trt_ = std::make_shared(model_file_path, precision, input_name, output_name); trt_->setup(); } @@ -79,7 +82,7 @@ bool CNNClassifier::getTrafficSignal( output_data_host.data(), output_data_device.get(), num_output * sizeof(float), cudaMemcpyDeviceToHost); - postProcess(output_data_host, traffic_signal); + postProcess(output_data_host, traffic_signal, apply_softmax_); /* debug */ if (0 < image_pub_.getNumSubscribers()) { @@ -153,19 +156,22 @@ void CNNClassifier::preProcess(cv::Mat & image, std::vector & input_tenso bool CNNClassifier::postProcess( std::vector & output_tensor, - autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal) + autoware_auto_perception_msgs::msg::TrafficSignal & traffic_signal, bool apply_softmax) { std::vector probs; int num_output = trt_->getNumOutput(); - calcSoftmax(output_tensor, probs, num_output); + if (apply_softmax) { + calcSoftmax(output_tensor, probs, num_output); + } std::vector sorted_indices = argsort(output_tensor, num_output); // ROS_INFO("label: %s, score: %.2f\%", // labels_[sorted_indices[0]].c_str(), // probs[sorted_indices[0]] * 100); - std::string match_label = labels_[sorted_indices[0]]; - float probability = probs[sorted_indices[0]]; + size_t max_indice = sorted_indices.front(); + std::string match_label = labels_[max_indice]; + float probability = apply_softmax ? probs[max_indice] : output_tensor[max_indice]; // label names are assumed to be comma-separated to represent each lamp // e.g. diff --git a/perception/traffic_light_classifier/src/single_image_debug_inference_node.cpp b/perception/traffic_light_classifier/src/single_image_debug_inference_node.cpp new file mode 100644 index 0000000000000..f324b04f1628c --- /dev/null +++ b/perception/traffic_light_classifier/src/single_image_debug_inference_node.cpp @@ -0,0 +1,164 @@ +// Copyright 2023 Tier IV, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#if ENABLE_GPU +#include +#endif + +#include +#include + +#include +#include + +namespace +{ +std::string toString(const uint8_t state) +{ + if (state == autoware_auto_perception_msgs::msg::TrafficLight::RED) { + return "red"; + } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::AMBER) { + return "yellow"; + } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::GREEN) { + return "green"; + } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::WHITE) { + return "white"; + } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::CIRCLE) { + return "circle"; + } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::LEFT_ARROW) { + return "left"; + } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::RIGHT_ARROW) { + return "right"; + } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::UP_ARROW) { + return "straight"; + } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::DOWN_ARROW) { + return "down"; + } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::DOWN_LEFT_ARROW) { + return "down_left"; + } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::DOWN_RIGHT_ARROW) { + return "down_right"; + } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::CROSS) { + return "cross"; + } else if (state == autoware_auto_perception_msgs::msg::TrafficLight::UNKNOWN) { + return "unknown"; + } else { + return ""; + } +} +} // namespace + +namespace traffic_light +{ +class SingleImageDebugInferenceNode : public rclcpp::Node +{ +public: + explicit SingleImageDebugInferenceNode(const rclcpp::NodeOptions & node_options) + : Node("single_image_debug_inference", node_options) + { + const auto image_path = declare_parameter("image_path", ""); + + int classifier_type = this->declare_parameter( + "classifier_type", + static_cast(TrafficLightClassifierNodelet::ClassifierType::HSVFilter)); + if (classifier_type == TrafficLightClassifierNodelet::ClassifierType::HSVFilter) { + classifier_ptr_ = std::make_unique(this); + } else if (classifier_type == TrafficLightClassifierNodelet::ClassifierType::CNN) { +#if ENABLE_GPU + classifier_ptr_ = std::make_unique(this); +#else + RCLCPP_ERROR(get_logger(), "please install CUDA, CUDNN and TensorRT to use cnn classifier"); +#endif + } + + image_ = cv::imread(image_path); + if (image_.empty()) { + RCLCPP_ERROR(get_logger(), "image is empty"); + return; + } + cv::namedWindow("inference image", cv::WINDOW_NORMAL); + cv::setMouseCallback("inference image", SingleImageDebugInferenceNode::onMouse, this); + + cv::imshow("inference image", image_); + + // loop until q character is pressed + while (cv::waitKey(0) != 113) { + } + cv::destroyAllWindows(); + rclcpp::shutdown(); + } + +private: + static void onMouse(int event, int x, int y, int flags, void * param) + { + SingleImageDebugInferenceNode * node = static_cast(param); + if (node) { + node->inferWithCrop(event, x, y, flags); + } + } + + void inferWithCrop(int action, int x, int y, [[maybe_unused]] int flags) + { + if (action == cv::EVENT_LBUTTONDOWN) { + top_left_corner_ = cv::Point(x, y); + } else if (action == cv::EVENT_LBUTTONUP) { + bottom_right_corner_ = cv::Point(x, y); + cv::Mat tmp = image_.clone(); + cv::Mat crop = image_(cv::Rect{top_left_corner_, bottom_right_corner_}).clone(); + if (crop.empty()) { + RCLCPP_ERROR(get_logger(), "crop image is empty"); + return; + } + cv::cvtColor(crop, crop, cv::COLOR_BGR2RGB); + autoware_auto_perception_msgs::msg::TrafficSignal traffic_signal; + if (!classifier_ptr_->getTrafficSignal(crop, traffic_signal)) { + RCLCPP_ERROR(get_logger(), "failed to classify image"); + return; + } + cv::Scalar color; + cv::Scalar text_color; + for (const auto & light : traffic_signal.lights) { + auto color_str = toString(light.color); + auto shape_str = toString(light.shape); + auto confidence_str = std::to_string(light.confidence); + if (shape_str == "circle") { + if (color_str == "red") { + color = cv::Scalar(0, 0, 255); + } else if (color_str == "green") { + color = cv::Scalar(0, 255, 0); + } else if (color_str == "yellow") { + color = cv::Scalar(0, 255, 255); + } else if (color_str == "white") { + color = cv::Scalar(0, 0, 0); + } else { + color = cv::Scalar(255, 255, 255); + } + } + RCLCPP_INFO_STREAM(get_logger(), color_str << " " << shape_str << " " << confidence_str); + } + cv::rectangle(tmp, top_left_corner_, bottom_right_corner_, color, 2, 8); + cv::imshow("inference image", tmp); + } + } + + cv::Point top_left_corner_; + cv::Point bottom_right_corner_; + cv::Mat image_; + std::unique_ptr classifier_ptr_; +}; +} // namespace traffic_light + +#include "rclcpp_components/register_node_macro.hpp" +RCLCPP_COMPONENTS_REGISTER_NODE(traffic_light::SingleImageDebugInferenceNode) diff --git a/perception/traffic_light_classifier/utils/trt_common.cpp b/perception/traffic_light_classifier/utils/trt_common.cpp index adb2fbe037a31..aa87a974fe542 100644 --- a/perception/traffic_light_classifier/utils/trt_common.cpp +++ b/perception/traffic_light_classifier/utils/trt_common.cpp @@ -37,11 +37,12 @@ void check_error(const ::cudaError_t e, decltype(__FILE__) f, decltype(__LINE__) } } -TrtCommon::TrtCommon(std::string model_path, std::string precision) +TrtCommon::TrtCommon( + std::string model_path, std::string precision, std::string input_name, std::string output_name) : model_file_path_(model_path), precision_(precision), - input_name_("input_0"), - output_name_("output_0"), + input_name_(input_name), + output_name_(output_name), is_initialized_(false) { runtime_ = UniquePtr(nvinfer1::createInferRuntime(logger_)); diff --git a/perception/traffic_light_classifier/utils/trt_common.hpp b/perception/traffic_light_classifier/utils/trt_common.hpp index 7fc3d3b3e46d9..d7e314a3b4705 100644 --- a/perception/traffic_light_classifier/utils/trt_common.hpp +++ b/perception/traffic_light_classifier/utils/trt_common.hpp @@ -108,7 +108,8 @@ Tn::UniquePtr make_unique() class TrtCommon { public: - TrtCommon(std::string model_path, std::string precision); + TrtCommon( + std::string model_path, std::string precision, std::string input_name, std::string output_name); ~TrtCommon() {} bool loadEngine(std::string engine_file_path);