ViSkill: Value-Informed Skill Chaining for Policy Learning of Long-Horizon Tasks with Surgical Robot
This is the official PyTorch implementation of the paper "Value-Informed Skill Chaining for Policy Learning of Long-Horizon Tasks with Surgical Robot" (IROS 2023).
- Ubuntu 18.04
- Python 3.7+
- Clone this repository.
git clone --recursive https://github.com/med-air/ViSkill.git
cd ViSkill
- Create a virtual environment
conda create -n viskill python=3.8
conda activate viskill
- Install packages
pip3 install -e SurRoL/ # install surrol environments
pip3 install -r requirements.txt
pip3 install -e .
- Then add one line of code at the top of gym/gym/envs/__init__.py to register SurRoL tasks:
# directory: anaconda3/envs/viskill/lib/python3.8/site-packages/
import surrol.gym
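If you prefer to script this step instead of editing the file by hand, the following is a minimal sketch. The helper name `prepend_line` is hypothetical and not part of this repository; it simply inserts the registration line at the top of a file and skips the edit if the line is already there.

```python
from pathlib import Path

REGISTER_LINE = "import surrol.gym"  # the line the step above asks for


def prepend_line(init_file: Path, line: str = REGISTER_LINE) -> bool:
    """Insert `line` at the top of `init_file` unless it is already present.

    Returns True if the file was modified, False if nothing had to change.
    (Hypothetical helper, not shipped with ViSkill or SurRoL.)
    """
    text = init_file.read_text()
    if line in text.splitlines():
        return False
    init_file.write_text(line + "\n" + text)
    return True
```

Call it with the path to the `__init__.py` named above inside your environment's site-packages directory. Running it twice is safe because the second call leaves the file unchanged.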
The commands below run ViSkill; results are logged to WandB. Before running them, change the wandb entity in skill_learning.yaml and skill_chaining.yaml to match your account.
We collect demonstration data for each subtask via the scripted controllers provided by SurRoL. Take the BiPegTransfer task as an example:
mkdir SurRoL/surrol/data/demo
python SurRoL/surrol/data/data_generation_bipegtransfer.py --env BiPegTransfer-v0 --subtask grasp
python SurRoL/surrol/data/data_generation_bipegtransfer.py --env BiPegTransfer-v0 --subtask handover
python SurRoL/surrol/data/data_generation_bipegtransfer.py --env BiPegTransfer-v0 --subtask release
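The three data-generation commands differ only in the `--subtask` value, so they can be generated in a loop. The sketch below only builds and prints the command lines; the function name `demo_commands` is an assumption made for illustration, while the script path, flags, and subtask names are taken from the commands above.

```python
SCRIPT = "SurRoL/surrol/data/data_generation_bipegtransfer.py"
SUBTASKS = ("grasp", "handover", "release")


def demo_commands(env="BiPegTransfer-v0", subtasks=SUBTASKS, script=SCRIPT):
    """Return one argv list per subtask, mirroring the commands listed above."""
    return [
        ["python", script, "--env", env, "--subtask", subtask]
        for subtask in subtasks
    ]


for argv in demo_commands():
    # Printed rather than executed, so the sketch has no side effects.
    print(" ".join(argv))
```

To actually launch the scripts, each `argv` list can be passed to `subprocess.run(argv, check=True)` from the repository root.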
- Train subtask policies:
mpirun -np 8 python -m train_sl seed=1 subtask=grasp
mpirun -np 8 python -m train_sl seed=1 subtask=handover
mpirun -np 8 python -m train_sl seed=1 subtask=release
- Train chaining policies:
mpirun -np 8 python -m train_sc model_seed=1 task=BiPegTransfer-v0
The default hyperparameters are defined in viskill/configs, where skill_learning.yaml and skill_chaining.yaml define the experiment settings for learning subtask policies and chaining policies, respectively, and the YAML files in the directories sl_agent and sc_agent define the hyperparameters of each method. These parameters can be modified directly in the experiment or agent config files, or overridden on the command line.
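To illustrate how `key=value` arguments such as `seed=1 subtask=grasp` can take precedence over file defaults, here is a minimal pure-Python sketch. It does not reproduce the repository's actual config loader; the functions `parse_value` and `apply_overrides`, the dotted-key convention, and the `agent.hidden_dim` entry are all assumptions made for illustration.

```python
import copy


def parse_value(raw):
    """Interpret an override value as bool, int, or float, else keep it as str."""
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            continue
    return raw


def apply_overrides(config, overrides):
    """Return a copy of `config` with `a.b=value` style overrides applied."""
    merged = copy.deepcopy(config)
    for item in overrides:
        dotted_key, _, raw = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = merged
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = parse_value(raw)
    return merged


# Hypothetical defaults; the real keys live in the YAML files described above.
defaults = {"seed": 0, "subtask": "grasp", "agent": {"hidden_dim": 256}}
print(apply_overrides(defaults, ["seed=1", "subtask=handover", "agent.hidden_dim=128"]))
```

The defaults dictionary is copied before merging, so the file-level values stay untouched and each run sees only its own overrides.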
The core RL algorithms are implemented within the BaseAgent class. To add a new skill chaining algorithm, create a new file in viskill/agents and subclass BaseAgent. In particular, any required networks (actor, critic, etc.) need to be constructed, and the update(...) and get_action(...) functions need to be overridden. Once the implementation is done, register the new agent in factory.py and add a config file in sc_agent to specify the model parameters.
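The steps above can be sketched roughly as follows. This is a toy illustration, not the repository's code: the `BaseAgent` defined here is a stand-in whose interface is assumed from the two method names mentioned above, and `AGENT_REGISTRY`, `register_agent`, and `RandomChainingAgent` are hypothetical names standing in for whatever factory.py actually does.

```python
import random


class BaseAgent:
    """Stand-in for the repository's BaseAgent; the real interface may differ."""

    def update(self, batch):
        raise NotImplementedError

    def get_action(self, obs):
        raise NotImplementedError


AGENT_REGISTRY = {}  # hypothetical stand-in for the registration in factory.py


def register_agent(name):
    """Decorator that records an agent class under `name`."""
    def decorator(cls):
        AGENT_REGISTRY[name] = cls
        return cls
    return decorator


@register_agent("random_chaining")
class RandomChainingAgent(BaseAgent):
    """Toy agent: uniform random actions, no learnable networks."""

    def __init__(self, action_dim, seed=0):
        self.action_dim = action_dim
        self.rng = random.Random(seed)
        self.num_updates = 0

    def get_action(self, obs):
        # A real agent would query its actor network here.
        return [self.rng.uniform(-1.0, 1.0) for _ in range(self.action_dim)]

    def update(self, batch):
        # A real agent would compute actor/critic losses and step optimizers here.
        self.num_updates += 1
        return {"num_updates": self.num_updates}


agent = AGENT_REGISTRY["random_chaining"](action_dim=5)
action = agent.get_action(obs=[0.0] * 10)
metrics = agent.update(batch=[])
```

Looking the class up by name keeps the training script independent of any particular agent, which is the usual reason for a registration step like the one in factory.py.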
viskill
|- agents # implements core algorithms in agent classes
|- components # reusable infrastructure for model training
| |- checkpointer.py # handles saving + loading of model checkpoints
| |- environment.py # environment wrappers for SurRoL environments
| |- normalizer.py # normalizer for vectorized input
| |- logger.py # implements core logging functionality using WandB
|
|- configs # experiment configs
| |- skill_learning.yaml # configs for subtask policy learning
| |- skill_chaining.yaml # configs for chaining policy learning
| |- sl_agent # configs for demonstration-guided RL algorithm
| |- sc_agent # configs for skill chaining algorithm
| |- hier_agent # configs for (synthetic) hierarchical agent
|
|- modules # reusable architecture components
| |- critic.py # basic critic implementations (e.g., MLP-based critic)
| |- distributions.py # PyTorch distribution utils for density model
| |- policy.py # basic actor implementations
| |- replay_buffer.py # HER replay buffer with future sampling strategy
| |- sampler.py # rollout sampler for collecting experience
| |- subnetworks.py # basic networks
|
|- trainers # main model training script, builds all components + runs training loop and logging
|
|- utils # general and RL utilities, PyTorch / visualization utilities, etc.
|- train_sl.py # experiment launcher for skill learning
|- train_sc.py # experiment launcher for skill chaining
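The tree above describes replay_buffer.py as a HER replay buffer with the future sampling strategy. As a rough illustration of that general technique (hindsight experience replay), and not of this repository's implementation, the sketch below relabels one transition with a goal that was actually achieved at the same or a later step of the episode. The transition keys, the sparse 0/-1 reward, and the distance threshold are assumptions.

```python
import random


def sparse_reward(achieved_goal, goal, threshold=0.05):
    """Return 0 when the achieved goal is within `threshold` of the goal, else -1."""
    distance = sum((a - g) ** 2 for a, g in zip(achieved_goal, goal)) ** 0.5
    return 0.0 if distance <= threshold else -1.0


def relabel_future(episode, t, rng):
    """Relabel transition `t` with a goal achieved at step `t` or later.

    `episode` is a list of dicts with keys 'obs', 'action', and
    'next_achieved_goal' (the goal state reached after the action).
    """
    future = rng.randint(t, len(episode) - 1)  # inclusive on both ends
    new_goal = episode[future]["next_achieved_goal"]
    transition = dict(episode[t])
    transition["goal"] = new_goal
    transition["reward"] = sparse_reward(transition["next_achieved_goal"], new_goal)
    return transition


episode = [
    {"obs": [0.0], "action": [0.1], "next_achieved_goal": [0.1]},
    {"obs": [0.1], "action": [0.1], "next_achieved_goal": [0.2]},
    {"obs": [0.2], "action": [0.1], "next_achieved_goal": [0.3]},
]
relabeled = relabel_future(episode, t=2, rng=random.Random(0))
```

Because the substituted goal was reached later in the same episode, relabeled transitions carry a useful reward signal even when the original goal was never achieved.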
For any questions, please feel free to email taou.cs13@gmail.com