-
Notifications
You must be signed in to change notification settings - Fork 370
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature(zlx): add envpool support (#228)
* feature(nyz): init envpool(ci skip) * feature(nyz): add naive envpool atari demo * fix(nyz): fix typo and unused import * feature(zlx): Modify envpool config * fixbug(zlx): Fix comments by nyz * fixbug(zlx): add unittest for envpool env manager * style(zlx) * fixbug(zlx): Add runtime installation of envpool * fixbug(zlx): Add pip install envpool in github unittest workflows * fixbug(zlx): Add pip install envpool in github unittest workflows * test(zlx): Seperate envpool test out of unit/platform test * fixbug(zlx): Withdraw changes in sample serial collector; In PoolEnvManager, reset returns until all envs are reset successfully Co-authored-by: niuyazhe <niuyazhe@sensetime.com>
- Loading branch information
Showing
16 changed files
with
411 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# This workflow will check pytest | ||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions | ||
|
||
name: envpool_test | ||
|
||
on: [push, pull_request] | ||
|
||
jobs: | ||
test_envpooltest: | ||
runs-on: ubuntu-latest | ||
if: "!contains(github.event.head_commit.message, 'ci skip')" | ||
strategy: | ||
matrix: | ||
python-version: [3.7, 3.8] # Envpool only supports python>=3.7 | ||
|
||
steps: | ||
- uses: actions/checkout@v2 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: do_envpool_test | ||
run: | | ||
python -m pip install . | ||
python -m pip install ".[test,k8s]" | ||
python -m pip install ".[envpool]" | ||
./ding/scripts/install-k8s-tools.sh | ||
make envpooltest |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .base_env_manager import BaseEnvManager, create_env_manager, get_env_manager_cls | ||
from .subprocess_env_manager import AsyncSubprocessEnvManager, SyncSubprocessEnvManager | ||
from .gym_vector_env_manager import GymVectorEnvManager | ||
# Do not import PoolEnvManager, because it depends on installation of `envpool` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import gym | ||
from easydict import EasyDict | ||
from copy import deepcopy | ||
import numpy as np | ||
from collections import namedtuple | ||
from typing import Any, Union, List, Tuple, Dict, Callable, Optional | ||
import logging | ||
try: | ||
import envpool | ||
except ImportError: | ||
import sys | ||
logging.warning("Please install envpool first, use 'pip install envpool'") | ||
envpool = None | ||
|
||
from ding.envs import BaseEnvTimestep | ||
from ding.utils import ENV_MANAGER_REGISTRY, deep_merge_dicts | ||
|
||
|
||
@ENV_MANAGER_REGISTRY.register('env_pool') | ||
class PoolEnvManager: | ||
''' | ||
Overview: | ||
Envpool now supports Atari, Classic Control, Toy Text, ViZDoom. | ||
Here we list some commonly used env_ids as follows. | ||
For more examples, you can refer to <https://envpool.readthedocs.io/en/latest/api/atari.html>. | ||
- Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5" | ||
- Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v0" | ||
''' | ||
|
||
@classmethod | ||
def default_config(cls) -> EasyDict: | ||
return EasyDict(deepcopy(cls.config)) | ||
|
||
config = dict( | ||
type='envpool', | ||
# Sync mode: batch_size == env_num | ||
# Async mode: batch_size < env_num | ||
env_num=8, | ||
batch_size=8, | ||
# Unlike other env managers, envpool's seed should be specified in config. | ||
seed=0, | ||
) | ||
|
||
def __init__(self, cfg: EasyDict) -> None: | ||
self._cfg = cfg | ||
self._env_num = cfg.env_num | ||
self._batch_size = cfg.batch_size | ||
self._seed = cfg.seed | ||
self._ready_obs = {} | ||
self._closed = True | ||
|
||
def launch(self) -> None: | ||
assert self._closed, "Please first close the env manager" | ||
self._envs = envpool.make( | ||
self._cfg.env_id, env_type="gym", num_envs=self._env_num, batch_size=self._batch_size, seed=self._seed | ||
) | ||
self._closed = False | ||
self.reset() | ||
|
||
def reset(self) -> None: | ||
self._envs.async_reset() | ||
obs, _, _, info = self._envs.recv() | ||
env_id = info['env_id'] | ||
print(env_id) | ||
obs = obs.astype(np.float32) | ||
self._ready_obs = {i: o for i, o in zip(env_id, obs)} | ||
|
||
def step(self, action) -> Dict[int, namedtuple]: | ||
env_id = np.array(list(action.keys())) | ||
action = np.array(list(action.values())) | ||
if len(action.shape) == 2: | ||
action = action.squeeze(1) | ||
self._envs.send(action, env_id) | ||
|
||
obs, rew, done, info = self._envs.recv() | ||
obs = obs.astype(np.float32) | ||
rew = rew.astype(np.float32) | ||
env_id = info['env_id'] | ||
timesteps = {} | ||
self._ready_obs = {} | ||
for i in range(len(env_id)): | ||
d = bool(done[i]) | ||
r = rew[i:i + 1] | ||
timesteps[env_id[i]] = BaseEnvTimestep(obs[i], r, d, info={'env_id': i, 'final_eval_reward': 0.}) | ||
self._ready_obs[env_id[i]] = obs[i] | ||
return timesteps | ||
|
||
def close(self) -> None: | ||
if self._closed: | ||
return | ||
# Envpool has no `close` API | ||
self._closed = True | ||
|
||
def seed(self, seed, dynamic_seed=False) -> None: | ||
# Envpool's seed is set in `envpool.make`. This method is preserved for compatibility. | ||
logging.warning("envpool doesn't support dynamic_seed in different episode") | ||
|
||
@property | ||
def env_num(self) -> int: | ||
return self._env_num | ||
|
||
@property | ||
def ready_obs(self) -> Dict[int, Any]: | ||
return self._ready_obs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import time | ||
import pytest | ||
import numpy as np | ||
from easydict import EasyDict | ||
|
||
from ..envpool_env_manager import PoolEnvManager | ||
|
||
env_num_args = [[16, 8]] | ||
|
||
|
||
@pytest.mark.envpooltest | ||
@pytest.mark.parametrize('env_num, batch_size', env_num_args) | ||
class TestPoolEnvManager: | ||
|
||
def test_naive(self, env_num, batch_size): | ||
env_manager_cfg = EasyDict({ | ||
'env_id': 'Pong-v5', | ||
'env_num': env_num, | ||
'batch_size': batch_size, | ||
'seed': 3, | ||
}) | ||
env_manager = PoolEnvManager(env_manager_cfg) | ||
assert env_manager._closed | ||
env_manager.launch() | ||
# Test step | ||
start_time = time.time() | ||
for count in range(20): | ||
env_id = env_manager.ready_obs.keys() | ||
action = {i: np.random.randint(4) for i in env_id} | ||
timestep = env_manager.step(action) | ||
if count > 10: | ||
assert len(timestep) == env_manager_cfg.batch_size | ||
print('Count {}'.format(count)) | ||
print([v.info for v in timestep.values()]) | ||
print([v.done for v in timestep.values()]) | ||
end_time = time.time() | ||
print('total step time: {}'.format(end_time - start_time)) | ||
# Test close | ||
env_manager.close() | ||
assert env_manager._closed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from dizoo.atari.config.serial.enduro import * | ||
from dizoo.atari.config.serial.montezuma import * | ||
from dizoo.atari.config.serial.pong import * | ||
from dizoo.atari.config.serial.qbert import * | ||
from dizoo.atari.config.serial.spaceinvaders import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .enduro_dqn_config import enduro_dqn_config, enduro_dqn_create_config |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .pong_dqn_config import pong_dqn_config, pong_dqn_create_config | ||
from .pong_dqn_config import pong_dqn_config, pong_dqn_create_config | ||
from .pong_dqn_envpool_config import pong_dqn_envpool_config, pong_dqn_envpool_create_config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from easydict import EasyDict | ||
|
||
pong_dqn_envpool_config = dict( | ||
exp_name='pong_dqn_envpool', | ||
env=dict( | ||
collector_env_num=8, | ||
collector_batch_size=8, | ||
evaluator_env_num=8, | ||
evaluator_batch_size=8, | ||
n_evaluator_episode=8, | ||
stop_value=20, | ||
env_id='Pong-v5', | ||
frame_stack=4, | ||
manager=dict(shared_memory=False, ) | ||
), | ||
policy=dict( | ||
cuda=True, | ||
priority=False, | ||
model=dict( | ||
obs_shape=[4, 84, 84], | ||
action_shape=6, | ||
encoder_hidden_size_list=[128, 128, 512], | ||
), | ||
nstep=3, | ||
discount_factor=0.99, | ||
learn=dict( | ||
update_per_collect=10, | ||
batch_size=32, | ||
learning_rate=0.0001, | ||
target_update_freq=500, | ||
), | ||
collect=dict(n_sample=96, ), | ||
eval=dict(evaluator=dict(eval_freq=4000, )), | ||
other=dict( | ||
eps=dict( | ||
type='exp', | ||
start=1., | ||
end=0.05, | ||
decay=250000, | ||
), | ||
replay_buffer=dict(replay_buffer_size=100000, ), | ||
), | ||
), | ||
) | ||
pong_dqn_envpool_config = EasyDict(pong_dqn_envpool_config) | ||
main_config = pong_dqn_envpool_config | ||
pong_dqn_envpool_create_config = dict( | ||
env=dict( | ||
type='atari', | ||
import_names=['dizoo.atari.envs.atari_env'], | ||
), | ||
env_manager=dict(type='env_pool'), | ||
policy=dict(type='dqn'), | ||
# replay_buffer=dict(type='deque'), | ||
) | ||
pong_dqn_envpool_create_config = EasyDict(pong_dqn_envpool_create_config) | ||
create_config = pong_dqn_envpool_create_config | ||
|
||
if __name__ == '__main__': | ||
from ding.entry import serial_pipeline | ||
serial_pipeline((main_config, create_config), seed=0) | ||
|
||
# Alternatively, one can be opt to run the following command to directly execute this config file | ||
# ding -m serial -c pong_dqn_envpool_config.py -s 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from spaceinvaders_dqn_config import spaceinvaders_dqn_config, spaceinvaders_dqn_create_config | ||
from .spaceinvaders_dqn_config import spaceinvaders_dqn_config, spaceinvaders_dqn_create_config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import os | ||
import gym | ||
from tensorboardX import SummaryWriter | ||
from easydict import EasyDict | ||
from copy import deepcopy | ||
from functools import partial | ||
|
||
from ding.config import compile_config | ||
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer | ||
from ding.envs import SyncSubprocessEnvManager | ||
from ding.policy import DQNPolicy | ||
from ding.model import DQN | ||
from ding.utils import set_pkg_seed, deep_merge_dicts | ||
from ding.rl_utils import get_epsilon_greedy_fn | ||
from dizoo.atari.envs import AtariEnv | ||
from dizoo.atari.config.serial.pong.pong_dqn_config import pong_dqn_config | ||
|
||
|
||
def main(cfg, seed=0, max_iterations=int(1e10)): | ||
cfg = compile_config( | ||
cfg, | ||
SyncSubprocessEnvManager, | ||
DQNPolicy, | ||
BaseLearner, | ||
SampleSerialCollector, | ||
InteractionSerialEvaluator, | ||
AdvancedReplayBuffer, | ||
save_cfg=True | ||
) | ||
collector_env_cfg = AtariEnv.create_collector_env_cfg(cfg.env) | ||
evaluator_env_cfg = AtariEnv.create_evaluator_env_cfg(cfg.env) | ||
collector_env = SyncSubprocessEnvManager( | ||
env_fn=[partial(AtariEnv, cfg=c) for c in collector_env_cfg], cfg=cfg.env.manager | ||
) | ||
evaluator_env = SyncSubprocessEnvManager( | ||
env_fn=[partial(AtariEnv, cfg=c) for c in evaluator_env_cfg], cfg=cfg.env.manager | ||
) | ||
|
||
collector_env.seed(seed) | ||
evaluator_env.seed(seed, dynamic_seed=False) | ||
set_pkg_seed(seed, use_cuda=cfg.policy.cuda) | ||
|
||
model = DQN(**cfg.policy.model) | ||
policy = DQNPolicy(cfg.policy, model=model) | ||
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) | ||
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) | ||
collector = SampleSerialCollector( | ||
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name | ||
) | ||
evaluator = InteractionSerialEvaluator( | ||
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name | ||
) | ||
replay_buffer = AdvancedReplayBuffer( | ||
cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name, instance_name='replay_buffer' | ||
) | ||
eps_cfg = cfg.policy.other.eps | ||
epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type) | ||
|
||
while True: | ||
if evaluator.should_eval(learner.train_iter): | ||
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) | ||
if stop: | ||
break | ||
eps = epsilon_greedy(collector.envstep) | ||
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps}) | ||
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) | ||
for i in range(cfg.policy.learn.update_per_collect): | ||
batch_size = learner.policy.get_attribute('batch_size') | ||
train_data = replay_buffer.sample(batch_size, learner.train_iter) | ||
if train_data is not None: | ||
learner.train(train_data, collector.envstep) | ||
|
||
|
||
if __name__ == "__main__": | ||
main(EasyDict(pong_dqn_config)) |
Oops, something went wrong.