@@ -147,29 +147,52 @@ def test_tensorboard_log_hparams_and_metrics(tmp_path):
147
147
148
148
149
149
@pytest .mark .parametrize ("example_input_array" , [None , torch .rand (2 , 32 )])
150
- def test_tensorboard_log_graph (tmp_path , example_input_array ):
151
- """Test that log graph works with both model.example_input_array and if array is passed externally."""
152
- # TODO(fabric): Test both nn.Module and LightningModule
153
- # TODO(fabric): Assert _apply_batch_transfer_handler is calling the batch transfer hooks
150
+ def test_tensorboard_log_graph_plain_module (tmp_path , example_input_array ):
154
151
model = BoringModel ()
155
- if example_input_array is not None :
156
- model .example_input_array = None
157
-
158
152
logger = TensorBoardLogger (tmp_path )
159
153
logger ._experiment = Mock ()
154
+
160
155
logger .log_graph (model , example_input_array )
161
156
if example_input_array is not None :
162
157
logger .experiment .add_graph .assert_called_with (model , example_input_array )
158
+ else :
159
+ logger .experiment .add_graph .assert_not_called ()
160
+
163
161
logger ._experiment .reset_mock ()
164
162
165
- # model wrapped in `FabricModule`
166
163
wrapped = _FabricModule (model , strategy = Mock ())
167
164
logger .log_graph (wrapped , example_input_array )
168
165
if example_input_array is not None :
169
166
logger .experiment .add_graph .assert_called_with (model , example_input_array )
170
167
171
168
172
- @pytest .mark .skipif (not _TENSORBOARD_AVAILABLE , reason = str (_TENSORBOARD_AVAILABLE ))
169
+ @pytest .mark .parametrize ("example_input_array" , [None , torch .rand (2 , 32 )])
170
+ def test_tensorboard_log_graph_with_batch_transfer_hooks (tmp_path , example_input_array ):
171
+ model = pytest .importorskip ("lightning.pytorch.demos.boring_classes" ).BoringModel ()
172
+ logger = TensorBoardLogger (tmp_path )
173
+ logger ._experiment = Mock ()
174
+
175
+ with (
176
+ mock .patch .object (model , "_on_before_batch_transfer" , return_value = example_input_array ) as before_mock ,
177
+ mock .patch .object (model , "_apply_batch_transfer_handler" , return_value = example_input_array ) as transfer_mock ,
178
+ ):
179
+ logger .log_graph (model , example_input_array )
180
+ logger ._experiment .reset_mock ()
181
+
182
+ wrapped = _FabricModule (model , strategy = Mock ())
183
+ logger .log_graph (wrapped , example_input_array )
184
+
185
+ if example_input_array is not None :
186
+ assert before_mock .call_count == 2
187
+ assert transfer_mock .call_count == 2
188
+ logger .experiment .add_graph .assert_called_with (model , example_input_array )
189
+ else :
190
+ before_mock .assert_not_called ()
191
+ transfer_mock .assert_not_called ()
192
+ logger .experiment .add_graph .assert_not_called ()
193
+
194
+
195
+ @pytest .mark .skipif (not _TENSORBOARD_AVAILABLE , reason = "tensorboard is required" )
173
196
def test_tensorboard_log_graph_warning_no_example_input_array (tmp_path ):
174
197
"""Test that log graph throws warning if model.example_input_array is None."""
175
198
model = BoringModel ()
0 commit comments