Commit 7686169

push sampling to datasets directory when done and option selected in config

1 parent 3d9080e

2 files changed: +64 -46 lines

configs/sampling.yaml (+3 -2)
@@ -10,7 +10,7 @@ obs_type: states # [states, pixels]
 frame_stack: 1 # only works if obs_type=pixels
 action_repeat: 1 # set to 2 for pixels
 discount: 0.99
-skill_dim: 500
+skill_dim: 10
 # train settings
 num_train_frames: 2000010
 num_seed_frames: 4000
@@ -32,8 +32,9 @@ update_encoder: false # can be either true or false depending if we want to fine
 # misc
 seed: 1
 device: cuda
-save_video: true
+save_video: false
 save_train_video: false
+save_to_data: true
 use_tb: false
 use_wandb: false
 # experiment
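
The new save_to_data flag is what gates the dataset export added to sampling.py below. As a minimal standalone sketch of reading these settings, assuming an OmegaConf-style config as the file layout suggests (the loading code here is illustrative, not from the repo):

    from omegaconf import OmegaConf

    # Load the sampling config and check the flag added by this commit.
    cfg = OmegaConf.load('configs/sampling.yaml')
    if cfg.save_to_data:
        print(f'will export dataset (skill_dim={cfg.skill_dim}, seed={cfg.seed})')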

sampling.py (+61 -44)
@@ -3,6 +3,7 @@
 warnings.filterwarnings('ignore', category=DeprecationWarning)
 
 import os
+import shutil
 os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
 os.environ['MUJOCO_GL'] = 'egl'
 
@@ -134,50 +135,50 @@ def sample(self):
         meta = self.agent.init_meta()
 
         self.replay_storage.add(time_step, meta)
-        while train_until_step(self.global_step):
-            while sample_until_step(episode):
-                time_step = self.sample_env.reset()
-                if self.cfg.agent.name not in self.prior_encoded_agents:
-                    # Update agent if not in the reward
-                    meta = self.agent.update_meta(meta, self.global_step, time_step)
-                    self.sample_env._env._env._env._env.environment.reset(random_start=False)
-                else:
-                    meta = self.agent.init_meta()
-                self.replay_storage.add(time_step, meta)
-                self.video_recorder.init(self.sample_env, enabled=True)
-                trajectory = []
-                while not time_step.last():
-                    with torch.no_grad(), utils.eval_mode(self.agent):
-                        action = self.agent.act(time_step.observation,
-                                                meta,
-                                                self.global_step,
-                                                eval_mode=True)
-                    time_step = self.sample_env.step(action)
-                    self.video_recorder.record(self.sample_env)
-                    total_reward += time_step.reward
-                    trajectory.append(time_step)
-                    step += 1
-                    self._global_step += 1
-
-                if self.cfg.data_type == 'unsupervised':
-                    # TODO: Provide a less hacky way of accessing info from environment
-                    info = self.sample_env._env._env._env._env.get_info()
-                    if self.meta_encoded:
-                        unsupervised_data = {'meta': meta, 'constraint': info['constraint'], 'done': info['done']}
-                    else:
-                        unsupervised_data = {'constraint': info['constraint'], 'done': info['done']}
-                    self.replay_storage.add(time_step, unsupervised_data)
+        while sample_until_step(episode):
+            time_step = self.sample_env.reset()
+            if self.cfg.agent.name not in self.prior_encoded_agents:
+                # Update agent if not in the reward
+                meta = self.agent.update_meta(meta, self.global_step, time_step)
+                self.sample_env._env._env._env._env.environment.reset(random_start=False)
+            else:
+                meta = self.agent.init_meta()
+            self.replay_storage.add(time_step, meta)
+            self.video_recorder.init(self.sample_env, enabled=True)
+            trajectory = []
+            while not time_step.last():
+                with torch.no_grad(), utils.eval_mode(self.agent):
+                    action = self.agent.act(time_step.observation,
+                                            meta,
+                                            self.global_step,
+                                            eval_mode=True)
+                time_step = self.sample_env.step(action)
+                self.video_recorder.record(self.sample_env)
+                total_reward += time_step.reward
+                trajectory.append(time_step)
+                step += 1
+                self._global_step += 1
+
+            if self.cfg.data_type == 'unsupervised':
+                # TODO: Provide a less hacky way of accessing info from environment
+                info = self.sample_env._env._env._env._env.get_info()
+                if self.meta_encoded:
+                    unsupervised_data = {'meta': meta, 'constraint': info['constraint'], 'done': info['done']}
                 else:
-                    self.replay_storage.add(time_step, meta)
-
-                episode += 1
-                # skill_index = str(meta['skill'])
-                if not seed_until_step(self.global_step):
-                    if self.cfg.agent.name not in self.prior_encoded_agents:
-                        batch = next(self.replay_iter)
-                        algo_batch = batch[0:5]
-                        algo_iter = iter([algo_batch])
-                        self.agent.update(algo_iter, self.global_step)
+                    unsupervised_data = {'constraint': info['constraint'], 'done': info['done']}
+                self.replay_storage.add(time_step, unsupervised_data)
+            else:
+                self.replay_storage.add(time_step, meta)
+
+            episode += 1
+            # skill_index = str(meta['skill'])
+            if not seed_until_step(self.global_step):
+                if self.cfg.agent.name not in self.prior_encoded_agents:
+                    batch = next(self.replay_iter)
+                    algo_batch = batch[0:5]
+                    algo_iter = iter([algo_batch])
+                    self.agent.update(algo_iter, self.global_step)
+            if self.cfg.save_video:
                 self.video_recorder.save(f'{episode}.mp4')
 
         with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
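
In this hunk the outer train_until_step loop is dropped, so episodes are driven by sample_until_step alone, and the mp4 save is now gated on the new save_video flag. These predicates follow the Until pattern common to URLB-derived codebases; a sketch of what sample_until_step(episode) is assumed to wrap (not shown in this diff):

    # Until-style predicate: true while the count is below a bound.
    class Until:
        def __init__(self, until, action_repeat=1):
            self._until = until
            self._action_repeat = action_repeat

        def __call__(self, step):
            if self._until is None:
                return True  # no bound set: keep going
            return step < (self._until // self._action_repeat)
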
@@ -189,14 +190,30 @@ def sample(self):
         # Store data in values
         buffer_path = os.path.join(self.work_dir, 'buffer')
         os.rename(buffer_path, f'{self.cfg.agent.name}_{self.cfg.snapshot_ts}')
+        if self.cfg.save_to_data:
+            domain, _ = self.cfg.task.split('_', 1)
+            if self.cfg.agent.name == 'diayn':
+                target_path = f'./../../../data/datasets/{self.cfg.obs_type}/{domain}/{self.cfg.agent.name}/{self.cfg.skill_dim}/{self.cfg.seed}/{self.cfg.snapshot_ts}'
+            else:
+                target_path = f'./../../../data/datasets/{self.cfg.obs_type}/{domain}/{self.cfg.agent.name}/{self.cfg.seed}/{self.cfg.snapshot_ts}'
+            if not os.path.exists(target_path):
+                os.makedirs(target_path)
+            source_path = os.path.join(self.work_dir, f'{self.cfg.agent.name}_{self.cfg.snapshot_ts}')
+            print(f'beginning to move: {source_path} -> {target_path}')
+            for file_name in os.listdir(source_path):
+                source = os.path.join(source_path, file_name)
+                target = os.path.join(target_path, file_name)
+                if os.path.isfile(source):
+                    shutil.move(source, target)
+
 
     def load_snapshot(self):
         snapshot_base_dir = Path(self.cfg.snapshot_base_dir)
         domain, _ = self.cfg.task.split('_', 1)
         snapshot_dir = snapshot_base_dir / self.cfg.obs_type / domain / self.cfg.agent.name
 
         def try_load(seed):
-            if self.cfg.agent == 'diayn':
+            if self.cfg.agent.name == 'diayn':
                 snapshot = snapshot_dir / f'{self.cfg.skill_dim}' / str(seed) / f'snapshot_{self.cfg.snapshot_ts}.pt'
             else:
                 snapshot = snapshot_dir / str(seed) / f'snapshot_{self.cfg.snapshot_ts}.pt'
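
The export block above builds target_path by hand and checks os.path.exists before os.makedirs. A more compact equivalent sketch using pathlib (names mirror the diff's config fields; this is an illustration, not repo code):

    import shutil
    from pathlib import Path

    def export_buffer(work_dir, obs_type, domain, agent_name, seed,
                      snapshot_ts, skill_dim=None,
                      datasets_root='./../../../data/datasets'):
        parts = [obs_type, domain, agent_name]
        if skill_dim is not None:  # diayn nests an extra skill_dim level
            parts.append(str(skill_dim))
        parts += [str(seed), str(snapshot_ts)]
        target = Path(datasets_root).joinpath(*parts)
        target.mkdir(parents=True, exist_ok=True)  # replaces exists()/makedirs()
        source = Path(work_dir) / f'{agent_name}_{snapshot_ts}'
        for f in source.iterdir():
            if f.is_file():
                shutil.move(str(f), str(target / f.name))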
