Skip to content

[BugFix] Fix optional imports #535

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 9, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .circleci/unittest/linux_optdeps/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,3 @@ dependencies:
- expecttest
- pyyaml
- scipy
- hydra-core
15 changes: 13 additions & 2 deletions test/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@
import pytest
import torch
from _utils_internal import get_available_devices, generate_seeds
from hydra import initialize, compose
from hydra.core.config_store import ConfigStore

try:
from hydra import initialize, compose
from hydra.core.config_store import ConfigStore

_has_hydra = True
except ImportError:
_has_hydra = False
from mocking_classes import (
ContinuousActionConvMockEnvNumpy,
ContinuousActionVecMockEnv,
Expand Down Expand Up @@ -49,6 +55,7 @@ def _assert_keys_match(td, expeceted_keys):


@pytest.mark.skipif(not _has_gym, reason="No gym library found")
@pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("noisy", [tuple(), ("noisy=True",)])
@pytest.mark.parametrize("distributional", [tuple(), ("distributional=True",)])
Expand Down Expand Up @@ -99,6 +106,7 @@ def test_dqn_maker(device, noisy, distributional, from_pixels):
proof_environment.close()


@pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
@pytest.mark.skipif(not _has_gym, reason="No gym library found")
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("from_pixels", [("from_pixels=True", "catframes=4"), tuple()])
Expand Down Expand Up @@ -173,6 +181,7 @@ def test_ddpg_maker(device, from_pixels, gsde, exploration):
del proof_environment


@pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
@pytest.mark.skipif(not _has_gym, reason="No gym library found")
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("from_pixels", [tuple(), ("from_pixels=True", "catframes=4")])
Expand Down Expand Up @@ -287,6 +296,7 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration):
del proof_environment


@pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("gsde", [tuple(), ("gSDE=True",)])
@pytest.mark.parametrize("from_pixels", [tuple()])
Expand Down Expand Up @@ -402,6 +412,7 @@ def test_sac_make(device, gsde, tanh_loc, from_pixels, exploration):
del proof_environment


@pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("from_pixels", [tuple(), ("from_pixels=True", "catframes=4")])
@pytest.mark.parametrize("gsde", [tuple(), ("gSDE=True",)])
Expand Down
20 changes: 16 additions & 4 deletions test/test_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,21 @@

import pytest
import torch
import torchvision

try:
import torchvision

_has_tv = True
except ImportError:
_has_tv = False

try:
import mlflow

_has_mlfow = True
except ImportError:
_has_mlfow = False

from torchrl.trainers.loggers.csv import CSVLogger
from torchrl.trainers.loggers.mlflow import MLFlowLogger, _has_mlflow
from torchrl.trainers.loggers.tensorboard import TensorboardLogger, _has_tb
Expand Down Expand Up @@ -225,7 +239,6 @@ def test_log_video(self):
@pytest.fixture
def mlflow_fixture():
torch.manual_seed(0)
import mlflow

with tempfile.TemporaryDirectory() as log_dir:
exp_name = "ramala"
Expand All @@ -240,7 +253,6 @@ def mlflow_fixture():
class TestMLFlowLogger:
@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
def test_log_scalar(self, steps, mlflow_fixture):
import mlflow

logger, client = mlflow_fixture
values = torch.rand(3)
Expand All @@ -259,8 +271,8 @@ def test_log_scalar(self, steps, mlflow_fixture):
assert metric.value == values[i].item()

@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
@pytest.mark.skipif(not _has_tv, reason="torchvision not installed")
def test_log_video(self, steps, mlflow_fixture):
import mlflow

logger, client = mlflow_fixture
videos = torch.cat(
Expand Down