// StorageImpl.cpp (forked from pytorch/pytorch)

#include <c10/core/StorageImpl.h>
#include <c10/util/flat_hash_map.h>

namespace c10 {

// Array of function pointers for creating custom StorageImpl objects,
// indexed by device type.
C10_API std::array<StorageImplCreateHelper, at::COMPILE_TIME_MAX_DEVICE_TYPES>
    StorageImplCreate;

// An allowlist of device types; currently only PrivateUse1 is allowed.
static ska::flat_hash_set<c10::DeviceType> DeviceTypeAllowList{
    DeviceType::PrivateUse1};

void throwNullDataPtrError() {
  TORCH_CHECK(
      false,
      "Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). "
      "If you're using torch.compile/export/fx, it is likely that we are erroneously "
      "tracing into a custom kernel. To fix this, please wrap the custom kernel into "
      "an opaque custom op. Please see the following for details: "
      "https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ");
}

// NOTE: [FakeTensor.data_ptr deprecation]
// Today:
// - FakeTensor.data_ptr errors out under torch.compile.
// - Otherwise, FakeTensor.data_ptr raises the deprecation warning below.
// - The warning below applies only to FakeTensor (for now); in the future we
//   may extend it to more wrapper Tensor subclasses.
void warnDeprecatedDataPtr() {
  TORCH_WARN_ONCE(
      "Accessing the data pointer of FakeTensor is deprecated and will error in "
      "PyTorch 2.5. This is almost definitely a bug in your code and will "
      "cause undefined behavior with subsystems like torch.compile. "
      "Please wrap calls to tensor.data_ptr() in an opaque custom op; "
      "if all else fails, you can guard accesses to tensor.data_ptr() on "
      "isinstance(tensor, FakeTensor).");
}

void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
  // Allowlist verification: an extension may register a StorageImpl create
  // method only for device types in the allowlist.
  TORCH_CHECK(
      DeviceTypeAllowList.find(t) != DeviceTypeAllowList.end(),
      "Registering a StorageImpl create method is only allowed ",
      "for PrivateUse1. ",
      "If you have related StorageImpl requirements, ",
      "please expand the allowlist.");
  // Register the function pointer.
  int device_type = static_cast<int>(t);
  TORCH_CHECK(
      StorageImplCreate[device_type] == nullptr,
      "The StorageImplCreate function pointer for ",
      t,
      " has already been registered.");
  StorageImplCreate[device_type] = fptr;
}

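// Usage sketch: an out-of-tree PrivateUse1 backend could register its own
// factory roughly like this. make_my_storage_impl and MyStorageImpl are
// hypothetical names; the parameter list mirrors how fptr is invoked in
// make_storage_impl below.
//
//   static c10::intrusive_ptr<c10::StorageImpl> make_my_storage_impl(
//       c10::StorageImpl::use_byte_size_t use_byte_size,
//       c10::SymInt size_bytes,
//       c10::DataPtr data_ptr,
//       c10::Allocator* allocator,
//       bool resizable) {
//     // Construct a backend-specific StorageImpl subclass here.
//     return c10::make_intrusive<MyStorageImpl>(
//         use_byte_size,
//         std::move(size_bytes),
//         std::move(data_ptr),
//         allocator,
//         resizable);
//   }
//
//   // Typically called once when the backend module is loaded:
//   c10::SetStorageImplCreate(
//       c10::DeviceType::PrivateUse1, &make_my_storage_impl);
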
StorageImplCreateHelper GetStorageImplCreate(DeviceType t) {
  int device_type = static_cast<int>(t);
  return StorageImplCreate[device_type];
}

c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
    c10::StorageImpl::use_byte_size_t use_byte_size,
    c10::SymInt size_bytes,
    c10::DataPtr data_ptr,
    c10::Allocator* allocator,
    bool resizable,
    c10::optional<at::Device> device_opt) {
  // This will be non-null only when a custom StorageImpl constructor has
  // been registered for the given device.
  c10::StorageImplCreateHelper fptr = nullptr;
  if (device_opt.has_value()) {
    // We only need to check here, since this is the only case where the
    // device can be something other than CPU (and thus where the StorageImpl
    // constructor can be overridden).
    fptr = c10::GetStorageImplCreate(device_opt.value().type());
  }
  if (fptr != nullptr) {
    return fptr(
        use_byte_size,
        std::move(size_bytes),
        std::move(data_ptr),
        allocator,
        resizable);
  }
  // Otherwise, create a default c10::StorageImpl object.
  if (data_ptr != nullptr) {
    return c10::make_intrusive<c10::StorageImpl>(
        use_byte_size,
        std::move(size_bytes),
        std::move(data_ptr),
        allocator,
        resizable);
  }
  return c10::make_intrusive<c10::StorageImpl>(
      use_byte_size, std::move(size_bytes), allocator, resizable);
}

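// Usage sketch: constructing a resizable, allocator-owned CPU storage of 16
// bytes through this helper. GetDefaultCPUAllocator is assumed to come from
// c10/core/CPUAllocator.h; with no custom factory registered for CPU, the
// call falls through to the plain c10::StorageImpl constructor, and the null
// DataPtr selects the allocator-owned path.
//
//   auto storage = c10::make_storage_impl(
//       c10::StorageImpl::use_byte_size_t(),
//       c10::SymInt(16),
//       c10::DataPtr(),
//       c10::GetDefaultCPUAllocator(),
//       /*resizable=*/true,
//       c10::Device(c10::DeviceType::CPU));
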
} // namespace c10