Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vmas/scenarios/debug/waterfall.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
anchor_a=(1, 0),
anchor_b=(-1, 0),
dist=self.agent_dist,
rotate_a=True,
rotate_b=True,
rotate_a=False,
rotate_b=False,
collidable=True,
width=0,
mass=1,
Expand Down
56 changes: 47 additions & 9 deletions vmas/simulator/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
Observable,
override,
TorchUtils,
TORQUE_CONSTRAINT_FORCE,
X,
Y,
)
Expand Down Expand Up @@ -1079,6 +1080,7 @@ def __init__(
dim_c: int = 0,
collision_force: float = COLLISION_FORCE,
joint_force: float = JOINT_FORCE,
torque_constraint_force: float = TORQUE_CONSTRAINT_FORCE,
contact_margin: float = 1e-3,
gravity: Tuple[float, float] = (0.0, 0.0),
):
Expand Down Expand Up @@ -1110,6 +1112,7 @@ def __init__(
self._collision_force = collision_force
self._joint_force = joint_force
self._contact_margin = contact_margin
self._torque_constraint_force = torque_constraint_force
# joints
self._joints = {}
# Pairs of collidable shapes
Expand Down Expand Up @@ -1597,8 +1600,6 @@ def step(self):
# apply gravity
self._apply_gravity(entity)

# self._apply_environment_force(entity, i)

self._apply_vectorized_enviornment_force()

for entity in self.entities:
Expand Down Expand Up @@ -1802,17 +1803,24 @@ def _vectorized_joint_constraints(self, joints):
pos_joint_b = []
dist = []
rotate = []
rot_a = []
rot_b = []

for entity_a, entity_b, joint in joints:
pos_joint_a.append(joint.pos_point(entity_a))
pos_joint_b.append(joint.pos_point(entity_b))
pos_a.append(entity_a.state.pos)
pos_b.append(entity_b.state.pos)
dist.append(torch.tensor(joint.dist, device=self.device))
rotate.append(torch.tensor(joint.rotate, device=self.device))
rot_a.append(entity_a.state.rot)
rot_b.append(entity_b.state.rot)
pos_a = torch.stack(pos_a, dim=-2)
pos_b = torch.stack(pos_b, dim=-2)
pos_joint_a = torch.stack(pos_joint_a, dim=-2)
pos_joint_b = torch.stack(pos_joint_b, dim=-2)
rot_a = torch.stack(rot_a, dim=-2)
rot_b = torch.stack(rot_b, dim=-2)
dist = (
torch.stack(
dist,
Expand Down Expand Up @@ -1846,13 +1854,19 @@ def _vectorized_joint_constraints(self, joints):
r_a = pos_joint_a - pos_a
r_b = pos_joint_b - pos_b

torque_a = torch.zeros_like(rotate, device=self.device, dtype=torch.float)
torque_b = torch.zeros_like(rotate, device=self.device, dtype=torch.float)
if rotate_prior.any():
torque_a_rotate = TorchUtils.compute_torque(force_a, r_a)
torque_b_rotate = TorchUtils.compute_torque(force_b, r_b)
torque_a = torch.where(rotate, torque_a_rotate, 0)
torque_b = torch.where(rotate, torque_b_rotate, 0)
torque_a_rotate = TorchUtils.compute_torque(force_a, r_a)
torque_b_rotate = TorchUtils.compute_torque(force_b, r_b)

torque_a_fixed, torque_b_fixed = self._get_constraint_torques(
rot_a, rot_b, force_multiplier=self._torque_constraint_force
)

torque_a = torch.where(
rotate, torque_a_rotate, torque_a_rotate + torque_a_fixed
)
torque_b = torch.where(
rotate, torque_b_rotate, torque_b_rotate + torque_b_fixed
)

for i, (entity_a, entity_b, _) in enumerate(joints):
self.update_env_forces(
Expand Down Expand Up @@ -2411,6 +2425,30 @@ def _get_constraint_forces(
force = torch.where((dist < dist_min).unsqueeze(-1), 0.0, force)
return force, -force

def _get_constraint_torques(
self,
rot_a: Tensor,
rot_b: Tensor,
force_multiplier: float = TORQUE_CONSTRAINT_FORCE,
) -> Tensor:
min_delta_rot = 1e-9
delta_rot = rot_a - rot_b
abs_delta_rot = torch.linalg.vector_norm(delta_rot, dim=-1).unsqueeze(-1)

# softmax penetration
k = 1
penetration = k * (torch.exp(abs_delta_rot / k) - 1)

torque = (
force_multiplier
* delta_rot
/ torch.where(abs_delta_rot > 0, abs_delta_rot, 1e-8)
* penetration
)
torque = torch.where((abs_delta_rot < min_delta_rot), 0.0, torque)

return -torque, torque

# integrate physical state
# uses semi-implicit euler with sub-stepping
def _integrate_state(self, entity: Entity, substep: int):
Expand Down
2 changes: 1 addition & 1 deletion vmas/simulator/joints.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
name=f"joint {entity_a.name} {entity_b.name}",
collide=collidable,
movable=True,
rotatable=rotate_a and rotate_b,
rotatable=True,
mass=mass,
shape=(
vmas.simulator.core.Box(length=dist, width=width)
Expand Down
1 change: 1 addition & 0 deletions vmas/simulator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
LINE_MIN_DIST = 4 / 6e2
COLLISION_FORCE = 100
JOINT_FORCE = 130
TORQUE_CONSTRAINT_FORCE = 1

DRAG = 0.25
LINEAR_FRICTION = 0.0
Expand Down