Skip to content

Commit f36050e

Browse files
Update test_tensorboard.py
1 parent c14a6ea commit f36050e

File tree

1 file changed

+31
-29
lines changed

1 file changed

+31
-29
lines changed

tests/tests_fabric/loggers/test_tensorboard.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -146,48 +146,50 @@ def test_tensorboard_log_hparams_and_metrics(tmp_path):
146146
logger.log_hyperparams(hparams, metrics)
147147

148148

149-
@pytest.mark.parametrize(
150-
"model_cls", [BoringModel, pytest.importorskip("lightning.pytorch.demos.boring_classes").BoringModel]
151-
)
152149
@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)])
153-
def test_tensorboard_log_graph(tmp_path, example_input_array, model_cls):
154-
"""Test that log graph works with both model.example_input_array and if array is passed externally."""
155-
model = model_cls()
150+
def test_tensorboard_log_graph_plain_module(tmp_path, example_input_array):
151+
model = BoringModel()
152+
logger = TensorBoardLogger(tmp_path)
153+
logger._experiment = Mock()
154+
155+
logger.log_graph(model, example_input_array)
156+
if example_input_array is not None:
157+
logger.experiment.add_graph.assert_called_with(model, example_input_array)
158+
else:
159+
logger.experiment.add_graph.assert_not_called()
160+
161+
logger._experiment.reset_mock()
162+
163+
wrapped = _FabricModule(model, strategy=Mock())
164+
logger.log_graph(wrapped, example_input_array)
156165
if example_input_array is not None:
157-
model.example_input_array = None
166+
logger.experiment.add_graph.assert_called_with(model, example_input_array)
158167

168+
169+
@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)])
170+
def test_tensorboard_log_graph_with_batch_transfer_hooks(tmp_path, example_input_array):
171+
model = pytest.importorskip("lightning.pytorch.demos.boring_classes").BoringModel()
159172
logger = TensorBoardLogger(tmp_path)
160173
logger._experiment = Mock()
161174

162-
if isinstance(model, torch.nn.Module) and hasattr(model, "_apply_batch_transfer_handler"):
163-
with (
164-
mock.patch.object(model, "_on_before_batch_transfer", return_value=example_input_array) as before_mock,
165-
mock.patch.object(
166-
model, "_apply_batch_transfer_handler", return_value=example_input_array
167-
) as transfer_mock,
168-
):
169-
logger.log_graph(model, example_input_array)
170-
logger._experiment.reset_mock()
171-
wrapped = _FabricModule(model, strategy=Mock())
172-
logger.log_graph(wrapped, example_input_array)
173-
if example_input_array is not None:
174-
assert before_mock.call_count == 2
175-
assert transfer_mock.call_count == 2
176-
logger.experiment.add_graph.assert_called_with(model, example_input_array)
177-
else:
178-
before_mock.assert_not_called()
179-
transfer_mock.assert_not_called()
180-
logger.experiment.add_graph.assert_not_called()
181-
else:
175+
with (
176+
mock.patch.object(model, "_on_before_batch_transfer", return_value=example_input_array) as before_mock,
177+
mock.patch.object(model, "_apply_batch_transfer_handler", return_value=example_input_array) as transfer_mock,
178+
):
182179
logger.log_graph(model, example_input_array)
183-
if example_input_array is not None:
184-
logger.experiment.add_graph.assert_called_with(model, example_input_array)
185180
logger._experiment.reset_mock()
186181

187182
wrapped = _FabricModule(model, strategy=Mock())
188183
logger.log_graph(wrapped, example_input_array)
184+
189185
if example_input_array is not None:
186+
assert before_mock.call_count == 2
187+
assert transfer_mock.call_count == 2
190188
logger.experiment.add_graph.assert_called_with(model, example_input_array)
189+
else:
190+
before_mock.assert_not_called()
191+
transfer_mock.assert_not_called()
192+
logger.experiment.add_graph.assert_not_called()
191193

192194

193195
@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason="tensorboard is required")

0 commit comments

Comments
 (0)