forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathRefcountedDeleter.cpp
78 lines (62 loc) · 2.42 KB
/
RefcountedDeleter.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <c10/core/RefcountedDeleter.h>
#include <mutex>
namespace c10 {
// Deleter installed by maybeApplyRefcountedDeleter. `ctx_` must point at a
// heap-allocated RefcountedDeleterContext. Each call drops one reference; the
// call that drops the last one clears `other_ctx` (presumably running the
// original DataPtr's deleter on the wrapped context -- confirm in
// RefcountedDeleter.h) and then frees the RefcountedDeleterContext itself.
//
// @param ctx_ opaque pointer to the shared RefcountedDeleterContext.
void refcounted_deleter(void* ctx_) {
  RefcountedDeleterContext& ctx =
      *static_cast<RefcountedDeleterContext*>(ctx_);
  // Test the value produced by the decrement itself instead of re-reading
  // the counter afterwards. With two concurrent final releases, a separate
  // re-read can let BOTH callers observe zero and free `ctx` twice (or read
  // it after it was freed). Assumes the refcount type is atomic, as the
  // cross-DataPtr sharing implies -- confirm in RefcountedDeleter.h.
  if (--ctx.refcount == 0) {
    ctx.other_ctx = nullptr;
    delete &ctx;
  }
}
// Serializes maybeApplyRefcountedDeleter's check-and-replace of a Storage's
// DataPtr, so two concurrent callers cannot both see a non-refcounted
// DataPtr and each wrap it.
std::mutex replace_data_ptr_mutex;

// Ensures `storage`'s DataPtr uses refcounted_deleter.
//
// If the DataPtr already uses refcounted_deleter, this is a no-op. Otherwise
// the original context and deleter are moved into a newly allocated
// RefcountedDeleterContext, and the DataPtr is replaced by one that has:
//   * the same data pointer and device,
//   * that RefcountedDeleterContext as its context,
//   * refcounted_deleter as its deleter.
//
// @param storage the Storage whose DataPtr may be replaced (the Storage handle
//                is const; the underlying StorageImpl is mutated).
void maybeApplyRefcountedDeleter(const c10::Storage& storage) {
  std::lock_guard<std::mutex> guard(replace_data_ptr_mutex);
  c10::DataPtr& data_ptr = storage.mutable_data_ptr();

  // Both operands are DeleterFnPtr, so compare them directly. Casting a
  // function pointer to void* is only conditionally supported in C++.
  if (data_ptr.get_deleter() == &c10::refcounted_deleter) {
    // Data pointer is already shared
    return;
  }

  void* data = data_ptr.get();
  void* other_ctx = data_ptr.get_context();
  c10::DeleterFnPtr other_deleter = data_ptr.get_deleter();
  c10::Device device = data_ptr.device();

  // Allocate the wrapper BEFORE detaching the original context. If `new`
  // throws here, `data_ptr` still owns its context and nothing is leaked.
  // Releasing first would leave the context owned by nobody on allocation
  // failure. Keep these two statements adjacent: nothing that can throw may
  // run between them.
  c10::RefcountedDeleterContext* refcount_ctx =
      new c10::RefcountedDeleterContext(other_ctx, other_deleter);
  // Release the context of the original DataPtr so that the data doesn't
  // get deleted when the original DataPtr is replaced. The context is now
  // reachable only through `refcount_ctx`.
  (void)data_ptr.release_context();

  c10::DataPtr new_data_ptr(
      data,
      static_cast<void*>(refcount_ctx),
      &c10::refcounted_deleter,
      device);
  storage.set_data_ptr(std::move(new_data_ptr));
}
// Builds a brand-new Storage (backed by a fresh StorageImpl) that aliases the
// same bytes as `storage`.
//
// `storage` is first converted, if needed, to use refcounted_deleter. The
// result then receives its own DataPtr pointing at the SAME
// RefcountedDeleterContext, whose refcount is bumped by one. The shared data
// is therefore only released once every DataPtr referencing that context has
// been destroyed.
//
// nbytes, allocator and resizability are copied from the source StorageImpl.
//
// @param storage the Storage whose data should be shared.
// @return a new Storage aliasing `storage`'s data.
c10::Storage newStorageImplFromRefcountedDataPtr(const c10::Storage& storage) {
  c10::maybeApplyRefcountedDeleter(storage);

  c10::StorageImpl* const source_impl = storage.unsafeGetStorageImpl();
  c10::DataPtr& shared = storage.mutable_data_ptr();
  void* const shared_ctx = shared.get_context();
  c10::RefcountedDeleterContext* const counter =
      static_cast<c10::RefcountedDeleterContext*>(shared_ctx);

  c10::DataPtr alias(
      shared.get(), shared_ctx, shared.get_deleter(), shared.device());
  // NOTE: The refcount bump below must stay directly after `alias` is
  // constructed, with no other code in between. If anything threw while
  // `alias` existed but before the increment, destroying `alias` would
  // decrement the refcount and leave it one lower than it should be.
  ++counter->refcount;

  return c10::make_intrusive<c10::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      source_impl->nbytes(),
      std::move(alias),
      source_impl->allocator(),
      /*resizable=*/source_impl->resizable());
}
} // namespace c10