Commit de84a6b: Removed tensorboard dependencies from pytorch_sac (facebookresearch#154)

Authored by luisenp on Jun 8, 2022. Parent commit: 9b7534d.

Showing 2 changed files with 5 additions and 35 deletions.
mbrl/third_party/pytorch_sac/logger.py: 0 additions & 30 deletions
@@ -6,7 +6,6 @@
 import numpy as np
 import torch
 from termcolor import colored
-from torch.utils.tensorboard import SummaryWriter

 COMMON_TRAIN_FORMAT = [
     ("episode", "E", "int"),
@@ -118,25 +117,13 @@ class Logger(object):
     def __init__(
         self,
         log_dir,
-        save_tb=False,
         log_frequency=10000,
         agent="sac",
         train_format=None,
         eval_format=None,
     ):
         self._log_dir = log_dir
         self._log_frequency = log_frequency
-        if save_tb:
-            tb_dir = os.path.join(log_dir, "tb")
-            if os.path.exists(tb_dir):
-                try:
-                    shutil.rmtree(tb_dir)
-                except:
-                    print("logger.py warning: Unable to remove tb directory")
-                    pass
-            self._sw = SummaryWriter(tb_dir)
-        else:
-            self._sw = None
         if not train_format:
             # each agent has specific output format for training
             assert agent in AGENT_TRAIN_FORMAT
@@ -153,27 +140,12 @@ def _should_log(self, step, log_frequency):
         log_frequency = log_frequency or self._log_frequency
         return step % log_frequency == 0

-    def _try_sw_log(self, key, value, step):
-        if self._sw is not None:
-            self._sw.add_scalar(key, value, step)
-
-    def _try_sw_log_video(self, key, frames, step):
-        if self._sw is not None:
-            frames = torch.from_numpy(np.array(frames))
-            frames = frames.unsqueeze(0)
-            self._sw.add_video(key, frames, step, fps=30)
-
-    def _try_sw_log_histogram(self, key, histogram, step):
-        if self._sw is not None:
-            self._sw.add_histogram(key, histogram, step)
-
     def log(self, key, value, step, n=1, log_frequency=1):
         if not self._should_log(step, log_frequency):
             return
         assert key.startswith("train") or key.startswith("eval")
         if type(value) == torch.Tensor:
             value = value.item()
-        self._try_sw_log(key, value / n, step)
         mg = self._train_mg if key.startswith("train") else self._eval_mg
         mg.log(key, value, n)

@@ -192,13 +164,11 @@ def log_video(self, key, frames, step, log_frequency=None):
         if not self._should_log(step, log_frequency):
             return
         assert key.startswith("train") or key.startswith("eval")
-        self._try_sw_log_video(key, frames, step)

     def log_histogram(self, key, histogram, step, log_frequency=None):
         if not self._should_log(step, log_frequency):
             return
         assert key.startswith("train") or key.startswith("eval")
-        self._try_sw_log_histogram(key, histogram, step)

     def dump(self, step, save=True, ty=None):
         if ty is None:
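With the SummaryWriter removed, Logger is now a console/CSV-only logger and no longer accepts save_tb. Below is a minimal sketch of how a training script could keep TensorBoard scalars alongside the trimmed-down Logger by owning its own SummaryWriter; the output directory and metric key are illustrative assumptions, not defaults of this repository:

import os

from torch.utils.tensorboard import SummaryWriter

from mbrl.third_party.pytorch_sac.logger import Logger

# Illustrative output directory; not a path defined by mbrl.
log_dir = "./sac_logs"
os.makedirs(log_dir, exist_ok=True)

# Console/CSV logging through the trimmed-down Logger (no save_tb argument).
logger = Logger(log_dir, log_frequency=1, agent="sac")

# TensorBoard, if still wanted, is created and owned by the caller.
writer = SummaryWriter(os.path.join(log_dir, "tb"))

for step in range(3):
    reward = float(step)  # stand-in for a real episode return
    logger.log("train/episode_reward", reward, step)  # keys must start with "train" or "eval"
    writer.add_scalar("train/episode_reward", reward, step)

logger.dump(step, save=True)  # flush accumulated meters to console and CSV
writer.close()

The same applies to the removed add_video and add_histogram calls: Logger.log_video and Logger.log_histogram still gate on the step and key, but after this commit they no longer write anything themselves.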
mbrl/third_party/pytorch_sac/train.py: 5 additions & 5 deletions
@@ -6,10 +6,11 @@
 import hydra
 import numpy as np
 import torch
-from pytorch_sac import utils
-from pytorch_sac.logger import Logger
-from pytorch_sac.replay_buffer import ReplayBuffer
-from pytorch_sac.video import VideoRecorder
+
+from mbrl.third_party.pytorch_sac import utils
+from mbrl.third_party.pytorch_sac.logger import Logger
+from mbrl.third_party.pytorch_sac.replay_buffer import ReplayBuffer
+from mbrl.third_party.pytorch_sac.video import VideoRecorder


 class Workspace(object):
@@ -21,7 +22,6 @@ def __init__(self, cfg):

         self.logger = Logger(
             self.work_dir,
-            save_tb=cfg.log_save_tb,
             log_frequency=cfg.log_frequency,
             agent="sac",
         )
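The train.py changes are mechanical: the vendored modules are now imported through their absolute mbrl.third_party path, and the Logger call drops the save_tb flag, so a log_save_tb entry in the training config is no longer read by this script. A tiny sanity-check sketch, assuming mbrl is installed in the current environment:

# Old style (pre-commit): relied on pytorch_sac being importable as a top-level package.
# from pytorch_sac.logger import Logger

# New style: resolves through the installed mbrl package, with no extra sys.path handling.
from mbrl.third_party.pytorch_sac.logger import Logger
from mbrl.third_party.pytorch_sac.replay_buffer import ReplayBuffer

print(Logger, ReplayBuffer)  # confirms both modules resolve via the absolute import path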
