Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 31 additions & 14 deletions vmas/scenarios/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
self.package_mass = kwargs.get("package_mass", 3)

# partial obs
self.partial_observations = kwargs.get("partial_observations", False)
self.package_observation_radius = kwargs.get("package_observation_radius", 0.35)
self.partial_observations = kwargs.get("partial_observations", True)
self.package_observation_dist = kwargs.get("package_observation_dist", 0.35)

# realism
self.linear_friction = kwargs.get("linear_friction", 0.01)
Expand Down Expand Up @@ -131,6 +131,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):

# Add agents
capabilities = [] # save capabilities for relative capabilities later
self.observation_sensors = [] # for partial observability
for i in range(n_agents):
max_linear_vel = self.default_agent_max_linear_vel * random.uniform(self.capability_mult_min, self.capability_mult_max)
max_angular_vel = self.default_agent_max_angular_vel * random.uniform(self.capability_mult_min, self.capability_mult_max)
Expand All @@ -152,6 +153,20 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):

world.add_agent(agent)

# add the observation sensor if partial observability is turned on.
if self.partial_observations:
self.observation_sensors.append(
Landmark(
name=f'obs_sensor_agent_{i}',
collide=False,
shape=Sphere(radius=self.package_observation_dist+radius),
color=(0.827, 0.827, 0.827, 0.65),
movable=False,
)
)
world.add_landmark(self.observation_sensors[-1])


self.capabilities = torch.tensor(capabilities)

# Add landmarks
Expand Down Expand Up @@ -191,7 +206,7 @@ def reset_world_at(self, env_index: int = None):
# only do this during batched resets!
if not env_index:
capabilities = [] # save capabilities for relative capabilities later
for agent in self.world.agents:
for i, agent in enumerate(self.world.agents):
max_linear_vel = self.default_agent_max_linear_vel * random.uniform(self.capability_mult_min, self.capability_mult_max)
max_angular_vel = self.default_agent_max_angular_vel * random.uniform(self.capability_mult_min, self.capability_mult_max)
radius = self.default_agent_radius * random.uniform(self.capability_mult_min, self.capability_mult_max)
Expand All @@ -203,6 +218,12 @@ def reset_world_at(self, env_index: int = None):
agent.shape=Sphere(radius)
agent.mass=mass

# spawn the sensor radius for each agent
if self.partial_observations:
self.observation_sensors[i].set_pos(self.world.agents[i].state.pos, env_index)
self.observation_sensors[i].shape = Sphere(self.package_observation_dist+radius)


self.capabilities = torch.tensor(capabilities)

# spawn goal at origin
Expand Down Expand Up @@ -256,7 +277,7 @@ def reset_world_at(self, env_index: int = None):
),
occupied_positions=package_occupied_pos,
)

self.package_starting_dists = []
self.og_package_positions = []
for i, package in enumerate(self.packages):
Expand Down Expand Up @@ -444,12 +465,17 @@ def partial_observation(self, agent: Agent):
# get positions of all entities in this agent's reference frame
package_obs = []
out_of_obs_val = -0.0001 # default value used for out-of-observation data in the observation vector

# spawn the sensor radius for each agent
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the sensor always attached to the agents?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is reset in the observation function for partial observability.

for i, agent_i_sensor in enumerate(self.observation_sensors):
agent_i_sensor.set_pos(self.world.agents[i].state.pos, None)

for i, package in enumerate(self.packages):
# box starting position and goal position alway part of the observation
package_obs.append(self.og_package_positions[i])
package_obs.append(package.on_goal.unsqueeze(-1))

mask = (torch.linalg.vector_norm(package.state.pos - agent.state.pos, dim=-1) < self.package_observation_radius)
mask = self.world.is_overlapping(self.observation_sensors[i], package)
pkg_state_vec = package.state.pos.clone()
pkg_rot_vec = package.state.rot.clone()
pkg_vel_vec = package.state.vel.clone()
Expand Down Expand Up @@ -606,15 +632,6 @@ def extra_render(self, env_index: int = 0) -> "List[Geom]":
geoms: List[Geom] = []
if not self.partial_observations:
return geoms

for i, agent in enumerate(self.world.agents):

obs_circle = rendering.make_circle(self.package_observation_radius, filled=True)
xform = rendering.Transform()
xform.set_translation(*agent.state.pos[env_index])
obs_circle.add_attr(xform)
obs_circle.set_color(*(0.827, 0.827, 0.827, 0.65))
geoms.append(obs_circle)

return geoms

Expand Down