Skip to content

Commit

Permalink
Curriculum manager clean-up (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
David Hoeller authored Jul 10, 2023
1 parent 5567f0a commit 4d18699
Showing 1 changed file with 76 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,50 +4,91 @@
# SPDX-License-Identifier: BSD-3-Clause


"""Curriculum manager for computing done signals for a given world."""
"""Curriculum manager for updating environment quantities subject to a training curriculum."""

import copy
import inspect
import torch
from prettytable import PrettyTable
from typing import TYPE_CHECKING, Dict, List
from typing import Dict, List, Optional, Sequence

from omni.isaac.orbit.utils.dict import class_to_dict, string_to_callable

if TYPE_CHECKING:
from omni.isaac.orbit_envs.isaac_env import IsaacEnv


class CurriculumManager:
"""Manager for computing done signals for a given world.
"""Manager to implement and execute specific curricula.
The curriculum manager updates various quantities of the environment subject to a training curriculum by calling a
list of terms. These help stabilize learning by progressively making the learning tasks harder as the agent
improves.
The curriculum terms are parsed from a nested dictionary containing the manager's settings and each term's
parameters. Each curriculum term dictionary contains the following keys:
- ``func``: The name of the function to be called. This function should take the environment object, environment
indices, and any other parameters as input and return the curriculum state as a 1-dimensional float tensor.
- ``**params``: The parameters to be passed to the function.
There are special keys in the curriculum term dictionary that are resolved by the curriculum manager:
- ``sensor_name``: The name of the sensor required by the curriculum term. If the sensor is not enabled, it will be
enabled by the curriculum manager on initialization.
- ``dofs``: The names of the degrees of freedom (dofs) required by the curriculum term. These are converted to dof
indices on initialization and passed to the curriculum term as a list of dof indices.
- ``bodies``: The names of the bodies required by the curriculum term. These are converted to body indices on
initialization and passed to the curriculum term as a list of body indices.
Usage:
.. code-block:: python
from collections import namedtuple
The curriculum manager updates various quantities of the environment subject to a training curriculum by calling a list of terms.
Each curriculum term is a function which takes the environment as an
argument.
def curriculum_1(env, env_ids, param_1, param_2):
# Use the environment data to update the curriculum
pass
The terms are parsed from a nested dictionary containing the curriculum manager's settings and
curriculum terms configuration.
def curriculum_2(env, env_ids, param_1):
# Use the environment data to update the curriculum
pass
# TODO: Add example of config. Also add logging of each term's value similar to reward manager.
# dummy environment with 20 instances
env = namedtuple("IsaacEnv", ["num_envs", "dt"])(20, 0.01)
# dummy device
device = "cpu"
# create curriculum manager
cfg = {
"curriculum_term_1": {"func": "curriculum_1", "param_1": 1, "param_2": 2},
"curriculum_term_2": {"func": "curriculum_2", "param_1": 1},
}
curriculum_man = CurriculumManager(cfg, env)
# print curriculum manager
# shows active curricula
print(curriculum_man)
# check number of active terms
assert len(curriculum_man.active_terms) == 2
# update the curriculum
curriculum_man.compute()
"""

def __init__(self, cfg, env: "IsaacEnv"):
def __init__(self, cfg, env: object):
"""Construct a list of curriculum functions which are used to update the environment quantities.
Args:
cfg (Dict[str, Dict[str, Any]]): Configuration for curriculum terms.
env (IsaacEnv): A world instance used for accessing state.
num_envs (int): Number of environment instances.
dt (float): The time-stepping for the environment.
device (int): The device on which to create buffers.
env (object): An environment object.
"""
# store input
if not isinstance(cfg, dict):
cfg = class_to_dict(cfg)
self._cfg = copy.deepcopy(cfg)
self._env = env
self._num_envs = env.num_envs # We can get this from env?
self._num_envs = env.num_envs
self._device = env.device
# parse config to create curriculum terms information
self._prepare_curriculum_terms()
Expand Down Expand Up @@ -97,24 +138,35 @@ def active_terms(self) -> List[str]:
Operations.
"""

def log_extra_info(self, env_ids=...) -> Dict[str, torch.Tensor]:
def log_extra_info(self, env_ids: Optional[Sequence[int]] = None) -> Dict[str, torch.Tensor]:
"""Returns the current state of individual curriculum terms.
Note:
This function does not use the environment indices :attr:`env_ids`
and logs the state of all the terms. The argument is only present
to maintain consistency with other classes.
Returns:
Dict[str, torch.Tensor]: Dictionary of curriculum terms.
Dict[str, torch.Tensor]: Dictionary of curriculum terms and their states.
"""
extras = {}
for key, value in self._curriculum_state.items():
if value is not None:
extras["Curriculum/" + key] = value
return extras

def compute(self, env_ids=...) -> torch.Tensor:
"""Computes the curriculum signal as union of individual terms.
def compute(self, env_ids: Optional[Sequence[int]] = None):
"""Update the curriculum terms.
This function calls each curriculum term managed by the class and performs a logical OR operation
to compute the net curriculum signal.
This function calls each curriculum term managed by the class.
Args:
env_ids (Optional[Sequence[int]]): The list of environment IDs to update.
If None, all the environments are updated. Defaults to None.
"""
# resolve environment indices
if env_ids is None:
env_ids = ...
# iterate over all the curriculum terms
for name, params, func in zip(
self._curriculum_term_names, self._curriculum_term_params, self._curriculum_term_functions
Expand Down

0 comments on commit 4d18699

Please sign in to comment.