Skip to content

Commit

Permalink
Allow specifying full Atari environment name.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 260580590
Change-Id: Ibc896555e2dbcaf7a7e74ed3259fc2c603f97f02
  • Loading branch information
sguada authored and copybara-github committed Jul 29, 2019
1 parent 27de6da commit 3072d16
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tf_agents/agents/dqn/examples/v1/train_eval_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@

flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
'Root directory for writing logs/summaries/checkpoints.')
flags.DEFINE_string('environment_name', None,
'Full name of Atari game to run, ex. PongNoFrameskip-v4.')
flags.DEFINE_string('game_name', 'Pong', 'Name of Atari game to run.')

flags.DEFINE_integer('num_iterations', None,
'Number of train/eval iterations to run.')
flags.DEFINE_integer('initial_collect_steps', None,
Expand Down Expand Up @@ -627,8 +630,10 @@ def get_run_args():
def main(_):
logging.set_verbosity(logging.INFO)
tf.enable_resource_variables()
TrainEval(FLAGS.root_dir, suite_atari.game(name=FLAGS.game_name),
**get_run_args()).run()
environment_name = FLAGS.environment_name
if environment_name is None:
environment_name = suite_atari.game(name=FLAGS.game_name)
TrainEval(FLAGS.root_dir, environment_name, **get_run_args()).run()


if __name__ == '__main__':
Expand Down

0 comments on commit 3072d16

Please sign in to comment.