-
Notifications
You must be signed in to change notification settings - Fork 25
/
train_finetuning.py
229 lines (188 loc) · 7.45 KB
/
train_finetuning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
#! /usr/bin/env python
import os
import pickle
from collections.abc import Mapping

import d4rl
import d4rl.gym_mujoco
import d4rl.locomotion
import dmcgym
import gym
import numpy as np
import tqdm
from absl import app, flags

try:
    from flax.training import checkpoints
except:
    print("Not loading checkpointing functionality.")
from ml_collections import config_flags

import wandb
from rlpd.agents import SACLearner
from rlpd.data import ReplayBuffer
from rlpd.data.d4rl_datasets import D4RLDataset

try:
    from rlpd.data.binary_datasets import BinaryDataset
except:
    print("not importing binary dataset")
from rlpd.evaluation import evaluate
from rlpd.wrappers import wrap_gym
FLAGS = flags.FLAGS

# --- Experiment / environment -------------------------------------------------
flags.DEFINE_string("project_name", "rlpd", "wandb project name.")
flags.DEFINE_string("env_name", "halfcheetah-expert-v2", "D4rl dataset name.")
# main() joins this root with a per-run prefix to place checkpoints and
# replay-buffer dumps; it must exist as a flag for FLAGS.log_dir to resolve.
flags.DEFINE_string(
    "log_dir", "./logs", "Root directory for model checkpoints and replay buffers."
)
# Fraction of each online-phase update batch drawn from the offline dataset;
# main() requires it to lie in [0, 1].
flags.DEFINE_float("offline_ratio", 0.5, "Offline ratio.")
flags.DEFINE_integer("seed", 42, "Random seed.")

# --- Logging / evaluation cadence (in loop iterations) -------------------------
flags.DEFINE_integer("eval_episodes", 10, "Number of episodes used for evaluation.")
flags.DEFINE_integer("log_interval", 1000, "Logging interval.")
flags.DEFINE_integer("eval_interval", 5000, "Eval interval.")

# --- Optimisation schedule -----------------------------------------------------
flags.DEFINE_integer("batch_size", 256, "Mini batch size.")
flags.DEFINE_integer("max_steps", int(1e6), "Number of training steps.")
# Online steps taken with uniformly random actions before any gradient update.
flags.DEFINE_integer(
    "start_training", int(1e4), "Number of training steps to start training."
)
flags.DEFINE_integer("pretrain_steps", 0, "Number of offline updates.")
flags.DEFINE_boolean("tqdm", True, "Use tqdm progress bar.")

# --- Outputs -------------------------------------------------------------------
flags.DEFINE_boolean("save_video", False, "Save videos during evaluation.")
flags.DEFINE_boolean("checkpoint_model", False, "Save agent checkpoint on evaluation.")
flags.DEFINE_boolean(
    "checkpoint_buffer", False, "Save agent replay buffer on evaluation."
)

# --- Algorithm knobs -----------------------------------------------------------
# The per-step sample is batch_size * utd_ratio transitions; agent.update
# receives utd_ratio as its second argument.
flags.DEFINE_integer("utd_ratio", 1, "Update to data ratio.")
flags.DEFINE_boolean(
    "binary_include_bc", True, "Whether to include BC data in the binary datasets."
)
# Agent hyperparameters; main() pops "model_cls" from this config and forwards
# the remaining entries to the agent's create().
config_flags.DEFINE_config_file(
    "config",
    "configs/sac_config.py",
    "File path to the training hyperparameter configuration.",
    lock_config=False,
)
def combine(one_dict, other_dict):
    """Merge two (possibly nested) batches by interleaving their rows on axis 0.

    Keys are taken from ``one_dict``; ``other_dict`` must contain every one of
    them. Nested mappings are merged recursively. Each leaf of the result has the
    dtype and trailing shape of the ``one_dict`` leaf, and its length is the sum of
    both leaf lengths.

    Row order:
      * when ``len(one) - len(other)`` is 0 or 1, rows strictly alternate
        ``one[0], other[0], one[1], other[1], ...``;
      * otherwise rows of both inputs are spread proportionally (stable, each
        input keeps its internal order), so any contiguous slice of the result
        holds roughly the same mix as the whole. Presumably this matters because
        the agent's update consumes contiguous chunks of the batch; confirm
        against ``agent.update``.

    Args:
        one_dict: Mapping of arrays (or nested mappings); placed first.
        other_dict: Mapping with at least the same keys/structure.

    Returns:
        A new plain ``dict`` with the merged arrays.
    """
    combined = {}
    for key, first in one_dict.items():
        second = other_dict[key]
        if isinstance(first, Mapping):
            combined[key] = combine(first, second)
        else:
            combined[key] = _interleave_rows(first, second)
    return combined


