From 3cfda4704cb3b23255f29d683cf05cbaf10daaa4 Mon Sep 17 00:00:00 2001 From: Simon Hollis Date: Mon, 22 Aug 2022 17:24:49 -0700 Subject: [PATCH] Fix import exceptions in tracing.py for older (<1.12) versions of pytorch Summary: X-link: https://github.com/facebookresearch/d2go/pull/362 Pull Request resolved: https://github.com/facebookresearch/detectron2/pull/4491 Recently landed D35518556 (https://github.com/facebookresearch/detectron2/commit/36a65a0907d90ed591479b2ebaa8b61cfa0b4ef0) / Github: 36a65a0907d90ed591479b2ebaa8b61cfa0b4ef0 throws an exception with older versions of PyTorch, due to a missing library for import. This has been reported by multiple members of the PyTorch community at https://github.com/facebookresearch/detectron2/commit/36a65a0907d90ed591479b2ebaa8b61cfa0b4ef0 This change uses `try/except` to check for libraries and set flags on presence/absence to later guard code that would use them. Reviewed By: wat3rBro Differential Revision: D38879134 fbshipit-source-id: 72f5a7a8d350eb82be87567f006368bf207f5a74 --- detectron2/modeling/poolers.py | 10 ++++++---- detectron2/utils/tracing.py | 32 ++++++++++++++++++++++++++------ 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/detectron2/modeling/poolers.py b/detectron2/modeling/poolers.py index 12073b0524..3393794507 100644 --- a/detectron2/modeling/poolers.py +++ b/detectron2/modeling/poolers.py @@ -7,7 +7,7 @@ from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple, shapes_to_tensor from detectron2.structures import Boxes -from detectron2.utils.tracing import assert_fx_safe +from detectron2.utils.tracing import assert_fx_safe, is_fx_tracing """ To export ROIPooler to torchscript, in this file, variables that should be annotated with @@ -220,9 +220,11 @@ def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]): """ num_level_assignments = len(self.level_poolers) - assert_fx_safe( - isinstance(x, list) and isinstance(box_lists, list), "Arguments to pooler must be lists" - ) + if not is_fx_tracing(): + torch._assert( + isinstance(x, list) and isinstance(box_lists, list), + "Arguments to pooler must be lists", + ) assert_fx_safe( len(x) == num_level_assignments, "unequal value, num_level_assignments={}, but x is list of {} Tensors".format( diff --git a/detectron2/utils/tracing.py b/detectron2/utils/tracing.py index 994b615a76..577df4e2f4 100644 --- a/detectron2/utils/tracing.py +++ b/detectron2/utils/tracing.py @@ -1,11 +1,22 @@ import inspect -from typing import Union import torch -from torch.fx._symbolic_trace import _orig_module_call -from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current from detectron2.utils.env import TORCH_VERSION +try: + from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current + + tracing_current_exists = True +except ImportError: + tracing_current_exists = False + +try: + from torch.fx._symbolic_trace import _orig_module_call + + tracing_legacy_exists = True +except ImportError: + tracing_legacy_exists = False + @torch.jit.ignore def is_fx_tracing_legacy() -> bool: @@ -20,14 +31,18 @@ def is_fx_tracing_legacy() -> bool: def is_fx_tracing() -> bool: """Returns whether execution is currently in Torch FX tracing mode""" - if TORCH_VERSION >= (1, 10): + if TORCH_VERSION >= (1, 10) and tracing_current_exists: return is_fx_tracing_current() - else: + elif tracing_legacy_exists: return is_fx_tracing_legacy() + else: + # Can't find either current or legacy tracing indication code. + # Enabling this assert_fx_safe() call regardless of tracing status. + return False @torch.jit.ignore -def assert_fx_safe(condition: Union[bool, str], message: str): +def assert_fx_safe(condition: bool, message: str) -> torch.Tensor: """An FX-tracing safe version of assert. Avoids erroneous type assertion triggering when types are masked inside an fx.proxy.Proxy object during tracing. @@ -35,6 +50,8 @@ def assert_fx_safe(condition: Union[bool, str], message: str): the condition to test. If this assert triggers an exception when tracing due to dynamic control flow, try encasing the expression in quotation marks and supplying it as a string.""" + # Must return a concrete tensor for compatibility with PyTorch <=1.8. + # If <=1.8 compatibility is not needed, return type can be converted to None if not is_fx_tracing(): try: if isinstance(condition, str): @@ -42,10 +59,13 @@ def assert_fx_safe(condition: Union[bool, str], message: str): torch._assert( eval(condition, caller_frame.f_globals, caller_frame.f_locals), message ) + return torch.ones(1) else: torch._assert(condition, message) + return torch.ones(1) except torch.fx.proxy.TraceError as e: print( "Found a non-FX compatible assertion. Skipping the check. Failure is shown below" + str(e) ) + return torch.zeros(1)