203 changes: 193 additions & 10 deletions vmas/scenarios/balance.py
@@ -2,6 +2,8 @@
# ProrokLab (https://www.proroklab.org/)
# All rights reserved.

import numpy as np
import random
import torch

from vmas import render_interactively
@@ -12,27 +14,110 @@


class Scenario(BaseScenario):
def get_rng_state(self, device):
"""
Returns a tuple of the form
(numpy random state, python's random state, torch's random state, torch.cuda's random state)
"""
np_rng_state = np.random.get_state()
py_rng_state = random.getstate()
torch_rng_state = torch.get_rng_state()
torch_cuda_rng_state = torch.cuda.get_rng_state(device)

return (np_rng_state, py_rng_state, torch_rng_state, torch_cuda_rng_state)

def set_eval_seed(self, eval_seed):
"""
Set a new seed for numpy, python.random, torch.random, and torch.cuda.random.

Intended to be used only with eval_seed and wrapped by get_rng_state()/set_rng_state().
"""
torch.manual_seed(eval_seed)
torch.cuda.manual_seed_all(eval_seed)
random.seed(eval_seed)
np.random.seed(eval_seed)

def set_rng_state(self, old_rng_state, device):
"""
Restore the prior RNG state (based on the return value of get_rng_state).
"""
assert old_rng_state is not None, "set_rng_state() must be called with the return value of get_rng_state()!"

np_rng_state, py_rng_state, torch_rng_state, torch_cuda_rng_state = old_rng_state

np.random.set_state(np_rng_state)
random.setstate(py_rng_state)
torch.set_rng_state(torch_rng_state)
torch.cuda.set_rng_state(torch_cuda_rng_state, device)
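# Intended usage (a minimal sketch, assuming `device` is a CUDA device, since
# torch.cuda.get_rng_state/set_rng_state operate on the CUDA generator):
#
#   state = self.get_rng_state(device)   # save the caller's RNG state
#   self.set_eval_seed(self.eval_seed)   # deterministic sampling for evaluation
#   ...sample capabilities / positions...
#   self.set_rng_state(state, device)    # restore so training RNG is untouched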

def make_world(self, batch_dim: int, device: torch.device, **kwargs):
self.n_agents = kwargs.get("n_agents", 3)
self.package_mass = kwargs.get("package_mass", 5)
self.random_package_pos_on_line = kwargs.get("random_package_pos_on_line", True)
self.world_semidim = kwargs.get("world_semidim", 1.0)
self.gravity = kwargs.get("gravity", -0.05)
self.eval_seed = kwargs.get("eval_seed", None)

# capabilities
self.capability_mult_range = kwargs.get("capability_mult_range", [0.75, 1.25])
self.multiple_ranges = kwargs.get("multiple_ranges", False)
if not self.multiple_ranges:
self.capability_mult_min = self.capability_mult_range[0]
self.capability_mult_max = self.capability_mult_range[1]
self.capability_representation = kwargs.get("capability_representation", "raw")
self.default_u_multiplier = kwargs.get("default_u_multiplier", 0.7)
self.default_agent_radius = kwargs.get("default_agent_radius", 0.03)
self.default_agent_mass = kwargs.get("default_agent_mass", 1)
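# When multiple_ranges is True, capability_mult_range is expected to be a list of
# [min, max] pairs, e.g. [[0.5, 0.75], [1.25, 1.5]] (illustrative values); with the
# default False it is a single [min, max] pair.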

# metrics
self.success_rate = None

# rng
rng_state = None
if self.eval_seed is not None:
rng_state = self.get_rng_state(device)
self.set_eval_seed(self.eval_seed)

assert self.n_agents > 1

self.line_length = 0.8
self.agent_radius = 0.03

self.shaping_factor = 100
self.fall_reward = -10
self.shaping_factor = 1
self.fall_reward = -0.1

# Make world
world = World(batch_dim, device, gravity=(0.0, -0.05), y_semidim=1)
world = World(batch_dim, device, gravity=(0.0, self.gravity), y_semidim=self.world_semidim)
# Add agents
capabilities = [] # save capabilities for relative capabilities later
for i in range(self.n_agents):
if self.multiple_ranges:
cap_idx = int(random.choice(np.arange(len(self.capability_mult_range))))
self.capability_mult_min = self.capability_mult_range[cap_idx][0]
self.capability_mult_max = self.capability_mult_range[cap_idx][1]
print("MADE IT HERE")
max_u = self.default_u_multiplier * random.uniform(self.capability_mult_min, self.capability_mult_max)
if self.multiple_ranges:
cap_idx = int(random.choice(np.arange(len(self.capability_mult_range))))
self.capability_mult_min = self.capability_mult_range[cap_idx][0]
self.capability_mult_max = self.capability_mult_range[cap_idx][1]
radius = self.default_agent_radius * random.uniform(self.capability_mult_min, self.capability_mult_max)
if self.multiple_ranges:
cap_idx = int(random.choice(np.arange(len(self.capability_mult_range))))
self.capability_mult_min = self.capability_mult_range[cap_idx][0]
self.capability_mult_max = self.capability_mult_range[cap_idx][1]
mass = self.default_agent_mass * random.uniform(self.capability_mult_min, self.capability_mult_max)
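# Example with the defaults (illustrative numbers): default_u_multiplier = 0.7 and a
# multiplier drawn uniformly from [0.75, 1.25] gives max_u in [0.525, 0.875]; radius
# and mass are scaled the same way from their own defaults (0.03 and 1), each with an
# independently drawn multiplier (and, if multiple_ranges is set, an independently
# drawn range).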

agent = Agent(
name=f"agent_{i}", shape=Sphere(self.agent_radius), u_multiplier=0.7
name=f"agent_{i}",
shape=Sphere(radius),
u_multiplier=max_u,
mass=mass,
render_action=True,
)
capabilities.append([max_u, agent.shape.radius, agent.mass])
world.add_agent(agent)
self.capabilities = torch.tensor(capabilities)
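# self.capabilities has shape (n_agents, 3) with columns [u_multiplier, radius, mass];
# get_capability_repr() later uses its per-column mean for team-relative capabilities.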

goal = Landmark(
name="goal",
@@ -75,9 +160,46 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
self.pos_rew = torch.zeros(batch_dim, device=device, dtype=torch.float32)
self.ground_rew = self.pos_rew.clone()

if self.eval_seed is not None:
self.set_rng_state(rng_state, device)

return world

def reset_world_at(self, env_index: int = None):
rng_state = None
if self.eval_seed is not None:
rng_state = self.get_rng_state(self.world.device)
self.set_eval_seed(self.eval_seed)

# reset capabilities; only re-sample during batched (full) resets
if env_index is None:
capabilities = [] # save capabilities for relative capabilities later
for agent in self.world.agents:
if self.multiple_ranges:
cap_idx = int(random.choice(np.arange(len(self.capability_mult_range))))
self.capability_mult_min = self.capability_mult_range[cap_idx][0]
self.capability_mult_max = self.capability_mult_range[cap_idx][1]
max_u = self.default_u_multiplier * random.uniform(self.capability_mult_min, self.capability_mult_max)
if self.multiple_ranges:
cap_idx = int(random.choice(np.arange(len(self.capability_mult_range))))
self.capability_mult_min = self.capability_mult_range[cap_idx][0]
self.capability_mult_max = self.capability_mult_range[cap_idx][1]
radius = self.default_agent_radius * random.uniform(self.capability_mult_min, self.capability_mult_max)
if self.multiple_ranges:
cap_idx = int(random.choice(np.arange(len(self.capability_mult_range))))
self.capability_mult_min = self.capability_mult_range[cap_idx][0]
self.capability_mult_max = self.capability_mult_range[cap_idx][1]
mass = self.default_agent_mass * random.uniform(self.capability_mult_min, self.capability_mult_max)

capabilities.append([max_u, radius, mass])

agent.u_multiplier = max_u
agent.shape = Sphere(radius)
agent.mass = mass

self.capabilities = torch.tensor(capabilities)
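# Same (n_agents, 3) layout as in make_world(). Re-assigning u_multiplier, shape and
# mass on existing agents assumes the simulator reads these attributes each step
# rather than caching them (an assumption of this comment, not verified in this diff).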

goal_pos = torch.cat(
[
torch.zeros(
Expand Down Expand Up @@ -111,7 +233,8 @@ def reset_world_at(self, env_index: int = None):
),
torch.full(
(1, 1) if env_index is not None else (self.world.batch_dim, 1),
-self.world.y_semidim + self.agent_radius * 2,
(
-self.world.y_semidim + self.default_agent_radius * self.capability_mult_max * 2
if not self.multiple_ranges
else -self.world.y_semidim + self.default_agent_radius * self.capability_mult_range[-1][1] * 2
),
device=self.world.device,
dtype=torch.float32,
),
@@ -151,7 +274,7 @@ def reset_world_at(self, env_index: int = None):
+ i
* (self.line_length - agent.shape.radius)
/ (self.n_agents - 1),
-self.agent_radius * 2,
-agent.shape.radius * 2,
],
device=self.world.device,
dtype=torch.float32,
@@ -182,7 +305,10 @@ def reset_world_at(self, env_index: int = None):
0,
-self.world.y_semidim
- self.floor.shape.width / 2
- self.agent_radius,
- (
self.default_agent_radius * self.capability_mult_max
if not self.multiple_ranges
else self.default_agent_radius * self.capability_mult_range[-1][1]
),
],
device=self.world.device,
),
@@ -205,6 +331,9 @@
* self.shaping_factor
)

if self.eval_seed is not None:
self.set_rng_state(rng_state, self.world.device)

def compute_on_the_ground(self):
self.on_the_ground = self.world.is_overlapping(
self.line, self.floor
@@ -230,8 +359,59 @@ def reward(self, agent: Agent):

return self.ground_rew + self.pos_rew

def get_capability_repr(self, agent: Agent):
"""
Get capability representation:
raw = raw multiplier values
relative = agent value minus the team mean (zero-centered across the team)
mixed = raw + relative (concatenated)
"""
# agent's normal capabilities
max_u = agent.u_multiplier
radius = agent.shape.radius
mass = agent.mass

# compute the mean capabilities across the team's agents
# then compute "relative capability" of this agent by subtracting the mean
team_mean = list(torch.mean(self.capabilities, dim=0))
rel_max_u = max_u - team_mean[0].item()
rel_radius = radius - team_mean[1].item()
rel_mass = mass - team_mean[2].item()
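# Worked example (illustrative values): if the team's u_multipliers are
# (0.6, 0.7, 0.8), the mean is 0.7 and the relative values are (-0.1, 0.0, +0.1).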

raw_capability_repr = [
torch.tensor(
max_u, device=self.world.device
).repeat(self.world.batch_dim, 1),
torch.tensor(
radius, device=self.world.device
).repeat(self.world.batch_dim, 1),
torch.tensor(
mass, device=self.world.device
).repeat(self.world.batch_dim, 1),
]

rel_capability_repr = [
torch.tensor(
rel_max_u, device=self.world.device
).repeat(self.world.batch_dim, 1),
torch.tensor(
rel_radius, device=self.world.device
).repeat(self.world.batch_dim, 1),
torch.tensor(
rel_mass, device=self.world.device
).repeat(self.world.batch_dim, 1),
]

if self.capability_representation == "raw":
return raw_capability_repr
elif self.capability_representation == "relative":
return rel_capability_repr
elif self.capability_representation == "mixed":
return raw_capability_repr + rel_capability_repr
else:
raise ValueError(f"Unknown capability_representation: {self.capability_representation}")
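# Each entry is a (batch_dim, 1) tensor, so "raw" and "relative" return 3 tensors and
# "mixed" returns 6; observation() concatenates them onto each agent's observation.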

def observation(self, agent: Agent):
# get positions of all entities in this agent's reference frame
capability_repr = self.get_capability_repr(agent)
return torch.cat(
[
agent.state.pos,
@@ -243,7 +423,7 @@ def observation(self, agent: Agent):
self.line.state.vel,
self.line.state.ang_vel,
self.line.state.rot % torch.pi,
],
] + capability_repr,
dim=-1,
)

@@ -253,7 +433,10 @@ def done(self):
)

def info(self, agent: Agent):
info = {"pos_rew": self.pos_rew, "ground_rew": self.ground_rew}
self.success_rate = self.world.is_overlapping(
self.package, self.package.goal
)
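# is_overlapping() gives one boolean per vectorized env, so this is a per-step,
# per-env success flag; averaging it into an actual rate is left to the logging side.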
info = {"pos_rew": self.pos_rew, "ground_rew": self.ground_rew, "success_rate": self.success_rate}
return info

