From e731410cf235fdac72b0a74aa5c83190026c05f2 Mon Sep 17 00:00:00 2001 From: Simon Hollis Date: Fri, 19 Aug 2022 21:23:33 -0700 Subject: [PATCH] Fix import exceptions in tracing.py for older (<1.12) versions of pytorch Summary: X-link: https://github.com/facebookresearch/d2go/pull/362 Pull Request resolved: https://github.com/facebookresearch/detectron2/pull/4491 Recently landed D35518556 (https://github.com/facebookresearch/detectron2/commit/36a65a0907d90ed591479b2ebaa8b61cfa0b4ef0) / Github: 36a65a0907d90ed591479b2ebaa8b61cfa0b4ef0 throws an exception with older versions of PyTorch, due to a missing library for import. This has been reported by multiple members of the PyTorch community at https://github.com/facebookresearch/detectron2/commit/36a65a0907d90ed591479b2ebaa8b61cfa0b4ef0 This change uses `try/except` to check for libraries and set flags on presence/absence to later guard code that would use them. Differential Revision: D38879134 fbshipit-source-id: a6d1b2e7484a44e0bf611eee293aef351dd0db45 --- detectron2/modeling/poolers.py | 2 +- detectron2/utils/tracing.py | 34 +++++++++++++++++++++++++--------- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/detectron2/modeling/poolers.py b/detectron2/modeling/poolers.py index 12073b0524..f8d0493989 100644 --- a/detectron2/modeling/poolers.py +++ b/detectron2/modeling/poolers.py @@ -221,7 +221,7 @@ def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]): num_level_assignments = len(self.level_poolers) assert_fx_safe( - isinstance(x, list) and isinstance(box_lists, list), "Arguments to pooler must be lists" + True, "Arguments to pooler must be lists", condition_expr="isinstance(x, list) and isinstance(box_lists, list)" ) assert_fx_safe( len(x) == num_level_assignments, diff --git a/detectron2/utils/tracing.py b/detectron2/utils/tracing.py index 994b615a76..d4b2e6c531 100644 --- a/detectron2/utils/tracing.py +++ b/detectron2/utils/tracing.py @@ -1,11 +1,23 @@ import inspect -from typing import Union +from typing import Optional import torch -from torch.fx._symbolic_trace import _orig_module_call -from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current from detectron2.utils.env import TORCH_VERSION +try: + from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current + + tracing_current_exists = True +except ImportError: + tracing_current_exists = False + +try: + from torch.fx._symbolic_trace import _orig_module_call + + tracing_legacy_exists = True +except ImportError: + tracing_legacy_exists = False + @torch.jit.ignore def is_fx_tracing_legacy() -> bool: @@ -20,14 +32,18 @@ def is_fx_tracing_legacy() -> bool: def is_fx_tracing() -> bool: """Returns whether execution is currently in Torch FX tracing mode""" - if TORCH_VERSION >= (1, 10): + if TORCH_VERSION >= (1, 10) and tracing_current_exists: return is_fx_tracing_current() - else: + elif tracing_legacy_exists: return is_fx_tracing_legacy() + else: + # Can't find either current or legacy tracing indication code. + # Enabling this assert_fx_safe() call regardless of tracing status. + return False -@torch.jit.ignore -def assert_fx_safe(condition: Union[bool, str], message: str): +@torch.jit.ignore # Remove and "try blocks aren't supported comes from torch.jit" +def assert_fx_safe(condition: bool, message: str, condition_expr: Optional[str] = None) -> None: """An FX-tracing safe version of assert. Avoids erroneous type assertion triggering when types are masked inside an fx.proxy.Proxy object during tracing. @@ -37,10 +53,10 @@ def assert_fx_safe(condition: Union[bool, str], message: str): marks and supplying it as a string.""" if not is_fx_tracing(): try: - if isinstance(condition, str): + if condition_expr is not None and isinstance(condition_expr, str): caller_frame = inspect.currentframe().f_back torch._assert( - eval(condition, caller_frame.f_globals, caller_frame.f_locals), message + eval(condition_expr, caller_frame.f_globals, caller_frame.f_locals), message ) else: torch._assert(condition, message)