Skip to content

Commit

Permalink
Fix import exceptions in tracing.py for older (<1.12) versions of pyt…
Browse files Browse the repository at this point in the history
…orch

Summary:
X-link: facebookresearch/d2go#362

Pull Request resolved: #4491

Recently landed D35518556 (36a65a0) / Github: 36a65a0 throws an exception with older versions of PyTorch, due to a missing library for import. This has been reported by multiple members of the PyTorch community at 36a65a0

This change uses `try/except` to check for libraries and set flags on presence/absence to later guard code that would use them.

Reviewed By: wat3rBro

Differential Revision: D38879134

fbshipit-source-id: 72f5a7a8d350eb82be87567f006368bf207f5a74
  • Loading branch information
Simon Hollis authored and facebook-github-bot committed Aug 23, 2022
1 parent 89ec4ab commit 3cfda47
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 10 deletions.
10 changes: 6 additions & 4 deletions detectron2/modeling/poolers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple, shapes_to_tensor
from detectron2.structures import Boxes
from detectron2.utils.tracing import assert_fx_safe
from detectron2.utils.tracing import assert_fx_safe, is_fx_tracing

"""
To export ROIPooler to torchscript, in this file, variables that should be annotated with
Expand Down Expand Up @@ -220,9 +220,11 @@ def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]):
"""
num_level_assignments = len(self.level_poolers)

assert_fx_safe(
isinstance(x, list) and isinstance(box_lists, list), "Arguments to pooler must be lists"
)
if not is_fx_tracing():
torch._assert(
isinstance(x, list) and isinstance(box_lists, list),
"Arguments to pooler must be lists",
)
assert_fx_safe(
len(x) == num_level_assignments,
"unequal value, num_level_assignments={}, but x is list of {} Tensors".format(
Expand Down
32 changes: 26 additions & 6 deletions detectron2/utils/tracing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
import inspect
from typing import Union
import torch
from torch.fx._symbolic_trace import _orig_module_call
from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current

from detectron2.utils.env import TORCH_VERSION

try:
from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current

tracing_current_exists = True
except ImportError:
tracing_current_exists = False

try:
from torch.fx._symbolic_trace import _orig_module_call

tracing_legacy_exists = True
except ImportError:
tracing_legacy_exists = False


@torch.jit.ignore
def is_fx_tracing_legacy() -> bool:
Expand All @@ -20,32 +31,41 @@ def is_fx_tracing_legacy() -> bool:
def is_fx_tracing() -> bool:
"""Returns whether execution is currently in
Torch FX tracing mode"""
if TORCH_VERSION >= (1, 10):
if TORCH_VERSION >= (1, 10) and tracing_current_exists:
return is_fx_tracing_current()
else:
elif tracing_legacy_exists:
return is_fx_tracing_legacy()
else:
# Can't find either current or legacy tracing indication code.
# Enabling this assert_fx_safe() call regardless of tracing status.
return False


@torch.jit.ignore
def assert_fx_safe(condition: Union[bool, str], message: str):
def assert_fx_safe(condition: bool, message: str) -> torch.Tensor:
"""An FX-tracing safe version of assert.
Avoids erroneous type assertion triggering when types are masked inside
an fx.proxy.Proxy object during tracing.
Args: condition - either a boolean expression or a string representing
the condition to test. If this assert triggers an exception when tracing
due to dynamic control flow, try encasing the expression in quotation
marks and supplying it as a string."""
# Must return a concrete tensor for compatibility with PyTorch <=1.8.
# If <=1.8 compatibility is not needed, return type can be converted to None
if not is_fx_tracing():
try:
if isinstance(condition, str):
caller_frame = inspect.currentframe().f_back
torch._assert(
eval(condition, caller_frame.f_globals, caller_frame.f_locals), message
)
return torch.ones(1)
else:
torch._assert(condition, message)
return torch.ones(1)
except torch.fx.proxy.TraceError as e:
print(
"Found a non-FX compatible assertion. Skipping the check. Failure is shown below"
+ str(e)
)
return torch.zeros(1)

0 comments on commit 3cfda47

Please sign in to comment.