-
Notifications
You must be signed in to change notification settings - Fork 0
/
base-bip.py
76 lines (64 loc) · 2.09 KB
/
base-bip.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gym
import tensorflow as tf
import os
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from baselines.run import get_learn_function, get_env_type, get_learn_function_defaults, build_env
import time
from types import SimpleNamespace
# Set the save path for this model
save_path = os.path.basename(__file__) + '.' + str(time.time()).replace('.', '')[-6:]
# The model path we want to load from
model_load_path = ''
# Write all the arguments into a dictionary that we can references e.g. args.env
args_dict={
'alg': 'ppo2',
'env': 'BipedalWalker-v2',
'network': 'mlp',
'learning_rate': 0.001,
'discount_factor':0.99,
'nminibatches': 64,
'cliprange': 0.2,
'total_timesteps': 1e6,
'num_env': 1,
'nsteps': 2048,
'noptepochs': 10,
'save_interval': 20,
'log_interval': 1,
'save_path': save_path,
'model_load_path': model_load_path,
'seed': 0,
'reward_scale': 1,
'flatten_dict_observations': True,
'transfer_weights': False
}
args = SimpleNamespace(**args_dict)
# Split args.env into its (env_type, env_id) pair as understood by baselines.
env_type, env_id = get_env_type(args.env)
# Training entry point for the chosen algorithm, plus that algorithm's default
# keyword arguments for this environment type.
learn = get_learn_function(args.alg)
alg_kwargs = get_learn_function_defaults(args.alg, env_type)
# Construct the training environment from the settings in args
# (presumably a baselines VecEnv — confirm against build_env).
env = build_env(args)
# Policy network architecture, forwarded to learn() through alg_kwargs.
alg_kwargs.update(network=args.network)
# Root directory for this experiment's results: <save_path>/<env>-<alg>.
full_path = os.path.join(args.save_path, args.env + '-' + args.alg)
# Create the results directory and its checkpoints sub-directory in one call:
# makedirs creates full_path as an intermediate directory, and exist_ok=True
# makes this safe to re-run. It also creates a missing 'checkpoints' folder
# when full_path already exists, which a plain exists() check would skip.
os.makedirs(os.path.join(full_path, 'checkpoints'), exist_ok=True)
# Collect every keyword argument for learn() in one dict. The explicit script
# settings are layered on top of the algorithm defaults in alg_kwargs (which
# also carries 'network'). Passing both separately would raise
# "got multiple values for keyword argument" whenever the defaults define the
# same key.
learn_kwargs = dict(alg_kwargs)
learn_kwargs.update(
    env=env,
    seed=args.seed,
    total_timesteps=args.total_timesteps,
    save_interval=args.save_interval,
    lr=args.learning_rate,
    # These three are configured in args_dict; without passing them, learn()
    # silently falls back to its own defaults.
    gamma=args.discount_factor,
    nminibatches=args.nminibatches,
    cliprange=args.cliprange,
    noptepochs=args.noptepochs,
    nsteps=args.nsteps,
    log_interval=args.log_interval,
    save_path=full_path,
    model_load_path=args.model_load_path,
    transfer_weights=args.transfer_weights,
)
print("About to start learning model")
try:
    model = learn(**learn_kwargs)
    # Save the trained model and its variables under <full_path>/final.
    print("Attempting to save model")
    model.save(os.path.join(full_path, 'final'))
finally:
    # Release the environment (and any worker processes it owns) even if
    # training or saving fails.
    env.close()