Skip to content

Commit

Permalink
further speedup (#209)
Browse files Browse the repository at this point in the history
Cache spec attributes for a speedup; profiling shows that reading `self._spec.xxx`
takes a lot of time.

After the fix, EnvPool's speedup increased from ~0.9x to ~1.1x:

```
Namespace(domain='cheetah', seed=0, task='run', total_step=200000)
100%|███████████████████████████████████████████████████████████████████████████████████████████████| 200000/200000 [00:16<00:00, 12291.98it/s]
FPS(dmc) = 12289.97
100%|███████████████████████████████████████████████████████████████████████████████████████████████| 200000/200000 [00:14<00:00, 14146.43it/s]
FPS(envpool) = 14145.82
EnvPool Speedup: 1.15x
```
  • Loading branch information
wangsiping97 authored Oct 26, 2022
1 parent 93474cf commit b494580
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
2 changes: 1 addition & 1 deletion envpool/python/dm_envpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _to_dm(
reset: bool,
return_info: bool,
) -> TimeStep:
values = [state_values[i] for i in state_idx]
values = map(lambda i: state_values[i], state_idx)
state = treevalue.unflatten(
[(path, vi) for (path, _), vi in zip(tree_pairs, values)]
)
Expand Down
19 changes: 15 additions & 4 deletions envpool/python/envpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,27 @@ def _from(
alist = treevalue.flatten(atree)
adict = {".".join(k): v for k, v in alist}
else: # only 3 keys in action_keys
if not hasattr(self, "_last_action_type"):
self._last_action_type = self._spec._action_spec[-1][0]
if not hasattr(self, "_last_action_name"):
self._last_action_name = self._spec._action_keys[-1]
if isinstance(action, np.ndarray):
# else it could be a jax array, when using xla
action = action.astype(self._spec._action_spec[-1][0], order='C')
adict = {self._spec._action_keys[-1]: action}
action = action.astype(
self._last_action_type, # type: ignore
order='C',
)
adict = {self._last_action_name: action} # type: ignore
if env_id is None:
if "env_id" not in adict:
adict["env_id"] = self.all_env_ids
else:
adict["env_id"] = env_id.astype(np.int32)
if "players.env_id" not in adict:
adict["players.env_id"] = adict["env_id"]
return list(map(lambda k: adict[k], self._spec._action_keys))
if not hasattr(self, "_action_names"):
self._action_names = self._spec._action_keys
return list(map(lambda k: adict[k], self._action_names)) # type: ignore

def __len__(self: EnvPool) -> int:
"""Return the number of environments."""
Expand All @@ -83,7 +92,9 @@ def __len__(self: EnvPool) -> int:
@property
def all_env_ids(self: EnvPool) -> np.ndarray:
"""All env_id in numpy ndarray with dtype=np.int32."""
return np.arange(self.config["num_envs"], dtype=np.int32)
if not hasattr(self, "_all_env_ids"):
self._all_env_ids = np.arange(self.config["num_envs"], dtype=np.int32)
return self._all_env_ids # type: ignore

@property
def is_async(self: EnvPool) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion envpool/python/gym_envpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _to_gym(
self: Any, state_values: List[np.ndarray], reset: bool, return_info: bool
) -> Union[Any, Tuple[Any, Any], Tuple[Any, np.ndarray, np.ndarray, Any],
Tuple[Any, np.ndarray, np.ndarray, np.ndarray, Any]]:
values = [state_values[i] for i in state_idx]
values = map(lambda i: state_values[i], state_idx)
state = treevalue.unflatten(
[(path, vi) for (path, _), vi in zip(tree_pairs, values)]
)
Expand Down

0 comments on commit b494580

Please sign in to comment.