[WIP] MDQN #74

Open
wants to merge 16 commits into master
244 changes: 244 additions & 0 deletions examples/atari/train_mdqn_ale.py
@@ -0,0 +1,244 @@
import argparse

import torch.nn as nn
import torch.optim as optim
import numpy as np

import pfrl
from pfrl.q_functions import DiscreteActionValueHead
from pfrl.agents import MDQN
from pfrl import experiments
from pfrl import explorers
from pfrl import nn as pnn
from pfrl import utils
from pfrl import replay_buffers

from pfrl.wrappers import atari_wrappers
from pfrl.initializers import init_chainer_default


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--env",
        type=str,
        default="BreakoutNoFrameskip-v4",
        help="OpenAI Atari domain to perform algorithm on.",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        default="results",
        help=(
            "Directory path to save output files."
            " If it does not exist, it will be created."
        ),
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 31)")
    parser.add_argument(
        "--gpu", type=int, default=0, help="GPU to use, set to -1 if no GPU."
    )
    parser.add_argument("--demo", action="store_true", default=False)
    parser.add_argument("--load", type=str, default=None)
    parser.add_argument(
        "--final-exploration-frames",
        type=int,
        default=250000,
        help="Timesteps after which we stop " + "annealing exploration rate",
    )
    parser.add_argument(
        "--final-epsilon",
        type=float,
        default=0.01,
        help="Final value of epsilon during training.",
    )
    parser.add_argument(
        "--eval-epsilon",
        type=float,
        default=0.001,
        help="Exploration epsilon used during eval episodes.",
    )
    parser.add_argument(
        "--arch",
        type=str,
        default="doubledqn",
        choices=["nature", "nips", "dueling", "doubledqn"],
        help="Network architecture to use.",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=5 * 10 ** 7,
        help="Total number of timesteps to train the agent.",
    )
    parser.add_argument(
        "--max-frames",
        type=int,
        default=30 * 60 * 60,  # 30 minutes with 60 fps
        help="Maximum number of frames for each episode.",
    )
    parser.add_argument(
        "--replay-start-size",
        type=int,
        default=5 * 10 ** 4,
        help="Minimum replay buffer size before " + "performing gradient updates.",
    )
    parser.add_argument(
        "--target-update-interval",
        type=int,
        default=8 * 10 ** 3,
        help="Frequency (in timesteps) at which " + "the target network is updated.",
    )
    parser.add_argument(
        "--eval-interval",
        type=int,
        default=10 ** 5,
        help="Frequency (in timesteps) of evaluation phase.",
    )
    parser.add_argument(
        "--update-interval",
        type=int,
        default=4,
        help="Frequency (in timesteps) of network updates.",
    )
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--eval-n-runs", type=int, default=10)
    parser.add_argument("--no-clip-delta", dest="clip_delta", action="store_false")
    parser.add_argument("--num-step-return", type=int, default=1)
    parser.set_defaults(clip_delta=True)
    parser.add_argument(
        "--log-level",
        type=int,
        default=20,
        help="Logging level. 10:DEBUG, 20:INFO etc.",
    )
    parser.add_argument(
        "--render",
        action="store_true",
        default=False,
        help="Render env states in a GUI window.",
    )
    parser.add_argument(
        "--monitor",
        action="store_true",
        default=False,
        help=(
            "Monitor env. Videos and additional information are saved as output files."
        ),
    )
    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate.")
    parser.add_argument(
        "--checkpoint-frequency",
        type=int,
        default=None,
        help="Frequency at which agents are stored.",
    )
    args = parser.parse_args()

    import logging

    logging.basicConfig(level=args.log_level)

    # Set a random seed used in PFRL.
    utils.set_random_seed(args.seed)

    # Set different random seeds for train and test envs.
    train_seed = args.seed
    test_seed = 2 ** 31 - 1 - args.seed

    args.outdir = experiments.prepare_output_dir(args, args.outdir)
    print("Output files are saved in {}".format(args.outdir))

    def make_env(test):
        # Use different random seeds for train and test envs
        env_seed = test_seed if test else train_seed
        env = atari_wrappers.wrap_deepmind(
            atari_wrappers.make_atari(args.env, max_frames=args.max_frames),
            episode_life=not test,
            clip_rewards=not test,
        )
        env.seed(int(env_seed))
        if test:
            # Randomize actions like epsilon-greedy in evaluation as well
            env = pfrl.wrappers.RandomizeAction(env, args.eval_epsilon)
        if args.monitor:
            env = pfrl.wrappers.Monitor(
                env, args.outdir, mode="evaluation" if test else "training"
            )
        if args.render:
            env = pfrl.wrappers.Render(env)
        return env

    env = make_env(test=False)
    eval_env = make_env(test=True)

    n_actions = env.action_space.n
    q_func = nn.Sequential(
        pnn.LargeAtariCNN(),
        init_chainer_default(nn.Linear(512, n_actions)),
        DiscreteActionValueHead(),
    )

    explorer = explorers.LinearDecayEpsilonGreedy(
        1.0,
        args.final_epsilon,
        args.final_exploration_frames,
        lambda: np.random.randint(n_actions),
    )

    opt = optim.Adam(q_func.parameters(), lr=args.lr, eps=1e-2 / args.batch_size)

    rbuf = replay_buffers.ReplayBuffer(10 ** 6, args.num_step_return)

    def phi(x):
        # Feature extractor
        return np.asarray(x, dtype=np.float32) / 255

    agent = MDQN(
        q_func,
        opt,
        rbuf,
        gpu=args.gpu,
        gamma=0.99,
        explorer=explorer,
        replay_start_size=args.replay_start_size,
        minibatch_size=args.batch_size,
        update_interval=args.update_interval,
        target_update_interval=args.target_update_interval,
        clip_delta=args.clip_delta,
        batch_accumulator="sum",
        phi=phi,
    )

    if args.load:
        agent.load(args.load)

    if args.demo:
        eval_stats = experiments.eval_performance(
            env=eval_env, agent=agent, n_steps=None, n_episodes=args.eval_n_runs
        )
        print(
            "n_runs: {} mean: {} median: {} stdev {}".format(
                args.eval_n_runs,
                eval_stats["mean"],
                eval_stats["median"],
                eval_stats["stdev"],
            )
        )
    else:
        experiments.train_agent_with_evaluation(
            agent=agent,
            env=env,
            steps=args.steps,
            eval_n_steps=None,
            checkpoint_freq=args.checkpoint_frequency,
            eval_n_episodes=args.eval_n_runs,
            eval_interval=args.eval_interval,
            outdir=args.outdir,
            save_best_so_far_agent=False,
            eval_env=eval_env,
        )


if __name__ == "__main__":
    main()
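
Assuming the usual PFRL setup, this example should be runnable like the repository's other Atari scripts, for instance: python examples/atari/train_mdqn_ale.py --env BreakoutNoFrameskip-v4 --gpu 0 --outdir results. Pass --gpu -1 to run on CPU, or --demo --load <saved agent directory> to evaluate a previously saved agent without training.
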
1 change: 1 addition & 0 deletions pfrl/agents/__init__.py
@@ -10,6 +10,7 @@
from pfrl.agents.dpp import DPP # NOQA
from pfrl.agents.dqn import DQN # NOQA
from pfrl.agents.iqn import IQN # NOQA
from pfrl.agents.mdqn import MDQN # NOQA
from pfrl.agents.pal import PAL # NOQA
from pfrl.agents.ppo import PPO # NOQA
from pfrl.agents.reinforce import REINFORCE # NOQA
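
For reference, the Munchausen-DQN target that the new MDQN agent is expected to implement (Vieillard et al., 2020) augments the standard DQN target with a scaled, clipped log-policy bonus on the taken action and an entropy-regularized (soft) bootstrap over next-state actions. The sketch below is only illustrative and is not taken from this PR; the function name, signature, and the temperature/scaling/clipping defaults are assumptions (the paper's reported values), not the agent's actual API.

import torch.nn.functional as F


def munchausen_target(
    r, q_curr, q_next, a, gamma=0.99, tau=0.03, alpha=0.9, clip_min=-1.0
):
    """Illustrative Munchausen-DQN target computation (hypothetical helper).

    r: rewards, shape (B,)
    q_curr: target-network Q-values at s_t, shape (B, n_actions)
    q_next: target-network Q-values at s_{t+1}, shape (B, n_actions)
    a: actions taken at s_t, shape (B,), dtype long
    """
    # Munchausen bonus: alpha * tau * log pi(a_t | s_t), clipped from below for stability.
    log_pi_curr = F.log_softmax(q_curr / tau, dim=1)
    bonus = alpha * tau * log_pi_curr.gather(1, a.unsqueeze(1)).squeeze(1)
    bonus = bonus.clamp(min=clip_min)

    # Soft bootstrap: sum over a' of pi(a' | s_{t+1}) * (q(s_{t+1}, a') - tau * log pi(a' | s_{t+1})),
    # where pi is the softmax of the target network's Q-values at temperature tau.
    # (Terminal-state masking is omitted for brevity.)
    pi_next = F.softmax(q_next / tau, dim=1)
    log_pi_next = F.log_softmax(q_next / tau, dim=1)
    soft_value = (pi_next * (q_next - tau * log_pi_next)).sum(dim=1)

    return r + bonus + gamma * soft_value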