Skip to content

Commit

Permalink
correct for state representation
Browse files Browse the repository at this point in the history
  • Loading branch information
AOS55 committed Oct 26, 2022
1 parent a29bf56 commit c2bfb72
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion libraries/latentsafesets/rl_trainers/mpc_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def __init__(self, env, cfg, modules):
self.logdir = cfg.log_dir

loss_plotter = LossPlotter(os.path.join(self.logdir, 'loss_plots'))
self.encoder_data_loader = EncoderDataLoader(cfg.env, frame_stack=cfg.frame_stack)

self.trainers = []

Expand All @@ -33,6 +32,8 @@ def initial_train(self, replay_buffer):
os.makedirs(update_dir, exist_ok=True)
for trainer in self.trainers:
if type(trainer) == VAETrainer:

self.encoder_data_loader = EncoderDataLoader(self.cfg.env, frame_stack=self.cfg.frame_stack)
trainer.initial_train(self.encoder_data_loader, update_dir)
else:
trainer.initial_train(replay_buffer, update_dir)
Expand Down

0 comments on commit c2bfb72

Please sign in to comment.