-
Notifications
You must be signed in to change notification settings - Fork 0
/
learn_and_save_model.py
58 lines (47 loc) · 1.57 KB
/
learn_and_save_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gym
import tensorflow as tf
import os
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from baselines.run import get_learn_function, get_env_type, get_learn_function_defaults, build_env
from types import SimpleNamespace
# Collect every experiment argument in one dictionary so we can reference
# them attribute-style (e.g. args.env) via SimpleNamespace below.
args_dict = {
    'alg': 'ppo2',
    'total_timesteps': 10000,
    'seed': 0,
    'env': 'BipedalWalker-v2',
    'network': 'mlp',
    'num_env': 1,
    'reward_scale': 1,
    'flatten_dict_observations': True,
    'save_interval': 1,
    'num_epochs': 1000,
    'steps_per_update': 1000,
    'log_interval': 1,
    # NOTE(review): hard-coded absolute user path — consider making this
    # configurable (CLI arg / env var) so the script runs on other machines.
    'save_path': '/Users/thomascartwright/Documents/Development/mlp/mlpgroup009/'
}
args = SimpleNamespace(**args_dict)

# Resolve the environment type/id and the algorithm's learn() entry point
# from the baselines registry, then build the (vectorized) environment.
env_type, env_id = get_env_type(args.env)
learn = get_learn_function(args.alg)
alg_kwargs = get_learn_function_defaults(args.alg, env_type)
env = build_env(args)
alg_kwargs['network'] = args.network

# The directory where the results of this experiment are stored.
# os.path.join avoids the double slash the old string concatenation produced
# (save_path already ends with '/').
full_path = os.path.join(args.save_path, args.env + '-' + args.alg)

# Create the run folder and its checkpoints subfolder. exist_ok=True makes
# this idempotent and also fixes a bug where the checkpoints dir was never
# created if the run folder already existed from an earlier partial run.
os.makedirs(os.path.join(full_path, 'checkpoints'), exist_ok=True)

# Train the model; checkpoints/epoch results are written under full_path
# every save_interval updates.
model = learn(
    env=env,
    seed=args.seed,
    total_timesteps=args.total_timesteps,
    save_interval=args.save_interval,
    noptepochs=args.num_epochs,
    nsteps=args.steps_per_update,
    log_interval=args.log_interval,
    save_path=full_path,
    **alg_kwargs
)

# Save the final model and its variables.
model.save(os.path.join(full_path, 'final'))