// DeviceGuard.h (forked from pytorch/pytorch)
#pragma once
#include <c10/core/Device.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/VirtualGuardImpl.h>
#include <c10/util/Optional.h>
namespace c10 {
/// RAII guard that sets a certain default device in its constructor, and
/// changes it back to the device that was originally active upon destruction.
///
/// The device is always reset to the one that was active at the time of
/// construction of the guard. Even if you `set_device` after construction, the
/// destructor will still reset the device to the one that was active at
/// construction time.
///
/// This device guard does NOT have an uninitialized state; it is guaranteed
/// to reset a device on exit. If you are in a situation where you *might*
/// want to setup a guard (i.e., are looking for the moral equivalent
/// of optional<DeviceGuard>), see OptionalDeviceGuard.
class DeviceGuard {
public:
/// No default constructor; see Note [Omitted default constructor from RAII]
explicit DeviceGuard() = delete;
/// Set the current device to the passed Device.
explicit DeviceGuard(Device device) : guard_(device) {}
/// This constructor is for testing only.
explicit DeviceGuard(
Device device,
const impl::DeviceGuardImplInterface* impl)
: guard_(device, impl) {}
/// Copy is disallowed
DeviceGuard(const DeviceGuard&) = delete;
DeviceGuard& operator=(const DeviceGuard&) = delete;
/// Move is disallowed, as DeviceGuard does not have an uninitialized state,
/// which is required for moves on types with nontrivial destructors.
DeviceGuard(DeviceGuard&& other) = delete;
DeviceGuard& operator=(DeviceGuard&& other) = delete;
/// Sets the device to the given one. The specified device must be consistent
/// with the device type originally specified during guard construction.
///
/// TODO: The consistency check here is inconsistent with StreamGuard's
/// behavior with set_stream, where a stream on a different device than
/// the original one isn't an error; we just reset the stream and then
/// switch devices.
void reset_device(at::Device device) {
guard_.reset_device(device);
}
/// This method is for testing only.
void reset_device(
at::Device device,
const impl::DeviceGuardImplInterface* impl) {
guard_.reset_device(device, impl);
}
/// Sets the device index to the given one. The device type is inferred
/// from the original device type the guard was constructed with.
void set_index(DeviceIndex index) {
guard_.set_index(index);
}
/// Returns the device that was set at the time the guard was constructed.
Device original_device() const {
return guard_.original_device();
}
/// Returns the most recent device that was set using this device guard,
/// either from construction, or via set_device.
Device current_device() const {
return guard_.current_device();
}
private:
impl::InlineDeviceGuard<impl::VirtualGuardImpl> guard_;
};
/**
* An OptionalDeviceGuard is an RAII class that sets a device to some value on
* initialization, and resets the device to its original value on destruction.
* Morally, an OptionalDeviceGuard is equivalent to optional<DeviceGuard>, but
* with extra constructors and methods as appropriate.
*
* Besides its obvious use (optionally applying a DeviceGuard),
* OptionalDeviceGuard is often also used for the following idiom:
*
* OptionalDeviceGuard g;
* for (const auto& t : tensors) {
* g.reset_device(t.device());
* do_something_with(t);
* }
*
* This usage is marginally more efficient than constructing a DeviceGuard every
* iteration of the for loop, as it avoids an unnecessary device reset.
*
* Unlike DeviceGuard, an OptionalDeviceGuard may be uninitialized. This occurs
* when you use the nullary constructor, or pass a nullopt to the constructor.
* Uninitialized OptionalDeviceGuards do *nothing*; they do not know what the
* original device was and they do not reset on destruction. This is why
* original_device() and current_device() return optional<Device> rather than
* Device (as they do in DeviceGuard), and also is why we didn't just
* provide OptionalDeviceGuard by default and hide DeviceGuard from users.
*
* The semantics of an OptionalDeviceGuard are exactly explained by thinking
* of it as an optional<DeviceGuard>. In particular, an initialized
* OptionalDeviceGuard doesn't restore device to its value at construction; it
* restores device to its value *at initialization*. So if you have the
* program:
*
* setDevice(1);
* OptionalDeviceGuard g;
* setDevice(2);
* g.reset_device(Device(DeviceType::CUDA, 3)); // initializes!
*
* On destruction, g will reset device to 2, rather than 1.
*
* An uninitialized OptionalDeviceGuard is distinct from an (initialized)
* DeviceGuard whose original_device_ and current_device_ match, since the
* DeviceGuard will still reset the device to original_device_.
*/
class OptionalDeviceGuard {
public:
/// Create an uninitialized guard. Set the guard later using reset_device.
explicit OptionalDeviceGuard() = default;
/// Initialize the guard, setting the current device to the passed Device.
explicit OptionalDeviceGuard(Device device) : guard_(device) {}
/// Initialize the guard if a Device is passed; otherwise leave the
/// guard uninitialized.
explicit OptionalDeviceGuard(optional<Device> device) : guard_(device) {}
/// Constructor for testing only.
explicit OptionalDeviceGuard(
Device device,
const impl::DeviceGuardImplInterface* impl)
: guard_(device, impl) {}
/// Copy is disallowed
OptionalDeviceGuard(const OptionalDeviceGuard&) = delete;
OptionalDeviceGuard& operator=(const OptionalDeviceGuard&) = delete;
/// Move is disallowed
/// See Note [Explicit initialization of optional fields]
/// and // Note [Move construction for RAII guards is tricky]
/// for rationale.
OptionalDeviceGuard(OptionalDeviceGuard&& other) = delete;
OptionalDeviceGuard& operator=(OptionalDeviceGuard&& other) = delete;
/// Sets the device to the given one. The specified device must be consistent
/// with the device type originally specified during guard construction.
void reset_device(at::Device device) {
guard_.reset_device(device);
}
/// For testing only
void reset_device(
at::Device device,
const impl::DeviceGuardImplInterface* impl) {
guard_.reset_device(device, impl);
}
/// Returns the device that was set at the time the guard was constructed.
optional<Device> original_device() const {
return guard_.original_device();
}
/// Returns the most recent device that was set using this device guard,
/// either from construction, or via reset_device.
optional<Device> current_device() const {
return guard_.current_device();
}
private:
impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl> guard_{};
};
// Note [Whither the DeviceGuard boilerplate]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Design note: in principle, we could avoid these wrappers using:
//
// using DeviceGuard = impl::InlineDeviceGuard<impl::VirtualGuardImpl>;
// using OptionalDeviceGuard =
// impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>;
//
// But the error messages are worse, and our users can't just look at the
// header file to find out what's going on. Furthermore, for specializations
// like CUDAStreamGuard, it can be profitable to replace some interfaces with
// refined types (e.g., return CUDAStream instead of Stream). So, we eat
// the boilerplate and write out the API explicitly.
} // namespace c10