
Commit c81c9ba

zou3519 authored and pytorchmergebot committed
Disallow {FakeTensor,FunctionalTensor}.data_ptr (#122514)
This PR:
- disallows FakeTensor.data_ptr when it is called inside PT2 or fx tracing.
- disallows FunctionalTensor.data_ptr (the Python FunctionalTensor is only used in PT2).

The motivation is that the leading cause of segfaults when using custom ops with PT2 is calling .data_ptr on a FunctionalTensor or FakeTensor.

This change is BC-breaking. If your code broke as a result, it is because the code had a bug (these .data_ptr calls should never happen). You can either fix the bug (recommended) or get the previous behavior back with:

```python
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor

data_ptr = 0 if isinstance(tensor, (FakeTensor, FunctionalTensor)) else tensor.data_ptr()
```

Test Plan:
- existing tests

Differential Revision: [D55366199](https://our.internmc.facebook.com/intern/diff/D55366199)

Pull Request resolved: #122514
Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/yifuwang, https://github.com/kurtamohler
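For reference, here is a minimal sketch of the fix the new error message recommends: wrap the raw-pointer kernel in an opaque custom op via torch.library so PT2 never traces into it. The op name `mylib::my_kernel` and both registered functions are hypothetical, and the registration API shown (torch.library.define / impl / impl_abstract) may differ across PyTorch versions:

```python
import torch

# Hypothetical schema for a kernel that needs raw pointers internally.
torch.library.define("mylib::my_kernel", "(Tensor x) -> Tensor")

@torch.library.impl("mylib::my_kernel", "CPU")
def my_kernel_cpu(x):
    # Runs only on real tensors, so touching real memory is safe here.
    ptr = x.data_ptr()  # e.g. hand the pointer to a C/CUDA kernel
    return x.clone()

@torch.library.impl_abstract("mylib::my_kernel")
def my_kernel_abstract(x):
    # Runs on FakeTensors under torch.compile/export; never touches data_ptr.
    return torch.empty_like(x)
```

Under torch.compile, calls to torch.ops.mylib.my_kernel then dispatch to the abstract implementation during tracing instead of stepping into the pointer-touching code.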
1 parent 04399a3 commit c81c9ba

15 files changed: +207 −13 lines changed

c10/core/StorageImpl.cpp

Lines changed: 10 additions & 0 deletions
@@ -11,6 +11,16 @@ C10_API std::array<StorageImplCreateHelper, at::COMPILE_TIME_MAX_DEVICE_TYPES>
 static ska::flat_hash_set<c10::DeviceType> DeviceTypeAllowList{
     DeviceType::PrivateUse1};
 
+void throwNullDataPtrError() {
+  TORCH_CHECK(
+      false,
+      "Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). "
+      "If you're using torch.compile/export/fx, it is likely that we are erroneously "
+      "tracing into a custom kernel. To fix this, please wrap the custom kernel into "
+      "an opaque custom op. Please see the following for details: "
+      "https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ");
+}
+
 void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
   // Allowlist verification.
   // Only if the devicetype is in the allowlist,

c10/core/StorageImpl.h

Lines changed: 43 additions & 4 deletions
@@ -16,6 +16,8 @@
 
 namespace c10 {
 
+C10_API void throwNullDataPtrError();
+
 // A storage represents the underlying backing data buffer for a
 // tensor. This concept was inherited from the original Torch7
 // codebase; we'd kind of like to get rid of the concept
@@ -59,6 +61,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
       TORCH_INTERNAL_ASSERT(
          allocator_, "For resizable storage, allocator must be provided");
     }
+    refresh_has_data_ptr_check();
   }
 
   StorageImpl(
@@ -118,12 +121,22 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
     return resizable_;
   }
 
+  const at::DataPtr& data_ptr() const {
+    return data_ptr_;
+  }
+
   at::DataPtr& mutable_data_ptr() {
-    maybe_materialize_cow();
+    if (C10_UNLIKELY(has_data_ptr_check_)) {
+      if (throw_on_mutable_data_ptr_) {
+        throwNullDataPtrError();
+      }
+      maybe_materialize_cow();
+    }
     return data_ptr_;
   }
 
-  const at::DataPtr& data_ptr() const {
+  // Returns the data_ptr. Bypasses all checks.
+  at::DataPtr& _mutable_data_ptr_no_checks() {
     return data_ptr_;
   }
 
@@ -137,14 +150,20 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
 
   void set_data_ptr_noswap(at::DataPtr&& data_ptr) {
     data_ptr_ = std::move(data_ptr);
+    refresh_has_data_ptr_check();
   }
 
   const void* data() const {
     return data_ptr_.get();
   }
 
   void* mutable_data() {
-    maybe_materialize_cow();
+    if (C10_UNLIKELY(has_data_ptr_check_)) {
+      if (throw_on_mutable_data_ptr_) {
+        throwNullDataPtrError();
+      }
+      maybe_materialize_cow();
+    }
     return data_ptr_.mutable_get();
   }
 
@@ -222,6 +241,11 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
     return &pyobj_slot_;
   }
 
+  void set_throw_on_mutable_data_ptr() {
+    throw_on_mutable_data_ptr_ = true;
+    refresh_has_data_ptr_check();
+  }
+
  protected:
   // materialize_cow_storage needs to call set_data_ptr_no_materlize_cow
   friend void c10::impl::cow::materialize_cow_storage(StorageImpl& storage);
@@ -231,13 +255,22 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
   at::DataPtr set_data_ptr_no_materialize_cow(at::DataPtr&& data_ptr) {
     at::DataPtr old_data_ptr(std::move(data_ptr_));
     data_ptr_ = std::move(data_ptr);
+    refresh_has_data_ptr_check();
     return old_data_ptr;
   }
 
  private:
+  void refresh_has_data_ptr_check() {
+    has_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_;
+  }
+
+  inline bool is_cow() const {
+    return c10::impl::cow::is_cow_data_ptr(data_ptr_);
+  }
+
   // Triggers a copy if this is a copy-on-write tensor.
   void maybe_materialize_cow() {
-    if (data_ptr_.get_deleter() == impl::cow::cow_deleter) {
+    if (is_cow()) {
       impl::cow::materialize_cow_storage(*this);
     }
   }
@@ -249,6 +282,12 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
   // Identifies that Storage was received from another process and doesn't have
   // local to process cuda memory allocation
   bool received_cuda_;
+  // All special checks in data/data_ptr calls are guarded behind this single
+  // boolean. This is for performance: .data/.data_ptr calls are commonly in the
+  // hot-path.
+  bool has_data_ptr_check_ = false;
+  // If we should throw when mutable_data_ptr() or mutable_data() is called.
+  bool throw_on_mutable_data_ptr_ = false;
   Allocator* allocator_;
   impl::PyObjectSlot pyobj_slot_;
 };

c10/core/impl/COW.cpp

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ c10::intrusive_ptr<StorageImpl> lazy_clone_storage(StorageImpl& storage) {
   if (has_simple_data_ptr(storage)) {
     // Case 1) We have a simple data pointer: wrap it.
     std::unique_ptr<void, DeleterFnPtr> original_ctx =
-        storage.mutable_data_ptr().move_context();
+        storage._mutable_data_ptr_no_checks().move_context();
 
     // Save this for the result.
     new_data_ptr = make_data_ptr(

test/dynamo/test_aot_autograd.py

Lines changed: 74 additions & 0 deletions
@@ -1,4 +1,5 @@
 # Owner(s): ["module: dynamo"]
+import copy
 import re
 import unittest
 from textwrap import dedent
@@ -12,6 +13,8 @@
 import torch.utils._pytree as pytree
 from torch._dynamo.testing import CompileCounter, expectedFailureDynamic, rand_strided
 from torch._functorch.aot_autograd import _aot_export_function, create_functional_call
+from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.fx.experimental.proxy_tensor import make_fx
 from torch.profiler import profile
 from torch.testing._internal.common_utils import compare_equal_outs_and_grads
 
@@ -1104,6 +1107,77 @@ def fn(x, z):
         self.assertEqual(x, x_opt)
         self.assertEqual(z.grad, z_opt.grad)
 
+    def test_data_ptr_access_copy(self):
+        with FakeTensorMode(_allow_unsafe_data_ptr_access=False):
+            x = torch.randn(3)
+            y = copy.copy(x)
+        self.assertEqual(y.shape, x.shape)
+
+    def test_data_ptr_access_fails_in_forward(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib)
+
+            @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
+            def _(x):
+                x.data_ptr()
+                return x.clone()
+
+            x = torch.randn(3)
+
+            def data_ptr_graph_input(x):
+                r0 = torch.ops.mylib.foo(x)
+                return r0
+
+            def data_ptr_graph_intermediate(x):
+                y = x.clone()
+                r0 = torch.ops.mylib.foo(y)
+                return r0
+
+            tests = [data_ptr_graph_input, data_ptr_graph_intermediate]
+
+            def ctx():
+                return self.assertRaisesRegex(
+                    RuntimeError, "Cannot access data pointer"
+                )
+
+            for f in tests:
+                with ctx():
+                    make_fx(f, tracing_mode="fake")(x)
+                with ctx():
+                    make_fx(f, tracing_mode="symbolic")(x)
+                with ctx():
+                    torch.compile(f, backend="eager", fullgraph=True)(x)
+
+    def test_data_ptr_access_fails_in_backward(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib)
+
+            backward_called = False
+
+            class Foo(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, x):
+                    return x.clone()
+
+                @staticmethod
+                def backward(ctx, grad):
+                    nonlocal backward_called
+                    backward_called = True
+                    grad.data_ptr()
+                    return grad.clone()
+
+            @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
+            def _(x):
+                return Foo.apply(x)
+
+            def f(x):
+                return torch.ops.mylib.foo(x)
+
+            x = torch.randn(3, requires_grad=True)
+            with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer"):
+                y = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
+            self.assertTrue(backward_called)
+
     # We don't know how to catch multiple mutations to the same memory location
     @unittest.expectedFailure
     def test_aot_autograd_expand_mutation_error(self):

test/torch_np/numpy_tests/core/test_dlpack.py

Lines changed: 4 additions & 1 deletion
@@ -2,6 +2,7 @@
 
 import functools
 import sys
+import unittest
 
 from unittest import skipIf as skipif
 
@@ -15,6 +16,7 @@
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
+    skipIfTorchDynamo,
     TEST_WITH_TORCHDYNAMO,
     TestCase,
     xpassIfTorchDynamo,
@@ -46,7 +48,8 @@ def test_dunder_dlpack_refcount(self):
         del y
         assert sys.getrefcount(x) == 2
 
-    @xpassIfTorchDynamo  # (reason="pytorch does not raise")
+    @unittest.expectedFailure
+    @skipIfTorchDynamo("I can't figure out how to get __dlpack__ into trace_rules.py")
     def test_dunder_dlpack_stream(self):
         x = np.arange(5)
         x.__dlpack__(stream=None)

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 0 deletions
@@ -1472,6 +1472,7 @@ def _dispatch_pystub(name: str, overload: str) -> Optional[Tuple[str, str]]: ...
 def _dispatch_is_alias_key(dispatch: _dispatchkey) -> _bool: ...
 def _functionality_to_backend_keys(dispatch: _dispatchkey) -> List[DispatchKey]: ...
 def _functionalization_reapply_views_tls() -> _bool: ...
+def _set_throw_on_mutable_data_ptr(tensor: Tensor) -> None: ...
 
 class DispatchKey(Enum):
     ${dispatch_key_hints}

torch/_dynamo/output_graph.py

Lines changed: 2 additions & 0 deletions
@@ -296,6 +296,7 @@ def __init__(
             shape_env=shape_env,
             # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
             allow_non_fake_inputs=True if self.export else False,
+            _allow_unsafe_data_ptr_access=False,
         )
         self.tracing_context: TracingContext = TracingContext(fake_mode)
         self.init_ambient_guards()
@@ -1138,6 +1139,7 @@ def compile_and_call_fx_graph(self, tx, rv, root):
             # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
             backend_fake_mode = torch._subclasses.FakeTensorMode(
                 shape_env=old_fake_mode.shape_env,
+                _allow_unsafe_data_ptr_access=False,
             )
             # TODO(voz): Ostensibily, this should be scoped and
             # restore back to old_fake_mode, but doing so currently violates

torch/_dynamo/trace_rules.py

Lines changed: 2 additions & 1 deletion
@@ -104,6 +104,8 @@
     "torch.compiler.is_compiling": TorchInGraphFunctionVariable,
     "torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable,
     "torch.autograd._profiler_enabled": SkipFunctionVariable,
+    "torch._C._to_dlpack": SkipFunctionVariable,
+    "torch.to_dlpack": SkipFunctionVariable,
     # We graph break on RNG state setters or getters like
     # `torch.get_rng_state` or `torch.set_rng_state`. These functions
     # are not aten operations and therefore they are completely ignored
@@ -1187,7 +1189,6 @@
         "torch._C._test_only_populate_upgraders",
         "torch._C._test_only_remove_entry_to_op_version_map",
         "torch._C._test_only_remove_upgraders",
-        "torch._C._to_dlpack",
         "torch._C._to_functionality_key",
         "torch._C._tracer_set_force_outplace",
         "torch._C._tracer_set_get_unique_name_fn",

torch/_subclasses/fake_tensor.py

Lines changed: 4 additions & 0 deletions
@@ -433,6 +433,8 @@ def __new__(cls, fake_mode, elem, device, constant=None):
             dispatch_device=True,
             device_for_backend_keys=device,
         )
+        if not fake_mode._allow_unsafe_data_ptr_access:
+            torch._C._set_throw_on_mutable_data_ptr(self)
 
         assert elem.device.type == "meta", elem.device.type
         device = device if isinstance(device, torch.device) else torch.device(device)
@@ -759,8 +761,10 @@ def __init__(
         allow_non_fake_inputs=False,
         shape_env=None,
         static_shapes=None,
+        _allow_unsafe_data_ptr_access=True,
     ):
         log.debug("create_mode 0x%x", id(self))
+        self._allow_unsafe_data_ptr_access = _allow_unsafe_data_ptr_access
         self.allow_fallback_kernels = allow_fallback_kernels
         self.fake_tensor_converter = FakeTensorConverter()
         if static_shapes is not None:
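Putting the StorageImpl guard and the new `_allow_unsafe_data_ptr_access` flag together, the user-visible behavior should look roughly like this (a sketch mirroring the tests added in test/dynamo/test_aot_autograd.py, not code from this commit):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# The private flag that dynamo now passes when constructing its FakeTensorMode.
with FakeTensorMode(_allow_unsafe_data_ptr_access=False):
    x = torch.randn(3)  # x is a FakeTensor
    try:
        x.data_ptr()  # reaches throwNullDataPtrError() in StorageImpl
    except RuntimeError as e:
        assert "Cannot access data pointer" in str(e)
```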

torch/_subclasses/functional_tensor.py

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ def __new__(cls, elem):
             False,  # dispatch_layout
             extra_dispatch_keys,  # _extra_dispatch_keys
         )
+        torch._C._set_throw_on_mutable_data_ptr(out)
         out.elem = elem
         return out
 
torch/_tensor.py

Lines changed: 11 additions & 2 deletions
@@ -378,9 +378,18 @@ def _reduce_ex_internal(self, proto):
             )
             return (torch._utils._rebuild_nested_tensor, args_nested)
         elif (
-            self.data_ptr() == 0
-            and type(self) is not torch.Tensor
+            type(self) is not torch.Tensor
             and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
+            and (
+                isinstance(
+                    self,
+                    (
+                        torch._subclasses.fake_tensor.FakeTensor,
+                        torch._subclasses.functional_tensor.FunctionalTensor,
+                    ),
+                )
+                or self.data_ptr() == 0
+            )
         ):
             arg_wrapper_subclass = (
                 type(self),
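The _reduce_ex_internal change above is what keeps copy.copy (and pickling) of these subclasses working without a data_ptr call; a sketch based on the new test_data_ptr_access_copy test, assuming copy.copy falls back to Tensor.__reduce_ex__:

```python
import copy

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode(_allow_unsafe_data_ptr_access=False):
    x = torch.randn(3)
    # Serializing a FakeTensor no longer calls self.data_ptr(); the subclass
    # branch in _reduce_ex_internal is taken instead.
    y = copy.copy(x)
assert y.shape == x.shape
```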

torch/autograd/graph.py

Lines changed: 24 additions & 2 deletions
@@ -536,11 +536,33 @@ def wrapped_fn(grad: torch.Tensor):
 
 
 def _get_tid(t) -> Tuple[int, int, int]:
-    return (id(t), t.data_ptr(), t._version)
+    # FIXME: This is almost definitely a bug.
+    if isinstance(
+        t,
+        (
+            torch._subclasses.fake_tensor.FakeTensor,
+            torch._subclasses.functional_tensor.FunctionalTensor,
+        ),
+    ):
+        data_ptr = 0
+    else:
+        data_ptr = t.data_ptr()
+    return (id(t), data_ptr, t._version)
 
 
 def _get_sid(t) -> Tuple[int, int]:
-    return (t.data_ptr(), t._version)
+    # FIXME: This is almost definitely a bug.
+    if isinstance(
+        t,
+        (
+            torch._subclasses.fake_tensor.FakeTensor,
+            torch._subclasses.functional_tensor.FunctionalTensor,
+        ),
+    ):
+        data_ptr = 0
+    else:
+        data_ptr = t.data_ptr()
+    return (data_ptr, t._version)
 
 
 class _Handle:
