Skip to content

Commit

Permalink
changed to no_alloc_cast (#5230)
Browse files: browse the repository at this point in the history
  • Loading branch information
LiyuanLucasLiu authored Mar 6, 2024
1 parent cf9fabf commit 7854388
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions deepspeed/runtime/bf16_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,21 @@ def fp8_to_fp32(fp8_param, fp8_group_flat):
out = out.view(fp8_group_flat.size())
return out

# NOTE(review): this span is a rendered diff hunk whose +/- markers and
# indentation were lost, so the pre-commit and post-commit versions of
# fp32_to_fp8 appear interleaved. The comments below label each side.
#
# Pre-commit signature (removed by the commit): the cast result was returned.
def fp32_to_fp8(fp8_param, fp8_partition_size, fp32_partition):
# Post-commit signature (added by the commit): the caller supplies the
# destination tensor as `out`.
def fp32_to_fp8(fp8_param, fp8_partition_size, fp32_partition, out):
# Shared by both versions: read the inverse scale and FP8 dtype stored on the
# parameter, and derive the forward scale as the reciprocal of the inverse.
scale_inv = fp8_param._scale_inv
scale = scale_inv.reciprocal()
fp8_dtype = fp8_param._fp8_dtype
# Pre-commit (removed): amax buffer created uninitialised, same shape as scale.
amax = torch.empty_like(scale)

# Pre-commit (removed): cast the flattened (1, N) fp32 partition with
# tex.cast_to_fp8, reshape the returned tensor to fp8_partition_size, and
# return it to the caller.
out = tex.cast_to_fp8(fp32_partition.view(1,-1),
scale,
amax,
scale_inv,
fp8_dtype,
).view(fp8_partition_size)
return out
# Post-commit (added): amax buffer is now initialised to ones.
amax = torch.ones_like(scale)

# Post-commit (added): call tex.cast_to_fp8_noalloc, passing `out` viewed as
# (1, N) alongside the flattened fp32 input.
# NOTE(review): presumably the cast result is written into `out` in place --
# confirm against the tex API; the function itself now returns None.
# NOTE(review): fp8_partition_size is no longer referenced in this version.
tex.cast_to_fp8_noalloc(
fp32_partition.view(1, -1),
scale,
out.view(1, -1),
amax,
scale_inv,
fp8_dtype,
)
return None


class BF16_Optimizer(ZeROOptimizer):
Expand Down Expand Up @@ -425,9 +427,11 @@ def update_lp_params(self):
fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
if bf16_partitions[partition_id].dtype == torch.uint8:
bf16_partitions[partition_id].data.copy_(fp32_to_fp8(self.bf16_groups[i][0],
bf16_partitions[partition_id].size(),
fp32_partition.data))
fp32_to_fp8(self.bf16_groups[i][0],
bf16_partitions[partition_id].size(),
fp32_partition.data,
out=bf16_partitions[partition_id].data)

else:
bf16_partitions[partition_id].data.copy_(fp32_partition.data)
# print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
Expand Down

0 comments on commit 7854388

Please sign in to comment.