Skip to content

Commit

Permalink
Enable torch tracing by changing assertions in d2go forwards to allow…
Browse files Browse the repository at this point in the history
… for torch.fx.proxy.Proxy type.

Summary:
Pull Request resolved: facebookresearch#4227

X-link: facebookresearch/d2go#241

Torch FX tracing propagates a type of `torch.fx.proxy.Proxy` through the graph.

Existing type assertions in the d2go code base trigger during torch FX tracing, causing tracing to fail.

This adds a check for FX tracing in progress and  adds a helper function `assert_fx_safe()`, that can be used in place of a standard assertion. This function only applies the assertion if one is not tracing, allowing d2go assertion tests to be compatible with FX tracing.

Reviewed By: wat3rBro

Differential Revision: D35518556

fbshipit-source-id: a9b5d3d580518ca74948544973ae89f8b9de3282
  • Loading branch information
Simon Hollis authored and facebook-github-bot committed Aug 18, 2022
1 parent 5aeb252 commit 36a65a0
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 12 deletions.
26 changes: 14 additions & 12 deletions detectron2/modeling/poolers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple, shapes_to_tensor
from detectron2.structures import Boxes
from detectron2.utils.tracing import assert_fx_safe

"""
To export ROIPooler to torchscript, in this file, variables that should be annotated with
Expand Down Expand Up @@ -219,19 +220,20 @@ def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]):
"""
num_level_assignments = len(self.level_poolers)

assert isinstance(x, list) and isinstance(
box_lists, list
), "Arguments to pooler must be lists"
assert (
len(x) == num_level_assignments
), "unequal value, num_level_assignments={}, but x is list of {} Tensors".format(
num_level_assignments, len(x)
assert_fx_safe(
isinstance(x, list) and isinstance(box_lists, list), "Arguments to pooler must be lists"
)

assert len(box_lists) == x[0].size(
0
), "unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format(
x[0].size(0), len(box_lists)
assert_fx_safe(
len(x) == num_level_assignments,
"unequal value, num_level_assignments={}, but x is list of {} Tensors".format(
num_level_assignments, len(x)
),
)
assert_fx_safe(
len(box_lists) == x[0].size(0),
"unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format(
x[0].size(0), len(box_lists)
),
)
if len(box_lists) == 0:
return _create_zeros(None, x[0].shape[1], *self.output_size, x[0])
Expand Down
51 changes: 51 additions & 0 deletions detectron2/utils/tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import inspect
from typing import Union
import torch
from torch.fx._symbolic_trace import _orig_module_call
from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current

from detectron2.utils.env import TORCH_VERSION


@torch.jit.ignore
def is_fx_tracing_legacy() -> bool:
"""
Returns a bool indicating whether torch.fx is currently symbolically tracing a module.
Can be useful for gating module logic that is incompatible with symbolic tracing.
"""
return torch.nn.Module.__call__ is not _orig_module_call


@torch.jit.ignore
def is_fx_tracing() -> bool:
"""Returns whether execution is currently in
Torch FX tracing mode"""
if TORCH_VERSION >= (1, 10):
return is_fx_tracing_current()
else:
return is_fx_tracing_legacy()


@torch.jit.ignore
def assert_fx_safe(condition: Union[bool, str], message: str):
"""An FX-tracing safe version of assert.
Avoids erroneous type assertion triggering when types are masked inside
an fx.proxy.Proxy object during tracing.
Args: condition - either a boolean expression or a string representing
the condition to test. If this assert triggers an exception when tracing
due to dynamic control flow, try encasing the expression in quotation
marks and supplying it as a string."""
if not is_fx_tracing():
try:
if isinstance(condition, str):
caller_frame = inspect.currentframe().f_back
torch._assert(
eval(condition, caller_frame.f_globals, caller_frame.f_locals), message
)
else:
torch._assert(condition, message)
except torch.fx.proxy.TraceError as e:
print(
"Found a non-FX compatible assertion. Skipping the check. Failure is shown below"
+ str(e)
)

0 comments on commit 36a65a0

Please sign in to comment.