Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 4e62017

Browse files
cclauss authored and Copybara-Service committed
internal merge of PR #1343
PiperOrigin-RevId: 228365154
1 parent c54464a commit 4e62017

File tree

9 files changed

+76
-164
lines changed

9 files changed

+76
-164
lines changed

tensor2tensor/data_generators/gym_env.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import itertools
2424
import os
2525
import random
26+
import re
2627

2728
from gym.spaces import Box
2829
import numpy as np
@@ -641,6 +642,58 @@ def base_env_name(self):
641642
def num_channels(self):
642643
return self.observation_space.shape[2]
643644

645+
@staticmethod
def infer_last_epoch_num(data_dir):
  """Infer highest epoch number from file names in data_dir.

  Epoch data files are expected to carry a trailing ".<epoch_num>"
  suffix, where epoch_num may be negative (e.g. ".-1").

  Args:
    data_dir: directory whose file names encode epoch numbers.

  Returns:
    The largest epoch number found among the file names in data_dir.

  Raises:
    ValueError: if no file name in data_dir has a ".<epoch_num>" suffix.
  """
  # Single pass with re.search instead of findall + sum(list, [])
  # (the latter concatenates lists quadratically).
  suffix_re = re.compile(r"\.(-?\d+)$")
  matches = (suffix_re.search(name) for name in os.listdir(data_dir))
  epochs = [int(m.group(1)) for m in matches if m]
  if not epochs:
    # Same exception type as max() on an empty sequence, clearer message.
    raise ValueError(
        "No file names with a '.<epoch_num>' suffix found in %s" % data_dir)
  return max(epochs)
653+
654+
@staticmethod
def setup_env_from_hparams(hparams, batch_size, max_num_noops):
  """Build a T2TGymEnv for the game configured in hparams.

  The game name is converted to CamelCase and suffixed with the gym
  "NoFrameskip-v4" variant to form the base environment name.

  Args:
    hparams: hparams holding game, grayscale, resize factors and
      rl_env_max_episode_steps.
    batch_size: number of environments in the batch.
    max_num_noops: maximum number of no-op actions at episode start.

  Returns:
    A T2TGymEnv instance.
  """
  env_name = "%s%s" % (
      misc_utils.snakecase_to_camelcase(hparams.game), "NoFrameskip-v4")
  return T2TGymEnv(
      base_env_name=env_name,
      batch_size=batch_size,
      grayscale=hparams.grayscale,
      resize_width_factor=hparams.resize_width_factor,
      resize_height_factor=hparams.resize_height_factor,
      rl_env_max_episode_steps=hparams.rl_env_max_episode_steps,
      max_num_noops=max_num_noops,
      maxskip_envs=True)
669+
670+
@staticmethod
def setup_and_load_epoch(hparams, data_dir, which_epoch_data=None):
  """Load T2TBatchGymEnv with data from one epoch.

  Args:
    hparams: hparams.
    data_dir: data directory.
    which_epoch_data: data from which epoch to load; "last" selects the
      highest epoch number found in data_dir, None starts a fresh epoch
      without loading recorded data.

  Returns:
    env.
  """
  env = T2TGymEnv.setup_env_from_hparams(
      hparams,
      batch_size=hparams.real_batch_size,
      max_num_noops=hparams.max_num_noops)
  if which_epoch_data is None:
    # No recorded data requested; -999 appears to be a sentinel epoch
    # number for "no data" — confirm against start_new_epoch.
    env.start_new_epoch(-999)
    return env
  if which_epoch_data == "last":
    which_epoch_data = T2TGymEnv.infer_last_epoch_num(data_dir)
  assert isinstance(which_epoch_data, int), \
      "{}".format(type(which_epoch_data))
  env.start_new_epoch(which_epoch_data, data_dir)
  return env
696+
644697
def _derive_observation_space(self, orig_observ_space):
645698
height, width, channels = orig_observ_space.shape
646699
if self.grayscale:

tensor2tensor/models/research/rl.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,6 @@ def rlmf_original():
335335
frame_stack_size=4,
336336
eval_sampling_temps=[0.0, 0.2, 0.5, 0.8, 1.0, 2.0],
337337
eval_max_num_noops=8,
338-
eval_rl_env_max_episode_steps=1000,
339338
resize_height_factor=2,
340339
resize_width_factor=2,
341340
grayscale=0,

tensor2tensor/rl/evaluator.py

Lines changed: 0 additions & 83 deletions
This file was deleted.

tensor2tensor/rl/player.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import six
5757

