-
Notifications
You must be signed in to change notification settings - Fork 7.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable torch tracing by changing assertions in d2go forwards to allow…
… for torch.fx.proxy.Proxy type. Summary: Pull Request resolved: #4227 X-link: facebookresearch/d2go#241 Torch FX tracing propagates a type of `torch.fx.proxy.Proxy` through the graph. Existing type assertions in the d2go code base trigger during torch FX tracing, causing tracing to fail. This adds a check for FX tracing in progress and adds a helper function `assert_fx_safe()`, that can be used in place of a standard assertion. This function only applies the assertion if one is not tracing, allowing d2go assertion tests to be compatible with FX tracing. Reviewed By: wat3rBro Differential Revision: D35518556 fbshipit-source-id: a9b5d3d580518ca74948544973ae89f8b9de3282
- Loading branch information
1 parent
5aeb252
commit 36a65a0
Showing
2 changed files
with
65 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import inspect | ||
from typing import Union | ||
import torch | ||
from torch.fx._symbolic_trace import _orig_module_call | ||
from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current | ||
|
||
from detectron2.utils.env import TORCH_VERSION | ||
|
||
|
||
@torch.jit.ignore | ||
def is_fx_tracing_legacy() -> bool: | ||
""" | ||
Returns a bool indicating whether torch.fx is currently symbolically tracing a module. | ||
Can be useful for gating module logic that is incompatible with symbolic tracing. | ||
""" | ||
return torch.nn.Module.__call__ is not _orig_module_call | ||
|
||
|
||
@torch.jit.ignore | ||
def is_fx_tracing() -> bool: | ||
"""Returns whether execution is currently in | ||
Torch FX tracing mode""" | ||
if TORCH_VERSION >= (1, 10): | ||
return is_fx_tracing_current() | ||
else: | ||
return is_fx_tracing_legacy() | ||
|
||
|
||
@torch.jit.ignore | ||
def assert_fx_safe(condition: Union[bool, str], message: str): | ||
"""An FX-tracing safe version of assert. | ||
Avoids erroneous type assertion triggering when types are masked inside | ||
an fx.proxy.Proxy object during tracing. | ||
Args: condition - either a boolean expression or a string representing | ||
the condition to test. If this assert triggers an exception when tracing | ||
due to dynamic control flow, try encasing the expression in quotation | ||
marks and supplying it as a string.""" | ||
if not is_fx_tracing(): | ||
try: | ||
if isinstance(condition, str): | ||
caller_frame = inspect.currentframe().f_back | ||
torch._assert( | ||
eval(condition, caller_frame.f_globals, caller_frame.f_locals), message | ||
) | ||
else: | ||
torch._assert(condition, message) | ||
except torch.fx.proxy.TraceError as e: | ||
print( | ||
"Found a non-FX compatible assertion. Skipping the check. Failure is shown below" | ||
+ str(e) | ||
) |
36a65a0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for pytorch = 1.11.0 version got ImportError: cannot import name 'is_fx_tracing' from 'torch.fx._symbolic_trace'
36a65a0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same
36a65a0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PyTorch 1.10.0 has same error
36a65a0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please verify commit before push to master at least on pytorch stable!!
Noone knows where did your dirty imports come from....
36a65a0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
ModuleNotFoundError: No module named 'torch.fx._symbolic_trace'
36a65a0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
giving the error of
cannot import name 'is_fx_tracing' from 'torch.fx._symbolic_trace'
on pytorch=1.12
36a65a0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi everyone. Thanks for your feedback regarding incompatibility of this change with older versions of pytorch. I will prepare an update to resolve this. Thanks for your patience.
36a65a0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a patch to resolve these import errors available in PR #4491 and facebookresearch/d2go#362
The patch passes CI tests for PyTorch 1.10 and I am aiming to commit this to master on Monday, but if you need unblocking before then, please try out this patch and let me know any feedback.