#pragma once

#include <c10/core/AutogradState.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/macros/Export.h>

namespace c10 {

// A RAII, thread local (!) guard that enables or disables inference mode upon
// construction, and sets it back to the original value upon destruction.
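//
// Typical usage (a sketch, not part of this header; `model` and `input` are
// placeholder names):
//   {
//     c10::InferenceMode guard;          // inference mode on for this thread
//     auto out = model.forward(input);   // runs without autograd bookkeeping
//   }                                    // previous TLS state restored here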
struct C10_API InferenceMode {
  // Note [Expected TLS state in InferenceMode]:
  // InferenceMode: ADInplaceOrView not in
  //                raw_local_dispatch_key_set.included(),
  //                Autograd in raw_local_dispatch_key_set.excluded(),
  //                GradMode is disabled.
  // NormalMode: ADInplaceOrView in raw_local_dispatch_key_set.included(),
  //             Autograd not in raw_local_dispatch_key_set.excluded(),
  //             GradMode is enabled by default unless toggled manually
  //             through other APIs, e.g. NoGradGuard.
  //
  // Invariants:
  // - ADInplaceOrView is never in the excluded set.
  // - Autograd is never in the included set.
  // - Setting InferenceMode will set GradMode accordingly, but not vice versa.
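  //
  // A sketch of how that state can be observed (illustrative, not part of
  // this header):
  //   {
  //     c10::InferenceMode guard;
  //     // c10::impl::tls_is_dispatch_key_included(
  //     //     c10::DispatchKey::ADInplaceOrView) -> false
  //     // c10::impl::tls_is_dispatch_key_excluded(
  //     //     c10::DispatchKey::AutogradCPU) -> true
  //   }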
  //
  // 1. Why do we put ADInplaceOrView in the included set outside
  //    InferenceMode?
  //
  //    Inplace updates to inference tensors outside InferenceMode are not
  //    allowed. See Note [Inplace update inference tensor] for more details.
  //    Without going through the ADInplaceOrView kernel, we cannot throw an
  //    error for the `inference_tensor.add_(1)` case.
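  //
  //    For illustration (a sketch, not part of this header):
  //      torch::Tensor t;
  //      {
  //        c10::InferenceMode guard;
  //        t = torch::ones({1, 2, 3}); // t is an inference tensor
  //      }
  //      t.add_(1); // dispatches through ADInplaceOrView, which raises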
  //
  // 2. Why not put ADInplaceOrView in the excluded set inside InferenceMode?
  //
  //    For example:
  //      torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true);
  //      torch::Tensor k = a + 2;
  //      {
  //        c10::InferenceMode guard(true);
  //        k.add_(2);
  //      }
  //    `k.add_(2)` still needs to go through the ADInplaceOrView kernel so
  //    that it is prepared for future autograd.
  //
  // 3. Why does setting InferenceMode also set GradMode?
  //
  //    This is required since InferenceMode is a faster and more restrictive
  //    version of NoGradGuard. All runtime checks using GradMode::is_enabled()
  //    are applicable to InferenceMode as well, e.g.
  //    `tensorTypeInCurrentExecutionContext` in interpreter.cpp.
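  //
  //    For instance (a sketch of such a check, illustrative only):
  //      {
  //        c10::InferenceMode guard;
  //        bool grad = c10::GradMode::is_enabled(); // false: the guard
  //                                                 // disabled GradMode too
  //      }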
  InferenceMode(bool enabled = true)
      : prev_mode(AutogradState::get_tls_state()),
        prev_keyset(c10::impl::tls_local_dispatch_key_set()) {
    // Enabling inference mode means disabling grad modes,
    // and disabling inference mode means enabling grad modes.
    AutogradState::set_tls_state(AutogradState(
        /* grad_mode */ !enabled,
        /* inference_mode */ enabled,
        /* fw_grad_mode */ !enabled,
        /* multithreading_enabled */ !enabled));
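    // Toggle the dispatch keys to match Note [Expected TLS state in
    // InferenceMode] above: drop (or re-add) ADInplaceOrView in the included
    // set, and add (or remove) the autograd keys in the excluded set.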
    DispatchKeySet included = enabled
        ? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView)
        : prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView);
    DispatchKeySet excluded = enabled
        ? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset)
        : (prev_keyset.excluded_ - c10::autograd_dispatch_keyset);
    c10::impl::PODLocalDispatchKeySet cur_keyset{};
    cur_keyset.set_included(included);
    cur_keyset.set_excluded(excluded);
    c10::impl::_force_tls_local_dispatch_key_set(cur_keyset);
  }
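
  // Restore both pieces of thread-local state saved at construction.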
  ~InferenceMode() {
    AutogradState::set_tls_state(prev_mode);
    c10::impl::_force_tls_local_dispatch_key_set(prev_keyset);
  }
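
  // Because the destructor restores the saved state, guards nest naturally
  // (a sketch, illustrative only):
  //   {
  //     c10::InferenceMode on(true);
  //     {
  //       c10::InferenceMode off(false); // back to normal mode here
  //     }                                // inference mode re-enabled
  //   }                                  // original thread state restored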

  static bool is_enabled();

 private:
  AutogradState prev_mode;
  c10::impl::LocalDispatchKeySet prev_keyset;
};
} // namespace c10