diff --git a/detectron2/modeling/postprocessing.py b/detectron2/modeling/postprocessing.py index a915491e32..52f273bb5e 100644 --- a/detectron2/modeling/postprocessing.py +++ b/detectron2/modeling/postprocessing.py @@ -27,7 +27,8 @@ def detector_postprocess( Returns: Instances: the resized output from the model, based on the output resolution """ - if torch.jit.is_tracing(): + if isinstance(output_width, torch.Tensor): + # This shape might (but not necessarily) be tensors during tracing. # Converts integer tensors to float temporaries to ensure true # division is performed when computing scale_x and scale_y. output_width_tmp = output_width.float() diff --git a/tests/test_model_analysis.py b/tests/test_model_analysis.py index deee4fa78d..234ec8f561 100644 --- a/tests/test_model_analysis.py +++ b/tests/test_model_analysis.py @@ -39,6 +39,11 @@ def test_flop(self): # almost 0 for random inputs. self.assertTrue(int(res["conv"]), 117) + def test_flop_with_output_shape(self): + inputs = [{"image": torch.rand(3, 800, 800), "height": 700, "width": 700}] + res = flop_count_operators(self.model, inputs) + self.assertTrue(int(res["conv"]), 117) + def test_param_count(self): res = parameter_count(self.model) self.assertTrue(res[""], 41699936)