Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 3dce919

Browse files
koz4klukaszkaiser
authored andcommitted
RL fixes (#1505)
1 parent 3669442 commit 3dce919

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

tensor2tensor/rl/evaluator.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def make_env(env_type, real_env, sim_env_kwargs):
252252

253253
def make_agent(
254254
agent_type, env, policy_hparams, policy_dir, sampling_temp,
255-
sim_env_kwargs=None, frame_stack_size=None, rollout_agent_type=None,
255+
sim_env_kwargs_fn=None, frame_stack_size=None, rollout_agent_type=None,
256256
batch_size=None, inner_batch_size=None, env_type=None, **planner_kwargs
257257
):
258258
"""Factory function for Agents."""
@@ -270,7 +270,7 @@ def make_agent(
270270
batch_size, make_agent(
271271
rollout_agent_type, env, policy_hparams, policy_dir,
272272
sampling_temp, batch_size=inner_batch_size
273-
), make_env(env_type, env.env, sim_env_kwargs),
273+
), make_env(env_type, env.env, sim_env_kwargs_fn()),
274274
lambda env: rl_utils.BatchStackWrapper(env, frame_stack_size),
275275
discount_factor=policy_hparams.gae_gamma, **planner_kwargs
276276
),
@@ -302,17 +302,18 @@ def make_agent_from_hparams(
302302
planner_hparams, model_dir, policy_dir, sampling_temp, video_writers=()
303303
):
304304
"""Creates an Agent from hparams."""
305-
sim_env_kwargs = rl.make_simulated_env_kwargs(
306-
base_env, loop_hparams, batch_size=planner_hparams.batch_size,
307-
model_dir=model_dir
308-
)
305+
def sim_env_kwargs_fn():
306+
return rl.make_simulated_env_kwargs(
307+
base_env, loop_hparams, batch_size=planner_hparams.batch_size,
308+
model_dir=model_dir
309+
)
309310
planner_kwargs = planner_hparams.values()
310311
planner_kwargs.pop("batch_size")
311312
planner_kwargs.pop("rollout_agent_type")
312313
planner_kwargs.pop("env_type")
313314
return make_agent(
314315
agent_type, stacked_env, policy_hparams, policy_dir, sampling_temp,
315-
sim_env_kwargs, loop_hparams.frame_stack_size,
316+
sim_env_kwargs_fn, loop_hparams.frame_stack_size,
316317
planner_hparams.rollout_agent_type,
317318
inner_batch_size=planner_hparams.batch_size,
318319
env_type=planner_hparams.env_type,

tensor2tensor/rl/rl_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,9 +414,8 @@ def augment_observation(
414414
(1, 15), "f:{:3}".format(int(frame_index)),
415415
fill=(255, 0, 0)
416416
)
417-
header = np.asarray(img)
417+
header = np.copy(np.asarray(img))
418418
del img
419-
header.setflags(write=1)
420419
if bar_color is not None:
421420
header[0, :, :] = bar_color
422421
return np.concatenate([header, observation], axis=0)

0 commit comments

Comments
 (0)