feature(zlx): add envpool support (#228)
* feature(nyz): init envpool(ci skip)

* feature(nyz): add naive envpool atari demo

* fix(nyz): fix typo and unused import

* feature(zlx): Modify envpool config

* fixbug(zlx): Fix comments by nyz

* fixbug(zlx): add unittest for envpool env manager

* style(zlx)

* fixbug(zlx): Add runtime installation of envpool

* fixbug(zlx): Add pip install envpool in github unittest workflows

* fixbug(zlx): Add pip install envpool in github unittest workflows

* test(zlx): Separate envpool test out of unit/platform test

* fixbug(zlx): Withdraw changes in sample serial collector; in PoolEnvManager, reset waits until all envs are reset successfully

Co-authored-by: niuyazhe <niuyazhe@sensetime.com>
LuciusMos and PaParaZz1 authored Mar 17, 2022
1 parent c02d048 commit 8a108c7
Showing 16 changed files with 411 additions and 4 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/envpool_test.yml
@@ -0,0 +1,28 @@
# This workflow will run unit tests with pytest
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: envpool_test

on: [push, pull_request]

jobs:
test_envpooltest:
runs-on: ubuntu-latest
if: "!contains(github.event.head_commit.message, 'ci skip')"
strategy:
matrix:
python-version: [3.7, 3.8] # Envpool only supports python>=3.7

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: do_envpool_test
run: |
python -m pip install .
python -m pip install ".[test,k8s]"
python -m pip install ".[envpool]"
./ding/scripts/install-k8s-tools.sh
make envpooltest
4 changes: 4 additions & 0 deletions Makefile
@@ -33,6 +33,10 @@ cudatest:
pytest ${TEST_DIR} \
-sv -m cudatest

envpooltest:
pytest ${TEST_DIR} \
-sv -m envpooltest

dockertest:
./ding/scripts/docker-test-entry.sh

1 change: 1 addition & 0 deletions ding/envs/env_manager/__init__.py
@@ -1,3 +1,4 @@
from .base_env_manager import BaseEnvManager, create_env_manager, get_env_manager_cls
from .subprocess_env_manager import AsyncSubprocessEnvManager, SyncSubprocessEnvManager
from .gym_vector_env_manager import GymVectorEnvManager
# Do not import PoolEnvManager, because it depends on installation of `envpool`
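Since PoolEnvManager is intentionally not re-exported from this package, it is imported from its own module once envpool is installed; a minimal sketch, mirroring the import used in the unit test below:

from ding.envs.env_manager.envpool_env_manager import PoolEnvManager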
105 changes: 105 additions & 0 deletions ding/envs/env_manager/envpool_env_manager.py
@@ -0,0 +1,105 @@
import gym
from easydict import EasyDict
from copy import deepcopy
import numpy as np
from collections import namedtuple
from typing import Any, Union, List, Tuple, Dict, Callable, Optional
import logging
try:
    import envpool
except ImportError:
    logging.warning("Please install envpool first, use 'pip install envpool'")
    envpool = None

from ding.envs import BaseEnvTimestep
from ding.utils import ENV_MANAGER_REGISTRY, deep_merge_dicts


@ENV_MANAGER_REGISTRY.register('env_pool')
class PoolEnvManager:
'''
Overview:
Envpool now supports Atari, Classic Control, Toy Text, ViZDoom.
Here we list some commonly used env_ids as follows.
For more examples, you can refer to <https://envpool.readthedocs.io/en/latest/api/atari.html>.
- Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5"
- Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v0"
'''

@classmethod
def default_config(cls) -> EasyDict:
return EasyDict(deepcopy(cls.config))

config = dict(
type='envpool',
# Sync mode: batch_size == env_num
# Async mode: batch_size < env_num
env_num=8,
batch_size=8,
# Unlike other env managers, envpool's seed should be specified in config.
seed=0,
)

def __init__(self, cfg: EasyDict) -> None:
self._cfg = cfg
self._env_num = cfg.env_num
self._batch_size = cfg.batch_size
self._seed = cfg.seed
self._ready_obs = {}
self._closed = True

def launch(self) -> None:
assert self._closed, "Please first close the env manager"
self._envs = envpool.make(
self._cfg.env_id, env_type="gym", num_envs=self._env_num, batch_size=self._batch_size, seed=self._seed
)
self._closed = False
self.reset()

    def reset(self) -> None:
        self._envs.async_reset()
        # Block until the first batch of observations is received.
        obs, _, _, info = self._envs.recv()
        env_id = info['env_id']
        obs = obs.astype(np.float32)
        self._ready_obs = {i: o for i, o in zip(env_id, obs)}

def step(self, action) -> Dict[int, namedtuple]:
env_id = np.array(list(action.keys()))
action = np.array(list(action.values()))
if len(action.shape) == 2:
action = action.squeeze(1)
self._envs.send(action, env_id)

obs, rew, done, info = self._envs.recv()
obs = obs.astype(np.float32)
rew = rew.astype(np.float32)
env_id = info['env_id']
timesteps = {}
self._ready_obs = {}
for i in range(len(env_id)):
d = bool(done[i])
r = rew[i:i + 1]
timesteps[env_id[i]] = BaseEnvTimestep(obs[i], r, d, info={'env_id': i, 'final_eval_reward': 0.})
self._ready_obs[env_id[i]] = obs[i]
return timesteps

def close(self) -> None:
if self._closed:
return
# Envpool has no `close` API
self._closed = True

def seed(self, seed, dynamic_seed=False) -> None:
# Envpool's seed is set in `envpool.make`. This method is preserved for compatibility.
        logging.warning("envpool doesn't support dynamic_seed across different episodes")

@property
def env_num(self) -> int:
return self._env_num

@property
def ready_obs(self) -> Dict[int, Any]:
return self._ready_obs
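For reference, a minimal usage sketch of the manager above in sync mode (batch_size == env_num), assuming envpool is installed; the env_id and action count (Pong, 6 discrete actions) follow the Atari example in this commit:

import numpy as np
from easydict import EasyDict
from ding.envs.env_manager.envpool_env_manager import PoolEnvManager

cfg = EasyDict(dict(env_id='Pong-v5', env_num=8, batch_size=8, seed=0))
manager = PoolEnvManager(cfg)
manager.launch()  # builds the envpool instance and performs the first reset
for _ in range(10):
    # ready_obs maps env_id -> observation; act on every ready env
    actions = {i: np.random.randint(6) for i in manager.ready_obs}
    timesteps = manager.step(actions)  # dict of env_id -> BaseEnvTimestep
manager.close()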
40 changes: 40 additions & 0 deletions ding/envs/env_manager/tests/test_envpool_env_manager.py
@@ -0,0 +1,40 @@
import time
import pytest
import numpy as np
from easydict import EasyDict

from ..envpool_env_manager import PoolEnvManager

env_num_args = [[16, 8]]


@pytest.mark.envpooltest
@pytest.mark.parametrize('env_num, batch_size', env_num_args)
class TestPoolEnvManager:

def test_naive(self, env_num, batch_size):
env_manager_cfg = EasyDict({
'env_id': 'Pong-v5',
'env_num': env_num,
'batch_size': batch_size,
'seed': 3,
})
env_manager = PoolEnvManager(env_manager_cfg)
assert env_manager._closed
env_manager.launch()
# Test step
start_time = time.time()
for count in range(20):
env_id = env_manager.ready_obs.keys()
action = {i: np.random.randint(4) for i in env_id}
timestep = env_manager.step(action)
if count > 10:
assert len(timestep) == env_manager_cfg.batch_size
print('Count {}'.format(count))
print([v.info for v in timestep.values()])
print([v.done for v in timestep.values()])
end_time = time.time()
print('total step time: {}'.format(end_time - start_time))
# Test close
env_manager.close()
assert env_manager._closed
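Note that the parametrization above only exercises async mode (batch_size < env_num), where each recv returns batch_size transitions; that is why the test asserts len(timestep) == batch_size after a short warm-up. A hedged sketch of extending the parametrization to also cover sync mode, assuming the same test body works when batch_size == env_num:

# Hypothetical extension: add a sync-mode case alongside the async one;
# with batch_size == env_num the existing assertion still holds.
env_num_args = [[16, 8], [8, 8]]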
2 changes: 1 addition & 1 deletion ding/worker/collector/base_serial_collector.py
@@ -150,7 +150,7 @@ def __getitem__(self, idx: int) -> Any:
- item (:obj:`Any`): The item we get.
"""
data = self._pool[idx]
if len(data) == 1:
if data is not None and len(data) == 1:
data = data[0]
return data

5 changes: 5 additions & 0 deletions dizoo/atari/config/serial/__init__.py
@@ -0,0 +1,5 @@
from dizoo.atari.config.serial.enduro import *
from dizoo.atari.config.serial.montezuma import *
from dizoo.atari.config.serial.pong import *
from dizoo.atari.config.serial.qbert import *
from dizoo.atari.config.serial.spaceinvaders import *
1 change: 1 addition & 0 deletions dizoo/atari/config/serial/enduro/__init__.py
@@ -0,0 +1 @@
from .enduro_dqn_config import enduro_dqn_config, enduro_dqn_create_config
Empty file.
3 changes: 2 additions & 1 deletion dizoo/atari/config/serial/pong/__init__.py
@@ -1 +1,2 @@
from .pong_dqn_config import pong_dqn_config, pong_dqn_create_config
from .pong_dqn_config import pong_dqn_config, pong_dqn_create_config
from .pong_dqn_envpool_config import pong_dqn_envpool_config, pong_dqn_envpool_create_config
64 changes: 64 additions & 0 deletions dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py
@@ -0,0 +1,64 @@
from easydict import EasyDict

pong_dqn_envpool_config = dict(
exp_name='pong_dqn_envpool',
env=dict(
collector_env_num=8,
collector_batch_size=8,
evaluator_env_num=8,
evaluator_batch_size=8,
n_evaluator_episode=8,
stop_value=20,
env_id='Pong-v5',
frame_stack=4,
manager=dict(shared_memory=False, )
),
policy=dict(
cuda=True,
priority=False,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
),
collect=dict(n_sample=96, ),
eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=250000,
),
replay_buffer=dict(replay_buffer_size=100000, ),
),
),
)
pong_dqn_envpool_config = EasyDict(pong_dqn_envpool_config)
main_config = pong_dqn_envpool_config
pong_dqn_envpool_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='env_pool'),
policy=dict(type='dqn'),
# replay_buffer=dict(type='deque'),
)
pong_dqn_envpool_create_config = EasyDict(pong_dqn_envpool_create_config)
create_config = pong_dqn_envpool_create_config

if __name__ == '__main__':
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), seed=0)

# Alternatively, one can run the following command to execute this config file directly
# ding -m serial -c pong_dqn_envpool_config.py -s 0
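For illustration only, a hedged sketch of building a collector-side PoolEnvManager directly from the env fields above; the mapping of collector_env_num/collector_batch_size onto env_num/batch_size is an assumption made for this sketch, not necessarily what compile_config does internally:

from easydict import EasyDict
from ding.envs.env_manager.envpool_env_manager import PoolEnvManager

env_cfg = pong_dqn_envpool_config.env
# Assumed field mapping, for illustration only.
collector_env = PoolEnvManager(
    EasyDict(dict(
        env_id=env_cfg.env_id,
        env_num=env_cfg.collector_env_num,
        batch_size=env_cfg.collector_batch_size,
        seed=0,
    ))
)
collector_env.launch()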
2 changes: 1 addition & 1 deletion dizoo/atari/config/serial/spaceinvaders/__init__.py
@@ -1 +1 @@
from spaceinvaders_dqn_config import spaceinvaders_dqn_config, spaceinvaders_dqn_create_config
from .spaceinvaders_dqn_config import spaceinvaders_dqn_config, spaceinvaders_dqn_create_config
75 changes: 75 additions & 0 deletions dizoo/atari/entry/atari_dqn_main.py
@@ -0,0 +1,75 @@
import os
import gym
from tensorboardX import SummaryWriter
from easydict import EasyDict
from copy import deepcopy
from functools import partial

from ding.config import compile_config
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
from ding.envs import SyncSubprocessEnvManager
from ding.policy import DQNPolicy
from ding.model import DQN
from ding.utils import set_pkg_seed, deep_merge_dicts
from ding.rl_utils import get_epsilon_greedy_fn
from dizoo.atari.envs import AtariEnv
from dizoo.atari.config.serial.pong.pong_dqn_config import pong_dqn_config


def main(cfg, seed=0, max_iterations=int(1e10)):
cfg = compile_config(
cfg,
SyncSubprocessEnvManager,
DQNPolicy,
BaseLearner,
SampleSerialCollector,
InteractionSerialEvaluator,
AdvancedReplayBuffer,
save_cfg=True
)
collector_env_cfg = AtariEnv.create_collector_env_cfg(cfg.env)
evaluator_env_cfg = AtariEnv.create_evaluator_env_cfg(cfg.env)
collector_env = SyncSubprocessEnvManager(
env_fn=[partial(AtariEnv, cfg=c) for c in collector_env_cfg], cfg=cfg.env.manager
)
evaluator_env = SyncSubprocessEnvManager(
env_fn=[partial(AtariEnv, cfg=c) for c in evaluator_env_cfg], cfg=cfg.env.manager
)

collector_env.seed(seed)
evaluator_env.seed(seed, dynamic_seed=False)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

model = DQN(**cfg.policy.model)
policy = DQNPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleSerialCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = AdvancedReplayBuffer(
cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name, instance_name='replay_buffer'
)
eps_cfg = cfg.policy.other.eps
epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)

while True:
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
eps = epsilon_greedy(collector.envstep)
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
for i in range(cfg.policy.learn.update_per_collect):
batch_size = learner.policy.get_attribute('batch_size')
train_data = replay_buffer.sample(batch_size, learner.train_iter)
if train_data is not None:
learner.train(train_data, collector.envstep)


if __name__ == "__main__":
main(EasyDict(pong_dqn_config))