diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 0f2741251bb6..42cdcd645f4d 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -808,17 +808,12 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False): start = partition_size * rank end = start + partition_size - dest_tensor = partition_buffer.view(-1).narrow(0, 0, partition_size) + dest_tensor_full_buffer = partition_buffer.view(-1).narrow(0, 0, partition_size) #print("before partition gradients") if start < param.ds_numel: elements = min(param.ds_numel - start, partition_size) - dest_tensor_full_buffer = partition_buffer.view(-1).narrow( - 0, - 0, - partition_size) - dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements) src_tensor = param.grad.view(-1).narrow(0, start, elements)