Skip to content

Commit

Permalink
Fix import exceptions in tracing.py for older (<1.12) versions of pyt…
Browse files Browse the repository at this point in the history
…orch

Summary:
X-link: facebookresearch/d2go#362

Pull Request resolved: facebookresearch#4491

Recently landed D35518556 (facebookresearch@36a65a0) / Github: 36a65a0 throws an exception with older versions of PyTorch, due to a missing library for import. This has been reported by multiple members of the PyTorch community at facebookresearch@36a65a0

This change uses `try/except` to check for libraries and set flags on presence/absence to later guard code that would use them.

Differential Revision: D38879134

fbshipit-source-id: 6a85f8044b06e2652ecc7ff6e7b358dd4ecccccb
  • Loading branch information
Simon Hollis authored and facebook-github-bot committed Aug 20, 2022
1 parent 89ec4ab commit f208983
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 10 deletions.
7 changes: 4 additions & 3 deletions detectron2/modeling/poolers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple, shapes_to_tensor
from detectron2.structures import Boxes
from detectron2.utils.tracing import assert_fx_safe
from detectron2.utils.tracing import assert_fx_safe, assert_fx_safe_expr

"""
To export ROIPooler to torchscript, in this file, variables that should be annotated with
Expand Down Expand Up @@ -220,8 +220,9 @@ def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]):
"""
num_level_assignments = len(self.level_poolers)

assert_fx_safe(
isinstance(x, list) and isinstance(box_lists, list), "Arguments to pooler must be lists"
assert_fx_safe_expr(
"isinstance(x, list) and isinstance(box_lists, list)",
"Arguments to pooler must be lists",
)
assert_fx_safe(
len(x) == num_level_assignments,
Expand Down
49 changes: 42 additions & 7 deletions detectron2/utils/tracing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
import inspect
from typing import Union
import torch
from torch.fx._symbolic_trace import _orig_module_call
from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current

from detectron2.utils.env import TORCH_VERSION

try:
from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current

tracing_current_exists = True
except ImportError:
tracing_current_exists = False

try:
from torch.fx._symbolic_trace import _orig_module_call

tracing_legacy_exists = True
except ImportError:
tracing_legacy_exists = False


@torch.jit.ignore
def is_fx_tracing_legacy() -> bool:
Expand All @@ -20,14 +31,18 @@ def is_fx_tracing_legacy() -> bool:
def is_fx_tracing() -> bool:
"""Returns whether execution is currently in
Torch FX tracing mode"""
if TORCH_VERSION >= (1, 10):
if TORCH_VERSION >= (1, 10) and tracing_current_exists:
return is_fx_tracing_current()
else:
elif tracing_legacy_exists:
return is_fx_tracing_legacy()
else:
# Can't find either current or legacy tracing indication code.
# Enabling this assert_fx_safe() call regardless of tracing status.
return False


@torch.jit.ignore
def assert_fx_safe(condition: Union[bool, str], message: str):
@torch.jit.ignore # Remove and "try blocks aren't supported comes from torch.jit"
def assert_fx_safe(condition: bool, message: str) -> None:
"""An FX-tracing safe version of assert.
Avoids erroneous type assertion triggering when types are masked inside
an fx.proxy.Proxy object during tracing.
Expand All @@ -49,3 +64,23 @@ def assert_fx_safe(condition: Union[bool, str], message: str):
"Found a non-FX compatible assertion. Skipping the check. Failure is shown below"
+ str(e)
)


def assert_fx_safe_expr(condition: str, message: str) -> None:
if not is_fx_tracing():
try:
if isinstance(condition, str):
caller_frame = inspect.currentframe().f_back
torch._assert(
eval(condition, caller_frame.f_globals, caller_frame.f_locals), message
)
else:
raise TypeError(
"Expected a string condition argument. \
Did you mean to use the bool-taking 'assert_fx_safe' instead?"
)
except torch.fx.proxy.TraceError as e:
print(
"Found a non-FX compatible assertion. Skipping the check. Failure is shown below"
+ str(e)
)

0 comments on commit f208983

Please sign in to comment.