From 68ec284dd57ba00230c6c6199cbb71d53e0696fc Mon Sep 17 00:00:00 2001 From: Maxime Chevalier-Boisvert Date: Wed, 25 Dec 2019 20:27:44 -0500 Subject: [PATCH 1/3] Update visualize.py --- scripts/visualize.py | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/scripts/visualize.py b/scripts/visualize.py index b5717ceb8..1c8159e75 100644 --- a/scripts/visualize.py +++ b/scripts/visualize.py @@ -2,7 +2,6 @@ import time import numpy import torch - import utils @@ -23,6 +22,9 @@ help="pause duration between two consequent actions of the agent (default: 0.1)") parser.add_argument("--gif", type=str, default=None, help="store output as gif with the given filename") +parser.add_argument("--episodes", type=int, default=1000000, + help="number of episodes to visualize") + args = parser.parse_args() # Set seed for all randomness sources @@ -55,22 +57,24 @@ done = True -while True: - if done: - obs = env.reset() +for episode in range(args.episodes): + obs = env.reset() - time.sleep(args.pause) - renderer = env.render() - if args.gif: - frames.append(numpy.moveaxis(env.render("rgb_array"), 2, 0)) + while True: + time.sleep(args.pause) + env.render('human') + if args.gif: + frames.append(numpy.moveaxis(env.render("rgb_array"), 2, 0)) - action = agent.get_action(obs) - obs, reward, done, _ = env.step(action) - agent.analyze_feedback(reward, done) + action = agent.get_action(obs) + obs, reward, done, _ = env.step(action) + agent.analyze_feedback(reward, done) - if renderer.window is None: - if args.gif: - print("Saving gif... ", end="") - write_gif(numpy.array(frames), args.gif+".gif", fps=1/args.pause) - print("Done.") - break + if done: + break + +if args.gif: + print("Saving gif... ", end="") + write_gif(numpy.array(frames), args.gif+".gif", fps=1/args.pause) + +print("Done.") From cb107324d04d947761d0212ee7963990eef46c30 Mon Sep 17 00:00:00 2001 From: Maxime Chevalier-Boisvert Date: Sun, 5 Jan 2020 00:11:44 -0500 Subject: [PATCH 2/3] Detect window being closed in visualize.py --- scripts/visualize.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/scripts/visualize.py b/scripts/visualize.py index 1c8159e75..82a329abc 100644 --- a/scripts/visualize.py +++ b/scripts/visualize.py @@ -54,15 +54,16 @@ if args.gif: from array2gif import write_gif frames = [] - -done = True + +# Create a window to view the environment +env.render('human') for episode in range(args.episodes): obs = env.reset() while True: - time.sleep(args.pause) env.render('human') + if args.gif: frames.append(numpy.moveaxis(env.render("rgb_array"), 2, 0)) @@ -70,9 +71,12 @@ obs, reward, done, _ = env.step(action) agent.analyze_feedback(reward, done) - if done: + if done or env.window.closed: break + if env.window.closed: + break + if args.gif: print("Saving gif... ", end="") write_gif(numpy.array(frames), args.gif+".gif", fps=1/args.pause) From 1eeca1242d31fd43611b02153f093c937e965fa1 Mon Sep 17 00:00:00 2001 From: Lucas Willems Date: Mon, 6 Jan 2020 23:14:23 +0100 Subject: [PATCH 3/3] Small fixes --- scripts/visualize.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/scripts/visualize.py b/scripts/visualize.py index 82a329abc..75d36a514 100644 --- a/scripts/visualize.py +++ b/scripts/visualize.py @@ -2,6 +2,7 @@ import time import numpy import torch + import utils @@ -54,7 +55,7 @@ if args.gif: from array2gif import write_gif frames = [] - + # Create a window to view the environment env.render('human') @@ -63,7 +64,6 @@ while True: env.render('human') - if args.gif: frames.append(numpy.moveaxis(env.render("rgb_array"), 2, 0)) @@ -80,5 +80,4 @@ if args.gif: print("Saving gif... ", end="") write_gif(numpy.array(frames), args.gif+".gif", fps=1/args.pause) - -print("Done.") + print("Done.")