Skip to content

Commit

Permalink
update tests for existing functions
Browse files Browse the repository at this point in the history
Signed-off-by: Takayuki Murooka <takayuki5168@gmail.com>
  • Loading branch information
takayuki5168 committed Feb 7, 2022
1 parent c86840c commit 5de79ba
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 94 deletions.
60 changes: 0 additions & 60 deletions common/interpolation/include/interpolation/interpolation_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,66 +96,6 @@ inline void validateKeysAndValues(
throw std::invalid_argument("The size of base_keys and base_values are not the same.");
}
}

// solve Ax = d
// where A is tridiagonal matrix
// [b_0 c_0 ... ]
// [a_0 b_1 c_1 ... O ]
// A = [ ... ]
// [ O ... a_N-3 b_N-2 c_N-2]
// [ ... a_N-2 b_N-1]
struct TDMACoef
{
explicit TDMACoef(const size_t num_row)
{
a.resize(num_row - 1);
b.resize(num_row);
c.resize(num_row - 1);
d.resize(num_row);
}

std::vector<double> a;
std::vector<double> b;
std::vector<double> c;
std::vector<double> d;
};

inline std::vector<double> solveTridiagonalMatrixAlgorithm(const TDMACoef & tdma_coef)
{
const auto & a = tdma_coef.a;
const auto & b = tdma_coef.b;
const auto & c = tdma_coef.c;
const auto & d = tdma_coef.d;

const size_t num_row = b.size();

std::vector<double> x(num_row);
if (num_row != 1) {
// calculate p and q
std::vector<double> p;
std::vector<double> q;
p.push_back(-c[0] / b[0]);
q.push_back(d[0] / b[0]);

for (size_t i = 1; i < num_row; ++i) {
const double den = b[i] + a[i - 1] * p[i - 1];
p.push_back(-c[i - 1] / den);
q.push_back((d[i] - a[i - 1] * q[i - 1]) / den);
}

// calculate solution
x[num_row - 1] = q[num_row - 1];

for (size_t i = 1; i < num_row; ++i) {
const size_t j = num_row - 1 - i;
x[j] = p[j] * x[j + 1] + q[j];
}
} else {
x.push_back(d[0] / b[0]);
}

return x;
}
} // namespace interpolation_utils

