@@ -19,10 +19,9 @@
 from __future__ import division
 from __future__ import print_function

-from copy import copy
+import copy
+
 from dopamine.agents.dqn import dqn_agent
-from dopamine.agents.dqn.dqn_agent import NATURE_DQN_OBSERVATION_SHAPE
-from dopamine.agents.dqn.dqn_agent import NATURE_DQN_STACK_SIZE
 from dopamine.atari import run_experiment
 from dopamine.replay_memory import circular_replay_buffer
 from dopamine.replay_memory.circular_replay_buffer import OutOfGraphReplayBuffer
@@ -45,7 +44,7 @@


 class ResizeObservation(gym.ObservationWrapper):
-  """ TODO(konradczechowski): Add doc-string."""
+  """TODO(konradczechowski): Add doc-string."""

   def __init__(self, env, size=84):
     """Based on WarpFrame from openai baselines atari_wrappers.py.
@@ -91,7 +90,7 @@ def step(self, action):


 class _DQNAgent(dqn_agent.DQNAgent):
-  """ Modify dopamine DQNAgent to match our needs.
+  """Modify dopamine DQNAgent to match our needs.

   Allow passing batch_size and replay_capacity to ReplayBuffer, allow not using
   (some of) terminal episode transitions in training.
@@ -107,8 +106,8 @@ def __init__(self, replay_capacity, batch_size, generates_trainable_dones,
   def _build_replay_buffer(self, use_staging):
     """Build WrappedReplayBuffer with custom OutOfGraphReplayBuffer."""
     replay_buffer_kwargs = dict(
-        observation_shape=NATURE_DQN_OBSERVATION_SHAPE,
-        stack_size=NATURE_DQN_STACK_SIZE,
+        observation_shape=dqn_agent.NATURE_DQN_OBSERVATION_SHAPE,
+        stack_size=dqn_agent.NATURE_DQN_STACK_SIZE,
         replay_capacity=self._replay_capacity,
         batch_size=self._batch_size,
         update_horizon=self.update_horizon,
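For context on the module-qualified constants above, a small hedged sketch of the dict-then-splat pattern; the capacity and batch size values below are illustrative only, not this file's defaults:

from dopamine.agents.dqn import dqn_agent
from dopamine.replay_memory.circular_replay_buffer import OutOfGraphReplayBuffer

replay_buffer_kwargs = dict(
    observation_shape=dqn_agent.NATURE_DQN_OBSERVATION_SHAPE,
    stack_size=dqn_agent.NATURE_DQN_STACK_SIZE,
    replay_capacity=10000,  # illustrative value
    batch_size=32,          # illustrative value
    update_horizon=1,
)
memory = OutOfGraphReplayBuffer(**replay_buffer_kwargs)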
@@ -127,7 +126,7 @@ def _build_replay_buffer(self, use_staging):


 class _OutOfGraphReplayBuffer(OutOfGraphReplayBuffer):
-  """ Replay not sampling artificial_terminal transition.
+  """Replay not sampling artificial_terminal transition.

   Adds to stored tuples 'artificial_done' field (as last ReplayElement).
   When sampling, ignores tuples for which artificial_done is True.
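The docstring above summarizes the mechanism. A hedged sketch of how such a buffer can be put together on top of dopamine's extra_storage_types and is_valid_transition hooks (the class name and structure below are illustrative, not this file's exact implementation):

import numpy as np
from dopamine.replay_memory.circular_replay_buffer import (
    OutOfGraphReplayBuffer, ReplayElement)


class ArtificialDoneBufferSketch(OutOfGraphReplayBuffer):
  """Sketch: store an 'artificial_done' flag and never sample such indices."""

  def __init__(self, **kwargs):
    # Register the extra field as the last ReplayElement so it is stored
    # alongside each transition and expected by add().
    extra = list(kwargs.pop("extra_storage_types", None) or [])
    extra.append(ReplayElement("artificial_done", (), np.uint8))
    super(ArtificialDoneBufferSketch, self).__init__(
        extra_storage_types=extra, **kwargs)

  def is_valid_transition(self, index):
    # Refuse to sample transitions whose terminal signal was artificial.
    if self._store["artificial_done"][index]:
      return False
    return super(ArtificialDoneBufferSketch, self).is_valid_transition(index)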
@@ -238,7 +237,7 @@ def _get_optimizer(params):


 class DQNLearner(PolicyLearner):
-  """ Interface for learning dqn implemented in dopamine."""
+  """Interface for learning dqn implemented in dopamine."""

   def __init__(self, frame_stack_size, base_event_dir, agent_model_dir):
     super(DQNLearner, self).__init__(frame_stack_size, base_event_dir,
@@ -296,7 +295,7 @@ def train(self,
     if num_env_steps is None:
       num_env_steps = hparams.num_frames

-    hparams = copy(hparams)
+    hparams = copy.copy(hparams)
     hparams.set_hparam(
         "agent_epsilon_eval", min(hparams.agent_epsilon_eval * sampling_temp, 1)
     )
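The shallow copy taken before set_hparam keeps the caller's hparams object unchanged when agent_epsilon_eval is rescaled by the sampling temperature. A minimal stand-alone illustration of that pattern (the FakeHParams class is only a stand-in for the real HParams object):

import copy


class FakeHParams(object):
  """Stand-in for an HParams-like object exposing set_hparam."""

  def __init__(self, **kwargs):
    self.__dict__.update(kwargs)

  def set_hparam(self, name, value):
    setattr(self, name, value)


def rescaled_eval_epsilon(hparams, sampling_temp):
  # Shallow-copy first so the caller's object is not mutated.
  hparams = copy.copy(hparams)
  hparams.set_hparam(
      "agent_epsilon_eval", min(hparams.agent_epsilon_eval * sampling_temp, 1))
  return hparams


original = FakeHParams(agent_epsilon_eval=0.5)
scaled = rescaled_eval_epsilon(original, sampling_temp=1.5)
assert original.agent_epsilon_eval == 0.5
assert scaled.agent_epsilon_eval == 0.75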
@@ -318,7 +317,7 @@ def evaluate(self, env_fn, hparams, sampling_temp):
     target_iterations = 0
     training_steps_per_iteration = 0

-    hparams = copy(hparams)
+    hparams = copy.copy(hparams)
     hparams.set_hparam(
         "agent_epsilon_eval", min(hparams.agent_epsilon_eval * sampling_temp, 1)
     )