def _interleave_rows(first, second):
    """Interleave the rows of two arrays along axis 0 (see ``combine``)."""
    n_first, n_second = first.shape[0], second.shape[0]
    if n_first - n_second in (0, 1):
        # Strict alternation: first fills even slots, second fills odd slots.
        out = np.empty((n_first + n_second, *first.shape[1:]), dtype=first.dtype)
        out[0::2] = first
        out[1::2] = second
        return out

    # Unequal sizes: put first[i] at fractional position i / n_first and
    # second[j] at (j + 0.5) / n_second. Both are scaled by 2 * n_first * n_second
    # so they stay exact integers. A stable sort keeps each input's own order
    # and puts `first` ahead on ties.
    keys = np.concatenate(
        (
            2 * np.arange(n_first) * n_second,
            (2 * np.arange(n_second) + 1) * n_first,
        )
    )
    order = np.argsort(keys, kind="stable")
    # Cast `second` to first's dtype, as assignment into the np.empty buffer does.
    stacked = np.concatenate(
        (first, np.asarray(second).astype(first.dtype, copy=False))
    )
    return stacked[order]
def _experiment_prefix():
    """Return the per-run directory name, e.g. ``s42_0pretrain`` or ``..._LN``."""
    prefix = f"s{FLAGS.seed}_{FLAGS.pretrain_steps}pretrain"
    if getattr(FLAGS.config, "critic_layer_norm", False):
        prefix += "_LN"
    return prefix


def _log_metrics(prefix, info, step):
    """Log each ``key: value`` in ``info`` to wandb as ``prefix/key`` at ``step``."""
    if info:
        wandb.log({f"{prefix}/{k}": v for k, v in info.items()}, step=step)