5858
from tensor2tensor.bin import t2t_trainer # pylint: disable=unused-import
59+
from tensor2tensor.data_generators.gym_env import T2TGymEnv
5960
from tensor2tensor.rl import player_utils
6061
from tensor2tensor.rl.envs.simulated_batch_env import PIL_Image
6162
from tensor2tensor.rl.envs.simulated_batch_env import PIL_ImageDraw
@@ -228,7 +229,7 @@ def main(_):
228229
directories["data"], directories["world_model"],
229230
hparams, which_epoch_data=epoch)
230231
else:
231-
env = player_utils.setup_and_load_epoch(
232+
env = T2TGymEnv.setup_and_load_epoch(
232233
hparams, data_dir=directories["data"],
233234
which_epoch_data=epoch)
234235
env = FlatBatchEnv(env)

tensor2tensor/rl/player_utils.py

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
import copy
2323
import os
24-
import re
2524

2625
import gym
2726
import numpy as np
@@ -40,42 +39,6 @@
4039
FLAGS = flags.FLAGS
4140

4241

43-
def infer_last_epoch_num(data_dir):
44-
"""Infer highest epoch number from file names in data_dir."""
45-
names = os.listdir(data_dir)
46-
epochs_str = [re.findall(pattern=r".*\.(-?\d+)$", string=name)
47-
for name in names]
48-
epochs_str = sum(epochs_str, [])
49-
return max([int(epoch_str) for epoch_str in epochs_str])
50-
51-
52-
def setup_and_load_epoch(hparams, data_dir, which_epoch_data=None):
53-
"""Load T2TGymEnv with data from one epoch.
54-
55-
Args:
56-
hparams: hparams.
57-
data_dir: data directory.
58-
which_epoch_data: data from which epoch to load.
59-
60-
Returns:
61-
env.
62-
"""
63-
t2t_env = rl_utils.setup_env(
64-
hparams, batch_size=hparams.real_batch_size,
65-
max_num_noops=hparams.max_num_noops
66-
)
67-
# Load data.
68-
if which_epoch_data is not None:
69-
if which_epoch_data == "last":
70-
which_epoch_data = infer_last_epoch_num(data_dir)
71-
assert isinstance(which_epoch_data, int), \
72-
"{}".format(type(which_epoch_data))
73-
t2t_env.start_new_epoch(which_epoch_data, data_dir)
74-
else:
75-
t2t_env.start_new_epoch(-999)
76-
return t2t_env
77-
78-
7942
def make_simulated_gym_env(real_env, world_model_dir, hparams, random_starts):
8043
"""Gym environment with world model."""
8144
initial_frame_chooser = rl_utils.make_initial_frame_chooser(
@@ -98,7 +61,7 @@ def load_data_and_make_simulated_env(
9861
data_dir, wm_dir, hparams, which_epoch_data="last", random_starts=True
9962
):
10063
hparams = copy.deepcopy(hparams)
101-
t2t_env = setup_and_load_epoch(
64+
t2t_env = T2TGymEnv.setup_and_load_epoch(
10265
hparams, data_dir=data_dir,
10366
which_epoch_data=which_epoch_data)
10467
return make_simulated_gym_env(

tensor2tensor/rl/rl_utils.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from tensor2tensor.models.research import rl
2929
from tensor2tensor.rl.dopamine_connector import DQNLearner
3030
from tensor2tensor.rl.ppo_learner import PPOLearner
31-
from tensor2tensor.utils import misc_utils
3231
from tensor2tensor.utils import trainer_lib
3332

3433
import tensorflow as tf
@@ -64,9 +63,8 @@ def evaluate_single_config(
6463
):
6564
"""Evaluate the PPO agent in the real environment."""
6665
eval_hparams = trainer_lib.create_hparams(hparams.base_algo_params)
67-
env = setup_env(
68-
hparams, batch_size=hparams.eval_batch_size, max_num_noops=max_num_noops,
69-
rl_env_max_episode_steps=hparams.eval_rl_env_max_episode_steps
66+
env = T2TGymEnv.setup_env_from_hparams(
67+
hparams, batch_size=hparams.eval_batch_size, max_num_noops=max_num_noops
7068
)
7169
env.start_new_epoch(0)
7270
env_fn = rl.make_real_env_fn(env)
@@ -100,38 +98,12 @@ def evaluate_all_configs(hparams, agent_model_dir):
10098
return metrics
10199

102100

103-
def summarize_metrics(eval_metrics_writer, metrics, epoch):
104-
"""Write metrics to summary."""
105-
for (name, value) in six.iteritems(metrics):
106-
summary = tf.Summary()
107-
summary.value.add(tag=name, simple_value=value)
108-
eval_metrics_writer.add_summary(summary, epoch)
109-
eval_metrics_writer.flush()
110-
111-
112101
LEARNERS = {
113102
"ppo": PPOLearner,
114103
"dqn": DQNLearner,
115104
}
116105

117106

118-
def setup_env(hparams, batch_size, max_num_noops, rl_env_max_episode_steps=-1):
119-
"""Setup."""
120-
game_mode = "NoFrameskip-v4"
121-
camel_game_name = misc_utils.snakecase_to_camelcase(hparams.game)
122-
camel_game_name += game_mode
123-
env_name = camel_game_name
124-
125-
env = T2TGymEnv(base_env_name=env_name,
126-
batch_size=batch_size,
127-
grayscale=hparams.grayscale,
128-
resize_width_factor=hparams.resize_width_factor,
129-
resize_height_factor=hparams.resize_height_factor,
130-
rl_env_max_episode_steps=rl_env_max_episode_steps,
131-
max_num_noops=max_num_noops, maxskip_envs=True)
132-
return env
133-
134-
135107
def update_hparams_from_hparams(target_hparams, source_hparams, prefix):
136108
"""Copy a subset of hparams to target_hparams."""
137109
for (param_name, param_value) in six.iteritems(source_hparams.values()):

tensor2tensor/rl/trainer_model_based.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import six
3838

3939
from tensor2tensor.bin import t2t_trainer # pylint: disable=unused-import
40+
from tensor2tensor.data_generators.gym_env import T2TGymEnv
4041
from tensor2tensor.layers import common_video
4142
from tensor2tensor.models.research import rl
4243
from tensor2tensor.models.research.rl import make_simulated_env_fn_from_hparams
@@ -377,6 +378,15 @@ def load_metrics(event_dir, epoch):
377378
return metrics
378379

379380

381+
def summarize_metrics(eval_metrics_writer, metrics, epoch):
  """Write metrics to summary.

  Adds one scalar summary point per metric at step `epoch`, then flushes
  the writer.

  Args:
    eval_metrics_writer: summary writer with add_summary/flush.
    metrics: dict mapping metric name to a scalar value.
    epoch: step index used for the summary points.
  """
  for metric_name, metric_value in six.iteritems(metrics):
    point = tf.Summary()
    point.value.add(tag=metric_name, simple_value=metric_value)
    eval_metrics_writer.add_summary(point, epoch)
  eval_metrics_writer.flush()
388+
389+
380390
def training_loop(hparams, output_dir, report_fn=None, report_metric=None):
381391
"""Run the main training loop."""
382392
if report_fn:
@@ -391,10 +401,9 @@ def training_loop(hparams, output_dir, report_fn=None, report_metric=None):
391401

392402
epoch = -1
393403
data_dir = directories["data"]
394-
env = rl_utils.setup_env(
404+
env = T2TGymEnv.setup_env_from_hparams(
395405
hparams, batch_size=hparams.real_batch_size,
396-
max_num_noops=hparams.max_num_noops,
397-
rl_env_max_episode_steps=hparams.rl_env_max_episode_steps
406+
max_num_noops=hparams.max_num_noops
398407
)
399408
env.start_new_epoch(epoch, data_dir)
400409

@@ -484,7 +493,7 @@ def training_loop(hparams, output_dir, report_fn=None, report_metric=None):
484493
log("World model eval metrics:\n{}".format(pprint.pformat(wm_metrics)))
485494
metrics.update(wm_metrics)
486495

487-
rl_utils.summarize_metrics(eval_metrics_writer, metrics, epoch)
496+
summarize_metrics(eval_metrics_writer, metrics, epoch)
488497

489498
# Report metrics
490499
if report_fn:

tensor2tensor/rl/trainer_model_based_params.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,6 @@ def _rlmb_base():
8484
# Sampling temperatures to try during eval.
8585
eval_sampling_temps=[0.5, 0.0, 1.0],
8686
eval_max_num_noops=8,
87-
# To speed up the pipeline. Some games want to run forever.
88-
eval_rl_env_max_episode_steps=1000,
8987

9088
game="pong",
9189
# Whether to evaluate the world model in each iteration of the loop to get
@@ -508,7 +506,6 @@ def _rlmb_tiny_overrides():
508506
resize_width_factor=2,
509507
wm_eval_rollout_ratios=[1],
510508
rl_env_max_episode_steps=7,
511-
eval_rl_env_max_episode_steps=7,
512509
simulated_rollout_length=2,
513510
eval_sampling_temps=[0.0, 1.0],
514511
)

tensor2tensor/rl/trainer_model_free.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
python -m tensor2tensor.rl.trainer_model_free \
2121
--output_dir=$HOME/t2t/rl_v1 \
2222
--hparams_set=pong_model_free \
23-
--hparams='batch_size=15'
23+
--loop_hparams='batch_size=15'
2424
"""
2525

2626
from __future__ import absolute_import
@@ -29,6 +29,7 @@
2929

3030
import pprint
3131

32+
from tensor2tensor.data_generators.gym_env import T2TGymEnv
3233
from tensor2tensor.models.research import rl
3334
from tensor2tensor.rl import rl_utils
3435
from tensor2tensor.utils import flags as t2t_flags # pylint: disable=unused-import
@@ -52,9 +53,9 @@
5253

5354
def initialize_env_specs(hparams):
5455
"""Initializes env_specs using T2TGymEnvs."""
55-
env = rl_utils.setup_env(hparams, hparams.batch_size,
56-
hparams.eval_max_num_noops,
57-
hparams.rl_env_max_episode_steps)
56+
env = T2TGymEnv.setup_env_from_hparams(
57+
hparams, hparams.batch_size, hparams.eval_max_num_noops
58+
)
5859
env.start_new_epoch(0)
5960

6061
# TODO(afrozm): Decouple env_fn from hparams and return both, is there

0 commit comments

Comments (0)