This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit e14a535

koz4k authored and Copybara-Service committed
internal merge of PR #1277
PiperOrigin-RevId: 224250704
1 parent 2663805 · commit e14a535

3 files changed: +12 −20 lines

tensor2tensor/models/research/rl.py
Lines changed: 1 addition & 9 deletions

@@ -61,14 +61,6 @@ def ppo_base_v1():
   return hparams


-#@registry.register_hparams
-#def ppo_continuous_action_base():
-#  hparams = ppo_base_v1()
-#  hparams.add_hparam("policy_network", feed_forward_gaussian_fun)
-#  hparams.add_hparam("policy_network_params", "basic_policy_parameters")
-#  return hparams
-
-
 @registry.register_hparams
 def basic_policy_parameters():
   wrappers = None
@@ -158,7 +150,7 @@ def get_policy(observations, hparams, action_space):
   """Get a policy network.

   Args:
-    observations
+    observations: observations
     hparams: parameters
     action_space: action space

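The hunk above drops a long-commented-out hparams set and completes the observations entry in the get_policy docstring. For context, the registration pattern the file relies on looks roughly like the sketch below; the variant name and the field it adds are hypothetical, and only the decorator and add_hparam call mirror what the diff shows.

from tensor2tensor.models.research.rl import ppo_base_v1
from tensor2tensor.utils import registry


@registry.register_hparams
def ppo_example_variant():
  """Hypothetical variant of ppo_base_v1, for illustration only."""
  hparams = ppo_base_v1()
  # Same pattern the deleted commented-out block used; note add_hparam
  # raises if the field already exists on hparams.
  hparams.add_hparam("policy_network_params", "basic_policy_parameters")
  return hparams

Registered sets are selected by name elsewhere in tensor2tensor, which is why a dead, commented-out registration can simply be deleted.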
tensor2tensor/rl/dopamine_connector.py
Lines changed: 10 additions & 11 deletions

@@ -19,10 +19,9 @@
 from __future__ import division
 from __future__ import print_function

-from copy import copy
+import copy
+
 from dopamine.agents.dqn import dqn_agent
-from dopamine.agents.dqn.dqn_agent import NATURE_DQN_OBSERVATION_SHAPE
-from dopamine.agents.dqn.dqn_agent import NATURE_DQN_STACK_SIZE
 from dopamine.atari import run_experiment
 from dopamine.replay_memory import circular_replay_buffer
 from dopamine.replay_memory.circular_replay_buffer import OutOfGraphReplayBuffer
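The import hunk trades `from copy import copy` for a plain `import copy`, so the call sites later in this file (see the train and evaluate hunks below) spell the shallow copy as copy.copy(hparams) rather than a bare copy(hparams). A minimal standard-library sketch of the resulting style:

import copy

base = {"agent_epsilon_eval": 0.001}
clone = copy.copy(base)  # shallow copy, equivalent to the old bare copy(...)
clone["agent_epsilon_eval"] = 1.0

assert base["agent_epsilon_eval"] == 0.001  # the original mapping is untouched

Qualifying the call with the module name avoids shadowing and makes it obvious that the standard-library copy is meant.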
@@ -45,7 +44,7 @@


 class ResizeObservation(gym.ObservationWrapper):
-  """ TODO(konradczechowski): Add doc-string."""
+  """TODO(konradczechowski): Add doc-string."""

   def __init__(self, env, size=84):
     """Based on WarpFrame from openai baselines atari_wrappers.py.
@@ -91,7 +90,7 @@ def step(self, action):


 class _DQNAgent(dqn_agent.DQNAgent):
-  """ Modify dopamine DQNAgent to match our needs.
+  """Modify dopamine DQNAgent to match our needs.

   Allow passing batch_size and replay_capacity to ReplayBuffer, allow not using
   (some of) terminal episode transitions in training.
@@ -107,8 +106,8 @@ def __init__(self, replay_capacity, batch_size, generates_trainable_dones,
   def _build_replay_buffer(self, use_staging):
     """Build WrappedReplayBuffer with custom OutOfGraphReplayBuffer."""
     replay_buffer_kwargs = dict(
-        observation_shape=NATURE_DQN_OBSERVATION_SHAPE,
-        stack_size=NATURE_DQN_STACK_SIZE,
+        observation_shape=dqn_agent.NATURE_DQN_OBSERVATION_SHAPE,
+        stack_size=dqn_agent.NATURE_DQN_STACK_SIZE,
         replay_capacity=self._replay_capacity,
         batch_size=self._batch_size,
         update_horizon=self.update_horizon,
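In the hunk above, the two Nature-DQN constants are no longer imported by name; they are read off the already-imported dqn_agent module, which keeps the import list to one line per dopamine module. A short sketch of the qualified access; the capacity and batch size are placeholder values, not settings from this commit:

from dopamine.agents.dqn import dqn_agent

# Constants qualified by module: the 84x84 Nature DQN frame shape and the
# number of stacked frames. Numeric kwargs below are placeholders.
replay_buffer_kwargs = dict(
    observation_shape=dqn_agent.NATURE_DQN_OBSERVATION_SHAPE,
    stack_size=dqn_agent.NATURE_DQN_STACK_SIZE,
    replay_capacity=100000,
    batch_size=32,
)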
@@ -127,7 +126,7 @@ def _build_replay_buffer(self, use_staging):


 class _OutOfGraphReplayBuffer(OutOfGraphReplayBuffer):
-  """ Replay not sampling artificial_terminal transition.
+  """Replay not sampling artificial_terminal transition.

   Adds to stored tuples 'artificial_done' field (as last ReplayElement).
   When sampling, ignores tuples for which artificial_done is True.
@@ -238,7 +237,7 @@ def _get_optimizer(params):


 class DQNLearner(PolicyLearner):
-  """ Interface for learning dqn implemented in dopamine."""
+  """Interface for learning dqn implemented in dopamine."""

   def __init__(self, frame_stack_size, base_event_dir, agent_model_dir):
     super(DQNLearner, self).__init__(frame_stack_size, base_event_dir,
@@ -296,7 +295,7 @@ def train(self,
     if num_env_steps is None:
       num_env_steps = hparams.num_frames

-    hparams = copy(hparams)
+    hparams = copy.copy(hparams)
     hparams.set_hparam(
         "agent_epsilon_eval", min(hparams.agent_epsilon_eval * sampling_temp, 1)
     )
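The reason train (and evaluate, below) copies hparams before calling set_hparam is that set_hparam mutates the object in place; the shallow copy keeps the temperature-scaled epsilon local to the call. A minimal sketch, assuming TF 1.x where tensor2tensor's HParams is tf.contrib.training.HParams, with a made-up sampling_temp:

import copy

from tensorflow.contrib.training import HParams

caller_hparams = HParams(agent_epsilon_eval=0.001)
sampling_temp = 10.0  # hypothetical value

# Copy first, then mutate only the copy, as the diff does.
hparams = copy.copy(caller_hparams)
hparams.set_hparam("agent_epsilon_eval",
                   min(hparams.agent_epsilon_eval * sampling_temp, 1))

assert caller_hparams.agent_epsilon_eval == 0.001  # caller's value is unchanged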
@@ -318,7 +317,7 @@ def evaluate(self, env_fn, hparams, sampling_temp):
     target_iterations = 0
     training_steps_per_iteration = 0

-    hparams = copy(hparams)
+    hparams = copy.copy(hparams)
     hparams.set_hparam(
         "agent_epsilon_eval", min(hparams.agent_epsilon_eval * sampling_temp, 1)
     )

tensor2tensor/rl/policy_learner.py
Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@ def train(
       eval_env_fn=None,
       report_fn=None
   ):
+    """Train."""
     # TODO(konradczechowski): pass name_scope instead of epoch?
     # TODO(konradczechowski): move 'simulated' to batch_env
     raise NotImplementedError()
