Add examples to use custom env, agent and learner #323

Open · wants to merge 5 commits into base: master
2 changes: 1 addition & 1 deletion Makefile
@@ -5,7 +5,7 @@ format:
test:
black . --check
isort . --check --diff --skip checkpoint --skip wandb --skip data
env PYTHONPATH=. pytest --pylint --flake8 --cov=tests --ignore=checkpoint --ignore=data --ignore=wandb --ignore tests/integration
env PYTHONPATH=. pytest --pylint --flake8 --cov=tests --ignore=checkpoint --ignore=data --ignore=wandb --ignore tests/integration --ignore example

integration-test:
env PYTHONPATH=. pytest tests/integration --cov=tests
Empty file added example/__init__.py
Empty file.
34 changes: 34 additions & 0 deletions example/custom_agent.py
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
"""Custom Agent for DQN.
This is example for using custom agent.
In this example, custom agent use state as exponential.
You can customize any function e.g) select_aciton, train ... etc.

To use custom agent just decorate class to build and import in main function.

- Author: Jiseong Han
- Contact: jisung.han@medipixel.io
"""

import numpy as np
import torch

from rl_algorithms.common.helper_functions import numpy2floattensor
from rl_algorithms.dqn.agent import DQNAgent
from rl_algorithms.registry import AGENTS


@AGENTS.register_module
class CustomDQN(DQNAgent):
"""Example Custom Agent for DQN"""

# pylint: disable=no-self-use
def _preprocess_state(self, state: np.ndarray) -> torch.Tensor:
"""Preprocess state so that actor selects an action."""
state = np.exp(state)
state = numpy2floattensor(state, self.learner.device)
return state

def train(self):
"""Custom train."""
pass
38 changes: 38 additions & 0 deletions example/custom_dqn.yaml
@@ -0,0 +1,38 @@
type: "CustomDQN" #Custom Agent name
hyper_params:
gamma: 0.99
tau: 0.005
buffer_size: 10000 # openai baselines: 10000
batch_size: 64 # openai baselines: 32
update_starts_from: 100 # openai baselines: 10000
multiple_update: 1 # multiple learning updates
train_freq: 1 # in openai baselines, train_freq = 4
gradient_clip: 10.0 # dueling: 10.0
n_step: 3
w_n_step: 1.0
w_q_reg: 0.0000001
per_alpha: 0.6 # openai baselines: 0.6
per_beta: 0.4
per_eps: 0.000001
max_epsilon: 1.0
min_epsilon: 0.01 # openai baselines: 0.01
epsilon_decay: 0.00001 # openai baselines: 1e-7 / 1e-1

learner_cfg:
type: "CustomDQNLearner" #Custom Learner name
loss_type:
type: "C51Loss"
backbone:
head:
type: "C51DuelingMLP"
configs:
hidden_sizes: [128, 64]
v_min: -300
v_max: 300
atom_size: 1530
output_activation: "identity"
use_noisy_net: False
optim_cfg:
lr_dqn: 0.0001
weight_decay: 0.0000001
adam_eps: 0.00000001
34 changes: 34 additions & 0 deletions example/custom_learner.py
@@ -0,0 +1,34 @@
"""This is example to use custom learner that inherit DQNLearner.
You need to decorate class to register your own Learner to build.
And import custom learner on main file.

If you want to make custom learner, you can inherit BaseLeaner or Learner.
If you make your own learner, you need to change config file to build.

- Author: Jiseong Han
- Contact: jisung.han@medipixel.io
"""
from typing import Tuple, Union

import numpy as np
import torch

from rl_algorithms.common.abstract.learner import TensorTuple
from rl_algorithms.dqn.learner import DQNLearner
from rl_algorithms.registry import LEARNERS


@LEARNERS.register_module
class CustomDQNLearner(DQNLearner):
"""Example of Custom DQN learner."""

def _init_network(self):
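# Delegates to DQNLearner's default network construction; override this
# method to build your own networks and optimizer.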
return super()._init_network()

def update_model(
self, experience: Union[TensorTuple, Tuple[TensorTuple]]
) -> Tuple[torch.Tensor, torch.Tensor, list, np.ndarray]: # type: ignore
"""
Custom Update model with experience.
"""
pass
179 changes: 179 additions & 0 deletions example/run_custom_env.py
@@ -0,0 +1,179 @@
# -*- coding: utf-8 -*-
"""Train or test algorithms on Custom Environment.

- Author: Jiseong Han
- Contact: jisung.han@medipixel.io
"""

import argparse
import datetime

import gym
import numpy as np

from rl_algorithms import build_agent
import rl_algorithms.common.env.utils as env_utils
import rl_algorithms.common.helper_functions as common_utils
from rl_algorithms.utils import YamlConfig


def parse_args() -> argparse.Namespace:
# configurations
parser = argparse.ArgumentParser(description="Pytorch RL algorithms")
parser.add_argument(
"--seed", type=int, default=777, help="random seed for reproducibility"
)
parser.add_argument(
"--integration-test",
dest="integration_test",
action="store_true",
help="for integration test",
)
parser.add_argument(
"--cfg-path",
type=str,
default="rl_algorithms/example/custom_dqn.yaml",
help="config path",
)
parser.add_argument(
"--test", dest="test", action="store_true", help="test mode (no training)"
)
parser.add_argument(
"--load-from",
type=str,
default=None,
help="load the saved model and optimizer at the beginning",
)
parser.add_argument(
"--off-render", dest="render", action="store_false", help="turn off rendering"
)
parser.add_argument(
"--render-after",
type=int,
default=0,
help="start rendering after the input number of episode",
)
parser.add_argument(
"--log", dest="log", action="store_true", help="turn on logging"
)
parser.add_argument(
"--save-period", type=int, default=100, help="save model period"
)
parser.add_argument(
"--episode-num", type=int, default=1500, help="total episode num"
)
parser.add_argument(
"--max-episode-steps", type=int, default=300, help="max episode step"
)
parser.add_argument(
"--interim-test-num",
type=int,
default=10,
help="number of test during training",
)

return parser.parse_args()


class CustomEnv(gym.Env):
"""Custom Environment for example."""

metadata = {"render.modes": ["human"]}

def __init__(self):
super(CustomEnv, self).__init__()
self.action_space = gym.spaces.Discrete(2)
self.observation_space = gym.spaces.Box(low=-3, high=3, shape=(1,))
self.pos = 0

def step(self, action):
"""
Reaching position +3 gives a +1 reward,
reaching position -3 gives a -1 reward,
and every other step gives a -0.1 reward.
"""
action = -1 if action == 0 else 1
self.pos += action
if self.pos <= -3:
reward = -1
elif self.pos >= 3:
reward = 1
else:
reward = -0.1
done = abs(self.pos) >= 3

return np.array([self.pos]), reward, done, {}

def reset(self):
self.pos = 0
return np.array([self.pos])

def render(self, mode="human"):
render_state = [[] for _ in range(7)]
render_state[self.pos + 3] = [0]
print(
"################################\n",
render_state,
"\n################################",
)


def main(custom_env):
"""Main."""
args = parse_args()

env_name = type(custom_env).__name__
env, max_episode_steps = env_utils.set_env(custom_env, args.max_episode_steps)

# set a random seed
common_utils.set_random_seed(args.seed, env)

# run
NOWTIMES = datetime.datetime.now()
curr_time = NOWTIMES.strftime("%y%m%d_%H%M%S")

cfg = YamlConfig(dict(agent=args.cfg_path)).get_config_dict()

# If running integration test, simplify experiment
if args.integration_test:
cfg = common_utils.set_cfg_for_intergration_test(cfg)

env_info = dict(
name=env_name,
observation_space=env.observation_space,
action_space=env.action_space,
is_atari=False,
)
log_cfg = dict(agent=cfg.agent.type, curr_time=curr_time, cfg_path=args.cfg_path)
build_args = dict(
env=env,
env_info=env_info,
log_cfg=log_cfg,
is_test=args.test,
load_from=args.load_from,
is_render=args.render,
render_after=args.render_after,
is_log=args.log,
save_period=args.save_period,
episode_num=args.episode_num,
max_episode_steps=max_episode_steps,
interim_test_num=args.interim_test_num,
)
agent = build_agent(cfg.agent, build_args)

if not args.test:
agent.train()
else:
agent.test()


if __name__ == "__main__":
###################################################################################
# To use the custom agent and learner, import them so they are registered.
from custom_agent import CustomDQN # noqa: F401
from custom_learner import CustomDQNLearner # noqa: F401

# Declare custom environment here.
env = CustomEnv()
###################################################################################
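# Example invocation (a sketch; assumes it is run from the repository root):
#   python example/run_custom_env.py --cfg-path example/custom_dqn.yaml --off-render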
main(env)
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
numpy==1.18.0
torch==1.6.0
torch>=1.6.0
gym==0.17.3
atari-py==0.2.6
box2d-py==2.3.8
2 changes: 1 addition & 1 deletion rl_algorithms/common/abstract/agent.py
@@ -157,7 +157,7 @@ def _test(self, interim_test: bool = False):
step += 1

print(
"[INFO] test %d\tstep: %d\ttotal score: %d" % (i_episode, step, score)
"[INFO] test %d\tstep: %d\ttotal score: %.2f" % (i_episode, step, score)
)
score_list.append(score)
