Skip to content

Commit 46bab6e

Browse files
committed
fix cpp interface with new library
1 parent aed6865 commit 46bab6e

File tree

2 files changed

+60
-41
lines changed

2 files changed

+60
-41
lines changed

CMakeLists.txt

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ find_package(tf2 REQUIRED)
2626
find_package(tf2_ros REQUIRED)
2727
find_package(turtlebot3_msgs REQUIRED)
2828
find_package(gmm_msgs REQUIRED)
29-
# find_package(GaussianMixtureModel REQUIRED)
3029
find_package(gaussian_mixture_model REQUIRED)
3130
find_package(visualization_msgs REQUIRED)
3231
find_package(rosidl_default_generators REQUIRED)
@@ -41,7 +40,7 @@ add_library(${PROJECT_NAME}_lib
4140
"src/centralized_gmm.cpp"
4241
"src/gmm_visualizer.cpp"
4342
"src/gmm_test_node.cpp"
44-
# "src/hs_interface.cpp"
43+
"src/hs_interface.cpp"
4544
)
4645

4746

@@ -55,7 +54,6 @@ set(dependencies
5554
"sensor_msgs"
5655
"turtlebot3_msgs"
5756
"visualization_msgs"
58-
# "GaussianMixtureModel"
5957
"gaussian_mixture_model"
6058
"gmm_msgs"
6159
)
@@ -69,9 +67,9 @@ set(dependencies
6967
target_link_libraries(${PROJECT_NAME}_lib)
7068
ament_target_dependencies(${PROJECT_NAME}_lib ${dependencies})
7169

72-
# add_executable(hs_interface src/hs_interface.cpp )
73-
# target_link_libraries(hs_interface ${SFML_libraries})
74-
# ament_target_dependencies(hs_interface ${dependencies})
70+
add_executable(hs_interface src/hs_interface.cpp )
71+
target_link_libraries(hs_interface ${SFML_libraries})
72+
ament_target_dependencies(hs_interface ${dependencies})
7573

7674
add_executable(centralized_gmm src/centralized_gmm.cpp )
7775
target_link_libraries(centralized_gmm ${SFML_libraries})
@@ -104,7 +102,7 @@ install(PROGRAMS
104102
)
105103

106104

107-
install(TARGETS centralized_gmm distributed_gmm gmm_visualizer supervisor_gmm gmm_test
105+
install(TARGETS centralized_gmm distributed_gmm gmm_visualizer supervisor_gmm gmm_test hs_interface
108106
DESTINATION lib/${PROJECT_NAME}
109107
)
110108

@@ -141,7 +139,6 @@ ament_export_dependencies(turtlebot3_msgs)
141139
ament_export_dependencies(gmm_msgs)
142140
ament_export_dependencies(visualization_msgs)
143141
ament_export_dependencies(gaussian_mixture_model)
144-
# ament_export_dependencies(GaussianMixtureModel)
145142
ament_export_dependencies(rosidl_default_runtime)
146143

147144
ament_package()

src/hs_interface.cpp

Lines changed: 55 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,11 @@
4545
#include <nav_msgs/msg/odometry.hpp>
4646
#include "turtlebot3_msgs/msg/gaussian.hpp"
4747
#include "turtlebot3_msgs/msg/gmm.hpp"
48-
#include <GaussianMixtureModel/GaussianMixtureModel.h>
49-
#include <GaussianMixtureModel/ExpectationMaximization.h>
50-
#include <GaussianMixtureModel/TrainSet.h>
48+
// #include <GaussianMixtureModel/GaussianMixtureModel.h>
49+
// #include <GaussianMixtureModel/ExpectationMaximization.h>
50+
// #include <GaussianMixtureModel/TrainSet.h>
51+
52+
#include "gaussian_mixture_model/gaussian_mixture_model.h"
5153

5254
#define M_PI 3.14159265358979323846 /*pi*/
5355

@@ -61,7 +63,7 @@ class Interface : public rclcpp::Node
6163
{
6264

6365
public:
64-
Interface() : Node("human_swarm_interface")
66+
Interface() : Node("human_swarm_interface"), gmm(4)
6567
{
6668
// --------------------------------------------------------- ROS parameters ----------------------------------------------------------
6769
// Area parameters
@@ -76,6 +78,8 @@ class Interface : public rclcpp::Node
7678

7779
this->declare_parameter<int>("CLUSTERS_NUM", 4);
7880
this->get_parameter("CLUSTERS_NUM", CLUSTERS_NUM);
81+
this->declare_parameter<int>("PARTICLES_NUM", 200);
82+
this->get_parameter("PARTICLES_NUM", PARTICLES_NUM);
7983

8084
// --------------------------------------------------------- GMM ROS publisher -------------------------------------------------------
8185
publisher = this->create_publisher<turtlebot3_msgs::msg::GMM>("/gaussian_mixture_model", 1);
@@ -89,38 +93,49 @@ class Interface : public rclcpp::Node
8993

9094

9195
drawPolygon(); // draw ROI and save vertices
92-
std::vector<Eigen::VectorXd> samples = generateSamples(2000); // generate desired number of samples inside ROI
93-
gauss::TrainSet samples_set(samples); // create train set from samples
94-
std::vector<gauss::gmm::Cluster> clusters = gauss::gmm::ExpectationMaximization(samples_set, CLUSTERS_NUM); // run EM algorithm to get GMM
95-
gauss::gmm::GaussianMixtureModel gmm_(clusters); // create GMM from clusters
96+
samples.resize(2, PARTICLES_NUM);
97+
samples = generateSamples(PARTICLES_NUM); // generate desired number of samples inside ROI
98+
std::cout << "Samples generated with shape : " << samples.size() << "\n";
99+
gmm.fitgmm(samples, CLUSTERS_NUM, 1000, 1e-3, false); // fit GMM to samples
100+
std::cout << "Fitting completed...\n";
101+
mean_points = gmm.getMeans();
102+
covariances = gmm.getCovariances();
103+
weights = gmm.getWeights();
96104
std::cout << "GMM initialized\n";
97-
for (int i=0; i<gmm_.getClusters().size(); i++)
105+
for(int i = 0; i < mean_points.size(); i++)
98106
{
99-
std::cout << "Cluster " << i << ": weight = " << gmm_.getClusters()[i].weight << std::endl;
100-
std::cout << "Mean: " << gmm_.getClusters()[i].distribution->getMean().transpose() << std::endl;
101-
std::cout << "Covariance matrix: \n" << gmm_.getClusters()[i].distribution->getCovariance().transpose() << std::endl;
107+
std::cout << mean_points[i] << std::endl;
102108
}
103109

104-
// Create GMM ROS msg
105-
for (int i=0; i<gmm_.getClusters().size(); i++)
110+
std::cout << "Covariances: " << std::endl;
111+
for(int i = 0; i < covariances.size(); i++)
106112
{
107-
turtlebot3_msgs::msg::Gaussian gaussian_msg;
108-
geometry_msgs::msg::Point mean_pt;
109-
110-
mean_pt.x = gmm_.getClusters()[i].distribution->getMean()[0];
111-
mean_pt.y = gmm_.getClusters()[i].distribution->getMean()[1];
112-
mean_pt.z = 0.0;
113-
gaussian_msg.mean_point = mean_pt;
114-
for (int j=0; j<gmm_.getClusters()[i].distribution->getCovariance().rows(); j++)
115-
{
116-
gaussian_msg.covariance.push_back(gmm_.getClusters()[i].distribution->getCovariance()(j,0));
117-
gaussian_msg.covariance.push_back(gmm_.getClusters()[i].distribution->getCovariance()(j,1));
118-
}
113+
std::cout << covariances[i] << std::endl;
114+
}
119115

120-
gmm_msg.gaussians.push_back(gaussian_msg);
121-
gmm_msg.weights.push_back(gmm_.getClusters()[i].weight);
116+
std::cout << "Weights: " << std::endl;
117+
for(int i = 0; i < weights.size(); i++)
118+
{
119+
std::cout << weights[i] << std::endl;
122120
}
123121

122+
// Create ROS msg
123+
gmm_msg.weights = weights;
124+
125+
for(int i = 0; i < mean_points.size(); i++)
126+
{
127+
turtlebot3_msgs::msg::Gaussian gaussian;
128+
gaussian.mean_point.x = mean_points[i](0);
129+
gaussian.mean_point.y = mean_points[i](1);
130+
gaussian.covariance.push_back(covariances[i](0,0));
131+
gaussian.covariance.push_back(covariances[i](0,1));
132+
gaussian.covariance.push_back(covariances[i](1,0));
133+
gaussian.covariance.push_back(covariances[i](1,1));
134+
gmm_msg.gaussians.push_back(gaussian);
135+
}
136+
137+
std::cout << "ROS MSG Initialized" << std::endl;
138+
124139
}
125140
~Interface()
126141
{
@@ -130,7 +145,8 @@ class Interface : public rclcpp::Node
130145

131146

132147
void drawPolygon();
133-
std::vector<Eigen::VectorXd> generateSamples(int n_samples);
148+
// void generateSamples(Eigen::MatrixXd& samples_matrix, int n_samples);
149+
Eigen::MatrixXd generateSamples(int n_samples);
134150
bool insideROI(Eigen::VectorXd q, std::vector<Eigen::VectorXd> verts);
135151

136152

@@ -142,7 +158,11 @@ class Interface : public rclcpp::Node
142158
double AREA_LEFT;
143159
double AREA_BOTTOM;
144160
//-----------------------------------------------------------------------------------
145-
161+
GaussianMixtureModel gmm;
162+
std::vector<Eigen::MatrixXd> covariances;
163+
std::vector<Eigen::VectorXd> mean_points;
164+
std::vector<double> weights;
165+
Eigen::MatrixXd samples;
146166
//------------------------- Publishers and subscribers ------------------------------
147167
rclcpp::TimerBase::SharedPtr timer_;
148168
rclcpp::Publisher<turtlebot3_msgs::msg::GMM>::SharedPtr publisher;
@@ -160,6 +180,7 @@ class Interface : public rclcpp::Node
160180

161181
std::vector<Eigen::VectorXd> vertices;
162182
int CLUSTERS_NUM;
183+
int PARTICLES_NUM;
163184

164185

165186
void timer_callback()
@@ -230,12 +251,13 @@ bool Interface::insideROI(Eigen::VectorXd q, std::vector<Eigen::VectorXd> verts)
230251
return c;
231252
}
232253

233-
std::vector<Eigen::VectorXd> Interface::generateSamples(int n_samples)
254+
Eigen::MatrixXd Interface::generateSamples(int n_samples)
234255
{
235256
std::vector<Eigen::VectorXd> samples;
236257

237258
// Get min and max values of x and y
238-
double x_min, x_max, y_min, y_max = -100;
259+
double x_min, y_min = 100;
260+
double x_max, y_max = -100;
239261
for (int i=0; i<this->vertices.size(); i++)
240262
{
241263
double x = this->vertices[i](0);
@@ -275,7 +297,7 @@ std::vector<Eigen::VectorXd> Interface::generateSamples(int n_samples)
275297
this->app_gui->drawParticles(displayMatrix);
276298
this->app_gui->display();
277299

278-
return samples;
300+
return displayMatrix;
279301
}
280302

281303

0 commit comments

Comments
 (0)