This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit bcaaffb

support torch 2.1

feat(dmodule): support parallelized DTensor init
feat(dtensor): support querying random ops
feat(dtensor): support deferred init on device

1 parent 9c1b9f5 commit bcaaffb

File tree: 4 files changed, +165 −3 lines

src/cc/torchdistx/deferred_init.cc

Lines changed: 110 additions & 1 deletion
@@ -173,6 +173,7 @@ class Op {
   }

   void materialize();
+  void materializeWithShape(c10::IntArrayRef shape, const c10::optional<c10::Device> device);

   std::size_t num_outputs() const noexcept {
     return num_outputs_;
@@ -220,7 +221,6 @@ Op Op::fromOperatorHandle(const OperatorHandle& handle, Stack s) {
   };

   const FunctionSchema& shm = handle.schema();
-
   return Op{shm.name(), std::move(fn), shm.arguments().size(), shm.returns().size(), std::move(s)};
 }

@@ -271,6 +271,44 @@ void Op::materialize() {
   materialized_ = true;
 }

+void Op::materializeWithShape(c10::IntArrayRef shape, const c10::optional<c10::Device> device) {
+  if (materialized_) {
+    return;
+  }
+
+  {
+    ThreadLocalStateGuard state_guard{*tls_};
+
+    auto replace_first_shape = [&](c10::IntArrayRef sp) {
+      IValue local_shape(sp);
+      stack_[0] = local_shape;
+    };
+
+    std::vector<std::string> op_white_list{"aten::randn", "aten::rand", "aten::empty", "aten::ones", "aten::zeros", "aten::full"};
+
+    if (std::find(op_white_list.begin(), op_white_list.end(), name()) != op_white_list.end()) {
+      // Whitelisted factory op: replace its recorded shape argument.
+      replace_first_shape(shape);
+    }
+
+    if (device.has_value()) {  // Redirect all device arguments to the target device.
+      for (size_t i = 0; i < stack_.size(); i++) {
+        if (stack_[i].isDevice()) {
+          stack_[i] = IValue(device.value());
+        }
+      }
+    }
+
+    fn_(stack_);
+  }
+
+  fn_ = nullptr;
+
+  tls_ = nullopt;
+
+  materialized_ = true;
+}
+
 const Tensor& Op::getOutput(std::size_t idx) const noexcept {
   const Tensor* opt_out = nullptr;

@@ -343,6 +381,8 @@ class OpNode {
   // Materializes the operation held by this node along with all the operations
   // in its recorded call stack.
   void materialize();
+  // Same as `materialize()`, but replays ops with the given shape and device.
+  void materializeWithShape(c10::IntArrayRef shape, c10::optional<c10::Device> device);

  private:
   void buildCallStack();
@@ -527,6 +567,30 @@ void OpNode::materialize() {
   call_stack_.clear();
 }

+void OpNode::materializeWithShape(c10::IntArrayRef shape, const c10::optional<c10::Device> device) {
+  // Do not try to shortcut this function by checking if the node is already
+  // materialized. A later in-place operation can still change the output of
+  // this node.
+
+  buildCallStack();
+
+  for (OpNode* node : call_stack_) {
+    if (node->op_.materialized()) {
+      continue;
+    }
+
+    node->materializeArguments();
+
+    node->op_.materializeWithShape(shape, device);
+
+    // Make sure that we deallocate parts of the operation graph that are not
+    // needed anymore.
+    node->detachDependencies();
+  }
+
+  call_stack_.clear();
+}
+
 void OpNode::buildCallStack() {
   OpNode* last_node = getLastInPlaceOpNode();

@@ -728,6 +792,24 @@ Tensor materialize(const Tensor& fake) {
   return out;
 }

+Tensor materialize_with_shape(const Tensor& fake, c10::IntArrayRef shape, const c10::optional<c10::Device> device) {
+  TensorRecord& record = getTensorRecord(fake);
+
+  const OpOutputDescriptor& output_desc = record.output_descriptor();
+
+  output_desc.node()->materializeWithShape(shape, device);
+
+  Tensor out = output_desc.node()->op().getOutput(output_desc.output_index());
+
+  // Unfortunately there is no way for us to track calls to `requires_grad_()`,
+  // so instead we explicitly set `requires_grad` after materialization.
+  if (fake.is_leaf() && fake.requires_grad()) {
+    out.set_requires_grad(true);
+  }
+
+  return out;
+}
+
 // The catch-all handler for the `DeferredInit` dispatch key.
 class DeferredInitHandler {
  public:
@@ -1032,6 +1114,12 @@ class ProxyVariableHooks : public VariableHooksInterface {
     inner_->requires_grad_(self, value);
   }

+  void basic_autograd_not_implemented_fallback(const c10::OperatorHandle& op,
+                                               c10::DispatchKeySet dispatch_keys,
+                                               torch::jit::Stack* stack) const override {
+    inner_->basic_autograd_not_implemented_fallback(op, dispatch_keys, stack);
+  }
+
   VariableHooksInterface* inner() noexcept {
     return inner_;
   }
@@ -1164,6 +1252,7 @@ bool canMaterialize(const Tensor& tensor) noexcept {
   return isFake(tensor) && unsafeAsFake(tensor).hasData(DispatchKey::DeferredInit);
 }

+
 Tensor materializeTensor(const Tensor& tensor) {
   if (canMaterialize(tensor)) {
     return detail::materialize(tensor);
@@ -1172,4 +1261,24 @@ Tensor materializeTensor(const Tensor& tensor) {
   }
 }

+Tensor materializeTensorWithLocalShape(const at::Tensor& tensor, c10::IntArrayRef shape, const c10::optional<c10::Device> device) {
+  if (canMaterialize(tensor)) {
+    return detail::materialize_with_shape(tensor, shape, device);
+  } else {
+    return tensor;
+  }
+}
+
+bool isGenByRandomOp(const Tensor& tensor) noexcept {
+  if (canMaterialize(tensor)) {
+    detail::TensorRecord& record = detail::getTensorRecord(tensor);
+    const detail::OpOutputDescriptor& output_desc = record.output_descriptor();
+    auto name = output_desc.node()->op().name();
+    std::vector<std::string> op_white_list{"aten::randn", "aten::rand"};
+    return std::find(op_white_list.begin(), op_white_list.end(), name) != op_white_list.end();
+  } else {
+    return false;
+  }
+}
+
 } // namespace torchdistx

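Note on the design: deferred init records each factory call and replays it on materialization. `Op::materializeWithShape` replays the same recorded call but overwrites the first (shape) argument for whitelisted factory ops and redirects any device arguments, which is what lets a rank materialize only its local shard of a parallelized DTensor. Below is a minimal Python sketch of that replay idea, using a hypothetical `RecordedOp` stand-in rather than the torchdistx internals:

import torch

# Illustrative whitelist mapping op names to factory functions; a stand-in
# for `op_white_list` in Op::materializeWithShape, not torchdistx API.
SHAPE_OPS = {
    "aten::randn": torch.randn,
    "aten::empty": torch.empty,
    "aten::ones": torch.ones,
    "aten::zeros": torch.zeros,
}

class RecordedOp:
    """Stand-in for a recorded deferred-init factory call."""

    def __init__(self, name, args):
        self.name = name
        self.args = list(args)

    def materialize_with_shape(self, shape, device=None):
        # Mirrors Op::materializeWithShape: for whitelisted factory ops,
        # overwrite the first (shape) argument, then replay the call.
        if self.name in SHAPE_OPS:
            self.args[0] = shape
            return SHAPE_OPS[self.name](*self.args, device=device)
        raise NotImplementedError(f"cannot replay {self.name}")

op = RecordedOp("aten::randn", [(8, 1024)])          # recorded global shape
local = op.materialize_with_shape((8, 256), "cpu")   # replayed with local shape
print(local.shape)                                   # torch.Size([8, 256])

The whitelist matters: only ops whose first argument is a shape can be replayed this way, and for `aten::randn`/`aten::rand` the drawn values depend on the shape, which is presumably why `isGenByRandomOp` is exposed as a separate query.
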
src/cc/torchdistx/deferred_init.h

Lines changed: 4 additions & 1 deletion
@@ -8,6 +8,8 @@

 #include <c10/core/DispatchKey.h>
 #include <c10/core/impl/LocalDispatchKeySet.h>
+#include <c10/core/SymIntArrayRef.h>
+#include <c10/core/Device.h>

 #include "macros.h"

@@ -27,9 +29,10 @@ TDX_API void leaveDeferredInit() noexcept;

 // Indicates whether `tensor` has been constructed in a deferred-init context.
 TDX_API bool canMaterialize(const at::Tensor& tensor) noexcept;
-
+TDX_API bool isGenByRandomOp(const at::Tensor& tensor) noexcept;
 // Materializes `tensor`.
 TDX_API at::Tensor materializeTensor(const at::Tensor& tensor);
+TDX_API at::Tensor materializeTensorWithLocalShape(const at::Tensor& tensor, c10::IntArrayRef shape, const c10::optional<c10::Device> device = {});

 // Temporarily disables deferred-init.
 class TDX_API NoDeferredInit {

src/python/torchdistx/_C.pyi

Lines changed: 5 additions & 0 deletions
@@ -5,12 +5,17 @@
 # LICENSE file in the root directory of this source tree.

 import torch
+from torch.types import _int, SymInt, _device
+from collections.abc import Sequence
+from typing import Union, Optional

 def enter_deferred_init() -> None: ...
 def leave_deferred_init() -> None: ...
 def enter_fake_mode(fake_mode: bool) -> None: ...
 def leave_fake_mode() -> None: ...
 def is_fake(tensor: torch.Tensor) -> bool: ...
+def is_gen_by_random_op(tensor: torch.Tensor) -> bool: ...
 def can_materialize(tensor: torch.Tensor) -> bool: ...
 def materialize_tensor(tensor: torch.Tensor) -> torch.Tensor: ...
+def materialize_tensor_with_local_shape(tensor: torch.Tensor, shape: Sequence[Union[_int, SymInt]], device: Optional[Union[_device, str]] = None) -> torch.Tensor: ...
 def meta_like(fake: torch.Tensor) -> torch.Tensor: ...

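Taken together, these stubs suggest the following usage pattern. This is a sketch, assuming the extension is importable as `torchdistx._C`; the 8×1024 global shape and 8×256 local shard are illustrative:

import torch
from torchdistx import _C

# Record the construction instead of executing it; `w` stays fake.
_C.enter_deferred_init()
try:
    w = torch.randn(8, 1024)
finally:
    _C.leave_deferred_init()

assert _C.can_materialize(w)
assert _C.is_gen_by_random_op(w)  # `w` was recorded from aten::randn

# Replay the recorded op with this rank's local shard shape.
local_w = _C.materialize_tensor_with_local_shape(w, (8, 256), "cpu")
print(local_w.shape)  # expected: torch.Size([8, 256])
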
src/python/torchdistx/_C/deferred_init.cc

Lines changed: 46 additions & 1 deletion
@@ -8,6 +8,7 @@

 #include <ATen/Tensor.h>
 #include <c10/core/TensorImpl.h>
+#include <c10/core/SymIntArrayRef.h>
 #include <torch/csrc/PyInterpreter.h>
 #include <torch/csrc/autograd/python_variable.h>
 #include <torch/csrc/utils/pybind.h>
@@ -94,14 +95,58 @@ py::object materializeVariable(const py::object& var) {
   return makeVariable(Py_TYPE(naked_var), std::move(materialized_data));
 }

+
+// Materializing a tensor in Python requires an extra step. We need to ensure
+// that the materialized tensor has the same Python class (e.g. `Variable` or
+// `Parameter`) as the original tensor. In the DTensor case we also need to
+// replace the shape with the parallelized (local) shape.
+py::object materializeVariableWithLocalShape(const py::object& var, const py::object& shape, const c10::optional<c10::Device> device) {
+  PyObject* naked_var = var.ptr();
+  auto c_shape = shape.cast<std::vector<int64_t>>();
+
+  if (!THPVariable_Check(naked_var)) {
+    throw TypeError{"`var` has to be a `Variable`, but got `%s`.", Py_TYPE(naked_var)->tp_name};
+  }
+
+  const Tensor& data = THPVariable_Unpack(naked_var);
+
+  auto materialize = [=](const Tensor& tensor, c10::IntArrayRef sp) {
+    py::gil_scoped_release guard{};
+
+    return materializeTensorWithLocalShape(tensor, sp, device);
+  };
+
+  Tensor materialized_data = materialize(data, at::IntArrayRef(c_shape));
+
+  // Check if we have really materialized `data`. Materializing a regular tensor
+  // is a no-op, so we can simply return.
+  if (materialized_data.is_same(data)) {
+    return var;
+  }
+
+  // We might have already materialized `data`. Make sure that we preserve its
+  // identity on the Python side and avoid creating a new Python tensor.
+  c10::optional<PyObject*> opt_materialized_var =
+      materialized_data.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter());
+  if (opt_materialized_var.has_value()) {
+    return py::reinterpret_borrow<py::object>(*opt_materialized_var);
+  }

+  // Otherwise ensure that our materialized tensor has the same Python class as
+  // the original tensor.
+  return makeVariable(Py_TYPE(naked_var), std::move(materialized_data));
+}
+
+
 } // namespace

 void initDeferredInitFunctions(py::module& m) {
   m.def("enter_deferred_init", enterDeferredInit);
   m.def("leave_deferred_init", leaveDeferredInit);
-
   m.def("can_materialize", canMaterialize);
+  m.def("is_gen_by_random_op", isGenByRandomOp);
   m.def("materialize_tensor", materializeVariable);
+  m.def("materialize_tensor_with_local_shape", materializeVariableWithLocalShape);
 }

 } // namespace torchdistx::python

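Why the extra Python-side step: `materializeVariableWithLocalShape` mirrors `materializeVariable` so that materialization preserves the Python class of the input (an `nn.Parameter` comes back as an `nn.Parameter`, and an already-materialized tensor keeps its Python identity). A sketch of that expectation, again assuming the bindings are importable as `torchdistx._C`:

import torch
from torch import nn
from torchdistx import _C

_C.enter_deferred_init()
try:
    p = nn.Parameter(torch.randn(4, 4))
finally:
    _C.leave_deferred_init()

out = _C.materialize_tensor_with_local_shape(p, (4, 2))
print(type(out))          # expected: <class 'torch.nn.parameter.Parameter'>
print(out.requires_grad)  # expected: True (set after materialization)
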