Skip to content

Commit 1faeea4

Browse files
Testing agent type in obs
1 parent 56529c2 commit 1faeea4

File tree

5 files changed

+25
-10
lines changed

5 files changed

+25
-10
lines changed

examples/eval/notebooks/obtain_guidance_data.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
from gpudrive.env.dataset import SceneDataLoader
88
import torch
99
import numpy as np
10-
10+
from gpudrive.datatypes.observation import LocalEgoState
1111

1212
if __name__ == "__main__":
1313

14-
GUIDANCE_MODE = "vbd_online"
15-
DATASET = "data/processed/wosac/validation_json_100" # Ensure VBD trajectory structures are in here
14+
GUIDANCE_MODE = "log_replay"
15+
DATASET = "data/processed/wosac/validation_interactive/json" # Ensure VBD trajectory structures are in here
1616
SAVE_PATH = "examples/eval/figures_data/guidance/"
1717

1818
env_config = EnvConfig(
@@ -30,7 +30,7 @@
3030

3131
train_loader = SceneDataLoader(
3232
root=DATASET,
33-
batch_size=10,
33+
batch_size=1,
3434
dataset_size=100,
3535
sample_with_replacement=False,
3636
shuffle=False,
@@ -41,11 +41,20 @@
4141
config=env_config,
4242
data_loader=train_loader,
4343
max_cont_agents=32,
44-
device="cuda",
44+
device="cpu",
4545
)
4646

4747
obs = env.reset(env.cont_agent_mask)
4848

49+
50+
# Get agent types
51+
agent_types = LocalEgoState.from_tensor(
52+
self_obs_tensor=env.sim.self_observation_tensor(),
53+
backend='torch',
54+
device="cpu",
55+
#mask=mask,
56+
).agent_type
57+
4958
# Save for analysis
5059
reference_traj = torch.cat(
5160
[

gpudrive/datatypes/observation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(self, self_obs_tensor: torch.Tensor, mask=None):
4141
self.is_goal_reached = self_obs_tensor[:, 7]
4242
self.id = self_obs_tensor[:, 8]
4343
self.steer_angle = self_obs_tensor[:, 9]
44+
self.agent_type = self_obs_tensor[:, 10].long()
4445
else:
4546
self.speed = self_obs_tensor[:, :, 0]
4647
self.vehicle_length = self_obs_tensor[:, :, 1] * AGENT_SCALE
@@ -52,6 +53,7 @@ def __init__(self, self_obs_tensor: torch.Tensor, mask=None):
5253
self.is_goal_reached = self_obs_tensor[:, :, 7]
5354
self.id = self_obs_tensor[:, :, 8]
5455
self.steer_angle = self_obs_tensor[:, :, 9]
56+
self.agent_type = self_obs_tensor[:, :, 10].long()
5557

5658
@classmethod
5759
def from_tensor(
@@ -89,7 +91,7 @@ def normalize(self):
8991
min_val=constants.MIN_REL_GOAL_COORD,
9092
max_val=constants.MAX_REL_GOAL_COORD,
9193
)
92-
self.steer_angle /= (torch.pi / 3)
94+
self.steer_angle /= torch.pi / 3
9395

9496
@property
9597
def shape(self) -> tuple[int, ...]:

gpudrive/env/env_torch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1753,7 +1753,7 @@ def swap_data_batch(self, data_batch=None):
17531753
controlled_agent_mask=self.cont_agent_mask,
17541754
reference_trajectory=self.reference_trajectory,
17551755
)
1756-
1756+
17571757
self.guidance_dropout_mask = self.create_guidance_dropout_mask()
17581758

17591759
def get_expert_actions(self):
@@ -1898,7 +1898,7 @@ def render(self, focus_env_idx=0, focus_agent_idx=[0, 1]):
18981898

18991899
# Create data loader
19001900
train_loader = SceneDataLoader(
1901-
root="data/processed/validation",
1901+
root="data/processed/wosac/validation_interactive",
19021902
batch_size=1,
19031903
dataset_size=100,
19041904
sample_with_replacement=False,

src/sim.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ inline void collectSelfObsSystem(Engine &ctx,
194194
const Info& info = ctx.get<Info>(agent_iface.e);
195195
self_obs.goalState = info.reachedGoal ? 1.f : 0.f;
196196
self_obs.steerAngle = vel.angular.z;
197+
self_obs.type = (float)ctx.get<EntityType>(agent_iface.e);
198+
std::cout << "Self Obs Type: " << self_obs.type << std::endl;
197199
}
198200

199201
inline void collectPartnerObsSystem(Engine &ctx,

src/types.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ namespace madrona_gpudrive
194194
float goalState;
195195
float id;
196196
float steerAngle;
197+
float type;
197198
static inline SelfObservation zero()
198199
{
199200
return SelfObservation{
@@ -203,11 +204,12 @@ namespace madrona_gpudrive
203204
.collisionState = 0,
204205
.goalState = 0,
205206
.id = -1,
206-
.steerAngle = 0};
207+
.steerAngle = 0,
208+
.type = static_cast<float>(EntityType::Padding)};
207209
}
208210
};
209211

210-
const size_t SelfObservationExportSize = 10; // 1 + 3 + 2 + 1 + 1 + 1
212+
const size_t SelfObservationExportSize = 11; // 1 + 3 + 2 + 1 + 1 + 1
211213

212214
static_assert(sizeof(SelfObservation) == sizeof(float) * SelfObservationExportSize);
213215

0 commit comments

Comments
 (0)