Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(interpolation): add functions for flexible usage (non-static spline interpolation) #352

Closed
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 81 additions & 7 deletions common/interpolation/include/interpolation/interpolation_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,18 @@ inline bool isNotDecreasing(const std::vector<double> & x)
return true;
}

inline void validateInput(
const std::vector<double> & base_keys, const std::vector<double> & base_values,
const std::vector<double> & query_keys)
inline void validateKeys(
const std::vector<double> & base_keys, const std::vector<double> & query_keys)
{
// when vectors are empty
if (base_keys.empty() || base_values.empty() || query_keys.empty()) {
if (base_keys.empty() || query_keys.empty()) {
throw std::invalid_argument("Points is empty.");
}

// when size of vectors are less than 2
if (base_keys.size() < 2 || base_values.size() < 2) {
if (base_keys.size() < 2) {
throw std::invalid_argument(
"The size of points is less than 2. base_keys.size() = " + std::to_string(base_keys.size()) +
", base_values.size() = " + std::to_string(base_values.size()));
"The size of points is less than 2. base_keys.size() = " + std::to_string(base_keys.size()));
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refactor: validateInput -> validateKeys and validateKeysAndValues

// when indices are not sorted
Expand All @@ -76,12 +74,88 @@ inline void validateInput(
if (query_keys.front() < base_keys.front() || base_keys.back() < query_keys.back()) {
throw std::invalid_argument("query_keys is out of base_keys");
}
}

inline void validateKeysAndValues(
const std::vector<double> & base_keys, const std::vector<double> & base_values)
{
// when vectors are empty
if (base_keys.empty() || base_values.empty()) {
throw std::invalid_argument("Points is empty.");
}

// when size of vectors are less than 2
if (base_keys.size() < 2 || base_values.size() < 2) {
throw std::invalid_argument(
"The size of points is less than 2. base_keys.size() = " + std::to_string(base_keys.size()) +
", base_values.size() = " + std::to_string(base_values.size()));
}

// when sizes of indices and values are not same
if (base_keys.size() != base_values.size()) {
throw std::invalid_argument("The size of base_keys and base_values are not the same.");
}
}

// solve Ax = d
// where A is tridiagonal matrix
// [b_0 c_0 ... ]
// [a_0 b_1 c_1 ... O ]
// A = [ ... ]
// [ O ... a_N-3 b_N-2 c_N-2]
// [ ... a_N-2 b_N-1]
struct TDMACoef
{
explicit TDMACoef(const size_t num_row)
{
a.resize(num_row - 1);
b.resize(num_row);
c.resize(num_row - 1);
d.resize(num_row);
}

std::vector<double> a;
std::vector<double> b;
std::vector<double> c;
std::vector<double> d;
};

inline std::vector<double> solveTridiagonalMatrixAlgorithm(const TDMACoef & tdma_coef)
{
const auto & a = tdma_coef.a;
const auto & b = tdma_coef.b;
const auto & c = tdma_coef.c;
const auto & d = tdma_coef.d;

const size_t num_row = b.size();

std::vector<double> x(num_row);
if (num_row != 1) {
// calculate p and q
std::vector<double> p;
std::vector<double> q;
p.push_back(-c[0] / b[0]);
q.push_back(d[0] / b[0]);

for (size_t i = 1; i < num_row; ++i) {
const double den = b[i] + a[i - 1] * p[i - 1];
p.push_back(-c[i - 1] / den);
q.push_back((d[i] - a[i - 1] * q[i - 1]) / den);
}

// calculate solution
x[num_row - 1] = q[num_row - 1];

for (size_t i = 1; i < num_row; ++i) {
const size_t j = num_row - 1 - i;
x[j] = p[j] * x[j + 1] + q[j];
}
} else {
x.push_back(d[0] / b[0]);
}

return x;
}
} // namespace interpolation_utils

#endif // INTERPOLATION__INTERPOLATION_UTILS_HPP_
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
#ifndef INTERPOLATION__SPLINE_INTERPOLATION_HPP_
#define INTERPOLATION__SPLINE_INTERPOLATION_HPP_

#include "interpolation/interpolation_utils.hpp"
#include "tier4_autoware_utils/geometry/geometry.hpp"

#include <algorithm>
#include <cmath>
#include <iostream>
Expand All @@ -26,6 +29,8 @@ namespace interpolation
// NOTE: X(s) = a_i (s - s_i)^3 + b_i (s - s_i)^2 + c_i (s - s_i) + d_i : (i = 0, 1, ... N-1)
struct MultiSplineCoef
{
MultiSplineCoef() = default;

explicit MultiSplineCoef(const size_t num_spline)
{
a.resize(num_spline);
Expand All @@ -40,9 +45,81 @@ struct MultiSplineCoef
std::vector<double> d;
};

// static spline interpolation functions
std::vector<double> slerp(
const std::vector<double> & base_keys, const std::vector<double> & base_values,
const std::vector<double> & query_keys);

// std::vector<double> slerpDiff(
// const std::vector<double> & base_keys, const std::vector<double> & base_values,
// const std::vector<double> & query_keys);

// TODO(murooka) use template
// template <typename T>
// std::vector<double> slerpYawFromPoints(const std::vector<T> & points);
std::vector<double> slerpYawFromPoints(const std::vector<geometry_msgs::msg::Point> & points);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(@rej55 Let's discuss in a thread)
image

Thank you. It's a good point.

For static spline interpolation (this function slerpYawFromPoints), I'm ok to move elsewhere like tier4_autoware_utils as you pointed.

For non-static spline interpolation, I implemented SplineInterpolation1d and SplineInterpolationPoint2d classes.
The second class includes yaw calculation (getSplineInterpolatedYaw). This member function is an option in this class. I mean, SplineInterpolationPoint2d is mainly for interpolating and calculating x and y, and as an option, you can get yaw angle as well (it is calculated from derivatives of x and y).
So it's not weird that SplineInterpolationPoint2d, which has getSplineInterpolatedYaw, is in interpolation package, IMO.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel it is a little weird having SplineInterpolation"Point2d" here as well since it depends on the geometric configuration.

It is better to create SplineInterpolation"Path2d" instead and manage it in a separate file. (The interface could be the same as the "Point2d".)

} // namespace interpolation

// non-static 1-dimensional spline interpolation
//
// Usage:
// ```
// SplineInterpolation1d spline;
// spline.calcSplineCoefficients(base_keys, base_values); // memorize pre-interpolation result
// internally const auto interpolation_result1 = spline.getSplineInterpolatedValues(base_keys,
// query_keys1); const auto interpolation_result2 = spline.getSplineInterpolatedValues(base_keys,
// query_keys2);
// ```
class SplineInterpolation1d
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1-dimensional spline interpolation

{
public:
SplineInterpolation1d() = default;

void calcSplineCoefficients(
const std::vector<double> & base_keys, const std::vector<double> & base_values);

std::vector<double> getSplineInterpolatedValues(
const std::vector<double> & base_keys, const std::vector<double> & query_keys) const;

private:
interpolation::MultiSplineCoef multi_spline_coef_;
};

// non-static points spline interpolation
// NOTE: We can calculate yaw from the x and y by interpolation derivatives.
//
// Usage:
// ```
// SplineInterpolationPoint spline;
// spline.calcSplineCoefficients(base_keys, base_values); // memorize pre-interpolation result
// internally const auto interpolation_result1 = spline.getSplineInterpolatedPoint(base_keys,
// query_keys1); const auto interpolation_result2 = spline.getSplineInterpolatedPoint(base_keys,
// query_keys2); const auto yaw_interpolation_result = spline.getSplineInterpolatedValues(base_keys,
// query_keys1);
// ```
class SplineInterpolationPoint
{
public:
SplineInterpolationPoint() = default;

// TODO(murooka) use template
// template <typename T>
// void calcSplineCoefficients(const std::vector<T> & points);
void calcSplineCoefficients(const std::vector<geometry_msgs::msg::Point> & points);

// TODO(murooka) implement these functions
// std::vector<geometry_msgs::msg::Point> getSplineInterpolatedPoints(const double width);
// std::vector<geometry_msgs::msg::Pose> getSplineInterpolatedPoses(const double width);

geometry_msgs::msg::Point getSplineInterpolatedPoint(const size_t idx, const double s) const;
double getSplineInterpolatedYaw(const size_t idx, const double s) const;

double getAccumulatedDistance(const size_t idx) const;

private:
interpolation::MultiSplineCoef multi_spline_coef_x_;
interpolation::MultiSplineCoef multi_spline_coef_y_;
std::vector<double> base_s_vec_;
};

#endif // INTERPOLATION__SPLINE_INTERPOLATION_HPP_
2 changes: 2 additions & 0 deletions common/interpolation/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
<license>Apache License 2.0</license>
<buildtool_depend>ament_cmake_auto</buildtool_depend>

<depend>tier4_autoware_utils</depend>

<test_depend>ament_lint_auto</test_depend>
<test_depend>autoware_lint_common</test_depend>

Expand Down
3 changes: 2 additions & 1 deletion common/interpolation/src/linear_interpolation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ std::vector<double> lerp(
const std::vector<double> & query_keys)
{
// throw exception for invalid arguments
interpolation_utils::validateInput(base_keys, base_values, query_keys);
interpolation_utils::validateKeys(base_keys, query_keys);
interpolation_utils::validateKeysAndValues(base_keys, base_values);

// calculate linear interpolation
std::vector<double> query_values;
Expand Down
Loading