Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cast input argument #1175

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion fairscale/nn/data_parallel/fully_sharded_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def __init__(
gradient_predivide_factor: Optional[float] = None,
limit_all_gather_events: bool = False,
limit_reduce_scatter_events: bool = False,
cast_input: bool = True,
):
try:
import torch._C
Expand Down Expand Up @@ -420,6 +421,7 @@ def __init__(
self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
self.disable_reshard_on_root = disable_reshard_on_root
self.mixed_precision = mixed_precision
self.cast_input = cast_input
self.fp32_reduce_scatter = fp32_reduce_scatter
self.flatten_parameters = flatten_parameters
self.move_params_to_cpu = move_params_to_cpu or cpu_offload
Expand Down Expand Up @@ -1431,7 +1433,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
# For root and mixed precision, we convert the input to FP16 (no_grad is needed for
# the conversion).
is_bf16 = self.compute_dtype == torch.bfloat16
if self._is_root and self.mixed_precision:
if self._is_root and self.mixed_precision and self.cast_input:
args, kwargs = cast_floats_to_right_precision(True, True, is_bf16, *args, **kwargs)

if self not in self._fsdp_forward_ordering:
Expand Down
Loading