@@ -146,48 +146,50 @@ def test_tensorboard_log_hparams_and_metrics(tmp_path):
146
146
logger .log_hyperparams (hparams , metrics )
147
147
148
148
149
- @pytest .mark .parametrize (
150
- "model_cls" , [BoringModel , pytest .importorskip ("lightning.pytorch.demos.boring_classes" ).BoringModel ]
151
- )
152
149
@pytest .mark .parametrize ("example_input_array" , [None , torch .rand (2 , 32 )])
153
- def test_tensorboard_log_graph (tmp_path , example_input_array , model_cls ):
154
- """Test that log graph works with both model.example_input_array and if array is passed externally."""
155
- model = model_cls ()
150
+ def test_tensorboard_log_graph_plain_module (tmp_path , example_input_array ):
151
+ model = BoringModel ()
152
+ logger = TensorBoardLogger (tmp_path )
153
+ logger ._experiment = Mock ()
154
+
155
+ logger .log_graph (model , example_input_array )
156
+ if example_input_array is not None :
157
+ logger .experiment .add_graph .assert_called_with (model , example_input_array )
158
+ else :
159
+ logger .experiment .add_graph .assert_not_called ()
160
+
161
+ logger ._experiment .reset_mock ()
162
+
163
+ wrapped = _FabricModule (model , strategy = Mock ())
164
+ logger .log_graph (wrapped , example_input_array )
156
165
if example_input_array is not None :
157
- model . example_input_array = None
166
+ logger . experiment . add_graph . assert_called_with ( model , example_input_array )
158
167
168
+
169
+ @pytest .mark .parametrize ("example_input_array" , [None , torch .rand (2 , 32 )])
170
+ def test_tensorboard_log_graph_with_batch_transfer_hooks (tmp_path , example_input_array ):
171
+ model = pytest .importorskip ("lightning.pytorch.demos.boring_classes" ).BoringModel ()
159
172
logger = TensorBoardLogger (tmp_path )
160
173
logger ._experiment = Mock ()
161
174
162
- if isinstance (model , torch .nn .Module ) and hasattr (model , "_apply_batch_transfer_handler" ):
163
- with (
164
- mock .patch .object (model , "_on_before_batch_transfer" , return_value = example_input_array ) as before_mock ,
165
- mock .patch .object (
166
- model , "_apply_batch_transfer_handler" , return_value = example_input_array
167
- ) as transfer_mock ,
168
- ):
169
- logger .log_graph (model , example_input_array )
170
- logger ._experiment .reset_mock ()
171
- wrapped = _FabricModule (model , strategy = Mock ())
172
- logger .log_graph (wrapped , example_input_array )
173
- if example_input_array is not None :
174
- assert before_mock .call_count == 2
175
- assert transfer_mock .call_count == 2
176
- logger .experiment .add_graph .assert_called_with (model , example_input_array )
177
- else :
178
- before_mock .assert_not_called ()
179
- transfer_mock .assert_not_called ()
180
- logger .experiment .add_graph .assert_not_called ()
181
- else :
175
+ with (
176
+ mock .patch .object (model , "_on_before_batch_transfer" , return_value = example_input_array ) as before_mock ,
177
+ mock .patch .object (model , "_apply_batch_transfer_handler" , return_value = example_input_array ) as transfer_mock ,
178
+ ):
182
179
logger .log_graph (model , example_input_array )
183
- if example_input_array is not None :
184
- logger .experiment .add_graph .assert_called_with (model , example_input_array )
185
180
logger ._experiment .reset_mock ()
186
181
187
182
wrapped = _FabricModule (model , strategy = Mock ())
188
183
logger .log_graph (wrapped , example_input_array )
184
+
189
185
if example_input_array is not None :
186
+ assert before_mock .call_count == 2
187
+ assert transfer_mock .call_count == 2
190
188
logger .experiment .add_graph .assert_called_with (model , example_input_array )
189
+ else :
190
+ before_mock .assert_not_called ()
191
+ transfer_mock .assert_not_called ()
192
+ logger .experiment .add_graph .assert_not_called ()
191
193
192
194
193
195
@pytest .mark .skipif (not _TENSORBOARD_AVAILABLE , reason = "tensorboard is required" )
0 commit comments