Commit a688b29

lw authored and facebook-github-bot committed

Support custom Python classes in CUDAFuture (pytorch#56516)

Summary:

Pull Request resolved: pytorch#56516

One problem with CUDAFuture's extraction of DataPtrs from IValues is that it only supported Python objects that could be converted to "regular" IValues (e.g., lists/dicts/tuples of ints/strings/tensors/...). One notable exception is custom Python classes, which are in fact a very common data type transferred over RPC. The only solution we found for those is to use the Python pickler to extract the tensors contained in them.

We can't insert a Python dependency directly into CUDAFuture, so instead I'm proposing to use the same indirection technique used to support `getSubValues` on Python objects: define some methods on the abstract class `PyObjectHolder` (which can be used by CUDAFuture) but only implement them in the concrete subclass `ConcretePyObjectHolder` (which is only built when Python support is enabled).

I am a bit worried about the performance toll of this (pickling isn't exactly known to be cheap), but I think we should start by providing a functionally complete API. We already have ideas on how to make this faster if needed, for example by having users provide a custom DataPtr extractor tailored to their class via a decorator. (Or just use TorchScript.)

ghstack-source-id: 127295014

Test Plan: Added a test later in the stack

Reviewed By: mrshenli

Differential Revision: D27887189

fbshipit-source-id: 9d27e4e62390b836e5bb4f06f401cc002f0cf95b

1 parent e4efc0c commit a688b29
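The pickler trick the summary describes can be sketched in plain Python. This is a hedged, self-contained illustration of the technique, not code from the PR: `FakeTensor`, `_Collector`, `extract_tensors`, and `RpcPayload` are hypothetical stand-ins so the sketch runs without torch (the real implementation is `_TensorExtractor` in `torch/_jit_internal.py`, which intercepts `torch.Tensor`).

```python
import io
import pickle
from typing import Any, List


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor, so this sketch runs without torch."""

    def __init__(self, name: str) -> None:
        self.name = name


class _Collector(pickle.Pickler):
    """Pickler subclass whose persistent_id hook collects tensors as a side
    effect. Returning a non-None id tells pickle to emit a reference instead
    of serializing the object; the resulting byte stream is simply discarded."""

    def __init__(self, *args: Any, tensors: List[FakeTensor], **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.tensors = tensors

    def persistent_id(self, obj: Any):
        if isinstance(obj, FakeTensor):
            self.tensors.append(obj)
            return ""  # placeholder id; we never unpickle the stream
        return None  # let pickle handle everything else normally


def extract_tensors(obj: Any) -> List[FakeTensor]:
    tensors: List[FakeTensor] = []
    _Collector(io.BytesIO(), protocol=-1, tensors=tensors).dump(obj)
    return tensors


class RpcPayload:
    """A custom Python class, i.e. the case getSubValues could not handle."""

    def __init__(self) -> None:
        self.data = {"a": FakeTensor("a"), "nested": [FakeTensor("b")]}
```

Calling `extract_tensors(RpcPayload())` finds both tensors even though `RpcPayload` is an arbitrary user-defined class, which is exactly why the PR routes Python objects through the pickler rather than through `getSubValues`.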

File tree: 4 files changed, +70 −8 lines changed

aten/src/ATen/core/ivalue_inl.h
Lines changed: 1 addition & 0 deletions

```diff
@@ -708,6 +708,7 @@ struct ivalue::PyObjectHolder : c10::intrusive_ptr_target {
   virtual c10::InferredType tryToInferType() = 0;
   virtual IValue toIValue(const TypePtr& type, c10::optional<int32_t> N = c10::nullopt) = 0;
   virtual std::string toStr() = 0;
+  virtual std::vector<at::Tensor> extractTensors() = 0;

   virtual ~PyObjectHolder(){};
 };
```

aten/src/ATen/cuda/CUDAFuture.cpp
Lines changed: 19 additions & 8 deletions

```diff
@@ -26,15 +26,26 @@ namespace {

 std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
     const at::IValue& value) {
-  at::IValue::HashAliasedIValues sub_values;
-  // Prefer getSubValues() over visit() as the latter is a silent no-op for
-  // some unsupported types, whereas the former at least fails loudly.
-  value.getSubValues(sub_values);
-
   std::vector<std::reference_wrapper<const at::DataPtr>> data_ptrs;
-  for (const at::IValue& sub_value : sub_values) {
-    if (sub_value.isTensor()) {
-      data_ptrs.emplace_back(sub_value.toTensor().storage().data_ptr());
+  // getSubValues works poorly on Python objects: it only works if they can be
+  // converted to a "regular" IValue type hence, for example, it doesn't support
+  // custom subclasses. Thus, instead, we extract the tensors through pickling.
+  if (value.isPyObject()) {
+    std::vector<at::Tensor> tensors =
+        value.toPyObjectHolder()->extractTensors();
+    data_ptrs.reserve(tensors.size());
+    for (const at::Tensor& tensor : tensors) {
+      data_ptrs.emplace_back(tensor.storage().data_ptr());
+    }
+  } else {
+    at::IValue::HashAliasedIValues sub_values;
+    // Prefer getSubValues() over visit() as the latter is a silent no-op for
+    // some unsupported types, whereas the former at least fails loudly.
+    value.getSubValues(sub_values);
+    for (const at::IValue& sub_value : sub_values) {
+      if (sub_value.isTensor()) {
+        data_ptrs.emplace_back(sub_value.toTensor().storage().data_ptr());
+      }
     }
   }
   return data_ptrs;
```
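To see why the Python-object branch above is needed, here is a rough pure-Python analogue of what a `getSubValues`-style walk does: it understands the "regular" container types and fails loudly on anything else, so tensors inside custom classes are unreachable without the pickling path. All names here (`StubTensor`, `get_sub_values`, `Payload`) are hypothetical illustrations, not PyTorch APIs.

```python
from typing import Any, List


class StubTensor:
    """Hypothetical stand-in for a tensor."""


def get_sub_values(value: Any, out: List[StubTensor]) -> None:
    """Rough analogue of IValue::getSubValues: recurses into the "regular"
    container types and fails loudly (rather than silently) on anything else."""
    if isinstance(value, StubTensor):
        out.append(value)
    elif isinstance(value, (list, tuple, set)):
        for item in value:
            get_sub_values(item, out)
    elif isinstance(value, dict):
        for key, val in value.items():
            get_sub_values(key, out)
            get_sub_values(val, out)
    elif isinstance(value, (bool, int, float, str, bytes, type(None))):
        pass  # leaf values that cannot contain tensors
    else:
        raise TypeError(f"unsupported type: {type(value).__name__}")


class Payload:
    """A custom class: its tensor is invisible to the container walk."""

    def __init__(self) -> None:
        self.t = StubTensor()
```

Walking `[StubTensor(), {"k": StubTensor()}]` finds both tensors, but walking `Payload()` raises, mirroring why `extractDataPtrs` now dispatches Python objects to the pickler-based `extractTensors` instead.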

torch/_jit_internal.py
Lines changed: 28 additions & 0 deletions

```diff
@@ -15,6 +15,8 @@
 import torch
 import sys
 import builtins
+import io
+import pickle
 # This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`.
 # Explicitly ask to import `torch.distributed.__init__` first.
 # Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised.
@@ -1119,3 +1121,29 @@ def _isinstance(obj, target_type) -> bool:

     # handle non-containers
     return isinstance(obj, target_type)
+
+
+class _TensorExtractor(pickle.Pickler):
+    def __init__(self, *args, tensors: List[torch.Tensor], **kwargs):
+        super().__init__(*args, **kwargs)
+        self.tensors = tensors
+
+    def persistent_id(self, obj):
+        if isinstance(obj, torch.Tensor):
+            self.tensors.append(obj)
+            return ""
+        else:
+            return None
+
+
+def _extract_tensors(obj):
+    r"""
+    This function is exclusively called from C++.
+    See ``torch/csrc/jit/python/python_ivalue.h``.
+
+    It extracts the tensors contained in the given object, through pickling.
+    """
+    tensors: List[torch.Tensor] = []
+    extractor = _TensorExtractor(io.BytesIO(), protocol=-1, tensors=tensors)
+    extractor.dump(obj)
+    return tensors
```

torch/csrc/jit/python/python_ivalue.h
Lines changed: 22 additions & 0 deletions

```diff
@@ -42,6 +42,28 @@ struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder {
     return py::str(py_obj_);
   }

+  std::vector<at::Tensor> extractTensors() override {
+    // We could implement this entirely in C++ via pybind11 but it turns out to
+    // be substantially slower. Namely, the total time taken by markCompleted on
+    // a CUDAFuture is 21.5us with this implementation, but goes up to 58.7us
+    // when using C++. The reason is unclear.
+    try {
+      pybind11::gil_scoped_acquire ag;
+      return py::module::import("torch._jit_internal")
+          .attr("_extract_tensors")(py_obj_)
+          .cast<std::vector<at::Tensor>>();
+    } catch (py::error_already_set& e) {
+      auto err = std::runtime_error(
+          c10::str("Cannot extract tensors from value: ", e.what()));
+      {
+        pybind11::gil_scoped_acquire ag;
+        e.restore();
+        PyErr_Clear();
+      }
+      throw err;
+    }
+  }
+
   // Note [Destructing py::object]
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~
   //
```
