Skip to content

Commit

Permalink
fix flop counting with explicit shapes
Browse files Browse the repository at this point in the history
Reviewed By: sstsai-adl

Differential Revision: D32120146

fbshipit-source-id: 06866e218337fe9d5a15df864eee870c51472dc6
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Nov 4, 2021
1 parent 175b245 commit c47167e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
3 changes: 2 additions & 1 deletion detectron2/modeling/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def detector_postprocess(
Returns:
Instances: the resized output from the model, based on the output resolution
"""
if torch.jit.is_tracing():
if isinstance(output_width, torch.Tensor):
# This shape might (but not necessarily) be tensors during tracing.
# Converts integer tensors to float temporaries to ensure true
# division is performed when computing scale_x and scale_y.
output_width_tmp = output_width.float()
Expand Down
5 changes: 5 additions & 0 deletions tests/test_model_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ def test_flop(self):
# almost 0 for random inputs.
self.assertTrue(int(res["conv"]), 117)

def test_flop_with_output_shape(self):
    """Flop counting must work when explicit output height/width are given.

    Regression test: `detector_postprocess` used to rely on
    `torch.jit.is_tracing()` to detect tensor-valued output sizes, which
    breaks flop counting when concrete int shapes are passed alongside
    the image.
    """
    inputs = [{"image": torch.rand(3, 800, 800), "height": 700, "width": 700}]
    res = flop_count_operators(self.model, inputs)
    # NOTE: assertTrue(x, msg) treats the second argument as a failure
    # message and never compares it — use assertEqual to actually check
    # the conv flop count.
    self.assertEqual(int(res["conv"]), 117)

def test_param_count(self):
    """Parameter count of the test model must match the known total."""
    res = parameter_count(self.model)
    # NOTE: assertTrue(x, msg) treats the second argument as a failure
    # message and never compares it — use assertEqual so the expected
    # total (41,699,936 params, keyed by "" for the whole model) is
    # actually verified.
    self.assertEqual(res[""], 41699936)
Expand Down

0 comments on commit c47167e

Please sign in to comment.