Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 205 additions & 0 deletions source/isaaclab_rl/isaaclab_rl/rsl_rl/symmetry_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

from dataclasses import MISSING

import torch
from isaaclab.managers.action_manager import ActionManager
from isaaclab.managers.observation_manager import ObservationManager
from isaaclab.utils import configclass
from isaaclab_rl.rsl_rl.vecenv_wrapper import RslRlVecEnvWrapper


@configclass
Expand Down Expand Up @@ -51,3 +55,204 @@ class RslRlSymmetryCfg:

mirror_loss_coeff: float = 0.0
"""The weight for the symmetry-mirror loss. Default is 0.0."""


@configclass
class SymmetryTermCfg:
    """Mirror operations for one tensor (actor observations, critic observations or actions).

    Every attribute maps an observation / action term name to the element indices, counted
    from the first element of that term, on which the operation is performed:

    - ``swap_terms`` exchanges the two elements of each index pair, i.e. A_sym, B_sym = B, A
    - ``swap_negate_terms`` exchanges and negates the two elements of each index pair,
      i.e. A_sym, B_sym = -B, -A
    - ``negate_terms`` flips the sign of each listed element, i.e. A_sym = -A
    """

    swap_terms: dict[str, list[tuple[int, int]]] = MISSING
    """Index pairs to exchange, per term name: A_sym, B_sym = B, A."""

    swap_negate_terms: dict[str, list[tuple[int, int]]] = MISSING
    """Index pairs to exchange and negate, per term name: A_sym, B_sym = -B, -A."""

    negate_terms: dict[str, list[int]] = MISSING
    """Indices to negate, per term name: A_sym = -A."""


@configclass
class SymmetryCfg:
    """Symmetry description of an environment.

    Groups one :class:`SymmetryTermCfg` for each tensor that is mirrored: the actor
    observations, the critic observations and the policy actions. Each of them names the
    observation / action terms and the element indices inside the term to operate on.
    """

    actor_observations: SymmetryTermCfg = MISSING
    """Symmetry terms for actor observations."""

    critic_observations: SymmetryTermCfg = MISSING
    """Symmetry terms for critic observations."""

    actions: SymmetryTermCfg = MISSING
    """Symmetry terms for policy actions."""


def __get_observation_term_index(
    observation_manager: ObservationManager, term_name: str, group_name: str = "policy"
) -> int:
    """
    Get the index of the first element of the specified observation term in the observation tensor.

    Args:
        observation_manager (ObservationManager): The observation manager.
        term_name (str): The name of the term to get the index of.
        group_name (str): The observation group whose term layout is searched. Defaults to "policy".

    Returns:
        int: The index of the first element of the specified observation term in the observation tensor.

    Raises:
        ValueError: If the group does not contain a term called ``term_name``.
    """
    term_index = 0
    for name, dims in zip(
        observation_manager._group_obs_term_names[group_name],
        observation_manager._group_obs_term_dim[group_name],
    ):
        if name == term_name:
            return term_index
        # the last entry of the term's dims is taken as the number of elements it occupies
        term_index += dims[-1]

    # Falling through used to return the total width of the group, i.e. an offset that points
    # past every term; fail loudly instead of mirroring the wrong elements.
    raise ValueError(f"Observation term '{term_name}' not found in observation group '{group_name}'.")


def __get_action_term_index(action_manager: ActionManager, term_name: str) -> int:
    """
    Get the index of the first element of the specified action term in the action tensor.

    Args:
        action_manager (ActionManager): The action manager.
        term_name (str): The name of the term to get the index of.

    Returns:
        int: The index of the first element of the specified action term in the action tensor.

    Raises:
        ValueError: If the action manager has no term called ``term_name``.
    """
    term_index = 0
    for name, term in action_manager._terms.items():
        if name == term_name:
            return term_index
        term_index += term.action_dim

    raise ValueError(f"Action term '{term_name}' not found in the action manager.")


