Skip to content

Commit

Permalink
Fix torch.backends.cudnn mypy error (pytorch#38947)
Browse files Browse the repository at this point in the history
Summary:
Fix pytorch#38410

![image](https://user-images.githubusercontent.com/6421097/82724121-74b26880-9c99-11ea-9b63-e92de2dccdf2.png)
Pull Request resolved: pytorch#38947

Differential Revision: D21765290

Pulled By: ezyang

fbshipit-source-id: 5d2b25f039a653c609d60cdaac4a7ac5812ae291
  • Loading branch information
ShawnZhong authored and facebook-github-bot committed Jun 3, 2020
1 parent 6a60a8c commit 21ba3b4
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 13 deletions.
6 changes: 0 additions & 6 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -340,15 +340,9 @@ ignore_errors = True
[mypy-torch.multiprocessing.spawn]
ignore_errors = True

[mypy-torch.backends.cudnn.rnn]
ignore_errors = True

[mypy-torch.backends.cuda]
ignore_errors = True

[mypy-torch.backends.cudnn]
ignore_errors = True

[mypy-torch.backends.quantized]
ignore_errors = True

Expand Down
7 changes: 7 additions & 0 deletions torch/_C/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,19 @@ def _get_backcompat_keepdim_warn() -> _bool: ...
def _is_xnnpack_enabled() -> _bool: ...
# MKL-DNN backend toggle, surfaced via torch.backends.mkldnn.
def _get_mkldnn_enabled() -> _bool: ...
def _set_mkldnn_enabled(arg: _bool) -> None: ...
# cuDNN backend toggles — these back the `enabled`, `benchmark`, and
# `deterministic` flags exposed by torch.backends.cudnn (see that module's
# CudnnModule property trick). Added here so mypy can type-check that module.
def _get_cudnn_enabled() -> _bool: ...
def _set_cudnn_enabled(arg: _bool) -> None: ...
def _get_cudnn_benchmark() -> _bool: ...
def _set_cudnn_benchmark(arg: _bool) -> None: ...
def _get_cudnn_deterministic() -> _bool: ...
def _set_cudnn_deterministic(arg: _bool) -> None: ...
def _set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API
def _set_default_dtype(d: _dtype) -> None: ...
def _initExtension(shm_manager_path: str) -> None: ...
# Build-time capability flags exported by the C extension.
has_openmp: _bool
has_mkldnn: _bool
has_mkl: _bool
has_cudnn: _bool
_GLIBCXX_USE_CXX11_ABI: _bool

# Defined in torch/csrc/jit/python/script_init.cpp
Expand Down
17 changes: 17 additions & 0 deletions torch/_C/_cudnn.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from enum import Enum

from torch.types import Tuple, Number, _bool

# Defined in torch/csrc/cuda/shared/cudnn.cpp
# NOTE(review): presumably True when this build of torch links against
# CUDA/cuDNN — confirm against torch/csrc/cuda/shared/cudnn.cpp.
is_cuda: _bool

# Version of the cuDNN library found at runtime — assumed to be a
# (major, minor, patch) triple; TODO confirm in the C++ source.
def getRuntimeVersion() -> Tuple[int, int, int]: ...
# Version of cuDNN that PyTorch was compiled against, same triple form.
def getCompileVersion() -> Tuple[int, int, int]: ...
# Single-integer encoding of the cuDNN version (the CUDNN_VERSION macro form).
def getVersionInt() -> int: ...

class RNNMode(int, Enum):
    """Integer enum of cuDNN RNN cell types.

    NOTE(review): presumably mirrors the cudnnRNNMode_t constants; the
    concrete member values are assigned by the C extension
    (torch/csrc/cuda/shared/cudnn.cpp), hence the ``...`` placeholders
    in this stub.
    """
    # Declared so mypy knows each member's .value is an int.
    value: int
    rnn_relu = ...
    rnn_tanh = ...
    lstm = ...
    gru = ...
11 changes: 6 additions & 5 deletions torch/backends/cudnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
try:
from torch._C import _cudnn
except ImportError:
_cudnn = None
_cudnn = None # type: ignore

# Write:
#
Expand Down Expand Up @@ -83,11 +83,7 @@ def is_acceptable(tensor):
return True


_handles = {}


def set_flags(_enabled, _benchmark, _deterministic):
global benchmark, deterministic
orig_flags = (torch._C._get_cudnn_enabled(),
torch._C._get_cudnn_benchmark(),
torch._C._get_cudnn_deterministic())
Expand Down Expand Up @@ -124,3 +120,8 @@ def __init__(self, m, name):
# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = CudnnModule(sys.modules[__name__], __name__)

# Add type annotation for the replaced module
enabled: bool
deterministic: bool
benchmark: bool
4 changes: 2 additions & 2 deletions torch/backends/cudnn/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
except ImportError:
# Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(),
# so it's safe to not emit any checks here.
_cudnn = None
_cudnn = None # type: ignore


def get_cudnn_mode(mode):
Expand Down Expand Up @@ -48,7 +48,7 @@ def init_dropout_state(dropout, train, dropout_seed, dropout_state):
if dropout_p == 0:
dropout_state[dropout_desc_name] = Unserializable(None)
else:
dropout_state[dropout_desc_name] = Unserializable(torch._cudnn_init_dropout_state(
dropout_state[dropout_desc_name] = Unserializable(torch._cudnn_init_dropout_state( # type: ignore
dropout_p,
train,
dropout_seed,
Expand Down

0 comments on commit 21ba3b4

Please sign in to comment.