Skip to content

Commit

Permalink
fix: fix compensator saving
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaiejj committed Jun 19, 2024
1 parent 7423bc1 commit 3934192
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
13 changes: 10 additions & 3 deletions omnisafe/algorithms/off_policy/ddpg_cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _init_env(self) -> None:
self._seed,
self._cfgs,
)
solver = PendulumSolver(device=self._cfgs.train_cfgs.device)
solver = PendulumSolver(device=self._device)
compensator = BarrierCompensator(
obs_dim=self._env.observation_space.shape[0],
act_dim=self._env.action_space.shape[0],
Expand Down Expand Up @@ -120,11 +120,18 @@ def _specific_save(self) -> None:
os.makedirs(os.path.dirname(path), exist_ok=True)
joblib.dump(self._env.gp_models, path)

def _log_what_to_save(self) -> dict[str, Any]:
"""Define what need to be saved below."""
def _setup_torch_saver(self) -> None:
"""Define what need to be saved below.
OmniSafe's main storage interface is based on PyTorch. If you need to save models in other
formats, please use :meth:`_specific_save`.
"""
what_to_save: dict[str, Any] = {}

what_to_save['pi'] = self._actor_critic.actor
what_to_save['compensator'] = self._env.compensator
if self._cfgs.algo_cfgs.obs_normalize:
obs_normalizer = self._env.save()['obs_normalizer']
what_to_save['obs_normalizer'] = obs_normalizer

self._logger.setup_torch_saver(what_to_save)
21 changes: 19 additions & 2 deletions omnisafe/algorithms/on_policy/barrier_function/trpo_cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from omnisafe.algorithms.on_policy.base.trpo import TRPO
from omnisafe.common.barrier_comp import BarrierCompensator
from omnisafe.common.barrier_solver import PendulumSolver
from omnisafe.typing import Any
from omnisafe.utils import distributed
from omnisafe.utils.distributed import get_rank

Expand Down Expand Up @@ -70,12 +71,12 @@ def _init_env(self) -> None:
// distributed.world_size()
// self._cfgs.train_cfgs.vector_env_nums
)
self.solver = PendulumSolver(device=self._cfgs.train_cfgs.device)
self.solver = PendulumSolver(device=self._device)
self.compensator = BarrierCompensator(
obs_dim=self._env.observation_space.shape[0],
act_dim=self._env.action_space.shape[0],
cfgs=self._cfgs.compensator_cfgs,
)
).to(self._device)
self._env.set_solver(solver=self.solver)
self._env.set_compensator(compensator=self.compensator)

Expand Down Expand Up @@ -165,3 +166,19 @@ def _specific_save(self) -> None:
)
os.makedirs(os.path.dirname(path), exist_ok=True)
joblib.dump(self._env.gp_models, path)

def _setup_torch_saver(self) -> None:
"""Define what need to be saved below.
OmniSafe's main storage interface is based on PyTorch. If you need to save models in other
formats, please use :meth:`_specific_save`.
"""
what_to_save: dict[str, Any] = {}

what_to_save['pi'] = self._actor_critic.actor
what_to_save['compensator'] = self._env.compensator
if self._cfgs.algo_cfgs.obs_normalize:
obs_normalizer = self._env.save()['obs_normalizer']
what_to_save['obs_normalizer'] = obs_normalizer

self._logger.setup_torch_saver(what_to_save)

0 comments on commit 3934192

Please sign in to comment.