def _apply_symmetry_terms(data: torch.Tensor, term_cfg: SymmetryTermCfg, get_term_index) -> torch.Tensor:
    """
    Return a mirrored copy of ``data``; ``data`` itself is left untouched.

    The operations are applied in the order swap, swap-and-negate, negate.

    Args:
        data (torch.Tensor): The tensor to mirror. The last dimension holds the concatenated terms.
        term_cfg (SymmetryTermCfg): The operations to perform, keyed by term name.
        get_term_index: Callable mapping a term name to the index of its first element
            along the last dimension of ``data``.

    Returns:
        torch.Tensor: The mirrored copy, with the same shape as ``data``.
    """
    mirrored = data.clone()

    # Swap pairs
    for term_name, id_pairs in term_cfg.swap_terms.items():
        term_index = get_term_index(term_name)
        for first_id, second_id in id_pairs:
            first, second = term_index + first_id, term_index + second_id
            # Integer indexing returns a view, so the first element has to be copied
            # before it is overwritten; a plain tuple swap would duplicate the second one.
            first_values = mirrored[..., first].clone()
            mirrored[..., first] = mirrored[..., second]
            mirrored[..., second] = first_values

    # Swap and negate pairs
    for term_name, id_pairs in term_cfg.swap_negate_terms.items():
        term_index = get_term_index(term_name)
        for first_id, second_id in id_pairs:
            first, second = term_index + first_id, term_index + second_id
            # negation allocates a new tensor, so no explicit copy is needed here
            negated_first = -mirrored[..., first]
            mirrored[..., first] = -mirrored[..., second]
            mirrored[..., second] = negated_first

    # Negate
    for term_name, element_ids in term_cfg.negate_terms.items():
        term_index = get_term_index(term_name)
        for element_id in element_ids:
            mirrored[..., term_index + element_id] = -mirrored[..., term_index + element_id]

    return mirrored


