Skip to content

Commit bd0120e

Browse files
authored
[BugFix] Fix optional imports (#535)
1 parent 9eeafbf commit bd0120e

File tree

3 files changed

+28
-16
lines changed

3 files changed

+28
-16
lines changed

.circleci/unittest/linux_optdeps/scripts/environment.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,4 @@ dependencies:
1414
- expecttest
1515
- pyyaml
1616
- scipy
17-
- hydra-core
1817
- coverage

test/test_helpers.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,14 @@
99
import pytest
1010
import torch
1111
from _utils_internal import get_available_devices, generate_seeds
12-
from hydra import initialize, compose
13-
from hydra.core.config_store import ConfigStore
12+
13+
try:
14+
from hydra import initialize, compose
15+
from hydra.core.config_store import ConfigStore
16+
17+
_has_hydra = True
18+
except ImportError:
19+
_has_hydra = False
1420
from mocking_classes import (
1521
ContinuousActionConvMockEnvNumpy,
1622
ContinuousActionVecMockEnv,
@@ -49,6 +55,7 @@ def _assert_keys_match(td, expeceted_keys):
4955

5056

5157
@pytest.mark.skipif(not _has_gym, reason="No gym library found")
58+
@pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
5259
@pytest.mark.parametrize("device", get_available_devices())
5360
@pytest.mark.parametrize("noisy", [tuple(), ("noisy=True",)])
5461
@pytest.mark.parametrize("distributional", [tuple(), ("distributional=True",)])
@@ -99,6 +106,7 @@ def test_dqn_maker(device, noisy, distributional, from_pixels):
99106
proof_environment.close()
100107

101108

109+
@pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
102110
@pytest.mark.skipif(not _has_gym, reason="No gym library found")
103111
@pytest.mark.parametrize("device", get_available_devices())
104112
@pytest.mark.parametrize("from_pixels", [("from_pixels=True", "catframes=4"), tuple()])
@@ -173,6 +181,7 @@ def test_ddpg_maker(device, from_pixels, gsde, exploration):
173181
del proof_environment
174182

175183

184+
@pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
176185
@pytest.mark.skipif(not _has_gym, reason="No gym library found")
177186
@pytest.mark.parametrize("device", get_available_devices())
178187
@pytest.mark.parametrize("from_pixels", [tuple(), ("from_pixels=True", "catframes=4")])
@@ -287,11 +296,12 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration):
287296
del proof_environment
288297

289298

299+
@pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
300+
@pytest.mark.skipif(not _has_gym, reason="No gym library found")
290301
@pytest.mark.parametrize("device", get_available_devices())
291302
@pytest.mark.parametrize("gsde", [tuple(), ("gSDE=True",)])
292303
@pytest.mark.parametrize("from_pixels", [tuple()])
293304
@pytest.mark.parametrize("tanh_loc", [tuple(), ("tanh_loc=True",)])
294-
@pytest.mark.skipif(not _has_gym, reason="No gym library found")
295305
@pytest.mark.parametrize("exploration", ["random", "mode"])
296306
def test_sac_make(device, gsde, tanh_loc, from_pixels, exploration):
297307
if not gsde and exploration != "random":
@@ -402,10 +412,11 @@ def test_sac_make(device, gsde, tanh_loc, from_pixels, exploration):
402412
del proof_environment
403413

404414

415+
@pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
416+
@pytest.mark.skipif(not _has_gym, reason="No gym library found")
405417
@pytest.mark.parametrize("device", get_available_devices())
406418
@pytest.mark.parametrize("from_pixels", [tuple(), ("from_pixels=True", "catframes=4")])
407419
@pytest.mark.parametrize("gsde", [tuple(), ("gSDE=True",)])
408-
@pytest.mark.skipif(not _has_gym, reason="No gym library found")
409420
@pytest.mark.parametrize("exploration", ["random", "mode"])
410421
def test_redq_make(device, from_pixels, gsde, exploration):
411422
if not gsde and exploration != "random":

test/test_loggers.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,22 @@
77

88
import pytest
99
import torch
10-
import torchvision
1110
from torchrl.trainers.loggers.csv import CSVLogger
12-
from torchrl.trainers.loggers.mlflow import MLFlowLogger, _has_mlflow
11+
from torchrl.trainers.loggers.mlflow import MLFlowLogger, _has_mlflow, _has_tv
1312
from torchrl.trainers.loggers.tensorboard import TensorboardLogger, _has_tb
1413
from torchrl.trainers.loggers.wandb import WandbLogger, _has_wandb
1514

15+
if _has_tv:
16+
import torchvision
17+
18+
if _has_tb:
19+
from tensorboard.backend.event_processing.event_accumulator import (
20+
EventAccumulator,
21+
)
22+
23+
if _has_mlflow:
24+
import mlflow
25+
1626

1727
@pytest.mark.skipif(not _has_tb, reason="TensorBoard not installed")
1828
class TestTensorboard:
@@ -34,9 +44,6 @@ def test_log_scalar(self, steps):
3444
)
3545

3646
sleep(0.01) # wait until events are registered
37-
from tensorboard.backend.event_processing.event_accumulator import (
38-
EventAccumulator,
39-
)
4047

4148
event_acc = EventAccumulator(logger.experiment.get_logdir())
4249
event_acc.Reload()
@@ -69,9 +76,6 @@ def test_log_video(self, steps):
6976
)
7077

7178
sleep(0.01) # wait until events are registered
72-
from tensorboard.backend.event_processing.event_accumulator import (
73-
EventAccumulator,
74-
)
7579

7680
event_acc = EventAccumulator(logger.experiment.get_logdir())
7781
event_acc.Reload()
@@ -225,7 +229,6 @@ def test_log_video(self):
225229
@pytest.fixture
226230
def mlflow_fixture():
227231
torch.manual_seed(0)
228-
import mlflow
229232

230233
with tempfile.TemporaryDirectory() as log_dir:
231234
exp_name = "ramala"
@@ -240,7 +243,6 @@ def mlflow_fixture():
240243
class TestMLFlowLogger:
241244
@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
242245
def test_log_scalar(self, steps, mlflow_fixture):
243-
import mlflow
244246

245247
logger, client = mlflow_fixture
246248
values = torch.rand(3)
@@ -259,8 +261,8 @@ def test_log_scalar(self, steps, mlflow_fixture):
259261
assert metric.value == values[i].item()
260262

261263
@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
264+
@pytest.mark.skipif(not _has_tv, reason="torchvision not installed")
262265
def test_log_video(self, steps, mlflow_fixture):
263-
import mlflow
264266

265267
logger, client = mlflow_fixture
266268
videos = torch.cat(

0 commit comments

Comments
 (0)