Skip to content

Commit

Permalink
Prediction: reduce the complexity of collision cost in interaction pr…
Browse files Browse the repository at this point in the history
…edictor
  • Loading branch information
kechxu committed Apr 4, 2019
1 parent 787141d commit a31545e
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 45 deletions.
52 changes: 41 additions & 11 deletions modules/prediction/predictor/interaction/interaction_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,21 +259,51 @@ double InteractionPredictor::CentripetalAccelerationCost(
double InteractionPredictor::CollisionWithEgoVehicleCost(
const LaneSequence& lane_sequence,
const double speed, const double acceleration) {
CHECK_GT(lane_sequence.lane_segment_size(), 0);
double cost_abs_sum = 0.0;
double cost_sqr_sum = 0.0;
int num_lane_segment = lane_sequence.lane_segment_size();

double remained_s = lane_sequence.lane_segment(0).start_s();
int lane_seg_idx = 0;
double prev_s = 0.0;

for (const TrajectoryPoint& adc_trajectory_point : adc_trajectory_) {
double relative_time = adc_trajectory_point.relative_time();
double s = GetSByConstantAcceleration(speed, acceleration, relative_time);
Point3D position = GetPositionByLaneSequenceS(lane_sequence, s);
double pos_x = position.x();
double pos_y = position.y();
double adc_x = adc_trajectory_point.path_point().x();
double adc_y = adc_trajectory_point.path_point().y();
double distance = std::hypot(adc_x - pos_x, adc_y - pos_y);
double cost =
std::exp(-FLAGS_collision_cost_exp_coefficient * distance * distance);
cost_abs_sum += std::abs(cost);
cost_sqr_sum += cost * cost;
double curr_s = GetSByConstantAcceleration(
speed, acceleration, relative_time);
double delta_s = curr_s - prev_s;
remained_s += delta_s;
while (lane_seg_idx < num_lane_segment) {
const LaneSegment& lane_segment =
lane_sequence.lane_segment(lane_seg_idx);
const std::string& lane_id = lane_segment.lane_id();
std::shared_ptr<const LaneInfo> lane_info_ptr =
PredictionMap::LaneById(lane_id);
if (lane_info_ptr == nullptr) {
AERROR << "Null lane info ptr found with lane ID [" << lane_id << "]";
continue;
}
double lane_length = lane_info_ptr->total_length();
if (remained_s < lane_length) {
apollo::common::PointENU point_enu =
lane_info_ptr->GetSmoothPoint(remained_s);
double obs_x = point_enu.x();
double obs_y = point_enu.y();
double adc_x = adc_trajectory_point.path_point().x();
double adc_y = adc_trajectory_point.path_point().y();
double distance = std::hypot(adc_x - obs_x, adc_y - obs_y);
double cost = std::exp(-FLAGS_collision_cost_exp_coefficient *
distance * distance);
cost_abs_sum += std::abs(cost);
cost_sqr_sum += cost * cost;
} else {
++lane_seg_idx;
remained_s -= lane_length;
}
}
// Out of the while loop
prev_s = curr_s;
}
return cost_sqr_sum / (cost_abs_sum + FLAGS_double_precision);
}
Expand Down
25 changes: 0 additions & 25 deletions modules/prediction/predictor/sequence/sequence_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -342,31 +342,6 @@ double SequencePredictor::GetLaneSequenceCurvatureByS(
return 0.0;
}

Point3D SequencePredictor::GetPositionByLaneSequenceS(
const LaneSequence& lane_sequence, const double s) {
CHECK_GT(lane_sequence.lane_segment_size(), 0);
Point3D position;
double lane_s = s + lane_sequence.lane_segment(0).start_s();
for (const LaneSegment& lane_segment : lane_sequence.lane_segment()) {
std::string lane_id = lane_segment.lane_id();
std::shared_ptr<const LaneInfo> lane_info_ptr =
PredictionMap::LaneById(lane_id);
double lane_length = lane_info_ptr->total_length();
if (lane_s > lane_length + FLAGS_double_precision) {
lane_s -= lane_length;
} else {
apollo::common::PointENU point_enu =
lane_info_ptr->GetSmoothPoint(lane_s);
position.set_x(point_enu.x());
position.set_y(point_enu.y());
position.set_z(point_enu.z());
return position;
}
}
AERROR << "Cannot find position by lane s";
return position;
}

bool SequencePredictor::GetLongitudinalPolynomial(
const Obstacle& obstacle, const LaneSequence& lane_sequence,
const std::pair<double, double>& lon_end_vt,
Expand Down
9 changes: 0 additions & 9 deletions modules/prediction/predictor/sequence/sequence_predictor.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,6 @@ class SequencePredictor : public Predictor {
double GetLaneSequenceCurvatureByS(const LaneSequence& lane_sequence,
const double s);

/**
* @brief Get position by s
* @param lane sequence
* @param s
* @return the position
*/
apollo::common::Point3D GetPositionByLaneSequenceS(
const LaneSequence& lane_sequence, const double s);

/**
* @brief Clear private members
*/
Expand Down

0 comments on commit a31545e

Please sign in to comment.