def symmetry_data_augmentation_function(
    env: RslRlVecEnvWrapper,
    obs: torch.Tensor | None,
    actions: torch.Tensor | None,
    obs_type: str = "policy",
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
    """
    The symmetry data augmentation function.

    Mirrors observations and actions according to the ``symmetry`` attribute of the environment
    configuration and appends the mirrored samples to the original ones along dimension 0.
    If the environment configuration has no ``symmetry`` attribute, a warning is printed and the
    inputs are returned unchanged.

    Args:
        env (RslRlVecEnvWrapper): The environment object. This is used to access the symmetry
            configuration and the observation / action managers.
        obs (torch.Tensor | None): The observation tensor. If None, the observation is not used.
        actions (torch.Tensor | None): The actions tensor. If None, the actions is not used.
        obs_type (str): The name of the observation type, "policy" or "critic". Defaults to "policy".
            This is useful when handling augmentation for different observation groups.

    Returns:
        tuple[torch.Tensor | None, torch.Tensor | None]: A tuple containing the augmented observation
        and actions tensors. An entry is None when the corresponding input is None.

    Raises:
        ValueError: If ``obs`` is given and ``obs_type`` is neither "policy" nor "critic", or if a
            configured term name does not exist in the observation group / action manager.
    """
    # TODO: change to search by class type, instead of attribute name
    symmetry_cfg: SymmetryCfg | None = getattr(getattr(env, "cfg", None), "symmetry", None)

    if symmetry_cfg is None:
        print("WARNING: No symmetry configuration found")
        return obs, actions

    # Augment the observation
    augmented_obs = None
    if obs is not None:
        if obs_type == "policy":
            obs_cfg = symmetry_cfg.actor_observations
        elif obs_type == "critic":
            obs_cfg = symmetry_cfg.critic_observations
        else:
            raise ValueError(f"Unsupported observation type '{obs_type}'. Expected 'policy' or 'critic'.")

        obs_manager = env.unwrapped.observation_manager
        # Use the layout of the group that matches the observation type. When the environment
        # defines no such group, keep the previous behaviour of using the "policy" layout.
        group_name = obs_type if obs_type in obs_manager._group_obs_term_names else "policy"

        symmetry_obs = _apply_symmetry_terms(
            obs,
            obs_cfg,
            lambda term_name: __get_observation_term_index(obs_manager, term_name, group_name),
        )
        augmented_obs = torch.cat([obs, symmetry_obs], dim=0)

    # Augment the actions
    augmented_actions = None
    if actions is not None:
        action_manager = env.unwrapped.action_manager

        symmetry_actions = _apply_symmetry_terms(
            actions,
            symmetry_cfg.actions,
            lambda term_name: __get_action_term_index(action_manager, term_name),
        )
        augmented_actions = torch.cat([actions, symmetry_actions], dim=0)

    return augmented_obs, augmented_actions
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from isaaclab.utils import configclass

from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg, RslRlPpoActorCriticCfg, RslRlPpoAlgorithmCfg
from isaaclab_rl.rsl_rl.symmetry_cfg import RslRlSymmetryCfg, symmetry_data_augmentation_function


@configclass
Expand Down Expand Up @@ -46,3 +47,17 @@ def __post_init__(self):
self.experiment_name = "g1_flat"
self.policy.actor_hidden_dims = [256, 128, 128]
self.policy.critic_hidden_dims = [256, 128, 128]


@configclass
class G1FlatSymmetryPPORunnerCfg(G1FlatPPORunnerCfg):
    """Flat-terrain G1 PPO runner configuration with symmetry augmentation and mirror loss enabled."""

    def __post_init__(self):
        # start from the flat-terrain G1 runner settings
        super().__post_init__()

        # mirrored samples are produced by `symmetry_data_augmentation_function`
        symmetry_settings = RslRlSymmetryCfg(
            use_data_augmentation=True,
            use_mirror_loss=True,
            data_augmentation_func=symmetry_data_augmentation_function,
            mirror_loss_coeff=0.1,
        )

        self.experiment_name = "velocity_g1_symmetry"
        self.algorithm.symmetry_cfg = symmetry_settings
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from isaaclab.utils import configclass

from .rough_env_cfg import G1RoughEnvCfg
from isaaclab_rl.rsl_rl.symmetry_cfg import SymmetryCfg, SymmetryTermCfg


@configclass
Expand Down Expand Up @@ -41,6 +42,116 @@ def __post_init__(self):
self.commands.base_velocity.ranges.ang_vel_z = (-1.0, 1.0)


# Joint names of the G1 robot. The symmetry id lists below are positions in this list.
# NOTE(review): the ids only address the intended elements if this order matches the joint
# order used by the "joint_pos" / "joint_vel" / "actions" terms — confirm against the robot asset.
g1_joints_names = [
    "left_shoulder_pitch_joint",
    "left_shoulder_roll_joint",
    "left_shoulder_yaw_joint",
    "left_elbow_pitch_joint",
    "left_elbow_roll_joint",
    "right_shoulder_pitch_joint",
    "right_shoulder_roll_joint",
    "right_shoulder_yaw_joint",
    "right_elbow_pitch_joint",
    "right_elbow_roll_joint",
    "left_hip_pitch_joint",
    "left_hip_roll_joint",
    "left_hip_yaw_joint",
    "left_knee_joint",
    "left_ankle_pitch_joint",
    "left_ankle_roll_joint",
    "right_hip_pitch_joint",
    "right_hip_roll_joint",
    "right_hip_yaw_joint",
    "right_knee_joint",
    "right_ankle_pitch_joint",
    "right_ankle_roll_joint",
    "torso_joint",
]


def _left_right_joint_ids(joint_suffix: str) -> tuple[int, int]:
    """Return the ``(left, right)`` positions in ``g1_joints_names`` of a mirrored joint pair."""
    return (
        g1_joints_names.index(f"left_{joint_suffix}"),
        g1_joints_names.index(f"right_{joint_suffix}"),
    )


# Symmetry joint configurations
# all pitch joints need to be swapped
symmetry_joint_swap_ids = [
    _left_right_joint_ids("hip_pitch_joint"),
    _left_right_joint_ids("knee_joint"),
    _left_right_joint_ids("shoulder_pitch_joint"),
    _left_right_joint_ids("ankle_pitch_joint"),
    _left_right_joint_ids("elbow_pitch_joint"),
]
# all roll and yaw joints need to be swapped and negated
symmetry_joint_swap_negate_ids = [
    _left_right_joint_ids("hip_roll_joint"),
    _left_right_joint_ids("hip_yaw_joint"),
    _left_right_joint_ids("shoulder_roll_joint"),
    _left_right_joint_ids("ankle_roll_joint"),
    _left_right_joint_ids("shoulder_yaw_joint"),
    # the elbow roll pair was missing, which left both elbow roll joints unmirrored
    _left_right_joint_ids("elbow_roll_joint"),
]
# non-symmetric (at center) yaw joints need to be negated
symmetry_joint_negate_ids = [
    # the center joint is listed as "torso_joint" above; looking up "waist_yaw_joint"
    # raised ValueError as soon as this module was imported
    g1_joints_names.index("torso_joint"),
]


@configclass
class G1Symmetry(SymmetryCfg):
    """Symmetry configuration for the G1 robot.

    Declares, for the actions, the actor observations and the critic observations, which
    elements of which terms are swapped, swapped and negated, or negated. All joint index
    lists used here (``symmetry_joint_swap_ids``, ``symmetry_joint_swap_negate_ids``,
    ``symmetry_joint_negate_ids``) are positions in ``g1_joints_names``.
    """

    # NOTE(review): the joint index lists are positions in `g1_joints_names`; they are only valid
    # if the "joint_pos", "joint_vel" and "actions" terms use that same joint order — confirm.

    @configclass
    class ActionsCfg(SymmetryTermCfg):
        """Mirror operations for the "joint_pos" action term."""

        swap_terms = {
            "joint_pos": symmetry_joint_swap_ids,
        }
        swap_negate_terms = {
            "joint_pos": symmetry_joint_swap_negate_ids,
        }
        negate_terms = {
            "joint_pos": symmetry_joint_negate_ids,
        }

    @configclass
    class ActorObservationsCfg(SymmetryTermCfg):
        """Mirror operations for the actor observations."""

        # the three joint-space terms share the same index lists
        swap_terms = {
            "joint_pos": symmetry_joint_swap_ids,
            "joint_vel": symmetry_joint_swap_ids,
            "actions": symmetry_joint_swap_ids,
        }
        swap_negate_terms = {
            "joint_pos": symmetry_joint_swap_negate_ids,
            "joint_vel": symmetry_joint_swap_negate_ids,
            "actions": symmetry_joint_swap_negate_ids,
        }
        # NOTE(review): the literal indices presumably select the components that change sign
        # under a left/right mirror (e.g. lateral velocity, roll / yaw rates) — confirm against
        # the layout of each observation term.
        # NOTE(review): no entry for "base_lin_vel" here, unlike the critic below — confirm that
        # the actor observations do not contain that term.
        negate_terms = {
            "velocity_commands": [1, 2],
            "base_ang_vel": [0, 2],
            "projected_gravity": [1],
            "joint_pos": symmetry_joint_negate_ids,
            "joint_vel": symmetry_joint_negate_ids,
            "actions": symmetry_joint_negate_ids,
        }

    @configclass
    class CriticObservationsCfg(ActorObservationsCfg):
        """Mirror operations for the critic observations.

        Subclasses the actor configuration and overrides only ``negate_terms``, which
        additionally lists "base_lin_vel".
        """

        negate_terms = {
            "velocity_commands": [1, 2],
            "base_lin_vel": [1],
            "base_ang_vel": [0, 2],
            "projected_gravity": [1],
            "joint_pos": symmetry_joint_negate_ids,
            "joint_vel": symmetry_joint_negate_ids,
            "actions": symmetry_joint_negate_ids,
        }

    # instances of the nested configurations fill the three fields declared by SymmetryCfg
    actions: ActionsCfg = ActionsCfg()
    actor_observations: ActorObservationsCfg = ActorObservationsCfg()
    critic_observations: CriticObservationsCfg = CriticObservationsCfg()


@configclass
class G1FlatSymmetryEnvCfg(G1FlatEnvCfg):
    """Flat-terrain G1 environment configuration that additionally carries the symmetry description."""

    symmetry: G1Symmetry = G1Symmetry()
    """Symmetry loss configuration."""


class G1FlatEnvCfg_PLAY(G1FlatEnvCfg):
def __post_init__(self) -> None:
# post init of parent
Expand Down