#endif // INTERPOLATION__INTERPOLATION_UTILS_HPP_
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,11 @@ std::vector<double> slerpYawFromPoints(const std::vector<geometry_msgs::msg::Poi
// ```
// SplineInterpolation1d spline;
// spline.calcSplineCoefficients(base_keys, base_values); // memorize pre-interpolation result
// internally const auto interpolation_result1 = spline.getSplineInterpolatedValues(base_keys,
// query_keys1); const auto interpolation_result2 = spline.getSplineInterpolatedValues(base_keys,
// query_keys2);
// internally
// const auto interpolation_result1 = spline.getSplineInterpolatedValues(base_keys,
// query_keys1);
// const auto interpolation_result2 = spline.getSplineInterpolatedValues(base_keys,
// query_keys2);
// ```
class SplineInterpolation1d
{
Expand All @@ -92,10 +94,13 @@ class SplineInterpolation1d
// ```
// SplineInterpolationPoint spline;
// spline.calcSplineCoefficients(base_keys, base_values); // memorize pre-interpolation result
// internally const auto interpolation_result1 = spline.getSplineInterpolatedPoint(base_keys,
// query_keys1); const auto interpolation_result2 = spline.getSplineInterpolatedPoint(base_keys,
// query_keys2); const auto yaw_interpolation_result = spline.getSplineInterpolatedValues(base_keys,
// query_keys1);
// internally
// const auto interpolation_result1 = spline.getSplineInterpolatedPoints(base_keys,
// query_keys1);
// const auto interpolation_result2 = spline.getSplineInterpolatedPoints(base_keys,
// query_keys2);
// const auto yaw_interpolation_result = spline.getSplineInterpolatedYaws(base_keys,
// query_keys1);
// ```
class SplineInterpolationPoint
{
Expand All @@ -112,7 +117,7 @@ class SplineInterpolationPoint
// std::vector<geometry_msgs::msg::Pose> getSplineInterpolatedPoses(const double width);

geometry_msgs::msg::Point getSplineInterpolatedPoint(const size_t idx, const double s) const;
double getSplineInterpolatedYaw(const size_t idx, const double s) const;
double getSplineInterpolatedYaws(const size_t idx, const double s) const;

double getAccumulatedDistance(const size_t idx) const;

Expand Down
67 changes: 63 additions & 4 deletions common/interpolation/src/spline_interpolation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,66 @@

namespace
{
// solve Ax = d
// where A is tridiagonal matrix
// [b_0 c_0 ... ]
// [a_0 b_1 c_1 ... O ]
// A = [ ... ]
// [ O ... a_N-3 b_N-2 c_N-2]
// [ ... a_N-2 b_N-1]
struct TDMACoef
{
explicit TDMACoef(const size_t num_row)
{
a.resize(num_row - 1);
b.resize(num_row);
c.resize(num_row - 1);
d.resize(num_row);
}

std::vector<double> a;
std::vector<double> b;
std::vector<double> c;
std::vector<double> d;
};

inline std::vector<double> solveTridiagonalMatrixAlgorithm(const TDMACoef & tdma_coef)
{
const auto & a = tdma_coef.a;
const auto & b = tdma_coef.b;
const auto & c = tdma_coef.c;
const auto & d = tdma_coef.d;

const size_t num_row = b.size();

std::vector<double> x(num_row);
if (num_row != 1) {
// calculate p and q
std::vector<double> p;
std::vector<double> q;
p.push_back(-c[0] / b[0]);
q.push_back(d[0] / b[0]);

for (size_t i = 1; i < num_row; ++i) {
const double den = b[i] + a[i - 1] * p[i - 1];
p.push_back(-c[i - 1] / den);
q.push_back((d[i] - a[i - 1] * q[i - 1]) / den);
}

// calculate solution
x[num_row - 1] = q[num_row - 1];

for (size_t i = 1; i < num_row; ++i) {
const size_t j = num_row - 1 - i;
x[j] = p[j] * x[j + 1] + q[j];
}
} else {
x.push_back(d[0] / b[0]);
}

return x;
}

interpolation::MultiSplineCoef getSplineCoefficients(
const std::vector<double> & base_keys, const std::vector<double> & base_values)
{
Expand All @@ -36,7 +96,7 @@ interpolation::MultiSplineCoef getSplineCoefficients(
std::vector<double> v = {0.0};
if (num_base > 2) {
// solve tridiagonal matrix algorithm
interpolation_utils::TDMACoef tdma_coef(num_base - 2); // N-1
TDMACoef tdma_coef(num_base - 2); // N-1

for (size_t i = 0; i < num_base - 2; ++i) {
tdma_coef.b[i] = 2 * (diff_keys[i] + diff_keys[i + 1]);
Expand All @@ -48,8 +108,7 @@ interpolation::MultiSplineCoef getSplineCoefficients(
6.0 * (diff_values[i + 1] / diff_keys[i + 1] - diff_values[i] / diff_keys[i]);
}

const std::vector<double> tdma_res =
interpolation_utils::solveTridiagonalMatrixAlgorithm(tdma_coef);
const std::vector<double> tdma_res = solveTridiagonalMatrixAlgorithm(tdma_coef);

// calculate v
v.insert(v.end(), tdma_res.begin(), tdma_res.end());
Expand Down Expand Up @@ -281,7 +340,7 @@ geometry_msgs::msg::Point SplineInterpolationPoint::getSplineInterpolatedPoint(
return geom_point;
}

double SplineInterpolationPoint::getSplineInterpolatedYaw(const size_t idx, const double s) const
double SplineInterpolationPoint::getSplineInterpolatedYaws(const size_t idx, const double s) const
{
double whole_s = base_s_vec_.at(idx) + s;

Expand Down
56 changes: 34 additions & 22 deletions common/interpolation/test/src/test_interpolation_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,54 +58,66 @@ TEST(interpolation_utils, isNotDecreasing)
EXPECT_EQ(interpolation_utils::isNotDecreasing(decreasing_vec), false);
}

// TODO(murooka) implement this test
/*
TEST(interpolation_utils, validateInput)
TEST(interpolation_utils, validateKeys)
{
using interpolation_utils::validateInput;
using interpolation_utils::validateKeys;

const std::vector<double> base_keys{0.0, 1.0, 2.0, 3.0};
const std::vector<double> base_values{0.0, 1.0, 2.0, 3.0};
const std::vector<double> query_keys{0.0, 1.0, 2.0, 3.0};

// valid
EXPECT_NO_THROW(validateInput(base_keys, base_values, query_keys));
EXPECT_NO_THROW(validateKeys(base_keys, query_keys));

// empty
const std::vector<double> empty_vec;
EXPECT_THROW(validateInput(empty_vec, base_values, query_keys), std::invalid_argument);
EXPECT_THROW(validateInput(base_keys, empty_vec, query_keys), std::invalid_argument);
EXPECT_THROW(validateInput(base_keys, base_values, empty_vec), std::invalid_argument);
EXPECT_THROW(validateKeys(empty_vec, query_keys), std::invalid_argument);
EXPECT_THROW(validateKeys(base_keys, empty_vec), std::invalid_argument);

// size is less than 2
const std::vector<double> short_vec{0.0};
EXPECT_THROW(validateInput(short_vec, base_values, query_keys), std::invalid_argument);
EXPECT_THROW(validateInput(base_keys, short_vec, query_keys), std::invalid_argument);
EXPECT_THROW(validateInput(short_vec, short_vec, query_keys), std::invalid_argument);
EXPECT_THROW(validateKeys(short_vec, query_keys), std::invalid_argument);

// partly not increase
const std::vector<double> partly_not_increasing_vec{0.0, 0.0, 2.0, 3.0};
// NOTE: base_keys must be strictly monotonous increasing vector
EXPECT_THROW(
validateInput(partly_not_increasing_vec, base_values, query_keys), std::invalid_argument);
EXPECT_THROW(validateKeys(partly_not_increasing_vec, query_keys), std::invalid_argument);
// NOTE: query_keys is allowed to be monotonous non-decreasing vector
EXPECT_NO_THROW(validateInput(base_keys, base_values, partly_not_increasing_vec));
EXPECT_NO_THROW(validateKeys(base_keys, partly_not_increasing_vec));

// decrease
const std::vector<double> decreasing_vec{0.0, -1.0, 2.0, 3.0};
EXPECT_THROW(validateInput(decreasing_vec, base_values, query_keys), std::invalid_argument);
EXPECT_THROW(validateInput(base_keys, base_values, decreasing_vec), std::invalid_argument);
EXPECT_THROW(validateKeys(decreasing_vec, query_keys), std::invalid_argument);
EXPECT_THROW(validateKeys(base_keys, decreasing_vec), std::invalid_argument);

// out of range
const std::vector<double> front_out_query_keys{-1.0, 1.0, 2.0, 3.0};
EXPECT_THROW(validateInput(base_keys, base_values, front_out_query_keys), std::invalid_argument);
EXPECT_THROW(validateKeys(base_keys, front_out_query_keys), std::invalid_argument);

const std::vector<double> back_out_query_keys{0.0, 1.0, 2.0, 4.0};
EXPECT_THROW(validateInput(base_keys, base_values, back_out_query_keys), std::invalid_argument);
EXPECT_THROW(validateKeys(base_keys, back_out_query_keys), std::invalid_argument);
}

TEST(interpolation_utils, validateKeysAndValues)
{
using interpolation_utils::validateKeysAndValues;

const std::vector<double> base_keys{0.0, 1.0, 2.0, 3.0};
const std::vector<double> base_values{0.0, 1.0, 2.0, 3.0};

// valid
EXPECT_NO_THROW(validateKeysAndValues(base_keys, base_values));

// empty
const std::vector<double> empty_vec;
EXPECT_THROW(validateKeysAndValues(empty_vec, base_values), std::invalid_argument);
EXPECT_THROW(validateKeysAndValues(base_keys, empty_vec), std::invalid_argument);

// size is less than 2
const std::vector<double> short_vec{0.0};
EXPECT_THROW(validateKeysAndValues(short_vec, base_values), std::invalid_argument);
EXPECT_THROW(validateKeysAndValues(base_keys, short_vec), std::invalid_argument);

// size is different
const std::vector<double> different_size_base_values{0.0, 1.0, 2.0};
EXPECT_THROW(
validateInput(base_keys, different_size_base_values, query_keys), std::invalid_argument);
EXPECT_THROW(validateKeysAndValues(base_keys, different_size_base_values), std::invalid_argument);
}
*/

0 comments on commit 5de79ba

Please sign in to comment.