@@ -32,7 +32,8 @@ Reducer::Reducer(
     std::shared_ptr<c10d::ProcessGroup> process_group,
     std::vector<std::vector<bool>> expect_sparse_gradients,
     int64_t bucket_bytes_cap,
-    bool find_unused_parameters)
+    bool find_unused_parameters,
+    bool gradient_as_bucket_view)
     : replicas_(std::move(replicas)),
       process_group_(std::move(process_group)),
       expect_sparse_gradients_(std::move(expect_sparse_gradients)),
@@ -41,6 +42,7 @@ Reducer::Reducer(
       next_bucket_(0),
       has_marked_unused_parameters_(false),
       find_unused_parameters_(find_unused_parameters),
+      gradient_as_bucket_view_(gradient_as_bucket_view),
       local_used_maps_reduced_(false),
       backward_stats_base_(0),
       has_rebuilt_bucket_(false),
@@ -310,6 +312,56 @@ void Reducer::verify_replica0_across_processes() {
   }
 }
 
+void Reducer::check_grad_layout(
+    const at::Tensor& grad,
+    const at::Tensor& bucket_view) {
+  // Ensure that the gradient type matches the bucket type.
+  TORCH_CHECK(
+      grad.options().type_equal(bucket_view.options()),
+      "Expected ",
+      bucket_view.toString(),
+      ", got ",
+      grad.toString());
+  TORCH_INTERNAL_ASSERT(grad.device() == bucket_view.device());
+  TORCH_INTERNAL_ASSERT(grad.numel() == bucket_view.numel());
+  // AccumulateGrad doesn't HAVE to obey the grad layout contract.
+  // The penalty for disobedience is reduced performance, not numerical
+  // death. Warnings here help diagnose poor DDP performance.
+  if (grad.strides() != bucket_view.strides()) {
+    TORCH_WARN_ONCE(
+        "Grad strides do not match bucket view strides. "
+        "This may indicate grad was not created according to the "
+        "gradient layout contract, or that the param's strides "
+        "changed since DDP was constructed. This is not an error, "
+        "but may impair performance.\n"
+        "grad.sizes() = ",
+        grad.sizes(),
+        ", strides() = ",
+        grad.strides(),
+        "\n",
+        "bucket_view.sizes() = ",
+        bucket_view.sizes(),
+        ", strides() = ",
+        bucket_view.strides());
+  }
+  if (!gradient_as_bucket_view_) {
+    TORCH_INTERNAL_ASSERT(!grad.is_alias_of(bucket_view));
+  }
+}
+
+void Reducer::copy_grad_to_bucket(at::Tensor& grad, at::Tensor& bucket_view) {
+  // See Note [DDP Communication Hook]
+  if (comm_hook_ == nullptr) {
+    // imitates wrapped_scalar_tensor in ATen/native/BinaryOps.cpp
+    auto wrapped = c10::scalar_to_tensor(double(1.) / divFactor_);
+    wrapped.unsafeGetTensorImpl()->set_wrapped_number(true);
+    // Divides while copying into the bucket view.
+    at::native::mul_out(bucket_view, grad, wrapped);
+  } else {
+    bucket_view.copy_(grad);
+  }
+}
+
 void Reducer::mark_variable_ready_dense(VariableIndex index) {
   const auto replica_index = index.replica_index;
   const auto variable_index = index.variable_index;
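For context, the stride comparison in check_grad_layout() only warns rather than errors. The standalone sketch below (not part of the diff; it assumes a libtorch build and uses only public ATen calls) reproduces the situation the warning targets: a gradient whose dtype, device, and numel match the bucket view, but whose strides do not.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // A flat bucket and a contiguous 2x3 view into it, roughly as the Reducer
  // sets up per parameter.
  at::Tensor bucket = at::zeros({12});
  at::Tensor bucket_view = bucket.narrow(0, 0, 6).view({2, 3});

  // A gradient with matching dtype/device/numel but transposed strides.
  at::Tensor grad = at::ones({3, 2}).t();

  // This is the condition that triggers the TORCH_WARN_ONCE above.
  std::cout << std::boolalpha
            << (grad.strides() != bucket_view.strides()) << std::endl; // true
  // And this is the invariant asserted when gradient_as_bucket_view_ is false.
  std::cout << grad.is_alias_of(bucket_view) << std::endl; // false
  return 0;
}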
@@ -327,49 +379,27 @@ void Reducer::mark_variable_ready_dense(VariableIndex index) {
   // of the bucket it would otherwise hold.
   runGradCallbackForVariable(variable, [&](auto& grad) {
     if (grad.defined()) {
-      // Ensure that the gradient type matches the bucket type.
-      TORCH_CHECK(
-          grad.options().type_equal(bucket_view.options()),
-          "Expected ",
-          bucket_view.toString(),
-          ", got ",
-          grad.toString());
-      // Assert that the grad tensor and the bucket don't share storage.
-      // If they did, we could avoid the copy altogether.
-      // The reason for not doing this is that existing code calls
-      // `detach_` from `zero_grad`, which is incompatible with views.
-      TORCH_INTERNAL_ASSERT(!grad.is_alias_of(bucket_view));
-      TORCH_INTERNAL_ASSERT(grad.device() == bucket_view.device());
-      TORCH_INTERNAL_ASSERT(grad.numel() == bucket_view.numel());
-      // AccumulateGrad doesn't HAVE to obey the grad layout contract.
-      // The penalty for disobedience is reduced performance, not numerical
-      // death. Warnings here help diagnose poor DDP performance.
-      if (grad.strides() != bucket_view.strides()) {
-        TORCH_WARN_ONCE(
-            "Grad strides do not match bucket view strides. "
-            "This may indicate grad was not created according to the "
-            "gradient layout contract, or that the param's strides "
-            "changed since DDP was constructed. This is not an error, "
-            "but may impair performance.\n"
-            "grad.sizes() = ",
-            grad.sizes(),
-            ", strides() = ",
-            grad.strides(),
-            "\n",
-            "bucket_view.sizes() = ",
-            bucket_view.sizes(),
-            ", strides() = ",
-            bucket_view.strides());
-      }
-      // See Note [DDP Communication Hook]
-      if (comm_hook_ == nullptr) {
-        // imitates wrapped_scalar_tensor in ATen/native/BinaryOps.cpp
-        auto wrapped = c10::scalar_to_tensor(double(1.) / divFactor_);
-        wrapped.unsafeGetTensorImpl()->set_wrapped_number(true);
-        // Divides while copying into the bucket view.
-        at::native::mul_out(bucket_view, grad, wrapped);
+      this->check_grad_layout(grad, bucket_view);
+      // When gradient_as_bucket_view_ is false, or even when
+      // gradient_as_bucket_view_ is true, in rare cases users may set grad to
+      // be None after every iteration. In these cases, grad and bucket_view
+      // point to different storages and the grads therefore need to be copied
+      // into bucket_view. If gradient_as_bucket_view_ is true, let grad point
+      // to bucket_view. If grad has already been set as a view of the bucket
+      // in previous iterations, no copy is needed.
+      if (!grad.is_alias_of(bucket_view)) {
+        this->copy_grad_to_bucket(grad, bucket_view);
+        if (gradient_as_bucket_view_) {
+          // Let grad point to bucket_view buffer.
+          grad = bucket_view;
+          // The grad is modified and needs to be written back.
+          return true;
+        }
       } else {
-        bucket_view.copy_(grad);
+        // If grad and bucket view point to the same storage, no copy is needed.
+        if (comm_hook_ == nullptr) {
+          bucket_view.div_(divFactor_);
+        }
       }
     } else {
       bucket_view.zero_();
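As a side note on the two scaling paths above: when the gradient does not alias the bucket, copy_grad_to_bucket() fuses the copy with the division by divFactor_; when it already aliases the bucket, only an in-place division is needed. Below is a minimal sketch of the same arithmetic, assuming libtorch; div_factor and the tensor names are illustrative stand-ins, not the Reducer's members.

#include <ATen/ATen.h>

int main() {
  const double div_factor = 4.0; // stands in for divFactor_, e.g. world size
  at::Tensor grad = at::full({8}, 2.0);
  at::Tensor bucket = at::zeros({8});
  at::Tensor bucket_view = bucket.narrow(0, 0, 8);

  // Fused path: scale by 1/div_factor while writing into the bucket view,
  // instead of a copy_() followed by a div_().
  at::mul_out(bucket_view, grad, at::scalar_tensor(1.0 / div_factor));

  // Aliasing path: grad already is the bucket view, so divide in place.
  at::Tensor aliased_grad = bucket_view;
  aliased_grad.div_(div_factor);
  return 0;
}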
@@ -674,6 +704,17 @@ void Reducer::mark_bucket_ready(size_t bucket_index) {
 
 void Reducer::initialize_buckets(
     std::vector<std::vector<size_t>> bucket_indices) {
+  // If initialize_buckets is called inside the DDP constructor, it does not
+  // matter whether the rpc context ptr is nullptr or not, since grad will not
+  // be mutated.
+  // If initialize_buckets is called during the training loop, e.g. inside
+  // rebuild_buckets(), grad could be mutated and pointed to bucket_view, so it
+  // needs to check whether the rpc context ptr is nullptr or not. If the rpc
+  // context ptr is nullptr, mutate variable.grad(); otherwise, mutate grad in
+  // the rpc context.
+  using torch::distributed::autograd::ThreadLocalDistAutogradContext;
+  this->rpc_context_.set(ThreadLocalDistAutogradContext::getContextPtr());
+
   // This shouldn't be called if we're expecting autograd hooks to fire.
   TORCH_CHECK(
       !expect_autograd_hooks_,
@@ -825,7 +866,7 @@ void Reducer::initialize_bucket_views(
     Reducer::BucketReplica& replica,
     at::Tensor& contents) {
   for (size_t i = 0; i < replica.variables.size(); i++) {
-    const auto& v = replica.variables[i];
+    auto& v = replica.variables[i];
     const auto offset = replica.offsets[i];
     const auto length = replica.lengths[i];
     if (v.is_non_overlapping_and_dense()) {
@@ -844,6 +885,29 @@ void Reducer::initialize_bucket_views(
     // By default `bucket_views_out` and `bucket_views_in` are
     // essentially the same thing.
     replica.bucket_views_out = replica.bucket_views_in;
+
+    // If gradient_as_bucket_view_ is set as true, then there are two cases to
+    // handle: initialize_bucket_views could be called inside initialize_buckets
+    // during rebuild_buckets; if grad has already been defined/calculated in a
+    // previous iteration, the old grad needs to be copied into the new
+    // bucket_view and grad needs to point to the new bucket_view.
+    // initialize_bucket_views could also be called inside initialize_buckets
+    // during construction. Grads are not defined at construction time; in this
+    // case, do not let grad point to bucket_view, because grads should be kept
+    // as undefined for globally unused parameters.
+    if (gradient_as_bucket_view_) {
+      auto& bucket_view = replica.bucket_views_in.back();
+      runGradCallbackForVariable(v, [&](auto& grad) {
+        if (grad.defined() && !grad.is_alias_of(bucket_view)) {
+          bucket_view.copy_(grad);
+          grad = bucket_view;
+          // The grad is modified and needs to be written back.
+          return true;
+        }
+        // The grad is not modified and does not need to be written back.
+        return false;
+      });
+    }
   }
 }
 
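To illustrate what initialize_bucket_views() arranges when gradient_as_bucket_view_ is true: every parameter's gradient ends up aliasing a slice of one flat bucket tensor, so the allreduce on the bucket updates the gradients with no extra copy. A standalone sketch, assuming libtorch; the two-parameter layout (lengths 6 and 4) is made up for illustration.

#include <ATen/ATen.h>
#include <vector>

int main() {
  // One flat bucket holding two flattened "parameters" of lengths 6 and 4.
  at::Tensor bucket = at::zeros({10});
  std::vector<at::Tensor> grads = {
      bucket.narrow(0, 0, 6).view({2, 3}), // grad of param0, a bucket view
      bucket.narrow(0, 6, 4)};             // grad of param1, a bucket view

  // Writing through a grad view mutates the bucket and vice versa, which is
  // what lets the reduced bucket contents become the gradients without a copy.
  grads[0].fill_(1.0);
  bool aliased = grads[1].is_alias_of(bucket); // true
  (void)aliased;
  return 0;
}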
@@ -965,6 +1029,31 @@ void Reducer::prepare_for_backward(
   }
 }
 
+void Reducer::copy_bucket_to_grad(
+    torch::autograd::Variable& variable,
+    Reducer::BucketReplica& replica,
+    size_t intra_bucket_index,
+    bool global_unused) {
+  const auto& bucket_view = replica.bucket_views_out[intra_bucket_index];
+  runGradCallbackForVariable(variable, [&](auto& grad) {
+    // If a parameter is globally unused, we keep its grad untouched.
+    if (!global_unused) {
+      if (!grad.defined()) {
+        // Creates grad according to the "Gradient Layout Contract"
+        // (see torch/csrc/grad/AccumulateGrad.h)
+        grad =
+            torch::autograd::utils::clone_obey_contract(bucket_view, variable);
+      } else {
+        grad.copy_(bucket_view);
+      }
+      // The grad is modified and needs to be written back.
+      return true;
+    }
+    // The grad is not modified.
+    return false;
+  });
+}
+
 // A bucket with one or more dense tensors needs to be unflattened.
 void Reducer::finalize_bucket_dense(Bucket& bucket) {
   for (size_t replica_index = 0; replica_index < bucket.replicas.size();
@@ -1015,24 +1104,52 @@ void Reducer::finalize_bucket_dense(Bucket& bucket) {
         }
       }
 
-      const auto& bucket_view = replica.bucket_views_out[intra_bucket_index];
-      runGradCallbackForVariable(variable, [&](auto& grad) {
-        // If a parameter is globally unused, we keep its grad untouched.
-        if (!global_unused) {
-          if (!grad.defined()) {
-            // Creates grad according to the "Gradient Layout Contract"
-            // (see torch/csrc/grad/AccumulateGrad.h)
-            grad = torch::autograd::utils::clone_obey_contract(
-                bucket_view, variable);
-          } else {
-            grad.copy_(bucket_view);
-          }
-          // The grad is modified and needs to be written back.
-          return true;
+      if (!gradient_as_bucket_view_) {
+        copy_bucket_to_grad(
+            variable, replica, intra_bucket_index, global_unused);
+      } else {
+        const auto& bucket_view_out =
+            replica.bucket_views_out[intra_bucket_index];
+        auto& bucket_view_in = replica.bucket_views_in[intra_bucket_index];
+        // If a communication hook is registered, bucket_view_out stores the
+        // allreduced results in a newly allocated tensor, so copy
+        // bucket_view_out back to bucket_view_in, which refers to the
+        // replica.content tensor and grad.
+        if (!bucket_view_in.is_alias_of(bucket_view_out)) {
+          bucket_view_in.copy_(bucket_view_out);
         }
-        // The grad is not modified.
-        return false;
-      });
+        runGradCallbackForVariable(variable, [&](auto& grad) {
+          // If a parameter is globally unused, we keep its grad untouched.
+          if (!global_unused) {
+            // If grad is globally used but locally unused, let grad point to
+            // bucket_view_in.
+            if (!grad.defined()) {
+              grad = bucket_view_in;
+            } else {
+              if (!grad.is_alias_of(bucket_view_in)) {
+                grad.copy_(bucket_view_in);
+                TORCH_WARN_ONCE(
+                    "Detected at least one parameter gradient is not the "
+                    "expected DDP bucket view when setting "
+                    "gradient_as_bucket_view=True. This can happen when "
+                    "multiple parameters share the same gradient. For "
+                    "example, param0 and param1 share the same gradient "
+                    "grad0. In this case, grad0 would first point to "
+                    "bucket_view_in0 when param0 is ready. Later, when "
+                    "param1 is ready, it will override grad0 to point to "
+                    "bucket_view_in1. However, param0 still expects grad0 "
+                    "to point to bucket_view_in0, and hence hits this "
+                    "warning. If you saw this message, please double-check if "
+                    "the above situation is expected for your application.");
+              }
+            }
+            // The grad is modified and needs to be written back.
+            return true;
+          }
+          // The grad is not modified.
+          return false;
+        });
+      }
     }
   }
 }
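One more sketch of the copy-back condition in the gradient_as_bucket_view_ branch of finalize_bucket_dense(): a communication hook may hand back its result in freshly allocated storage, in which case the result no longer aliases the bucket and must be copied back into the in-view. Assuming libtorch; hook_result is an illustrative stand-in for the tensor a hook's future would return.

#include <ATen/ATen.h>

int main() {
  at::Tensor bucket = at::zeros({8});
  at::Tensor bucket_view_in = bucket.narrow(0, 0, 8);

  // Without a hook, the allreduce runs on the bucket itself and the "out"
  // view aliases the "in" view: nothing to copy back.
  at::Tensor bucket_view_out = bucket_view_in;

  // With a hook, the result may live in new storage.
  at::Tensor hook_result = at::ones({8});
  bucket_view_out = hook_result;

  if (!bucket_view_in.is_alias_of(bucket_view_out)) {
    bucket_view_in.copy_(bucket_view_out); // mirrors the copy-back in the diff
  }
  return 0;
}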