Skip to content

Commit

Permalink
[Feature] Add is_tracing to wrap torch.jit.is_tracing in differen…
Browse files Browse the repository at this point in the history
…t versions. (#1187)

* Add `is_tracing` to wrap `torch.jit.is_tracing` in different versions.

* Remame `is_tracing` to `is_jit_tracing`

* Ignore `is_jit_tracing` tests in CI.
  • Loading branch information
mzr1996 authored Jul 13, 2021
1 parent c3ddcf9 commit 6659c38
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
- name: Run unittests and generate coverage report
run: |
pip install -r requirements/test.txt
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py
build_without_ops:
runs-on: ubuntu-18.04
Expand Down
3 changes: 2 additions & 1 deletion mmcv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd,
_ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config)
from .registry import Registry, build_from_cfg
from .trace import is_jit_tracing
__all__ = [
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
Expand All @@ -63,5 +64,5 @@
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
'assert_params_all_zeros', 'check_python_script',
'is_method_overridden'
'is_method_overridden', 'is_jit_tracing'
]
21 changes: 21 additions & 0 deletions mmcv/utils/trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import warnings
from distutils.version import LooseVersion

import torch


def is_jit_tracing() -> bool:
if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'):
on_trace = torch.jit.is_tracing()
# In PyTorch 1.6, torch.jit.is_tracing has a bug.
# Refers to https://github.com/pytorch/pytorch/issues/42448
if isinstance(on_trace, bool):
return on_trace
else:
return torch._C._is_tracing()
else:
warnings.warn(
'torch.jit.is_tracing is only supported after v1.6.0. '
'Therefore is_tracing returns False automatically. Please '
'set on_trace manually if you are using trace.', UserWarning)
return False
26 changes: 26 additions & 0 deletions tests/test_utils/test_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from distutils.version import LooseVersion

import pytest
import torch

from mmcv.utils import is_jit_tracing


@pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion('1.6.0'),
reason='torch.jit.is_tracing is not available before 1.6.0')
def test_is_jit_tracing():

def foo(x):
if is_jit_tracing():
return x
else:
return x.tolist()

x = torch.rand(3)
# test without trace
assert isinstance(foo(x), list)

# test with trace
traced_foo = torch.jit.trace(foo, (torch.rand(1), ))
assert isinstance(traced_foo(x), torch.Tensor)

0 comments on commit 6659c38

Please sign in to comment.