warnings.filterwarnings('ignore', category=DeprecationWarning)

import os
+ import shutil
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
os.environ['MUJOCO_GL'] = 'egl'

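A side note on the setup above: a minimal sketch, under the assumption that a caller may already export these variables, of making the defaults non-destructive with os.environ.setdefault (this variant is not part of the commit):

    import os

    # Hypothetical variant: only force the MKL workaround and the EGL
    # rendering backend when nothing else has been configured.
    os.environ.setdefault('MKL_SERVICE_FORCE_INTEL', '1')
    os.environ.setdefault('MUJOCO_GL', 'egl')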
@@ -134,50 +135,50 @@ def sample(self):
        meta = self.agent.init_meta()

        self.replay_storage.add(time_step, meta)
-         while train_until_step(self.global_step):
-             while sample_until_step(episode):
-                 time_step = self.sample_env.reset()
-                 if self.cfg.agent.name not in self.prior_encoded_agents:
-                     # Update agent if not in the reward
-                     meta = self.agent.update_meta(meta, self.global_step, time_step)
-                     self.sample_env._env._env._env._env.environment.reset(random_start=False)
-                 else:
-                     meta = self.agent.init_meta()
-                 self.replay_storage.add(time_step, meta)
-                 self.video_recorder.init(self.sample_env, enabled=True)
-                 trajectory = []
-                 while not time_step.last():
-                     with torch.no_grad(), utils.eval_mode(self.agent):
-                         action = self.agent.act(time_step.observation,
-                                                 meta,
-                                                 self.global_step,
-                                                 eval_mode=True)
-                     time_step = self.sample_env.step(action)
-                     self.video_recorder.record(self.sample_env)
-                     total_reward += time_step.reward
-                     trajectory.append(time_step)
-                     step += 1
-                     self._global_step += 1
-
-                 if self.cfg.data_type == 'unsupervised':
-                     # TODO: Provide a less hacky way of accessing info from environment
-                     info = self.sample_env._env._env._env._env.get_info()
-                     if self.meta_encoded:
-                         unsupervised_data = {'meta': meta, 'constraint': info['constraint'], 'done': info['done']}
-                     else:
-                         unsupervised_data = {'constraint': info['constraint'], 'done': info['done']}
-                     self.replay_storage.add(time_step, unsupervised_data)
+         while sample_until_step(episode):
+             time_step = self.sample_env.reset()
+             if self.cfg.agent.name not in self.prior_encoded_agents:
+                 # Update agent if not in the reward
+                 meta = self.agent.update_meta(meta, self.global_step, time_step)
+                 self.sample_env._env._env._env._env.environment.reset(random_start=False)
+             else:
+                 meta = self.agent.init_meta()
+             self.replay_storage.add(time_step, meta)
+             self.video_recorder.init(self.sample_env, enabled=True)
+             trajectory = []
+             while not time_step.last():
+                 with torch.no_grad(), utils.eval_mode(self.agent):
+                     action = self.agent.act(time_step.observation,
+                                             meta,
+                                             self.global_step,
+                                             eval_mode=True)
+                 time_step = self.sample_env.step(action)
+                 self.video_recorder.record(self.sample_env)
+                 total_reward += time_step.reward
+                 trajectory.append(time_step)
+                 step += 1
+                 self._global_step += 1
+
+             if self.cfg.data_type == 'unsupervised':
+                 # TODO: Provide a less hacky way of accessing info from environment
+                 info = self.sample_env._env._env._env._env.get_info()
+                 if self.meta_encoded:
+                     unsupervised_data = {'meta': meta, 'constraint': info['constraint'], 'done': info['done']}
                else:
-                     self.replay_storage.add(time_step, meta)
-
-                 episode += 1
-                 # skill_index = str(meta['skill'])
-                 if not seed_until_step(self.global_step):
-                     if self.cfg.agent.name not in self.prior_encoded_agents:
-                         batch = next(self.replay_iter)
-                         algo_batch = batch[0:5]
-                         algo_iter = iter([algo_batch])
-                         self.agent.update(algo_iter, self.global_step)
+                     unsupervised_data = {'constraint': info['constraint'], 'done': info['done']}
+                 self.replay_storage.add(time_step, unsupervised_data)
+             else:
+                 self.replay_storage.add(time_step, meta)
+
+             episode += 1
+             # skill_index = str(meta['skill'])
+             if not seed_until_step(self.global_step):
+                 if self.cfg.agent.name not in self.prior_encoded_agents:
+                     batch = next(self.replay_iter)
+                     algo_batch = batch[0:5]
+                     algo_iter = iter([algo_batch])
+                     self.agent.update(algo_iter, self.global_step)
+             if self.cfg.save_video:
                self.video_recorder.save(f'{episode}.mp4')

        with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
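On the TODO above about reaching the base environment: a minimal sketch, not part of this commit, of a generic unwrapping helper that walks the wrapper chain instead of hard-coding four ._env hops. The attribute name _env and the get_info() call are taken from the diff; the helper name unwrap_env is hypothetical.

    def unwrap_env(env, attr='_env'):
        """Follow the wrapper chain until the innermost environment is reached.

        Sketch only: assumes every wrapper stores the wrapped environment in `_env`.
        """
        while hasattr(env, attr):
            env = getattr(env, attr)
        return env

    # Usage sketch: base = unwrap_env(self.sample_env); info = base.get_info()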
@@ -189,14 +190,30 @@ def sample(self):
        # Store data in values
        buffer_path = os.path.join(self.work_dir, 'buffer')
        os.rename(buffer_path, f'{self.cfg.agent.name}_{self.cfg.snapshot_ts}')
+         if self.cfg.save_to_data:
+             domain, _ = self.cfg.task.split('_', 1)
+             if self.cfg.agent.name == 'diayn':
+                 target_path = f'./../../../data/datasets/{self.cfg.obs_type}/{domain}/{self.cfg.agent.name}/{self.cfg.skill_dim}/{self.cfg.seed}/{self.cfg.snapshot_ts}'
+             else:
+                 target_path = f'./../../../data/datasets/{self.cfg.obs_type}/{domain}/{self.cfg.agent.name}/{self.cfg.seed}/{self.cfg.snapshot_ts}'
+             if not os.path.exists(target_path):
+                 os.makedirs(target_path)
+             source_path = os.path.join(self.work_dir, f'{self.cfg.agent.name}_{self.cfg.snapshot_ts}')
+             print(f'beginning to move: {source_path} -> {target_path}')
+             for file_name in os.listdir(source_path):
+                 source = os.path.join(source_path, file_name)
+                 target = os.path.join(target_path, file_name)
+                 if os.path.isfile(source):
+                     shutil.move(source, target)
+

    def load_snapshot(self):
        snapshot_base_dir = Path(self.cfg.snapshot_base_dir)
        domain, _ = self.cfg.task.split('_', 1)
        snapshot_dir = snapshot_base_dir / self.cfg.obs_type / domain / self.cfg.agent.name

        def try_load(seed):
-             if self.cfg.agent == 'diayn':
+             if self.cfg.agent.name == 'diayn':
                snapshot = snapshot_dir / f'{self.cfg.skill_dim}' / str(seed) / f'snapshot_{self.cfg.snapshot_ts}.pt'
            else:
                snapshot = snapshot_dir / str(seed) / f'snapshot_{self.cfg.snapshot_ts}.pt'
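For reference, a minimal sketch of the file-relocation step added above, written with pathlib; move_buffer is a hypothetical name and the sketch assumes the same flat directory of buffer files that the loop in sample() iterates over.

    from pathlib import Path
    import shutil

    def move_buffer(source_dir, target_dir):
        """Move every regular file from source_dir into target_dir.

        Sketch only: mirrors the loop added in sample(), with parents/exist_ok
        replacing the separate os.path.exists check before makedirs.
        """
        target = Path(target_dir)
        target.mkdir(parents=True, exist_ok=True)
        for entry in Path(source_dir).iterdir():
            if entry.is_file():
                shutil.move(str(entry), str(target / entry.name))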