def main(_):
    """Offline pretraining followed by online fine-tuning on mixed batches.

    Everything is configured through absl flags and the ``--config`` file.

    Phase 1 runs ``FLAGS.pretrain_steps`` updates using only offline-dataset
    samples. Phase 2 runs ``FLAGS.max_steps + 1`` environment steps:
      * the first ``FLAGS.start_training`` steps take random actions;
      * every later step updates the agent on a batch of
        ``batch_size * utd_ratio`` transitions, of which a fraction
        ``offline_ratio`` comes from the offline dataset and the rest from the
        online replay buffer.

    Metrics go to wandb. Checkpoints and replay-buffer pickles are written under
    ``FLAGS.log_dir`` only when the corresponding flags are set.

    Raises:
        ValueError: if ``FLAGS.offline_ratio`` is outside [0, 1].
    """
    if not 0.0 <= FLAGS.offline_ratio <= 1.0:
        raise ValueError(
            f"offline_ratio must be in [0, 1], got {FLAGS.offline_ratio}."
        )

    wandb.init(project=FLAGS.project_name)
    wandb.config.update(FLAGS)

    # Resolve output directories only when something will be written there.
    chkpt_dir = buffer_dir = None
    if FLAGS.checkpoint_model or FLAGS.checkpoint_buffer:
        log_dir = os.path.join(FLAGS.log_dir, _experiment_prefix())
        if FLAGS.checkpoint_model:
            chkpt_dir = os.path.join(log_dir, "checkpoints")
            os.makedirs(chkpt_dir, exist_ok=True)
        if FLAGS.checkpoint_buffer:
            buffer_dir = os.path.join(log_dir, "buffers")
            os.makedirs(buffer_dir, exist_ok=True)

    env = gym.make(FLAGS.env_name)
    env = wrap_gym(env, rescale_actions=True)
    env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=1)
    env.seed(FLAGS.seed)

    # not ideal, but works for now: pick the offline dataset by env-name substring.
    if "binary" in FLAGS.env_name:
        ds = BinaryDataset(env, include_bc_data=FLAGS.binary_include_bc)
    else:
        ds = D4RLDataset(env)

    eval_env = gym.make(FLAGS.env_name)
    eval_env = wrap_gym(eval_env, rescale_actions=True)
    eval_env.seed(FLAGS.seed + 42)

    # The config names the agent class (e.g. SACLearner, imported into this
    # module); the remaining entries are forwarded to its create().
    kwargs = dict(FLAGS.config)
    model_cls = kwargs.pop("model_cls")
    agent = globals()[model_cls].create(
        FLAGS.seed, env.observation_space, env.action_space, **kwargs
    )

    replay_buffer = ReplayBuffer(
        env.observation_space, env.action_space, FLAGS.max_steps
    )
    replay_buffer.seed(FLAGS.seed)

    # Antmaze rewards are shifted by -1 before every update.
    is_antmaze = "antmaze" in FLAGS.env_name

    # The two halves must add up to exactly total_size: truncating both products
    # with int() could lose a transition. Making offline the remainder keeps the
    # offline half the larger one when the split is uneven.
    total_size = FLAGS.batch_size * FLAGS.utd_ratio
    online_size = int(total_size * (1 - FLAGS.offline_ratio))
    offline_size = total_size - online_size

    # ---- Phase 1: offline-only pretraining -----------------------------------
    for i in tqdm.tqdm(
        range(0, FLAGS.pretrain_steps), smoothing=0.1, disable=not FLAGS.tqdm
    ):
        # Copy into a plain dict so the rewards entry can be replaced.
        batch = dict(ds.sample(total_size))
        if is_antmaze:
            batch["rewards"] = batch["rewards"] - 1
        agent, update_info = agent.update(batch, FLAGS.utd_ratio)

        if i % FLAGS.log_interval == 0:
            _log_metrics("offline-training", update_info, i)
        if i % FLAGS.eval_interval == 0:
            eval_info = evaluate(agent, eval_env, num_episodes=FLAGS.eval_episodes)
            _log_metrics("offline-evaluation", eval_info, i)

    # ---- Phase 2: online fine-tuning with mixed batches ----------------------
    episode_stat_names = {"r": "return", "l": "length", "t": "time"}
    observation, done = env.reset(), False
    for i in tqdm.tqdm(
        range(0, FLAGS.max_steps + 1), smoothing=0.1, disable=not FLAGS.tqdm
    ):
        # wandb steps continue after the pretraining steps.
        step = i + FLAGS.pretrain_steps

        if i < FLAGS.start_training:
            action = env.action_space.sample()
        else:
            action, agent = agent.sample_actions(observation)
        next_observation, reward, done, info = env.step(action)

        # Bootstrap (mask=1) unless the episode really terminated. Time-limit
        # truncation is not a terminal state. Read the flag's value rather than
        # test its presence: gym's TimeLimit wrapper (pre-0.25 API) may store
        # False under this key when the env also terminated on that step.
        if not done or info.get("TimeLimit.truncated", False):
            mask = 1.0
        else:
            mask = 0.0

        replay_buffer.insert(
            dict(
                observations=observation,
                actions=action,
                rewards=reward,
                masks=mask,
                dones=done,
                next_observations=next_observation,
            )
        )
        observation = next_observation

        if done:
            observation, done = env.reset(), False
            # info["episode"] comes from the RecordEpisodeStatistics wrapper.
            wandb.log(
                {
                    f"training/{episode_stat_names[k]}": v
                    for k, v in info["episode"].items()
                },
                step=step,
            )

        if i >= FLAGS.start_training:
            online_batch = replay_buffer.sample(online_size)
            offline_batch = ds.sample(offline_size)
            batch = combine(offline_batch, online_batch)
            if is_antmaze:
                batch["rewards"] = batch["rewards"] - 1
            agent, update_info = agent.update(batch, FLAGS.utd_ratio)

            if i % FLAGS.log_interval == 0:
                _log_metrics("training", update_info, step)

        if i % FLAGS.eval_interval == 0:
            eval_info = evaluate(
                agent,
                eval_env,
                num_episodes=FLAGS.eval_episodes,
                save_video=FLAGS.save_video,
            )
            _log_metrics("evaluation", eval_info, step)

            # Saving is best-effort: report the failure and keep training.
            # Exception (not a bare except) still lets Ctrl-C stop the run.
            if chkpt_dir is not None:
                try:
                    checkpoints.save_checkpoint(
                        chkpt_dir, agent, step=i, keep=20, overwrite=True
                    )
                except Exception as err:
                    print(f"Could not save model checkpoint: {err}")

            if buffer_dir is not None:
                try:
                    with open(os.path.join(buffer_dir, "buffer"), "wb") as f:
                        pickle.dump(replay_buffer, f, pickle.HIGHEST_PROTOCOL)
                except Exception as err:
                    print(f"Could not save agent buffer: {err}")


if __name__ == "__main__":
    app.run(main)