Skip to content

Commit

Permalink
Fix import exceptions in tracing.py for older (<1.12) versions of pyt…
Browse files Browse the repository at this point in the history
…orch

Summary:
X-link: facebookresearch/d2go#362

Pull Request resolved: facebookresearch#4491

Recently landed D35518556 (facebookresearch@36a65a0) / Github: 36a65a0 throws an exception with older versions of PyTorch, due to a missing library for import. This has been reported by multiple members of the PyTorch community at facebookresearch@36a65a0

This change uses `try/except` to check for libraries and set flags on presence/absence to later guard code that would use them.

Differential Revision: D38879134

fbshipit-source-id: 7c16f5cea14f7f19184f8c14f2503ab330e73b14
  • Loading branch information
Simon Hollis authored and facebook-github-bot committed Aug 19, 2022
1 parent 36a65a0 commit cc4f271
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
6 changes: 3 additions & 3 deletions detectron2/modeling/poolers.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,16 +221,16 @@ def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]):
num_level_assignments = len(self.level_poolers)

assert_fx_safe(
isinstance(x, list) and isinstance(box_lists, list), "Arguments to pooler must be lists"
"isinstance(x, list) and isinstance(box_lists, list)", "Arguments to pooler must be lists"
)
assert_fx_safe(
len(x) == num_level_assignments,
"len(x) == num_level_assignments",
"unequal value, num_level_assignments={}, but x is list of {} Tensors".format(
num_level_assignments, len(x)
),
)
assert_fx_safe(
len(box_lists) == x[0].size(0),
"len(box_lists) == x[0].size(0)",
"unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format(
x[0].size(0), len(box_lists)
),
Expand Down
26 changes: 21 additions & 5 deletions detectron2/utils/tracing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
import inspect
from typing import Union
import torch
from torch.fx._symbolic_trace import _orig_module_call
from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current

from detectron2.utils.env import TORCH_VERSION

try:
from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current

tracing_current_exists = True
except ImportError:
tracing_current_exists = False

try:
from torch.fx._symbolic_trace import _orig_module_call

tracing_legacy_exists = True
except ImportError:
tracing_legacy_exists = False


@torch.jit.ignore
def is_fx_tracing_legacy() -> bool:
Expand All @@ -20,14 +32,18 @@ def is_fx_tracing_legacy() -> bool:
def is_fx_tracing() -> bool:
"""Returns whether execution is currently in
Torch FX tracing mode"""
if TORCH_VERSION >= (1, 10):
if TORCH_VERSION >= (1, 10) and tracing_current_exists:
return is_fx_tracing_current()
else:
elif tracing_legacy_exists:
return is_fx_tracing_legacy()
else:
# Can't find either current or legacy tracing indication code.
# Enabling this assert_fx_safe() call regardless of tracing status.
return False


@torch.jit.ignore
def assert_fx_safe(condition: Union[bool, str], message: str):
def assert_fx_safe(condition: str, message: str):
"""An FX-tracing safe version of assert.
Avoids erroneous type assertion triggering when types are masked inside
an fx.proxy.Proxy object during tracing.
Expand Down

0 comments on commit cc4f271

Please sign in to comment.