Skip to content

Commit 43691d4

Browse files
Implement todos tensorboard (#20874)
* test: enhance tensorboard log graph * Update test_tensorboard.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d032388 commit 43691d4

File tree

1 file changed

+32
-9
lines changed

1 file changed

+32
-9
lines changed

tests/tests_fabric/loggers/test_tensorboard.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -147,29 +147,52 @@ def test_tensorboard_log_hparams_and_metrics(tmp_path):
147147

148148

149149
@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)])
150-
def test_tensorboard_log_graph(tmp_path, example_input_array):
151-
"""Test that log graph works with both model.example_input_array and if array is passed externally."""
152-
# TODO(fabric): Test both nn.Module and LightningModule
153-
# TODO(fabric): Assert _apply_batch_transfer_handler is calling the batch transfer hooks
150+
def test_tensorboard_log_graph_plain_module(tmp_path, example_input_array):
154151
model = BoringModel()
155-
if example_input_array is not None:
156-
model.example_input_array = None
157-
158152
logger = TensorBoardLogger(tmp_path)
159153
logger._experiment = Mock()
154+
160155
logger.log_graph(model, example_input_array)
161156
if example_input_array is not None:
162157
logger.experiment.add_graph.assert_called_with(model, example_input_array)
158+
else:
159+
logger.experiment.add_graph.assert_not_called()
160+
163161
logger._experiment.reset_mock()
164162

165-
# model wrapped in `FabricModule`
166163
wrapped = _FabricModule(model, strategy=Mock())
167164
logger.log_graph(wrapped, example_input_array)
168165
if example_input_array is not None:
169166
logger.experiment.add_graph.assert_called_with(model, example_input_array)
170167

171168

172-
@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE))
169+
@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)])
170+
def test_tensorboard_log_graph_with_batch_transfer_hooks(tmp_path, example_input_array):
171+
model = pytest.importorskip("lightning.pytorch.demos.boring_classes").BoringModel()
172+
logger = TensorBoardLogger(tmp_path)
173+
logger._experiment = Mock()
174+
175+
with (
176+
mock.patch.object(model, "_on_before_batch_transfer", return_value=example_input_array) as before_mock,
177+
mock.patch.object(model, "_apply_batch_transfer_handler", return_value=example_input_array) as transfer_mock,
178+
):
179+
logger.log_graph(model, example_input_array)
180+
logger._experiment.reset_mock()
181+
182+
wrapped = _FabricModule(model, strategy=Mock())
183+
logger.log_graph(wrapped, example_input_array)
184+
185+
if example_input_array is not None:
186+
assert before_mock.call_count == 2
187+
assert transfer_mock.call_count == 2
188+
logger.experiment.add_graph.assert_called_with(model, example_input_array)
189+
else:
190+
before_mock.assert_not_called()
191+
transfer_mock.assert_not_called()
192+
logger.experiment.add_graph.assert_not_called()
193+
194+
195+
@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason="tensorboard is required")
173196
def test_tensorboard_log_graph_warning_no_example_input_array(tmp_path):
174197
"""Test that log graph throws warning if model.example_input_array is None."""
175198
model = BoringModel()

0 commit comments

Comments
 (0)