1 change: 1 addition & 0 deletions chainerrl/misc/__init__.py
@@ -5,4 +5,5 @@
from chainerrl.misc.draw_computational_graph import is_graphviz_available # NOQA
from chainerrl.misc import env_modifiers # NOQA
from chainerrl.misc.is_return_code_zero import is_return_code_zero # NOQA
from chainerrl.misc.random_seed import sample_from_space # NOQA
from chainerrl.misc.random_seed import set_random_seed # NOQA
30 changes: 30 additions & 0 deletions chainerrl/misc/random_seed.py
@@ -6,6 +6,7 @@
from future import standard_library
standard_library.install_aliases() # NOQA

import contextlib
import os
import random

@@ -39,3 +40,32 @@ def set_random_seed(seed, gpus=()):
chainer.cuda.cupy.random.seed(seed)
# chainer.functions.n_step_rnn directly depends on CHAINER_SEED
os.environ['CHAINER_SEED'] = str(seed)


@contextlib.contextmanager
def using_numpy_random_for_gym_spaces():
    from gym import spaces
    gym_spaces_random_state = spaces.prng.np_random
    # np.random.rand.__self__ is numpy's global RandomState instance
    spaces.prng.np_random = np.random.rand.__self__
    try:
        yield
    finally:
        # Restore gym's own random state even if the body raises
        spaces.prng.np_random = gym_spaces_random_state


def sample_from_space(space):
    """Sample from gym.spaces.Space.

    Unlike gym.spaces.Space.sample, this function uses numpy's global
    random state.

    Users should use this function instead of gym.spaces.Space.sample
    because calling gym.spaces.Space.sample inside algorithms is discouraged.
    See https://github.com/openai/gym/blob/master/gym/spaces/prng.py

    Args:
        space (gym.spaces.Space): Space.

    Returns:
        object: Sample from the given space.
    """
    with using_numpy_random_for_gym_spaces():
        return space.sample()
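
A minimal usage sketch of the new helper, assuming numpy, gym, and chainerrl are importable: seeding numpy's global random state makes samples reproducible, and a zero-argument sampler can be passed to an explorer as its random_action_func, mirroring the example-script changes below.

import numpy as np
import gym.spaces
import chainerrl
from chainerrl import explorers

# Seeding numpy's global state controls sample_from_space
space = gym.spaces.Discrete(4)
np.random.seed(0)
action = chainerrl.misc.sample_from_space(space)

# Typical wiring: a zero-argument sampler for an explorer
explorer = explorers.ConstantEpsilonGreedy(
    epsilon=0.1,
    random_action_func=lambda: chainerrl.misc.sample_from_space(space))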
8 changes: 6 additions & 2 deletions examples/ale/train_nsq_ale.py
@@ -19,6 +19,7 @@
import gym.wrappers
import numpy as np

import chainerrl
from chainerrl.action_value import DiscreteActionValue
from chainerrl.agents import nsq
from chainerrl import experiments
@@ -107,6 +108,9 @@ def make_env(process_idx, test):
opt = rmsprop_async.RMSpropAsync(lr=args.lr, eps=1e-1, alpha=0.99)
opt.setup(q_func)

def action_sampler():
    return chainerrl.misc.sample_from_space(action_space)

def phi(x):
    # Feature extractor
    return np.asarray(x, dtype=np.float32) / 255
@@ -123,7 +127,7 @@ def make_agent(process_idx):
epsilon_target = 0.5
explorer = explorers.LinearDecayEpsilonGreedy(
1, epsilon_target, args.final_exploration_frames,
action_space.sample)
action_sampler)
# Suppress the explorer logger
explorer.logger.setLevel(logging.INFO)
return nsq.NSQ(q_func, opt, t_max=5, gamma=0.99,
@@ -141,7 +145,7 @@ def make_agent(process_idx):
args.eval_n_runs, eval_stats['mean'], eval_stats['median'],
eval_stats['stdev']))
else:
explorer = explorers.ConstantEpsilonGreedy(0.05, action_space.sample)
explorer = explorers.ConstantEpsilonGreedy(0.05, action_sampler)

# Linearly decay the learning rate to zero
def lr_setter(env, agent, value):
7 changes: 2 additions & 5 deletions examples/gym/train_ddpg_gym.py
@@ -135,11 +135,8 @@ def make_env(test):

rbuf = replay_buffer.ReplayBuffer(5 * 10 ** 5)

def random_action():
    a = action_space.sample()
    if isinstance(a, np.ndarray):
        a = a.astype(np.float32)
    return a
def phi(obs):
    return obs.astype(np.float32)

ou_sigma = (action_space.high - action_space.low) * 0.2
explorer = explorers.AdditiveOU(sigma=ou_sigma)
6 changes: 5 additions & 1 deletion examples/gym/train_dqn_gym.py
@@ -130,10 +130,14 @@ def make_env(test):
obs_size, n_actions,
n_hidden_channels=args.n_hidden_channels,
n_hidden_layers=args.n_hidden_layers)

def action_sampler():
    return chainerrl.misc.sample_from_space(action_space)

# Use epsilon-greedy for exploration
explorer = explorers.LinearDecayEpsilonGreedy(
args.start_epsilon, args.end_epsilon, args.final_exploration_steps,
action_space.sample)
action_sampler)

if args.noisy_net_sigma is not None:
links.to_factorized_noisy(q_func)
3 changes: 2 additions & 1 deletion examples/quickstart/quickstart.ipynb
@@ -218,7 +218,8 @@
"\n",
"# Use epsilon-greedy for exploration\n",
"explorer = chainerrl.explorers.ConstantEpsilonGreedy(\n",
" epsilon=0.3, random_action_func=env.action_space.sample)\n",
" epsilon=0.3,\n",
" random_action_func=lambda: np.random.randint(env.action_space.n))\n",
"\n",
"# DQN uses Experience Replay.\n",
"# Specify a replay buffer and its capacity.\n",
3 changes: 2 additions & 1 deletion tests/agents_tests/basetest_ddpg.py
@@ -9,6 +9,7 @@
from chainer import optimizers
import numpy as np

import chainerrl
from chainerrl.agents.ddpg import DDPGModel
from chainerrl.envs.abc import ABC
from chainerrl.explorers.epsilon_greedy import LinearDecayEpsilonGreedy
@@ -46,7 +47,7 @@ def make_ddpg_agent(self, env, model, actor_opt, critic_opt, explorer,

def make_explorer(self, env):
    def random_action_func():
        a = env.action_space.sample()
        a = chainerrl.misc.sample_from_space(env.action_space)
        if isinstance(a, np.ndarray):
            return a.astype(np.float32)
        else:
3 changes: 2 additions & 1 deletion tests/agents_tests/basetest_dqn_like.py
@@ -9,6 +9,7 @@
from chainer import optimizers
import numpy as np

import chainerrl
from chainerrl.envs.abc import ABC
from chainerrl.explorers.epsilon_greedy import LinearDecayEpsilonGreedy
from chainerrl import q_functions
@@ -60,7 +61,7 @@ def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu):

def make_explorer(self, env):
    def random_action_func():
        a = env.action_space.sample()
        a = chainerrl.misc.sample_from_space(env.action_space)
        if isinstance(a, np.ndarray):
            return a.astype(np.float32)
        else:
3 changes: 2 additions & 1 deletion tests/agents_tests/basetest_pgt.py
@@ -12,6 +12,7 @@
from chainer import optimizers
import numpy as np

import chainerrl
from chainerrl.envs.abc import ABC
from chainerrl.explorers.epsilon_greedy import LinearDecayEpsilonGreedy
from chainerrl.links import Sequence
@@ -47,7 +48,7 @@ def make_pgt_agent(self, env, model, actor_opt, critic_opt, explorer,

def make_explorer(self, env):
    def random_action_func():
        a = env.action_space.sample()
        a = chainerrl.misc.sample_from_space(env.action_space)
        if isinstance(a, np.ndarray):
            return a.astype(np.float32)
        else:
19 changes: 15 additions & 4 deletions tests/agents_tests/test_agents.py
@@ -10,14 +10,16 @@
import gym
import gym.spaces

import basetest_agents as base
import chainerrl
from chainerrl import agents
from chainerrl import explorers
from chainerrl import policies
from chainerrl import q_functions
from chainerrl import replay_buffer
from chainerrl import v_function

import basetest_agents as base


def create_stochastic_policy_for_env(env):
assert isinstance(env.observation_space, gym.spaces.Box)
@@ -127,8 +129,11 @@ def create_agent(self, env):
rbuf = replay_buffer.ReplayBuffer(10 ** 5)
opt = optimizers.Adam()
opt.setup(model)

def action_sampler():
    return chainerrl.misc.sample_from_space(env.action_space)
explorer = explorers.ConstantEpsilonGreedy(
0.2, random_action_func=lambda: env.action_space.sample())
0.2, random_action_func=action_sampler)
return agents.DQN(model, opt, rbuf, gamma=0.99, explorer=explorer)


@@ -144,8 +149,11 @@ def create_agent(self, env):
rbuf = replay_buffer.ReplayBuffer(10 ** 5)
opt = optimizers.Adam()
opt.setup(model)

def action_sampler():
    return chainerrl.misc.sample_from_space(env.action_space)
explorer = explorers.ConstantEpsilonGreedy(
0.2, random_action_func=lambda: env.action_space.sample())
0.2, random_action_func=action_sampler)
return agents.DoubleDQN(
model, opt, rbuf, gamma=0.99, explorer=explorer)

@@ -161,8 +169,11 @@ def create_agent(self, env):
model = create_state_q_function_for_env(env)
opt = optimizers.Adam()
opt.setup(model)

def action_sampler():
    return chainerrl.misc.sample_from_space(env.action_space)
explorer = explorers.ConstantEpsilonGreedy(
0.2, random_action_func=lambda: env.action_space.sample())
0.2, random_action_func=action_sampler)
return agents.NSQ(
q_function=model,
optimizer=opt,
33 changes: 33 additions & 0 deletions tests/misc_tests/test_random_seed.py
@@ -52,3 +52,36 @@ def test_numpy_random(self):
    @attr.gpu
    def test_cupy_random(self):
        self._test_xp_random(chainer.cuda.cupy, gpus=(0,))


class TestSampleFromSpace(unittest.TestCase):

    def test_discrete(self):
        from gym import spaces
        space = spaces.Discrete(10000)
        np.random.seed(0)
        a = chainerrl.misc.sample_from_space(space)
        np.random.seed(0)
        b = chainerrl.misc.sample_from_space(space)
        self.assertTrue(space.contains(a))
        self.assertTrue(space.contains(b))
        self.assertEqual(a, b)

    def test_box(self):
        from gym import spaces
        space = spaces.Box(low=-1, high=1, shape=10)
        np.random.seed(0)
        a0 = chainerrl.misc.sample_from_space(space)
        np.random.seed(1)
        a1 = chainerrl.misc.sample_from_space(space)
        np.random.seed(0)
        b0 = chainerrl.misc.sample_from_space(space)
        np.random.seed(1)
        b1 = chainerrl.misc.sample_from_space(space)
        self.assertTrue(space.contains(a0))
        self.assertTrue(space.contains(b0))
        self.assertTrue(space.contains(a1))
        self.assertTrue(space.contains(b1))
        self.assertTrue((a0 == b0).all())
        self.assertTrue((a1 == b1).all())
        self.assertTrue((a0 != a1).all())