documentation
chrisxcai committed May 15, 2024
1 parent 5926a79 commit 5d08aa3
Showing 2 changed files with 31 additions and 38 deletions.
19 changes: 16 additions & 3 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1105,14 +1105,20 @@ def no_sync(self) -> Generator:
if isinstance(m, FullyShardedDataParallel):
old_flags.append((m, m._require_backward_grad_sync))
m._require_backward_grad_sync = False
m._fsdp_wrapped_module._require_backward_grad_sync = False
if self.optimize_backward_concat:
# Set the flag on the wrapped FlattenParamsWrapper module as well,
# so that FlattenParamsWrapper can accumulate grads at the corresponding
# leaf nodes without triggering concat operations when gradient
# synchronization is not needed.
m._fsdp_wrapped_module._require_backward_grad_sync = False
try:
yield
finally:
for m, old_flag in old_flags:
assert m._require_backward_grad_sync is False
m._require_backward_grad_sync = old_flag
m._fsdp_wrapped_module._require_backward_grad_sync = old_flag
if self.optimize_backward_concat:
m._fsdp_wrapped_module._require_backward_grad_sync = old_flag

@contextlib.contextmanager
def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
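The hunk above only flips synchronization flags; it is meant to support the usual no_sync() gradient-accumulation loop. A minimal sketch of that usage, assuming the FSDP constructor exposes the optimize_backward_concat flag referenced above (my_module, microbatches, and optimizer are hypothetical names):

model = FullyShardedDataParallel(my_module, optimize_backward_concat=True)
for i, microbatch in enumerate(microbatches):
    if i < len(microbatches) - 1:
        # no_sync() sets _require_backward_grad_sync = False on the FSDP module and,
        # with optimize_backward_concat, on the wrapped FlattenParamsWrapper, so grads
        # accumulate per leaf parameter without any concat or reduce-scatter.
        with model.no_sync():
            model(microbatch).sum().backward()
    else:
        # Last microbatch: the flags stay True, so backward() triggers the concat
        # and the normal gradient reduction.
        model(microbatch).sum().backward()
optimizer.step()
optimizer.zero_grad()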
@@ -1723,8 +1729,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
self._use_fp32_param_shard([param])

if self.fp32_reduce_scatter:

if self.optimize_backward_concat:
# Flatten and concat the accumulated fp32 grads
# and assign them to param.unsharded_main_grad
param.unsharded_main_grad = torch.cat([grad.flatten() for grad in self._fsdp_wrapped_module.fp32_grads])
# Clean up accumulated grads between data batches
self._fsdp_wrapped_module.fp32_grads = []
@@ -1866,6 +1873,9 @@ def _wait_for_post_backward(self) -> None:
# state will remain in `TrainingState.BACKWARD_PRE`.
if any([p.requires_grad for p in self.params]):
if self.optimize_backward_concat:
# If self.optimize_backward_concat == True, FSDP's post-backward logic
# (which will invoke the concat() op) should only be triggered
# when self._fsdp_wrapped_module._require_backward_grad_sync is True.
if self._fsdp_wrapped_module._require_backward_grad_sync:
self.assert_state(TrainingState.BACKWARD_POST)
else:
@@ -1949,6 +1959,9 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
# state will remain in `TrainingState.BACKWARD_PRE`.
if any([p.requires_grad for p in m.params]):
if self.optimize_backward_concat:
# If self.optimize_backward_concat == True, FSDP's post-backward logic
# (which will invoke the concat() op) should only be triggered
# when self._fsdp_wrapped_module._require_backward_grad_sync is True.
if self._fsdp_wrapped_module._require_backward_grad_sync:
m.assert_state(TrainingState.BACKWARD_POST)
else:
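For reference, the flatten-and-concat step that _post_backward_hook performs when fp32_reduce_scatter and optimize_backward_concat are both enabled boils down to a single torch.cat over the accumulated per-parameter fp32 grads. A standalone sketch (illustrative shapes, not fairscale code):

import torch

# Stand-ins for the per-parameter fp32 grads accumulated by FlattenParamsWrapper.
fp32_grads = [torch.randn(3, 4), torch.randn(5)]

# Flatten each grad and concatenate in parameter order, matching the FlatParameter layout.
unsharded_main_grad = torch.cat([g.flatten() for g in fp32_grads])
assert unsharded_main_grad.numel() == sum(g.numel() for g in fp32_grads)

# Cleared between data batches, mirroring `self._fsdp_wrapped_module.fp32_grads = []`.
fp32_grads = []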
50 changes: 15 additions & 35 deletions fairscale/nn/misc/flatten_params_wrapper.py
@@ -9,7 +9,6 @@
from contextlib import contextmanager
import functools
from itertools import chain
from re import split
import tempfile
import typing
from typing import (
@@ -35,14 +34,10 @@

from fairscale.experimental.nn.ssd_offload import SsdFlatParameter
from fairscale.utils.state_dict import replace_by_prefix_
import functools

if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401

from logging import getLogger
logger = getLogger()

class FlatParameter(nn.Parameter):
"""A parameter that is initialized from a list of parameters and can be
turned into a list of views as needed.
@@ -95,36 +90,8 @@ def get_param_views(self, require_backward_grad_sync, external_data: Optional[Te
raise ValueError(
f"Incorrect numel of supplied data: got {data.numel()} but expected {sum(self._param_numels)}"
)
# logger.info(f"CHRISLOG: {data.numel()=}")
# logger.info(f"CHRISLOG: {self._param_numels=}")
# logger.info(f"CHRISLOG: {self._param_shapes=}")

# logger.info(f"CHRISLOG: {data.is_leaf=}, {data.grad_fn=}")

# def post_accumulate_grad_hook(
# param
# ):
# logger.info(f"CHRISLOG: cleaning up {param.grad=}")
# param.grad = None

# data.register_post_accumulate_grad_hook(
# functools.partial(
# post_accumulate_grad_hook
# )
# )
# logger.info("CHRISLOG: registered post_accumulate_grad_hook for bf16 grad cleanup on data")

split_outputs = data.split(self._param_numels)
# for split_output in split_outputs:
# logger.info(f"CHRISLOG: {require_backward_grad_sync=} {split_output.is_leaf=}, {split_output.grad_fn=}, {split_output.grad=}") #
# if not require_backward_grad_sync:
# split_output.register_hook(
# functools.partial(
# post_accumulate_grad_hook
# )
# )
# logger.info("CHRISLOG: registered post_accumulate_grad_hook for bf16 grad cleanup on split_output")

return (t.view(s) for (t, s) in zip(split_outputs, self._param_shapes))

def metadata(self) -> Tuple[List[str], List[torch.Size], List[int]]:
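For context, the surviving body of get_param_views splits the flat tensor back into per-parameter views via split() and view(). A tiny standalone illustration (made-up shapes, not fairscale code):

import torch

param_shapes = [torch.Size([2, 3]), torch.Size([4])]
param_numels = [s.numel() for s in param_shapes]  # [6, 4]

# Stand-in for the flat parameter's data.
data = torch.arange(sum(param_numels), dtype=torch.float32)
split_outputs = data.split(param_numels)
views = [t.view(s) for t, s in zip(split_outputs, param_shapes)]
assert [v.shape for v in views] == param_shapes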
@@ -183,6 +150,11 @@ class FlattenParamsWrapper(nn.Module):
flat_param_names (Optional[List[str]]):
originally, give each flat_param a unique name. Note a "flat_param_"
prefix will be added to those names.
optimize_backward_concat (bool):
If True, only trigger backward() on self.flat_params (which will
invoke the parent FSDP module's _post_backward_hook() and the
concat() op) when self._require_backward_grad_sync is True
(e.g. on the last microbatch).
NOTE: this will likely incur higher GPU memory usage.
"""

def __init__(
@@ -197,9 +169,12 @@ def __init__(
super().__init__()
self._fpw_module = module
self.is_flattened = False

self.optimize_backward_concat = optimize_backward_concat
# If self.optimize_backward_concat == True, used to propagate the
# parent FSDP module's _require_backward_grad_sync flag.
self._require_backward_grad_sync = True
# If self.optimize_backward_concat == True, used to accumulate the
# fp32 gradients for the flattened parameters
self.fp32_grads = []

# Handle param_list being None.
@@ -404,7 +379,7 @@ def _unflatten_params(self, external_data: Optional[List[Optional[Tensor]]] = No
delattr(self, n)
self.flat_params = []


# The post-backward hook used to accumulate fp32 gradients.
def _grad_accumulation_hook(
self,
grad,
@@ -434,8 +409,12 @@ def _unflatten_params_as_views(self) -> None:
for (_, m, n), p in zip(self._param_infos, ps):
setattr(p, '_fsdp_weight', True)
setattr(m, n, p) # This will set as plain attr
# The param_index of p, used to accumulate the corresponding
# gradients in self.fp32_grads.
param_index = len(param_views)
if self.optimize_backward_concat:
# Register a post-backward hook to accumulate the gradients
# in self.fp32_grads.
p.register_hook(
functools.partial(
self._grad_accumulation_hook,
@@ -445,6 +424,7 @@ def _unflatten_params_as_views(self) -> None:
param_views.append(p)

if self.optimize_backward_concat and len(self.fp32_grads) == 0:
# Allocate self.fp32_grads at the beginning of each data batch's forward()
self.fp32_grads = [None] * len(param_views)

# Save param views for easy access if anyone still wants to access
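The hook registration above is the other half of the mechanism: each unflattened view gets a gradient hook that stashes an fp32 copy of its incoming gradient into self.fp32_grads at its param_index. A self-contained sketch of that pattern (illustrative only; the shapes and the hook body are assumptions, since _grad_accumulation_hook's implementation is not shown in this diff):

import functools
import torch

views = [torch.randn(2, 2, requires_grad=True), torch.randn(3, requires_grad=True)]
fp32_grads = [None] * len(views)  # allocated once per data batch, as in the wrapper

def grad_accumulation_hook(grad, param_index):
    # Accumulate an fp32 copy of the incoming grad for this view.
    if fp32_grads[param_index] is None:
        fp32_grads[param_index] = grad.to(torch.float32)
    else:
        fp32_grads[param_index] += grad.to(torch.float32)
    return grad

for param_index, view in enumerate(views):
    view.register_hook(functools.partial(grad_accumulation_hook, param_index=param_index))

# Two "microbatches": gradients accumulate in fp32_grads without any concat.
(views[0].sum() + views[1].sum()).backward()
(views[0].sum() + views[1].sum()).backward()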
