Skip to content

Commit

Permalink
Added requires_grad check for params_with_grad method (#1171)
Browse files Browse the repository at this point in the history
Co-authored-by: Jie Wang <jiewang@meta.com>
  • Loading branch information
2 people authored and chrisxcai committed May 15, 2024
1 parent a3ff5c4 commit 9d0e41e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion fairscale/nn/data_parallel/fully_sharded_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ def _cast_buffers(
@property
def params_with_grad(self) -> List[Parameter]:
"""[p for p in self.parameters() if p.grad is not None]"""
return [p for p in self.parameters() if (p.grad is not None or p.main_grad is not None)]
return [p for p in self.parameters() if (p.requires_grad and (p.grad is not None or p.main_grad is not None))]

@torch.no_grad()
def clip_grad_norm_(
Expand Down

0 comments on commit 9d0e41e

Please sign in to comment.