Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Docopt #38

Merged
merged 29 commits into from
Jan 7, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
final changes
  • Loading branch information
jtatusko committed Jan 7, 2019
commit 65393c847c869f07a36665f35b452bf77fd7e2da
3 changes: 1 addition & 2 deletions adept/scripts/render_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def main(
engine = env_registry.lookup_engine(train_args.env)
assert engine == Engines.GYM, "render_atari.py is only for Atari."

train_args.nb_env = 1
env = SimpleEnvManager.from_args(
train_args, seed=args.seed, nb_env=1, registry=env_registry
)
Expand Down Expand Up @@ -120,7 +119,7 @@ def main(
env.gpu_preprocessor,
env.engine,
env.action_space,
nb_env=args.nb_episode
nb_env=1
)

# create a rendering container
Expand Down
133 changes: 83 additions & 50 deletions adept/scripts/replay_gen_sc2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,86 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
__ __
____ _____/ /__ ____ / /_
/ __ `/ __ / _ \/ __ \/ __/
/ /_/ / /_/ / __/ /_/ / /_
\__,_/\__,_/\___/ .___/\__/
/_/

Replay Gen SC2

Generates StarCraft 2 Replay files of an agent interacting with the environment.

Usage:
replay_gen_sc2 (--log-id-dir <path> --epoch <int>) [options]
replay_gen_sc2 (-h | --help)

Required:
--log-id-dir <path> Path to train logs (.../logs/<env-id>/<log-id>)
--epoch <int> Epoch number to load

Options:
--gpu-id <int> CUDA device ID of GPU [default: 0]
--seed <int> Seed for random variables [default: 512]
--render Render environment
"""
import json
import os

import torch
from absl import flags

from adept.agents.agent_registry import AgentRegistry
from adept.containers import ReplayGenerator
from adept.environments import SubProcEnvManager
from adept.environments.env_registry import EnvModuleRegistry, Engines
from adept.utils.logging import print_ascii_logo
from adept.utils.script_helpers import make_agent, make_network, \
get_head_shapes, parse_bool
from adept.utils.script_helpers import (
make_network, LogDirHelper
)
from adept.utils.util import DotDict

# hack to use argparse for SC2
FLAGS = flags.FLAGS
FLAGS(['local.py'])


def main(args, env_registry=EnvModuleRegistry()):
def parse_args():
    """
    Parse the command line with docopt, using this module's docstring as the
    usage specification.

    Option names are converted to attribute-style keys (e.g. ``--log-id-dir``
    becomes ``log_id_dir``), the help flags are dropped, and the numeric
    options are cast to ``int``.

    :return: DotDict holding log_id_dir, epoch (int), gpu_id (int),
        seed (int) and render
    """
    # function-scope import, as in the original block
    from docopt import docopt

    raw_args = docopt(__doc__)
    # lstrip('-') removes only the leading dashes. str.strip('--') treats its
    # argument as a character set and would also strip trailing dashes.
    args = {
        name.lstrip('-').replace('-', '_'): value
        for name, value in raw_args.items()
    }
    # Drop the help flags; pop() with a default does not raise if the usage
    # string stops declaring one of them.
    args.pop('h', None)
    args.pop('help', None)
    args = DotDict(args)
    # int(float(...)) also accepts values written like "10.0" or "1e3"
    args.epoch = int(float(args.epoch))
    args.gpu_id = int(args.gpu_id)
    args.seed = int(args.seed)
    return args


def main(
args,
agent_registry=AgentRegistry(),
env_registry=EnvModuleRegistry()
):
"""
Run an evaluation.

:param args: Dict[str, Any]
:param agent_registry: AgentRegistry
:param env_registry: EnvModuleRegistry
:return:
"""

print_ascii_logo()
print('Saving replays... Press Ctrl+C to stop.')

with open(args.args_file, 'r') as args_file:
log_dir_helper = LogDirHelper(args.log_id_dir)

with open(log_dir_helper.args_file_path(), 'r') as args_file:
train_args = DotDict(json.load(args_file))
train_args.nb_env = 1 # TODO remove

engine = env_registry.lookup_engine(train_args.env)
assert engine == Engines.SC2, "replay_gen_sc2.py is only for SC2."
Expand All @@ -50,24 +103,39 @@ def main(args, env_registry=EnvModuleRegistry()):
seed=args.seed,
nb_env=1,
registry=env_registry,
sc2_replay_dir=os.path.split(args.network_file)[0],
sc2_replay_dir=log_dir_helper.epoch_path_at_epoch(args.epoch),
sc2_render=args.render
)

# construct network
network_head_shapes = get_head_shapes(env.action_space, train_args.agent)
network = make_network(
env.observation_space, network_head_shapes, train_args
env.observation_space,
agent_registry.lookup_output_shape(train_args.agent, env.action_space),
train_args
)
network.load_state_dict(
torch.load(
log_dir_helper.network_path_at_epoch(args.epoch),
map_location=lambda storage, loc: storage
)
)
network.load_state_dict(torch.load(args.network_file))

# create an agent (add act_eval method)
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device(
"cuda:{}".format(args.gpu_id)
if (torch.cuda.is_available() and args.gpu_id >= 0)
else "cpu"
)
torch.backends.cudnn.benchmark = True
agent = make_agent(
network, device, env.gpu_preprocessor, env.engine, env.action_space,
train_args
agent = agent_registry.lookup_agent(train_args.agent).from_args(
train_args,
network,
device,
env_registry.lookup_reward_normalizer(train_args.env),
env.gpu_preprocessor,
env.engine,
env.action_space,
nb_env=1
)

# create a rendering container
Expand All @@ -80,39 +148,4 @@ def main(args, env_registry=EnvModuleRegistry()):


if __name__ == '__main__':
import argparse

parser = argparse.ArgumentParser(description='AdeptRL Renderer')
parser.add_argument(
'--network-file',
help='path to args file (.../logs/<env-id>/<log-id>/<epoch>/model.pth)'
)
parser.add_argument(
'--args-file',
help='path to args file (.../logs/<env-id>/<log-id>/args.json)'
)
parser.add_argument(
'-s',
'--seed',
type=int,
default=32,
metavar='S',
help='random seed (default: 32)'
)
parser.add_argument(
'-r',
'--render',
type=parse_bool,
nargs='?',
const=True,
default=False,
help='render the environment during eval. (default: False)'
)
parser.add_argument(
'--gpu-id',
type=int,
default=0,
help='Which GPU to use for training (default: 0)'
)
args = parser.parse_args()
main(args)
main(parse_args())