Skip to content

Minor fixes to wandb logger #2999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 2 additions & 34 deletions torchrl/record/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class WandbLogger(Logger):

@classmethod
def __new__(cls, *args, **kwargs):
cls._prev_video_step = -1
return super().__new__(cls)

def __init__(
Expand Down Expand Up @@ -95,7 +94,7 @@ def __init__(

self.video_log_counter = 0

def _create_experiment(self) -> WandbLogger:
def _create_experiment(self) -> "wandb.Experiment":
"""Creates a wandb experiment.

Args:
Expand All @@ -122,10 +121,7 @@ def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
step (int, optional): The step at which the scalar is logged.
Defaults to None.
"""
if step is not None:
self.experiment.log({name: value, "trainer/step": step})
else:
self.experiment.log({name: value})
self.experiment.log({name: value}, step=step)

def log_video(self, name: str, video: Tensor, **kwargs) -> None:
"""Log videos inputs to wandb.
Expand All @@ -139,39 +135,11 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
passed as-is to the :obj:`experiment.log` method.
"""
import wandb

# check for correct format of the video tensor ((N), T, C, H, W)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like this was mistakenly copied from tensorbaord logger. I dont see why its needed here.

# check that the color channel (C) is either 1 or 3
if video.dim() != 5 or video.size(dim=2) not in {1, 3}:
raise Exception(
"Wrong format of the video tensor. Should be ((N), T, C, H, W)"
)
if not self._has_imported_moviepy:
try:
import moviepy # noqa

self._has_imported_moviepy = True
except ImportError:
raise Exception(
"moviepy not found, videos cannot be logged with TensorboardLogger"
)
self.video_log_counter += 1
fps = kwargs.pop("fps", self.video_fps)
step = kwargs.pop("step", None)
format = kwargs.pop("format", "mp4")
if step not in (None, self._prev_video_step, self._prev_video_step + 1):
warnings.warn(
"when using step with wandb_logger.log_video, it is expected "
"that the step is equal to the previous step or that value incremented "
f"by one. Got step={step} but previous value was {self._prev_video_step}. "
f"The step value will be set to {self._prev_video_step+1}. This warning will "
f"be silenced from now on but the values will keep being incremented."
)
step = self._prev_video_step + 1
self._prev_video_step = step if step is not None else self._prev_video_step + 1
self.experiment.log(
{name: wandb.Video(video, fps=fps, format=format)},
# step=step,
**kwargs,
)

Expand Down