@@ -163,8 +163,8 @@ inline void collectSelfObsSystem(Engine &ctx,
163163 auto &self_obs = ctx.get <SelfObservation>(agent_iface.e );
164164 self_obs.speed = vel.linear .length ();
165165 self_obs.vehicle_size = size;
166- auto goalPos = goal.position - pos. xy () ;
167- self_obs.goal .position = rot.inv ().rotateVec ({goalPos.x , goalPos.y , 0 }). xy ( );
166+ auto goalPos = goal.position - pos;
167+ self_obs.goal .position = rot.inv ().rotateVec ({goalPos.x , goalPos.y , goalPos. z } );
168168
169169 auto hasCollided = collisionEvent.hasCollided .load_relaxed ();
170170 self_obs.collisionState = hasCollided ? 1 .f : 0 .f ;
@@ -231,13 +231,13 @@ inline void collectMapObservationsSystem(Engine &ctx,
231231 const auto alg = ctx.data ().params .roadObservationAlgorithm ;
232232 if (alg == FindRoadObservationsWith::KNearestEntitiesWithRadiusFiltering) {
233233 selectKNearestRoadEntities<consts::kMaxAgentMapObservationsCount >(
234- ctx, rot, pos. xy () , map_obs.obs );
234+ ctx, rot, pos, map_obs.obs );
235235 return ;
236236 }
237237
238238 assert (alg == FindRoadObservationsWith::AllEntitiesWithRadiusFiltering);
239239
240- utils::ReferenceFrame referenceFrame (pos. xy () , rot);
240+ utils::ReferenceFrame referenceFrame (pos, rot);
241241 CountT arrIndex = 0 ; CountT roadIdx = 0 ;
242242 while (roadIdx < ctx.data ().numRoads && arrIndex < consts::kMaxAgentMapObservationsCount ) {
243243 Entity road = ctx.data ().roads [roadIdx++];
@@ -452,13 +452,13 @@ inline void rewardSystem(Engine &ctx,
452452 const auto &rewardType = ctx.data ().params .rewardParams .rewardType ;
453453 if (rewardType == RewardType::DistanceBased)
454454 {
455- float dist = (position. xy () - goal.position ).length ();
455+ float dist = (position - goal.position ).length ();
456456 float reward = -dist;
457457 out_reward.v = reward;
458458 }
459459 else if (rewardType == RewardType::OnGoalAchieved)
460460 {
461- float dist = (position. xy () - goal.position ).length ();
461+ float dist = (position - goal.position ).length ();
462462 float reward = (dist < ctx.data ().params .rewardParams .distanceToGoalThreshold ) ? 1 .f : 0 .f ;
463463 out_reward.v = reward;
464464 }
@@ -502,7 +502,7 @@ inline void doneSystem(Engine &ctx,
502502 // An agent can be done early if it reaches the goal
503503 if (done.v != 1 || info.reachedGoal != 1 )
504504 {
505- float dist = (position. xy () - goal.position ).length ();
505+ float dist = (position - goal.position ).length ();
506506 if (dist < ctx.data ().params .rewardParams .distanceToGoalThreshold )
507507 {
508508 done.v = 1 ;
0 commit comments