Skip to content

Commit

Permalink
Prediction: implemented extrapolate by lane
Browse files Browse the repository at this point in the history
  • Loading branch information
kechxu committed Nov 2, 2019
1 parent 4824c4b commit f35a41e
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 14 deletions.
6 changes: 3 additions & 3 deletions modules/prediction/container/obstacles/obstacle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,7 @@ void Obstacle::BuildLaneGraph() {
for (auto& lane : feature->lane().current_lane_feature()) {
std::shared_ptr<const LaneInfo> lane_info =
PredictionMap::LaneById(lane.lane_id());
const LaneGraph& lane_graph = ObstacleClusters::GetLaneGraph(
LaneGraph lane_graph = ObstacleClusters::GetLaneGraph(
lane.lane_s(), road_graph_search_distance, true, lane_info);
if (lane_graph.lane_sequence_size() > 0) {
++curr_lane_count;
Expand Down Expand Up @@ -1074,7 +1074,7 @@ void Obstacle::BuildLaneGraph() {
for (auto& lane : feature->lane().nearby_lane_feature()) {
std::shared_ptr<const LaneInfo> lane_info =
PredictionMap::LaneById(lane.lane_id());
const LaneGraph& lane_graph = ObstacleClusters::GetLaneGraph(
LaneGraph lane_graph = ObstacleClusters::GetLaneGraph(
lane.lane_s(), road_graph_search_distance, false, lane_info);
if (lane_graph.lane_sequence_size() > 0) {
++nearby_lane_count;
Expand Down Expand Up @@ -1241,7 +1241,7 @@ void Obstacle::BuildLaneGraphFromLeftToRight() {
bool vehicle_is_on_lane = (lane_id == center_lane_info->lane().id().id());
std::shared_ptr<const LaneInfo> curr_lane_info =
PredictionMap::LaneById(lane_id);
const LaneGraph& local_lane_graph =
LaneGraph local_lane_graph =
ObstacleClusters::GetLaneGraphWithoutMemorizing(
feature->lane().lane_feature().lane_s(), road_graph_search_distance,
true, curr_lane_info);
Expand Down
7 changes: 2 additions & 5 deletions modules/prediction/container/obstacles/obstacle_clusters.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,26 @@ namespace prediction {

using ::apollo::hdmap::LaneInfo;

std::unordered_map<std::string, LaneGraph> ObstacleClusters::lane_graphs_;
std::unordered_map<std::string, std::vector<LaneObstacle>>
ObstacleClusters::lane_obstacles_;
std::unordered_map<std::string, StopSign>
ObstacleClusters::lane_id_stop_sign_map_;

void ObstacleClusters::Clear() {
lane_graphs_.clear();
lane_obstacles_.clear();
lane_id_stop_sign_map_.clear();
}

void ObstacleClusters::Init() { Clear(); }

const LaneGraph& ObstacleClusters::GetLaneGraph(
LaneGraph ObstacleClusters::GetLaneGraph(
const double start_s, const double length, const bool is_on_lane,
std::shared_ptr<const LaneInfo> lane_info_ptr) {
std::string lane_id = lane_info_ptr->id().id();
RoadGraph road_graph(start_s, length, is_on_lane, lane_info_ptr);
LaneGraph lane_graph;
road_graph.BuildLaneGraph(&lane_graph);
lane_graphs_[lane_id] = std::move(lane_graph);
return lane_graphs_[lane_id];
return lane_graph;
}

LaneGraph ObstacleClusters::GetLaneGraphWithoutMemorizing(
Expand Down
3 changes: 1 addition & 2 deletions modules/prediction/container/obstacles/obstacle_clusters.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class ObstacleClusters {
* @param lane info
* @return a corresponding lane graph
*/
static const LaneGraph& GetLaneGraph(
static LaneGraph GetLaneGraph(
const double start_s, const double length, const bool is_on_lane,
std::shared_ptr<const apollo::hdmap::LaneInfo> lane_info_ptr);

Expand Down Expand Up @@ -131,7 +131,6 @@ class ObstacleClusters {
static void Clear();

private:
static std::unordered_map<std::string, LaneGraph> lane_graphs_;
static std::unordered_map<std::string, std::vector<LaneObstacle>>
lane_obstacles_;
static std::unordered_map<std::string, StopSign> lane_id_stop_sign_map_;
Expand Down
3 changes: 3 additions & 0 deletions modules/prediction/predictor/extrapolation/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@ cc_library(
"-DMODULE_NAME=\\\"prediction\\\"",
],
deps = [
"//modules/common/math:geometry",
"//modules/prediction/common:prediction_gflags",
"//modules/prediction/common:prediction_map",
"//modules/prediction/common:prediction_util",
"//modules/prediction/container/obstacles:obstacle_clusters",
"//modules/prediction/container/obstacles:obstacles_container",
"//modules/prediction/predictor/sequence:sequence_predictor",
"//modules/prediction/proto:lane_graph_proto",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@
*****************************************************************************/

#include "modules/prediction/predictor/extrapolation/extrapolation_predictor.h"

#include "modules/common/math/vec2d.h"
#include "modules/prediction/common/prediction_gflags.h"
#include "modules/prediction/common/prediction_map.h"
#include "modules/prediction/container/obstacles/obstacle_clusters.h"
#include "modules/prediction/proto/lane_graph.pb.h"

namespace apollo {
namespace prediction {

using apollo::common::PathPoint;
using apollo::common::TrajectoryPoint;
using apollo::hdmap::LaneInfo;

ExtrapolationPredictor::ExtrapolationPredictor() {
predictor_type_ = ObstacleConf::EXTRAPOLATION_PREDICTOR;
Expand Down Expand Up @@ -56,7 +61,7 @@ void ExtrapolationPredictor::PostProcess(Trajectory* trajectory_ptr) {
ExtrapolationPredictor::LaneSearchResult
lane_search_result = SearchExtrapolationLane(*trajectory_ptr, kNumTailPoint);
if (lane_search_result.found) {
ExtrapolateByLane(lane_search_result.lane_id, trajectory_ptr);
ExtrapolateByLane(lane_search_result, trajectory_ptr);
} else {
ExtrapolateByFreeMove(kNumTailPoint, trajectory_ptr);
}
Expand Down Expand Up @@ -91,8 +96,77 @@ ExtrapolationPredictor::SearchExtrapolationLane(
}

void ExtrapolationPredictor::ExtrapolateByLane(
const std::string& lane_id, Trajectory* trajectory_ptr) {
// TODO(kechxu) implement
const LaneSearchResult& lane_search_result, Trajectory* trajectory_ptr) {
std::string start_lane_id = lane_search_result.lane_id;
int point_index = lane_search_result.point_index;
while (trajectory_ptr->trajectory_point_size() > point_index + 1) {
trajectory_ptr->mutable_trajectory_point()->RemoveLast();
}
auto lane_info_ptr = PredictionMap::LaneById(start_lane_id);
int num_trajectory_point = trajectory_ptr->trajectory_point_size();
const TrajectoryPoint& last_point =
trajectory_ptr->trajectory_point(num_trajectory_point);

Eigen::Vector2d position(last_point.path_point().x(),
last_point.path_point().y());
double lane_s = 0.0;
double lane_l = 0.0;
bool projected = PredictionMap::GetProjection(
position, lane_info_ptr, &lane_s, &lane_l);
if (!projected) {
AERROR << "Position (" << position.x() << ", " << position.y() << ") "
<< "cannot be projected onto lane [" << start_lane_id << "]";
return;
}

double last_relative_time = last_point.relative_time();
double speed = last_point.v();
double time_range = FLAGS_prediction_trajectory_time_length -
last_relative_time;
double time_resolution = FLAGS_prediction_trajectory_time_resolution;
double length = speed * time_range;

LaneGraph lane_graph = ObstacleClusters::GetLaneGraph(
lane_s, length, false, lane_info_ptr);
CHECK_EQ(lane_graph.lane_sequence_size(), 1);
const LaneSequence& lane_sequence = lane_graph.lane_sequence(0);
int lane_segment_index = 0;
std::string lane_id =
lane_sequence.lane_segment(lane_segment_index).lane_id();

int num_point_remained = static_cast<int>(time_range / time_resolution);
for (int i = 0; i < num_point_remained; ++i) {
double relative_time = last_relative_time +
static_cast<double>(i) * time_resolution;
Eigen::Vector2d point;
double theta = M_PI;
if (!PredictionMap::SmoothPointFromLane(lane_id, lane_s, lane_l, &point,
&theta)) {
AERROR << "Unable to get smooth point from lane [" << lane_id
<< "] with s [" << lane_s << "] and l [" << lane_l << "]";
break;
}
TrajectoryPoint* trajectory_point = trajectory_ptr->add_trajectory_point();
PathPoint* path_point = trajectory_point->mutable_path_point();
path_point->set_x(point.x());
path_point->set_y(point.y());
path_point->set_z(0.0);
path_point->set_theta(theta);
path_point->set_lane_id(lane_id);
trajectory_point->set_v(speed);
trajectory_point->set_a(0.0);
trajectory_point->set_relative_time(relative_time);

lane_s += speed * time_resolution;
while (lane_s > PredictionMap::LaneById(lane_id)->total_length() &&
lane_segment_index + 1 < lane_sequence.lane_segment_size()) {
lane_segment_index += 1;
lane_s = lane_s - PredictionMap::LaneById(lane_id)->total_length();
lane_id = lane_sequence.lane_segment(lane_segment_index).lane_id();
}

lane_l *= FLAGS_go_approach_rate;
}
}

void ExtrapolationPredictor::ExtrapolateByFreeMove(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class ExtrapolationPredictor : public SequencePredictor {
LaneSearchResult SearchExtrapolationLane(const Trajectory& trajectory,
const int num_tail_point);

void ExtrapolateByLane(const std::string& lane_id,
void ExtrapolateByLane(const LaneSearchResult& lane_search_result,
Trajectory* trajectory_ptr);

void ExtrapolateByFreeMove(const int num_tail_point,
Expand Down

0 comments on commit f35a41e

Please sign in to comment.