@@ -173,6 +173,7 @@ class Op {
   }

   void materialize();
+  void materializeWithShape(c10::IntArrayRef shape, const c10::optional<c10::Device> device);

   std::size_t num_outputs() const noexcept {
     return num_outputs_;
@@ -220,7 +221,6 @@ Op Op::fromOperatorHandle(const OperatorHandle& handle, Stack s) {
   };

   const FunctionSchema& shm = handle.schema();
-
   return Op{shm.name(), std::move(fn), shm.arguments().size(), shm.returns().size(), std::move(s)};
 }

@@ -271,6 +271,44 @@ void Op::materialize() {
   materialized_ = true;
 }

+void Op::materializeWithShape(c10::IntArrayRef shape, const c10::optional<c10::Device> device) {
+  if (materialized_) {
+    return;
+  }
+
+  {
+    ThreadLocalStateGuard state_guard{*tls_};
+
+    auto replace_first_shape = [&](c10::IntArrayRef sp) {
+      IValue local_shape(sp);
+      stack_[0] = local_shape;
+    };
+
+    std::vector<std::string> op_white_list{"aten::randn", "aten::rand", "aten::empty", "aten::ones", "aten::zeros", "aten::full"};
+
+    if (std::find(op_white_list.begin(), op_white_list.end(), name()) != op_white_list.end()) {
+      // The op is a whitelisted factory op whose first argument is the shape.
+      replace_first_shape(shape);
+    }
+
+    if (device.has_value()) {  // Redirect all device arguments to the target device.
+      for (size_t i = 0; i < stack_.size(); i++) {
+        if (stack_[i].isDevice()) {
+          stack_[i] = IValue(device.value());
+        }
+      }
+    }
+
+    fn_(stack_);
+  }
+
+  fn_ = nullptr;
+
+  tls_ = nullopt;
+
+  materialized_ = true;
+}
+
 const Tensor& Op::getOutput(std::size_t idx) const noexcept {
   const Tensor* opt_out = nullptr;
@@ -343,6 +381,8 @@ class OpNode {
   // Materializes the operation held by this node along with all the operations
   // in its recorded call stack.
   void materialize();
+  // Same as `materialize()`, but overrides the shape (and optionally the device) of the op.
+  void materializeWithShape(c10::IntArrayRef shape, c10::optional<c10::Device> device);

  private:
   void buildCallStack();
@@ -527,6 +567,30 @@ void OpNode::materialize() {
   call_stack_.clear();
 }

+void OpNode::materializeWithShape(c10::IntArrayRef shape, const c10::optional<c10::Device> device) {
+  // Do not try to shortcut this function by checking if the node is already
+  // materialized. A later in-place operation can still change the output of
+  // this node.
+
+  buildCallStack();
+
+  for (OpNode* node : call_stack_) {
+    if (node->op_.materialized()) {
+      continue;
+    }
+
+    node->materializeArguments();
+
+    node->op_.materializeWithShape(shape, device);
+
+    // Make sure that we deallocate parts of the operation graph that are not
+    // needed anymore.
+    node->detachDependencies();
+  }
+
+  call_stack_.clear();
+}
+
 void OpNode::buildCallStack() {
   OpNode* last_node = getLastInPlaceOpNode();

@@ -728,6 +792,24 @@ Tensor materialize(const Tensor& fake) {
   return out;
 }

+Tensor materialize_with_shape(const Tensor& fake, c10::IntArrayRef shape, const c10::optional<c10::Device> device) {
+  TensorRecord& record = getTensorRecord(fake);
+
+  const OpOutputDescriptor& output_desc = record.output_descriptor();
+
+  output_desc.node()->materializeWithShape(shape, device);
+
+  Tensor out = output_desc.node()->op().getOutput(output_desc.output_index());
+
+  // Unfortunately there is no way for us to track calls to `requires_grad_()`,
+  // so instead we explicitly set `requires_grad` after materialization.
+  if (fake.is_leaf() && fake.requires_grad()) {
+    out.set_requires_grad(true);
+  }
+
+  return out;
+}
+
 // The catch-all handler for the `DeferredInit` dispatch key.
 class DeferredInitHandler {
  public:
@@ -1032,6 +1114,12 @@ class ProxyVariableHooks : public VariableHooksInterface {
     inner_->requires_grad_(self, value);
   }

+  void basic_autograd_not_implemented_fallback(const c10::OperatorHandle& op,
+                                               c10::DispatchKeySet dispatch_keys,
+                                               torch::jit::Stack* stack) const override {
+    inner_->basic_autograd_not_implemented_fallback(op, dispatch_keys, stack);
+  }
+
   VariableHooksInterface* inner() noexcept {
     return inner_;
   }
@@ -1164,6 +1252,7 @@ bool canMaterialize(const Tensor& tensor) noexcept {
   return isFake(tensor) && unsafeAsFake(tensor).hasData(DispatchKey::DeferredInit);
 }

+
 Tensor materializeTensor(const Tensor& tensor) {
   if (canMaterialize(tensor)) {
     return detail::materialize(tensor);
@@ -1172,4 +1261,24 @@ Tensor materializeTensor(const Tensor& tensor) {
   }
 }

+Tensor materializeTensorWithLocalShape(const at::Tensor& tensor, c10::IntArrayRef shape, const c10::optional<c10::Device> device) {
+  if (canMaterialize(tensor)) {
+    return detail::materialize_with_shape(tensor, shape, device);
+  } else {
+    return tensor;
+  }
+}
+
+bool isGenByRandomOp(const Tensor& tensor) noexcept {
+  if (canMaterialize(tensor)) {
+    detail::TensorRecord& record = detail::getTensorRecord(tensor);
+    const detail::OpOutputDescriptor& output_desc = record.output_descriptor();
+    auto name = output_desc.node()->op().name();
+    std::vector<std::string> op_white_list{"aten::randn", "aten::rand"};
+    return std::find(op_white_list.begin(), op_white_list.end(), name) != op_white_list.end();
+  } else {
+    return false;
+  }
+}
+
 } // namespace torchdistx
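
Below is a minimal usage sketch of the new entry points, assuming they are exposed alongside materializeTensor() in the torchdistx/deferred_init.h header; the wrapper function, shape, and device values are illustrative only and not part of the change above.

#include <ATen/ATen.h>
#include <c10/core/Device.h>
#include <torchdistx/deferred_init.h>  // assumed location of the declarations above

// Hypothetical caller. `fake` is expected to be a tensor recorded under
// deferred initialization (i.e. canMaterialize(fake) would return true).
at::Tensor rematerialize_locally(const at::Tensor& fake) {
  if (torchdistx::isGenByRandomOp(fake)) {
    // The tensor originates from aten::randn/aten::rand, so replay the
    // recorded op graph with a per-rank local shape and an explicit device.
    return torchdistx::materializeTensorWithLocalShape(
        fake, /*shape=*/{128, 64}, /*device=*/c10::Device(c10::kCPU));
  }
  // Otherwise fall back to the regular path, which keeps the originally
  // recorded shape and device.
  return torchdistx::materializeTensor(fake);
}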