4 changes: 2 additions & 2 deletions baselines/ppo/config/ppo_base_puffer.yaml
@@ -10,7 +10,7 @@ environment: # Overrides default environment configs (see pygpudrive/env/config.
name: "gpudrive"
num_worlds: 75 # Number of parallel environments
k_unique_scenes: 75 # Number of unique scenes to sample from
max_controlled_agents: 64 # Maximum number of agents controlled by the model. Make sure this aligns with the variable kMaxAgentCount in src/consts.hpp
max_controlled_agents: 32 # Maximum number of agents controlled by the model; any value between 1 and kMaxAgentCount in src/consts.hpp
ego_state: true
road_map_obs: true
partner_obs: true
@@ -77,7 +77,7 @@ train:
vf_coef: 0.3
max_grad_norm: 0.5
target_kl: null
log_window: 1000
log_window: 10 # Number of aggregated info entries to buffer before computing logging stats

# # # Network # # #
network:
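As a quick editorial illustration of the max_controlled_agents comment above (not part of this diff; the kMaxAgentCount value of 64 is assumed from the previous default), the bound can be checked when the config is loaded:

# Hypothetical sanity check for the new config value; 64 is assumed to mirror src/consts.hpp::kMaxAgentCount.
import yaml

K_MAX_AGENT_COUNT = 64  # assumption, not read from the C++ header

with open("baselines/ppo/config/ppo_base_puffer.yaml") as f:
    config = yaml.safe_load(f)

max_controlled = config["environment"]["max_controlled_agents"]
assert 1 <= max_controlled <= K_MAX_AGENT_COUNT, (
    f"max_controlled_agents={max_controlled} must lie in [1, {K_MAX_AGENT_COUNT}]"
)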
220 changes: 142 additions & 78 deletions gpudrive/env/env_puffer.py
@@ -66,6 +66,7 @@ def __init__(
render_format="mp4",
render_fps=15,
zoom_radius=50,
minimum_frames_to_log=50,  # log a video once this many frames have accumulated, even if the scene is not done
buf=None,
**kwargs,
):
@@ -97,6 +98,7 @@ def __init__(
self.render_format = render_format
self.render_fps = render_fps
self.zoom_radius = zoom_radius
self.minimum_frames_to_log = minimum_frames_to_log

# VBD
self.vbd_model_path = vbd_model_path
@@ -169,8 +171,11 @@ def __init__(
self.observations = self.env.reset(self.controlled_agent_mask)

self.masks = torch.ones(self.num_agents, dtype=bool)
self.world_size = self.controlled_agent_mask.shape[1]
# Action tensor must match simulator's expected shape: (num_worlds, max_num_agents_in_scene)
# The simulator will only use actions for agents marked as controlled in cont_agent_mask
self.actions = torch.zeros(
(self.num_worlds, self.max_cont_agents_per_env), dtype=torch.int64
(self.num_worlds, self.world_size), dtype=torch.int64
).to(self.device)
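To make the comment above concrete, here is a minimal standalone sketch (toy shapes, not code from this change) of how boolean-mask assignment scatters one action per controlled agent into the full (num_worlds, world_size) tensor:

import torch

num_worlds, world_size = 2, 4
controlled_agent_mask = torch.tensor([[True, False, True, False],
                                      [False, True, True, True]])
actions = torch.zeros((num_worlds, world_size), dtype=torch.int64)

# One action per controlled agent, in row-major order of the mask (5 agents here).
flat_actions = torch.tensor([3, 1, 4, 1, 5])
actions[controlled_agent_mask] = flat_actions
# actions is now [[3, 0, 1, 0], [0, 4, 1, 5]]; uncontrolled slots keep their zeros.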

# Setup rendering storage
@@ -211,37 +216,79 @@ def reset(self, seed=None):
self.num_agents, dtype=torch.float32
).to(self.device)
self.agent_episode_returns = torch.zeros(
(self.num_worlds, self.max_cont_agents_per_env),
(self.num_worlds, self.world_size),
dtype=torch.float32,
).to(self.device)
self.episode_lengths = torch.zeros(
(self.num_worlds, self.max_cont_agents_per_env),
(self.num_worlds, self.world_size),
dtype=torch.float32,
).to(self.device)
self.live_agent_mask = torch.ones(
(self.num_worlds, self.max_cont_agents_per_env), dtype=bool
(self.num_worlds, self.world_size), dtype=bool
).to(self.device)
self.collided_in_episode = torch.zeros(
(self.num_worlds, self.max_cont_agents_per_env),
(self.num_worlds, self.world_size),
dtype=torch.float32,
).to(self.device)
self.offroad_in_episode = torch.zeros(
(self.num_worlds, self.max_cont_agents_per_env),
(self.num_worlds, self.world_size),
dtype=torch.float32,
).to(self.device)

self.initialize_tracking()

return self.observations, []

def initialize_tracking(self):
self.done_or_truncated_worlds = torch.zeros(self.num_worlds, dtype=torch.int32).to(self.device)
self.goal_achieved_mask = torch.zeros(
(self.num_worlds, self.world_size),
dtype=torch.int32
).to(self.device)
self.collided_mask = torch.zeros(
(self.num_worlds, self.world_size),
dtype=torch.int32
).to(self.device)
self.offroad_mask = torch.zeros(
(self.num_worlds, self.world_size),
dtype=torch.int32
).to(self.device)
self.truncated_mask = torch.zeros(
(self.num_worlds, self.world_size),
dtype=torch.int32
).to(self.device)
self.reward_agent = torch.zeros(
(self.num_worlds, self.world_size),
dtype=torch.float32
).to(self.device)
self.episode_length_agent = torch.zeros(
(self.num_worlds, self.world_size),
dtype=torch.float32
).to(self.device)
self.total_offroad_count = torch.zeros(
(self.num_worlds, self.world_size),
dtype=torch.int32
).to(self.device)
self.total_collided_count = torch.zeros(
(self.num_worlds, self.world_size),
dtype=torch.int32
).to(self.device)


def step(self, action):
"""
Step the environment with the given actions. Note that we reset worlds
asynchronously when they are done.
Args:
action: A numpy array of actions for the controlled agents. Shape:
(num_worlds, max_cont_agents_per_env)
(total_controlled_agents,) - will be mapped to controlled positions
in the (num_worlds, max_num_agents_in_scene) action tensor
"""

# Set the action for the controlled agents
# print(f"action shape: {action.shape}")
# print(f"self.controlled_agent_mask shape: {self.controlled_agent_mask.shape}")
# print(f"total controlled agents: {self.controlled_agent_mask.sum().item()}")
self.actions[self.controlled_agent_mask] = action

# Step the simulator with controlled agents actions
@@ -262,10 +309,28 @@

# Check if any worlds are done (terminal or truncated)
controlled_per_world = self.controlled_agent_mask.sum(dim=1)
done_worlds = torch.where(

# Worlds where all controlled agents are terminal
terminal_done_worlds = torch.where(
(terminal * self.controlled_agent_mask).sum(dim=1)
== controlled_per_world
)[0]

# Worlds where episodes have reached maximum length (truncated)
max_episode_length = self.env.episode_len
truncated_done_worlds = torch.where(
self.episode_lengths[:, 0] >= max_episode_length
)[0]

# Combine both types of done worlds; torch.cat and torch.unique handle empty index tensors
done_worlds = torch.unique(torch.cat([terminal_done_worlds, truncated_done_worlds]))
done_worlds_cpu = done_worlds.cpu().numpy()

# Add rewards for living agents
@@ -284,7 +349,7 @@ def step(self, action):
self.masks = self.live_agent_mask[self.controlled_agent_mask]

# Set the mask to False for _agents_ that are terminated for the next step
# Shape: (num_worlds, max_cont_agents_per_env)
# Shape: (num_worlds, world_size)
self.live_agent_mask[terminal] = 0

# Truncated is defined as not crashed nor goal achieved
@@ -300,75 +365,41 @@
terminal = terminal[self.controlled_agent_mask]

info_lst = []
if len(done_worlds) > 0:

if self.render:
for render_env_idx in range(self.render_k_scenarios):
self.log_video_to_wandb(render_env_idx, done_worlds)

# Log episode statistics
controlled_mask = self.controlled_agent_mask[
done_worlds, :
].clone()

num_finished_agents = controlled_mask.sum().item()

# Collision rates are summed across all agents in the episode
off_road_rate = (
torch.where(
self.offroad_in_episode[done_worlds, :][controlled_mask]
> 0,
1,
0,
).sum()
/ num_finished_agents
)
collision_rate = (
torch.where(
self.collided_in_episode[done_worlds, :][controlled_mask]
> 0,
1,
0,
).sum()
/ num_finished_agents

if self.render:
for render_env_idx in range(self.render_k_scenarios):
self.log_video_to_wandb(render_env_idx, done_worlds)

if len(done_worlds) > 0:
self.done_or_truncated_worlds[done_worlds] = 1
done_world_mask = torch.zeros_like(self.controlled_agent_mask, dtype=torch.bool)
done_world_mask[done_worlds, :] = True
combined_mask = done_world_mask & self.controlled_agent_mask

# Now use the combined mask for proper assignment
self.goal_achieved_mask[combined_mask] = torch.where(
self.env.get_infos().goal_achieved[combined_mask].to(torch.int32) > 0,
torch.tensor(1, dtype=torch.int32),
torch.tensor(0, dtype=torch.int32),
)
goal_achieved_rate = (
self.env.get_infos()
.goal_achieved[done_worlds, :][controlled_mask]
.sum()
/ num_finished_agents
self.collided_mask[combined_mask] = torch.where(
self.collided_in_episode[combined_mask].to(torch.int32) > 0,
torch.tensor(1, dtype=torch.int32),
torch.tensor(0, dtype=torch.int32),
)

total_collisions = self.collided_in_episode[done_worlds, :].sum()
total_off_road = self.offroad_in_episode[done_worlds, :].sum()

agent_episode_returns = self.agent_episode_returns[done_worlds, :][
controlled_mask
]

num_truncated = (
truncated[done_worlds, :][controlled_mask].sum().item()
self.offroad_mask[combined_mask] = torch.where(
self.offroad_in_episode[combined_mask].to(torch.int32) > 0,
torch.tensor(1, dtype=torch.int32),
torch.tensor(0, dtype=torch.int32),
)

if num_finished_agents > 0:
# fmt: off
info_lst.append(
{
"mean_episode_reward_per_agent": agent_episode_returns.mean().item(),
"perc_goal_achieved": goal_achieved_rate.item(),
"perc_off_road": off_road_rate.item(),
"perc_veh_collisions": collision_rate.item(),
"total_controlled_agents": self.num_agents,
"control_density": self.num_agents / self.controlled_agent_mask.numel(),
"episode_length": self.episode_lengths[done_worlds, :].mean().item(),
"perc_truncated": num_truncated / num_finished_agents,
"num_completed_episodes": len(done_worlds),
"total_collisions": total_collisions.item(),
"total_off_road": total_off_road.item(),
}
)
# fmt: on

# Store per-agent collision/off-road step counts; they are summed only when the aggregated stats are logged
self.total_collided_count[combined_mask] = self.collided_in_episode[combined_mask].to(torch.int32)
self.total_offroad_count[combined_mask] = self.offroad_in_episode[combined_mask].to(torch.int32)

self.truncated_mask[combined_mask] = truncated[combined_mask].to(torch.int32)
self.reward_agent[combined_mask] = self.agent_episode_returns[combined_mask]
self.episode_length_agent[combined_mask] = self.episode_lengths[combined_mask]

# reset the done_worlds
# Get obs for the last terminal step (before reset)
self.last_obs = self.env.get_obs(self.controlled_agent_mask)
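For reference, a small standalone sketch of the done_world_mask and controlled_agent_mask combination used above (toy shapes, editorial illustration only):

import torch

controlled_agent_mask = torch.tensor([[True, True, False, False],
                                      [True, False, True, False],
                                      [False, True, True, True]])
done_worlds = torch.tensor([0, 2])  # worlds whose episodes just ended

done_world_mask = torch.zeros_like(controlled_agent_mask)  # all False
done_world_mask[done_worlds, :] = True
combined_mask = done_world_mask & controlled_agent_mask
# combined_mask selects only the controlled agents of worlds 0 and 2, so assignments
# such as goal_achieved_mask[combined_mask] = ... leave the still-running world 1 untouched.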

@@ -384,6 +415,36 @@ def step(self, action):
]
self.offroad_in_episode[done_worlds, :] = 0
self.collided_in_episode[done_worlds, :] = 0

if self.done_or_truncated_worlds.sum().item() == self.num_worlds:
# we have finished all synced worlds, now we can log the data
goal_achieved_rate = self.goal_achieved_mask.sum() / self.num_agents
off_road_rate = self.offroad_mask.sum() / self.num_agents
collision_rate = self.collided_mask.sum() / self.num_agents
truncated_rate = self.truncated_mask.sum() / self.num_agents
crashed = self.collided_mask | self.offroad_mask
crashed_rate = crashed.sum() / self.num_agents
mean_episode_reward = self.reward_agent.sum() / self.num_agents

# print(f"mean episode reward per agent: {mean_episode_reward.item()}")
# print(f"goal_achieved_rate: {goal_achieved_rate.item()}, off_road_rate: {off_road_rate.item()}, collision_rate: {collision_rate.item()}, truncated_rate: {truncated_rate.item()}, PercentCrashedorGoalAchievedorTruncated: {goal_achieved_rate.item() + crashed_rate.item() + truncated_rate.item()}")

info_lst.append(
{
"perc_goal_achieved": goal_achieved_rate.item(),
"perc_crashed(collided or offroad)": crashed_rate.item(),
"perc_off_road": off_road_rate.item(),
"perc_veh_collisions": collision_rate.item(),
"perc_truncated": truncated_rate.item(),
"mean_episode_reward_per_agent": mean_episode_reward.item(),
"episode_length": self.episode_length_agent.mean().item(),
"total_offroad_count": self.total_offroad_count.sum().item(),
"total_collided_count": self.total_collided_count.sum().item(),
}
)

# reset the tracking variables
self.initialize_tracking()
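The block above only emits metrics once every world has been marked done or truncated at least once. A stripped-down sketch of that pattern, with assumed names and without the bookkeeping details:

import torch

def maybe_flush_metrics(done_or_truncated_worlds, goal_achieved_mask,
                        collided_mask, offroad_mask, num_agents):
    """Return an aggregated info dict once all worlds have finished, else None."""
    if done_or_truncated_worlds.sum().item() < done_or_truncated_worlds.numel():
        return None  # still waiting on some worlds
    crashed = collided_mask | offroad_mask
    return {
        "perc_goal_achieved": (goal_achieved_mask.sum() / num_agents).item(),
        "perc_crashed(collided or offroad)": (crashed.sum() / num_agents).item(),
        "perc_off_road": (offroad_mask.sum() / num_agents).item(),
        "perc_veh_collisions": (collided_mask.sum() / num_agents).item(),
    }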

# Get the next observations. Note that we do this after resetting
# the worlds so that we always return a fresh observation
@@ -462,11 +523,14 @@ def clear_render_storage(self):

def log_video_to_wandb(self, render_env_idx, done_worlds):
"""Log arrays as videos to wandb."""
# if(len(self.frames[render_env_idx]) > 0):
# print(f"iter: {self.iters}, render_env_idx: {render_env_idx}, frames length: {len(self.frames[render_env_idx])}, done_worlds: {done_worlds}")
if (
render_env_idx in done_worlds
and len(self.frames[render_env_idx]) > 0
(render_env_idx in done_worlds and len(self.frames[render_env_idx]) > 0)
or len(self.frames[render_env_idx]) > self.minimum_frames_to_log
):
frames_array = np.array(self.frames[render_env_idx])
# print(f"frames shape: {frames_array.shape}")
self.wandb_obj.log(
{
f"vis/state/env_{render_env_idx}": wandb.Video(
@@ -511,4 +575,4 @@ def log_data_coverage(self):
* 100,
},
step=self.global_step,
)
)
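The widened trigger in log_video_to_wandb above can be read as a small predicate; the sketch below is an editorial restatement with assumed argument types (done_worlds as a container of world indices, frames as a list of frame lists):

def should_log_video(env_idx, done_worlds, frames, minimum_frames_to_log=50):
    # Case 1: the scenario just finished and at least one frame was captured.
    finished_with_frames = env_idx in done_worlds and len(frames[env_idx]) > 0
    # Case 2: the frame buffer outgrew the configured minimum, even mid-episode.
    buffer_full_enough = len(frames[env_idx]) > minimum_frames_to_log
    return finished_with_frames or buffer_full_enough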
9 changes: 2 additions & 7 deletions gpudrive/integrations/puffer/ppo.py
@@ -205,17 +205,12 @@ def evaluate(data):

with profile.eval_misc:
data.stats = {}

# Store the average across K done worlds across last N rollouts
# ensure we are logging an unbiased estimate of the performance
if sum(data.infos["num_completed_episodes"]) > data.config.log_window:
if len(data.infos["perc_goal_achieved"]) > data.config.log_window:
for k, v in data.infos.items():
try:
if "num_completed_episodes" in k:
data.stats[k] = np.sum(v)
else:
data.stats[k] = np.mean(v)

data.stats[k] = np.mean(v)
# Log variance for goal and collision metrics
if "goal" in k:
data.stats[f"std_{k}"] = np.std(v)
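With env_puffer.py now appending one aggregated info dict per full synchronization of all worlds, log_window counts buffered entries rather than completed episodes, which is presumably why the config value drops from 1000 to 10. A simplified sketch of the buffering behaviour (the structure of data.infos is assumed, not verified pufferlib internals):

import numpy as np
from collections import defaultdict

infos = defaultdict(list)  # stands in for data.infos
log_window = 10            # from ppo_base_puffer.yaml

def add_info(info_dict):
    # Called once per aggregated entry returned by the environment.
    for k, v in info_dict.items():
        infos[k].append(v)

def compute_stats():
    stats = {}
    if len(infos["perc_goal_achieved"]) > log_window:
        for k, v in infos.items():
            stats[k] = np.mean(v)
            if "goal" in k:
                stats[f"std_{k}"] = np.std(v)
        infos.clear()  # the real trainer manages its own buffer; cleared here only for the sketch
    return stats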