Skip to content

Commit

Permalink
When simulation is reset, make sure spline plan is shifted to time 0.
Browse files Browse the repository at this point in the history
Make sure that whenever the policy is mutated, it's done with a writer lock (unique_lock) and not a reader lock (shared_lock).

PiperOrigin-RevId: 644688932
Change-Id: I890ccfa590feb7fa320695ae0d8222f935021014
  • Loading branch information
nimrod-gileadi authored and copybara-github committed Jun 19, 2024
1 parent d7caeb5 commit 3b29b1e
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 6 deletions.
25 changes: 19 additions & 6 deletions mjpc/planners/sampling/planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <algorithm>
#include <chrono>
#include <mutex>
#include <shared_mutex>

#include <absl/random/random.h>
Expand Down Expand Up @@ -112,8 +113,11 @@ void SamplingPlanner::Reset(int horizon,
time = 0.0;

// policy parameters
policy.Reset(horizon, initial_repeated_action);
previous_policy.Reset(horizon, initial_repeated_action);
{
const std::unique_lock<std::shared_mutex> lock(mtx_);
policy.Reset(horizon, initial_repeated_action);
previous_policy.Reset(horizon, initial_repeated_action);
}

// scratch
plan_scratch.Clear();
Expand Down Expand Up @@ -266,10 +270,19 @@ void SamplingPlanner::UpdateNominalPolicy(int horizon) {
time_shift = time_horizon;
}

const std::shared_lock<std::shared_mutex> lock(mtx_);
const std::unique_lock<std::shared_mutex> lock(mtx_);

// special case for when simulation time is reset (which doesn't cause
// Planner::Reset)
if (policy.plan.Size() && policy.plan.begin()->time() > nominal_time) {
// time went backwards. keep the nominal plan, but start at the new time
policy.plan.ShiftTime(nominal_time);
previous_policy.plan.ShiftTime(nominal_time);
}

policy.plan.DiscardBefore(nominal_time);
if (policy.plan.Size() == 0) {
policy.plan.AddNode(time);
policy.plan.AddNode(nominal_time);
}
while (policy.plan.Size() < num_spline_points) {
// duplicate the last node, with a time further in the future.
Expand Down Expand Up @@ -303,7 +316,7 @@ void SamplingPlanner::UpdateNominalPolicy(int horizon) {

// copy scratch into plan
{
const std::shared_lock<std::shared_mutex> lock(mtx_);
const std::unique_lock<std::shared_mutex> lock(mtx_);
policy.plan = plan_scratch;
}
}
Expand Down Expand Up @@ -527,7 +540,7 @@ void SamplingPlanner::CopyCandidateToPolicy(int candidate) {
winner = trajectory_order[candidate];

{
const std::shared_lock<std::shared_mutex> lock(mtx_);
const std::unique_lock<std::shared_mutex> lock(mtx_);
previous_policy = policy;
policy = candidate_policy[winner];
}
Expand Down
10 changes: 10 additions & 0 deletions mjpc/spline/spline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,16 @@ int TimeSpline::DiscardBefore(double time) {
return nodes_to_remove;
}

void TimeSpline::ShiftTime(double start_time) {
if (times_.empty()) {
return;
}
double shift = start_time - times_[0];
for (int i = 0; i < times_.size(); i++) {
times_[i] += shift;
}
}

void TimeSpline::Clear() {
times_.clear();
values_begin_ = 0;
Expand Down
5 changes: 5 additions & 0 deletions mjpc/spline/spline.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,11 @@ class TimeSpline {
// Returns the number of nodes removed.
int DiscardBefore(double time);

// Keeps all existing nodes, but shifts the time of the first node to be
// `start_time`, and all other times are shifted accordingly. No resampling
// is performed.
void ShiftTime(double start_time);

// Removes all existing nodes.
void Clear();

Expand Down
20 changes: 20 additions & 0 deletions mjpc/test/spline/spline_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,26 @@ TEST(TimeSplineTest, Cubic) {
}
}

TEST(TimeSplineAllInterpolationsTest, ShiftTime) {
TimeSpline spline(/*dim=*/2);
spline.SetInterpolation(mjpc::spline::kLinearSpline);

spline.AddNode(1.0, {1.0, 2.0});
spline.AddNode(2.0, {2.0, 3.0});
spline.AddNode(3.0, {3.0, 4.0});
spline.AddNode(4.0, {4.0, 5.0});
EXPECT_EQ(spline.Size(), 4);

EXPECT_THAT(spline.Sample(1.0), ElementsAre(1.0, 2.0));
EXPECT_THAT(spline.Sample(1.5), ElementsAre(1.5, 2.5));

// Shift the spline so that the first node is at 1.5.
spline.ShiftTime(1.5);
EXPECT_EQ(spline.Size(), 4);
EXPECT_THAT(spline.Sample(1.5), ElementsAre(1.0, 2.0));
EXPECT_THAT(spline.Sample(2.0), ElementsAre(1.5, 2.5));
}

TEST_P(TimeSplineAllInterpolationsTest, DiscardBefore) {
const TimeSplineTestCase& test_case = GetParam();
TimeSpline spline(/*dim=*/2);
Expand Down

0 comments on commit 3b29b1e

Please sign in to comment.