diff --git a/detectron2/modeling/poolers.py b/detectron2/modeling/poolers.py index 12073b0524..8d368e81ae 100644 --- a/detectron2/modeling/poolers.py +++ b/detectron2/modeling/poolers.py @@ -7,7 +7,7 @@ from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple, shapes_to_tensor from detectron2.structures import Boxes -from detectron2.utils.tracing import assert_fx_safe +from detectron2.utils.tracing import assert_fx_safe, assert_fx_safe_expr """ To export ROIPooler to torchscript, in this file, variables that should be annotated with @@ -220,8 +220,9 @@ def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]): """ num_level_assignments = len(self.level_poolers) - assert_fx_safe( - isinstance(x, list) and isinstance(box_lists, list), "Arguments to pooler must be lists" + assert_fx_safe_expr( + "isinstance(x, list) and isinstance(box_lists, list)", + "Arguments to pooler must be lists", ) assert_fx_safe( len(x) == num_level_assignments, diff --git a/detectron2/utils/tracing.py b/detectron2/utils/tracing.py index 994b615a76..5d7c7b115a 100644 --- a/detectron2/utils/tracing.py +++ b/detectron2/utils/tracing.py @@ -1,11 +1,22 @@ import inspect -from typing import Union import torch -from torch.fx._symbolic_trace import _orig_module_call -from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current from detectron2.utils.env import TORCH_VERSION +try: + from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current + + tracing_current_exists = True +except ImportError: + tracing_current_exists = False + +try: + from torch.fx._symbolic_trace import _orig_module_call + + tracing_legacy_exists = True +except ImportError: + tracing_legacy_exists = False + @torch.jit.ignore def is_fx_tracing_legacy() -> bool: @@ -20,14 +31,18 @@ def is_fx_tracing_legacy() -> bool: def is_fx_tracing() -> bool: """Returns whether execution is currently in Torch FX tracing mode""" - if TORCH_VERSION >= (1, 10): + if TORCH_VERSION >= (1, 10) and tracing_current_exists: return is_fx_tracing_current() - else: + elif tracing_legacy_exists: return is_fx_tracing_legacy() + else: + # Can't find either current or legacy tracing indication code. + # Enabling this assert_fx_safe() call regardless of tracing status. + return False -@torch.jit.ignore -def assert_fx_safe(condition: Union[bool, str], message: str): +@torch.jit.ignore # Remove and "try blocks aren't supported comes from torch.jit" +def assert_fx_safe(condition: bool, message: str) -> None: """An FX-tracing safe version of assert. Avoids erroneous type assertion triggering when types are masked inside an fx.proxy.Proxy object during tracing. @@ -49,3 +64,23 @@ def assert_fx_safe(condition: Union[bool, str], message: str): "Found a non-FX compatible assertion. Skipping the check. Failure is shown below" + str(e) ) + + +def assert_fx_safe_expr(condition: str, message: str) -> None: + if not is_fx_tracing(): + try: + if isinstance(condition, str): + caller_frame = inspect.currentframe().f_back + torch._assert( + eval(condition, caller_frame.f_globals, caller_frame.f_locals), message + ) + else: + raise TypeError( + "Expected a string condition argument. \ + Did you mean to use the bool-taking 'assert_fx_safe' instead?" + ) + except torch.fx.proxy.TraceError as e: + print( + "Found a non-FX compatible assertion. Skipping the check. Failure is shown below" + + str(e) + )