forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsb2rllib_sb_example.py
40 lines (34 loc) · 1.13 KB
/
sb2rllib_sb_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""
Example script showing how to train, save, load, and test a Stable Baselines 2 agent.

Code taken and adjusted from the SB2 docs:
https://stable-baselines.readthedocs.io/en/master/guide/quickstart.html

Equivalent script with RLlib: sb2rllib_rllib_example.py
"""
import os

import gym
from stable_baselines import PPO2
from stable_baselines.common.policies import MlpPolicy

# settings used for both stable baselines and rllib
env_name = "CartPole-v1"
train_steps = 10000
learning_rate = 1e-3
save_dir = "saved_models"
save_path = os.path.join(save_dir, f"sb_model_{train_steps}steps")

env = gym.make(env_name)

# training and saving
model = PPO2(MlpPolicy, env, learning_rate=learning_rate, verbose=1)
model.learn(total_timesteps=train_steps)
# Create the checkpoint directory up front so saving cannot fail on a
# missing folder; exist_ok makes repeated runs harmless.
os.makedirs(save_dir, exist_ok=True)
model.save(save_path)
print(f"Trained model saved at {save_path}")

# delete and load model (just for illustration)
del model
model = PPO2.load(save_path)
print(f"Agent loaded from saved model at {save_path}")

# inference: roll out one episode (capped at 1000 steps) with the restored
# agent, rendering every step
try:
    obs = env.reset()
    for i in range(1000):
        action, _states = model.predict(obs)
        obs, reward, done, info = env.step(action)
        env.render()
        if done:
            # `i` is 0-based, so the episode lasted i + 1 steps.
            print(f"Cart pole dropped after {i + 1} steps.")
            break
finally:
    # Release the environment (and its render window) even if the rollout
    # raises part-way through.
    env.close()