Skip to content

Commit

Permalink
fixbug(zlx): Withdraw changes in sample serial collector; In PoolEnvM…
Browse files Browse the repository at this point in the history
…anager, reset returns until all envs are reset successfully
  • Loading branch information
zlx-sensetime committed Mar 17, 2022
1 parent 9cc47df commit 35fd537
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 15 deletions.
6 changes: 2 additions & 4 deletions .github/workflows/envpool_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@ name: envpool_test
on: [push, pull_request]

jobs:
test_unittest:
runs-on: ${{ matrix.os }}
test_envpooltest:
runs-on: ubuntu-latest
if: "!contains(github.event.head_commit.message, 'ci skip')"
strategy:
matrix:
# os: [windows-latest]
python-version: [3.7, 3.8] # Envpool only supports python>=3.7

steps:
Expand All @@ -25,6 +24,5 @@ jobs:
python -m pip install .
python -m pip install ".[test,k8s]"
python -m pip install ".[envpool]"
# python -m pip uninstall pytest-timeouts -y
./ding/scripts/install-k8s-tools.sh
make envpooltest
16 changes: 9 additions & 7 deletions ding/envs/env_manager/envpool_env_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
envpool = None

from ding.envs import BaseEnvTimestep
from ding.utils import ENV_MANAGER_REGISTRY
from ding.utils import ENV_MANAGER_REGISTRY, deep_merge_dicts


@ENV_MANAGER_REGISTRY.register('env_pool')
Expand Down Expand Up @@ -47,7 +47,7 @@ def __init__(self, cfg: EasyDict) -> None:
self._env_num = cfg.env_num
self._batch_size = cfg.batch_size
self._seed = cfg.seed
self._ready_obs = None
self._ready_obs = {}
self._closed = True

def launch(self) -> None:
Expand All @@ -59,11 +59,13 @@ def launch(self) -> None:
self.reset()

def reset(self) -> None:
self._envs.async_reset()
obs, _, _, _ = self._envs.recv()
# obs = self._envs.reset()
obs = obs.astype(np.float32)
self._ready_obs = {i: o for i, o in enumerate(obs)}
while True:
self._envs.async_reset()
obs, _, _, _ = self._envs.recv()
obs = obs.astype(np.float32)
self._ready_obs = deep_merge_dicts({i: o for i, o in enumerate(obs)}, self._ready_obs)
if len(self._ready_obs) == self._env_num:
break

def step(self, action) -> Dict[int, namedtuple]:
env_id = np.array(list(action.keys()))
Expand Down
4 changes: 0 additions & 4 deletions ding/worker/collector/sample_serial_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,6 @@ def collect(self,
# TODO(nyz) vectorize this for loop
for env_id, timestep in timesteps.items():
with self._timer:
# In async mode, this env is reset successfully just now.
if self._obs_pool[env_id] is None:
self._obs_pool.update({env_id: timestep.obs})
continue
if timestep.info.get('abnormal', False):
# If there is an abnormal timestep, reset all the related variables(including this env).
# suppose there is no reset param, just reset this env
Expand Down

0 comments on commit 35fd537

Please sign in to comment.