Skip to content

Commit

Permalink
Fix import exceptions in tracing.py for older (<1.12) versions of pyt…
Browse files Browse the repository at this point in the history
…orch

Summary:
X-link: facebookresearch/d2go#362

Pull Request resolved: facebookresearch#4491

Recently landed D35518556 (facebookresearch@36a65a0) / Github: 36a65a0 throws an exception with older versions of PyTorch, due to a missing library for import. This has been reported by multiple members of the PyTorch community at facebookresearch@36a65a0

This change uses `try/except` to check for libraries and set flags on presence/absence to later guard code that would use them.

Differential Revision: D38879134

fbshipit-source-id: a6d1b2e7484a44e0bf611eee293aef351dd0db45
  • Loading branch information
Simon Hollis authored and facebook-github-bot committed Aug 20, 2022
1 parent 89ec4ab commit e731410
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
2 changes: 1 addition & 1 deletion detectron2/modeling/poolers.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]):
num_level_assignments = len(self.level_poolers)

assert_fx_safe(
isinstance(x, list) and isinstance(box_lists, list), "Arguments to pooler must be lists"
True, "Arguments to pooler must be lists", condition_expr="isinstance(x, list) and isinstance(box_lists, list)"
)
assert_fx_safe(
len(x) == num_level_assignments,
Expand Down
34 changes: 25 additions & 9 deletions detectron2/utils/tracing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
import inspect
from typing import Union
from typing import Optional
import torch
from torch.fx._symbolic_trace import _orig_module_call
from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current

from detectron2.utils.env import TORCH_VERSION

try:
from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current

tracing_current_exists = True
except ImportError:
tracing_current_exists = False

try:
from torch.fx._symbolic_trace import _orig_module_call

tracing_legacy_exists = True
except ImportError:
tracing_legacy_exists = False


@torch.jit.ignore
def is_fx_tracing_legacy() -> bool:
Expand All @@ -20,14 +32,18 @@ def is_fx_tracing_legacy() -> bool:
def is_fx_tracing() -> bool:
"""Returns whether execution is currently in
Torch FX tracing mode"""
if TORCH_VERSION >= (1, 10):
if TORCH_VERSION >= (1, 10) and tracing_current_exists:
return is_fx_tracing_current()
else:
elif tracing_legacy_exists:
return is_fx_tracing_legacy()
else:
# Can't find either current or legacy tracing indication code.
# Enabling this assert_fx_safe() call regardless of tracing status.
return False


@torch.jit.ignore
def assert_fx_safe(condition: Union[bool, str], message: str):
@torch.jit.ignore # Remove and "try blocks aren't supported comes from torch.jit"
def assert_fx_safe(condition: bool, message: str, condition_expr: Optional[str] = None) -> None:
"""An FX-tracing safe version of assert.
Avoids erroneous type assertion triggering when types are masked inside
an fx.proxy.Proxy object during tracing.
Expand All @@ -37,10 +53,10 @@ def assert_fx_safe(condition: Union[bool, str], message: str):
marks and supplying it as a string."""
if not is_fx_tracing():
try:
if isinstance(condition, str):
if condition_expr is not None and isinstance(condition_expr, str):
caller_frame = inspect.currentframe().f_back
torch._assert(
eval(condition, caller_frame.f_globals, caller_frame.f_locals), message
eval(condition_expr, caller_frame.f_globals, caller_frame.f_locals), message
)
else:
torch._assert(condition, message)
Expand Down

0 comments on commit e731410

Please sign in to comment.