-
Notifications
You must be signed in to change notification settings - Fork 9
/
train.py
85 lines (63 loc) · 2.21 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#!/usr/bin/env python
# encoding: utf-8
from SupplyChain_gym.envs.InventoryEnvFile import InventoryEnv
from ray.tune.registry import register_env
import gym
import os
import ray
import ray.rllib.agents.ppo as ppo
import shutil
def _reset_dir(path):
    """Delete *path* and everything beneath it; a missing path is not an error."""
    shutil.rmtree(path, ignore_errors=True)


def _train(agent, chkpt_root, n_iter):
    """Run *n_iter* training iterations, saving a checkpoint after each one.

    Prints one status line per iteration (min/mean/max episode reward, mean
    episode length, checkpoint path) and returns the path returned by the
    last ``agent.save`` call.
    """
    status = "{:2d} reward {:6.2f}/{:6.2f}/{:6.2f} len {:4.2f} saved {}"
    chkpt_file = None
    for n in range(n_iter):
        result = agent.train()
        chkpt_file = agent.save(chkpt_root)
        print(status.format(
            n + 1,
            result["episode_reward_min"],
            result["episode_reward_mean"],
            result["episode_reward_max"],
            result["episode_len_mean"],
            chkpt_file,
        ))
    return chkpt_file


def _rollout(agent, env, n_step):
    """Step *env* for *n_step* steps with actions from ``agent.compute_action``.

    Renders every step and prints the cumulative reward each time an
    episode ends, then resets the environment and the running total.
    """
    state = env.reset()
    sum_reward = 0
    for _ in range(n_step):
        action = agent.compute_action(state)
        state, reward, done, _info = env.step(action)
        sum_reward += reward
        env.render()
        if done:
            # report at the end of each episode
            print("cumulative reward", sum_reward)
            state = env.reset()
            sum_reward = 0


def main(n_iter=5, n_step=20):
    """Train a PPO agent on InventoryEnv with RLlib, then roll the policy out.

    Args:
        n_iter: number of training iterations; one checkpoint is written per
            iteration. Must be >= 1 so there is a checkpoint to restore.
        n_step: number of environment steps in the final rollout.

    Raises:
        ValueError: if ``n_iter`` is less than 1.
    """
    if n_iter < 1:
        raise ValueError("n_iter must be >= 1, got {}".format(n_iter))

    # init directory in which to save checkpoints
    chkpt_root = "tmp/exa"
    _reset_dir(chkpt_root)
    # init directory in which to log results.
    # NOTE(review): this wipes the user's whole ~/ray_results, not just this
    # run's logs -- confirm that is intended.
    # expanduser works even where $HOME is unset (os.getenv would give None).
    ray_results = os.path.join(os.path.expanduser("~"), "ray_results")
    _reset_dir(ray_results)

    # start Ray -- add `local_mode=True` here for debugging
    ray.init(ignore_reinit_error=True)
    try:
        # register the custom environment with RLlib; the creator must build
        # the class that is actually imported (InventoryEnv)
        select_env = "example-v0"
        register_env(select_env, lambda config: InventoryEnv())

        # configure the environment and create agent
        config = ppo.DEFAULT_CONFIG.copy()
        config["log_level"] = "WARN"
        agent = ppo.PPOTrainer(config, env=select_env)

        # train a policy with RLlib using PPO
        chkpt_file = _train(agent, chkpt_root, n_iter)

        # examine the trained policy; Keras' summary() prints on its own and
        # returns None, so it is not wrapped in print()
        policy = agent.get_policy()
        policy.model.base_model.summary()

        # apply the trained policy in a rollout. register_env only populates
        # RLlib's registry, so gym.make("example-v0") cannot find the name;
        # build the same environment class directly instead.
        agent.restore(chkpt_file)
        env = InventoryEnv()
        _rollout(agent, env, n_step)
    finally:
        # release Ray workers/resources even if training or rollout fails
        ray.shutdown()
if __name__ == "__main__":
    # Train and roll out only when run as a script, never on import.
    main()