logging
chrisxcai committed May 9, 2024
1 parent 14499fe commit ad7aa1f
Showing 1 changed file with 8 additions and 0 deletions.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py: 8 additions & 0 deletions
@@ -1723,6 +1723,14 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# logger.info(f"CHRISLOG:{grad_sizes=}")

new_unsharded_main_grad_in_fp32 = torch.cat([grad.flatten() for grad in self._fsdp_wrapped_module.fp32_grads])
baseline_grad = param.grad.to(torch.float32)


logger.info(f"CHRISLOG: baseline grad {baseline_grad=}, {baseline_grad.size()=}")
logger.info(f"CHRISLOG: new grad {new_unsharded_main_grad_in_fp32=}, {new_unsharded_main_grad_in_fp32.size()=}")
assert torch.allclose(baseline_grad, new_unsharded_main_grad_in_fp32, atol=0, rtol=0)
logger.info("CHRISLOG: baseline grad and new grad passed allclose check")

# logger.info(f"CHRISLOG: assigning new unsharded_main_grad with size {new_unsharded_main_grad_in_fp32.size()}, type:{new_unsharded_main_grad_in_fp32.dtype}, original grad size {param.grad.size()}")
# if getattr(param, "unsharded_main_grad", None) is None:
# param.unsharded_main_grad = param.grad.to(torch.float32)
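For context, a minimal standalone sketch (not part of the commit) of the equivalence this logging is intended to verify: concatenating the flattened per-parameter fp32 gradients should reproduce the full-precision view of the flat gradient. The fp32_grads list and its shapes below are illustrative stand-ins for self._fsdp_wrapped_module.fp32_grads, not values taken from fairscale.

import torch

# Illustrative stand-in for self._fsdp_wrapped_module.fp32_grads: per-parameter grads kept in fp32.
fp32_grads = [torch.randn(4, 3), torch.randn(7)]

# Stand-in for param.grad on the flat parameter: the same values laid out as one flat fp32 buffer.
flat_grad = torch.cat([g.flatten() for g in fp32_grads])

baseline_grad = flat_grad.to(torch.float32)
new_unsharded_main_grad_in_fp32 = torch.cat([g.flatten() for g in fp32_grads])

# atol=0, rtol=0 demands exact equality; it holds here because both tensors are built
# from the same fp32 values in the same order.
assert torch.allclose(baseline_grad, new_unsharded_main_grad_in_fp32, atol=0, rtol=0)
print("baseline grad and new grad passed allclose check")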
