From 39341928a818ea2af67975590a97df24be256558 Mon Sep 17 00:00:00 2001 From: Gaiejj Date: Wed, 19 Jun 2024 20:16:35 +0800 Subject: [PATCH] fix: fix compensator saving --- omnisafe/algorithms/off_policy/ddpg_cbf.py | 13 +++++++++--- .../on_policy/barrier_function/trpo_cbf.py | 21 +++++++++++++++++-- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/omnisafe/algorithms/off_policy/ddpg_cbf.py b/omnisafe/algorithms/off_policy/ddpg_cbf.py index 17dcacda3..f69310fff 100644 --- a/omnisafe/algorithms/off_policy/ddpg_cbf.py +++ b/omnisafe/algorithms/off_policy/ddpg_cbf.py @@ -51,7 +51,7 @@ def _init_env(self) -> None: self._seed, self._cfgs, ) - solver = PendulumSolver(device=self._cfgs.train_cfgs.device) + solver = PendulumSolver(device=self._device) compensator = BarrierCompensator( obs_dim=self._env.observation_space.shape[0], act_dim=self._env.action_space.shape[0], @@ -120,11 +120,18 @@ def _specific_save(self) -> None: os.makedirs(os.path.dirname(path), exist_ok=True) joblib.dump(self._env.gp_models, path) - def _log_what_to_save(self) -> dict[str, Any]: - """Define what need to be saved below.""" + def _setup_torch_saver(self) -> None: + """Define what need to be saved below. + + OmniSafe's main storage interface is based on PyTorch. If you need to save models in other + formats, please use :meth:`_specific_save`. + """ what_to_save: dict[str, Any] = {} what_to_save['pi'] = self._actor_critic.actor what_to_save['compensator'] = self._env.compensator + if self._cfgs.algo_cfgs.obs_normalize: + obs_normalizer = self._env.save()['obs_normalizer'] + what_to_save['obs_normalizer'] = obs_normalizer self._logger.setup_torch_saver(what_to_save) diff --git a/omnisafe/algorithms/on_policy/barrier_function/trpo_cbf.py b/omnisafe/algorithms/on_policy/barrier_function/trpo_cbf.py index 8125151d6..b0b64f892 100644 --- a/omnisafe/algorithms/on_policy/barrier_function/trpo_cbf.py +++ b/omnisafe/algorithms/on_policy/barrier_function/trpo_cbf.py @@ -28,6 +28,7 @@ from omnisafe.algorithms.on_policy.base.trpo import TRPO from omnisafe.common.barrier_comp import BarrierCompensator from omnisafe.common.barrier_solver import PendulumSolver +from omnisafe.typing import Any from omnisafe.utils import distributed from omnisafe.utils.distributed import get_rank @@ -70,12 +71,12 @@ def _init_env(self) -> None: // distributed.world_size() // self._cfgs.train_cfgs.vector_env_nums ) - self.solver = PendulumSolver(device=self._cfgs.train_cfgs.device) + self.solver = PendulumSolver(device=self._device) self.compensator = BarrierCompensator( obs_dim=self._env.observation_space.shape[0], act_dim=self._env.action_space.shape[0], cfgs=self._cfgs.compensator_cfgs, - ) + ).to(self._device) self._env.set_solver(solver=self.solver) self._env.set_compensator(compensator=self.compensator) @@ -165,3 +166,19 @@ def _specific_save(self) -> None: ) os.makedirs(os.path.dirname(path), exist_ok=True) joblib.dump(self._env.gp_models, path) + + def _setup_torch_saver(self) -> None: + """Define what need to be saved below. + + OmniSafe's main storage interface is based on PyTorch. If you need to save models in other + formats, please use :meth:`_specific_save`. + """ + what_to_save: dict[str, Any] = {} + + what_to_save['pi'] = self._actor_critic.actor + what_to_save['compensator'] = self._env.compensator + if self._cfgs.algo_cfgs.obs_normalize: + obs_normalizer = self._env.save()['obs_normalizer'] + what_to_save['obs_normalizer'] = obs_normalizer + + self._logger.setup_torch_saver(what_to_save)