Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Way to record the replays of the training #29

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,14 @@ After Reaver has finished training, you can look at how it performs by appending
python -m reaver.run --env MoveToBeacon --agent a2c --test --render 2> stderr.log
```

To record replays of the training you can use `--replay` and `--replay_dir`. The `replay` flag needs an integger `n`, after `n` episodes a replay will be saved. The `replay_dir` flag will need a `path` where do you want to save the replay, `Linux OS` will save on ~/StarCraftII/Replays/ + `path(replay_dir)` and `Windows OS` will save on `path(replay_dir)`.

```bash
python -m reaver.run --env MoveToBeacon --agent a2c --test --render --replay 10 --replay_dir a2c_agent_mtb_replay 2> stderr.log
```
**NB!** If you want to use the flag `replay` you will need to install manually it from source as described above.


### Google Colab

A companion [Google Colab notebook](https://colab.research.google.com/drive/1DvyCUdymqgjk85FB5DrTtAwTFbI494x7)
Expand Down
9 changes: 8 additions & 1 deletion reaver/envs/sc2.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,17 @@ def __init__(
spatial_dim=16,
step_mul=8,
obs_features=None,
save_replay_episodes=0,
replay_dir=None,
action_ids=ACTIONS_MINIGAMES
):
super().__init__(map_name, render, reset_done, max_ep_len)

self.step_mul = step_mul
self.spatial_dim = spatial_dim
self._env = None
self.save_replay_episodes = save_replay_episodes
self.replay_dir = replay_dir

# sensible action set for all minigames
if not action_ids or action_ids in [ACTIONS_MINIGAMES, ACTIONS_MINIGAMES_ALL]:
Expand Down Expand Up @@ -79,7 +83,10 @@ def start(self):
rgb_screen=None,
rgb_minimap=None
)],
step_mul=self.step_mul,)
step_mul=self.step_mul,
save_replay_episodes=self.save_replay_episodes,
replay_dir=self.replay_dir
)

def step(self, action):
try:
Expand Down
12 changes: 11 additions & 1 deletion reaver/run.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@
'Run an agent in test mode: restore flag is set to true and number of envs set to 1'
'Loss is calculated, but gradients are not applied.'
'Checkpoints, summaries, log files are not updated, but console logger is enabled.')
flags.DEFINE_integer('replay', 0, "Save a replay after this many episodes. Default of 0 means don't save replays.")
flags.DEFINE_string('replay_dir', None, 'Directory to save replays. '
'Linux distros will save on ~/StarCraftII/Replays/ + path(replay_dir)'
'Windows distros will save on path(replay_dir)')


flags.DEFINE_alias('e', 'env')
flags.DEFINE_alias('a', 'agent')
Expand All @@ -42,6 +47,8 @@
flags.DEFINE_alias('la', 'log_eps_avg')
flags.DEFINE_alias('n', 'experiment')
flags.DEFINE_alias('g', 'gin_bindings')
flags.DEFINE_alias('r', 'replay')
flags.DEFINE_alias('rd', 'replay_dir')


def main(argv):
Expand Down Expand Up @@ -75,7 +82,10 @@ def main(argv):
sess_mgr = rvr.utils.tensorflow.SessionManager(sess, expt.path, args.ckpt_freq, training_enabled=not args.test)

env_cls = rvr.envs.GymEnv if '-v' in args.env else rvr.envs.SC2Env
env = env_cls(args.env, args.render, max_ep_len=args.max_ep_len)
env = env_cls(args.env, args.render,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would break support for non-sc2 environments.
Can you add a general-purpose flag similar to render to base env class (e.g. call it monitor)? Simply ignore it in other envs for now.

max_ep_len=args.max_ep_len,
save_replay_episodes=args.replay,
replay_dir=args.replay_dir)

agent = rvr.agents.registry[args.agent](env.obs_spec(), env.act_spec(), sess_mgr=sess_mgr, n_envs=args.n_envs)
agent.logger = rvr.utils.StreamLogger(args.n_envs, args.log_freq, args.log_eps_avg, sess_mgr, expt.log_path)
Expand Down