improve gym speedup #210

Merged
merged 16 commits into from
Nov 1, 2022
199 changes: 100 additions & 99 deletions docs/content/new_env.rst
@@ -81,8 +81,7 @@ state space, and action space. Create a class ``CartPoleEnvFns``:
class CartPoleEnvFns {
public:
static decltype(auto) DefaultConfig() {
return MakeDict("max_episode_steps"_.Bind(200),
"reward_threshold"_.Bind(195.0));
return MakeDict("reward_threshold"_.Bind(195.0));
}
template <typename Config>
static decltype(auto) StateSpec(const Config& conf) {
@@ -114,12 +113,12 @@ available to see on the python side:
>>> import envpool
>>> spec = envpool.make_spec("CartPole-v0")
>>> spec
CartPoleEnvSpec(num_envs=1, batch_size=0, num_threads=0, max_num_players=1, thread_affinity_offset=-1, base_path='envpool', seed=42, max_episode_steps=200, reward_threshold=195.0)
CartPoleEnvSpec(num_envs=1, batch_size=1, num_threads=0, max_num_players=1, thread_affinity_offset=-1, base_path='envpool', seed=42, gym_reset_return_info=False, max_episode_steps=200, reward_threshold=195.0)

>>> # if we change a config value
>>> env = envpool.make_gym("CartPole-v0", reward_threshold=666)
>>> env
CartPoleGymEnvPool(num_envs=1, batch_size=0, num_threads=0, max_num_players=1, thread_affinity_offset=-1, base_path='envpool', seed=42, max_episode_steps=200, reward_threshold=666.0)
CartPoleGymEnvPool(num_envs=1, batch_size=1, num_threads=0, max_num_players=1, thread_affinity_offset=-1, base_path='envpool', seed=42, gym_reset_return_info=True, max_episode_steps=200, reward_threshold=666.0)

>>> # observation space and action space
>>> env.observation_space
@@ -421,7 +420,8 @@ Miscellaneous
.. code-block:: c++

#ifdef ENVPOOL_TEST
fprintf(stderr, "here");
fprintf(stderr, "here\n");
LOG(INFO) << "another error log print method.";
#endif


@@ -472,6 +472,48 @@ instantiate ``CartPoleEnvSpec``, ``CartPoleDMEnvPool``, and
]


Register CartPole-v0/1 in EnvPool
---------------------------------

To register a task in EnvPool, call the ``register`` function in
``envpool.registration``. Here is ``registration.py``:
::

from envpool.registration import register

register(
task_id="CartPole-v0",
import_path="envpool.classic_control",
spec_cls="CartPoleEnvSpec",
dm_cls="CartPoleDMEnvPool",
gym_cls="CartPoleGymEnvPool",
max_episode_steps=200,
reward_threshold=195.0,
)

register(
task_id="CartPole-v1",
import_path="envpool.classic_control",
spec_cls="CartPoleEnvSpec",
dm_cls="CartPoleDMEnvPool",
gym_cls="CartPoleGymEnvPool",
max_episode_steps=500,
reward_threshold=475.0,
)

``task_id``, ``import_path``, ``spec_cls``, ``dm_cls``, and ``gym_cls`` are
required arguments. Other arguments such as ``max_episode_steps`` and
``reward_threshold`` are env-specific. For example, if someone uses
``envpool.make("CartPole-v1")``, the ``reward_threshold`` will be set to 475.0
at ``CartPoleEnvPool`` initialization.

Finally, it is crucial to let the top-level module import this file. In
``envpool/entry.py``, add the following line:
::

import envpool.classic_control.registration # noqa: F401


Write Bazel BUILD File
----------------------

@@ -556,25 +598,25 @@ Let's first take a look at ``BUILD`` file in ``classic_control``:
deps = ["//envpool/python:api"],
)

py_library(
name = "classic_control_registration",
srcs = ["registration.py"],
deps = [
"//envpool:registration",
],
)

py_test(
name = "classic_control_test",
srcs = ["classic_control_test.py"],
deps = [
":classic_control",
":classic_control_registration",
requirement("numpy"),
requirement("absl-py"),
],
)

py_library(
name = "classic_control_registration",
srcs = ["registration.py"],
deps = [
"//envpool:registration",
],
)


We have several ways for dependency declaration:

1. use relative path: ``:cartpole`` points to first item (cartpole cc_library);
@@ -584,6 +626,50 @@ We have several ways for dependency declaration:
runtime dependencies;
4. third-party dependency (not shown above): will explain in the next section.

And don't forget to modify the top-level Bazel BUILD dependency:

.. code-block:: diff

py_library(
name = "entry",
srcs = ["entry.py"],
deps = [
"//envpool/atari:atari_registration",
+ "//envpool/classic_control:classic_control_registration",
],
)

py_library(
name = "envpool",
srcs = ["__init__.py"],
deps = [
":entry",
":registration",
"//envpool/atari",
+ "//envpool/classic_control",
"//envpool/python",
],
)

Also, pay attention to check if ``.so`` file is packed into ``.whl``
successfully. In ``setup.cfg``:

.. code-block:: diff

[options.package_data]
envpool = atari/*.so
atari/roms/*.bin
+ classic_control/*.so

Now you can run ``envpool.make("CartPole-v0")`` by re-installing EnvPool:

.. code-block:: bash

# generate .whl file
make bazel-build
# install .whl
pip install dist/envpool-<version>-*.whl


Testing
~~~~~~~
@@ -701,91 +787,6 @@ documentation
`Atari BUILD example <https://github.com/sail-sg/envpool/blob/v0.6.1.post1/envpool/atari/BUILD>`_.


Register CartPole-v0/1 in EnvPool
---------------------------------

To register a task in EnvPool, call the ``register`` function in
``envpool.registration``. Here is ``registration.py``:
::

from envpool.registration import register

register(
task_id="CartPole-v0",
import_path="envpool.classic_control",
spec_cls="CartPoleEnvSpec",
dm_cls="CartPoleDMEnvPool",
gym_cls="CartPoleGymEnvPool",
max_episode_steps=200,
reward_threshold=195.0,
)

register(
task_id="CartPole-v1",
import_path="envpool.classic_control",
spec_cls="CartPoleEnvSpec",
dm_cls="CartPoleDMEnvPool",
gym_cls="CartPoleGymEnvPool",
max_episode_steps=500,
reward_threshold=475.0,
)

``task_id``, ``import_path``, ``spec_cls``, ``dm_cls``, and ``gym_cls`` are
required arguments. Other arguments such as ``max_episode_steps`` and
``reward_threshold`` are env-specific. For example, if someone uses
``envpool.make("CartPole-v1")``, the ``reward_threshold`` will be set to 475.0
at ``CartPoleEnvPool`` initialization.

Finally, it is crucial to let the top-level module import this file. In
``envpool/entry.py``, add the following line:
::

import envpool.classic_control.registration

And don't forget to modify the Bazel BUILD dependency:

.. code-block:: diff

py_library(
name = "entry",
srcs = ["entry.py"],
deps = [
"//envpool/atari:atari_registration",
+ "//envpool/classic_control:classic_control_registration",
],
)

py_library(
name = "envpool",
srcs = ["__init__.py"],
deps = [
":entry",
":registration",
"//envpool/atari",
+ "//envpool/classic_control",
"//envpool/python",
],
)

Also, pay attention to check if ``.so`` file is packed into ``.whl``
successfully. In ``setup.cfg``:

.. code-block:: diff

[options.package_data]
envpool = atari/*.so
atari/roms/*.bin
+ classic_control/*.so

Now you can run ``envpool.make("CartPole-v0")`` by re-installing EnvPool:

.. code-block:: bash

# generate .whl file
make bazel-build
# install .whl
pip install dist/envpool-<version>-*.whl


Add Unit Test for CartPoleEnv
-----------------------------
4 changes: 3 additions & 1 deletion docs/content/xla_interface.rst
@@ -40,8 +40,10 @@ We can now write the actor loop as:
def actor_step(iter, loop_var):
handle0, states = loop_var
action = policy(states)
# for gym
# for gym < 0.26
handle1, (new_states, rew, done, info) = step(handle0, action)
# for gym >= 0.26
# handle1, (new_states, rew, term, trunc, info) = step(handle0, action)
# for dm
# handle1, new_states = step(handle0, action)
return (handle1, new_states)
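
The hunk above documents two return shapes for ``step``: a 4-tuple for gym < 0.26 and a 5-tuple for gym >= 0.26. Code that must run against both could normalize the result first. The helper below is a hypothetical sketch, not part of EnvPool or gym; in particular, mapping the old ``done`` flag to ``term`` and setting ``trunc`` to ``False`` is an assumption, since the old API does not distinguish the two cases.

```python
from typing import Any, Sequence, Tuple


def unpack_gym_step(result: Sequence[Any]) -> Tuple[Any, Any, Any, Any, Any]:
    """Return (obs, rew, term, trunc, info) for a 4- or 5-element step result.

    Hypothetical helper for illustration only.
    """
    if len(result) == 5:
        # gym >= 0.26 style: already (obs, rew, term, trunc, info).
        obs, rew, term, trunc, info = result
        return obs, rew, term, trunc, info
    if len(result) == 4:
        # gym < 0.26 style: (obs, rew, done, info).
        # Assumption: treat done as termination and report no truncation.
        obs, rew, done, info = result
        return obs, rew, done, False, info
    raise ValueError(f"unexpected step result of length {len(result)}")


old_style = unpack_gym_step(("obs", 1.0, True, {}))
new_style = unpack_gym_step(("obs", 1.0, False, True, {}))
```

This plain-Python version is meant for ordinary (non-jitted) code paths; inside a traced XLA loop the tuple shape is fixed at trace time, so the unpacking pattern would be chosen once up front instead.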
19 changes: 11 additions & 8 deletions envpool/atari/BUILD
@@ -75,11 +75,20 @@ pybind_extension(
],
)

py_library(
name = "atari_registration",
srcs = ["registration.py"],
deps = [
"//envpool:registration",
],
)

py_test(
name = "api_test",
srcs = ["api_test.py"],
deps = [
":atari",
":atari_registration",
"//envpool/python",
requirement("numpy"),
],
@@ -91,6 +100,7 @@ py_test(
srcs = ["atari_envpool_test.py"],
deps = [
":atari",
":atari_registration",
requirement("numpy"),
requirement("jax"),
requirement("dm-env"),
@@ -109,18 +119,11 @@ py_test(
deps = [
":atari",
":atari_network",
":atari_registration",
requirement("numpy"),
requirement("absl-py"),
requirement("tianshou"),
requirement("tqdm"),
requirement("torch"),
],
)

py_library(
name = "atari_registration",
srcs = ["registration.py"],
deps = [
"//envpool:registration",
],
)