Changed to only run reshard hook if all gradients computed (#1166)
* Changed to only run reshard hook if all gradients computed

* Fix decreasing it/s with multi-grad hook
awgu authored and chrisxcai committed May 15, 2024
1 parent 9d0e41e commit e43a22f
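
In effect, the local register_multi_grad_hook added by this diff fires its callback at most once per backward pass, and only in a backward pass where every watched tensor receives a gradient. A minimal sketch of that contract, assuming the function defined in the diff below is in scope (the tensors and callback here are illustrative, not part of the commit):

import torch

# Illustrative tensors; in FSDP these are a wrapped module's inputs.
a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

def on_all_grads(grads):
    # Runs only when every watched tensor received a gradient
    # in the current backward pass.
    print("all grads ready:", [g is not None for g in grads])

handle = register_multi_grad_hook((a, b), on_all_grads)

a.sum().backward()        # only `a` gets a gradient -> hook does not fire
(a * b).sum().backward()  # both get gradients -> hook fires once
handle.remove()

Requiring all gradients also lets nb_calls be the constant len(grad_fns) rather than the per-backward _will_engine_execute_node scan left commented out in the diff, which appears to be the "decreasing it/s" fix mentioned in the message above.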
Showing 1 changed file with 73 additions and 3 deletions.

fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -28,6 +28,7 @@
     Mapping,
     NamedTuple,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Union,
@@ -42,6 +43,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
+from torch.utils.hooks import RemovableHandle
 
 from fairscale.nn.misc import FlattenParamsWrapper
 from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
@@ -1678,12 +1680,9 @@ def _register_post_backward_hooks(self) -> None:
     def _register_post_backward_reshard_hooks(
         self, args: Tuple[Any, ...], kwargs: Dict[str, Any]
     ) -> None:
-        if not hasattr(torch.autograd.graph, "register_multi_grad_hook"):
-            return  # unsupported
         if not torch.is_grad_enabled():
             return
         from torch.utils._pytree import tree_flatten
-        from torch.autograd.graph import register_multi_grad_hook
         # Construct `inp_tensors` lazily to avoid CPU overhead in typical case
         # where each parameter requires gradient
         inp_tensors: Optional[List[torch.Tensor]] = None
@@ -2867,3 +2866,74 @@ def auto_wrap_bn(
         enable_wrap(config_auto_wrap_policy, wrapper_cls=FullyShardedDataParallel) if wrap_it else contextlib.suppress()
     ):
         return auto_wrap(module)
+
+
+class Handle(RemovableHandle):
+    handles: Tuple[RemovableHandle, ...]
+
+    def __init__(self, handles: Tuple[RemovableHandle, ...]):
+        self.handles = handles
+
+    def remove(self):
+        for handle in self.handles:
+            handle.remove()
+
+    def __getstate__(self):
+        return self.handles
+
+    def __setstate__(self, state):
+        self.handles = state
+
+
+def register_multi_grad_hook(
+    tensors: Sequence[torch.Tensor],
+    fn: Callable[[Sequence[Optional[torch.Tensor]]], None]
+):
+    count: Dict[int, int] = dict()
+    nb_calls = None
+    buffer: Dict[int, List[Optional[torch.Tensor]]] = dict()
+
+    grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
+    len_tensors = len(tensors)
+
+    def get_inner_hook(idx):
+        def inner_hook(grad: torch.Tensor):
+            nonlocal count, nb_calls, buffer, fn
+            id = torch._C._current_graph_task_id()
+            assert (
+                id != -1
+            ), "expected this hook to be called inside a backward call"
+            count[id] = count.get(id, 0)
+            buffer[id] = buffer.get(id, [None] * len_tensors)
+
+            if count[id] == 0:
+                # On the first call, compute the actual nb_calls and buffer
+                # nb_calls = sum(torch._C._will_engine_execute_node(g) for g in grad_fns)  # type: ignore[attr-defined]
+
+                # NOTE: To avoid resharding too early when microbatches share
+                # some same module inputs, let us require all gradients to be
+                # computed in this backward for the hook to run.
+                nb_calls = len(grad_fns)
+
+            buffer[id][idx] = grad
+            count[id] += 1
+
+            if count[id] == nb_calls:
+                fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn)
+                fn(buffer[id])
+                del count[id]
+                del buffer[id]
+
+        return inner_hook
+
+    handles: Tuple[RemovableHandle, ...] = tuple(
+        t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors)
+    )
+    return Handle(handles)
+
+
+def _get_grad_fn_or_grad_acc(t):
+    if t.requires_grad and t.grad_fn is None:
+        return t.view_as(t).grad_fn.next_functions[0][0]
+    else:
+        return t.grad_fn

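The per-backward bookkeeping in inner_hook keys count and buffer on torch._C._current_graph_task_id(), a private API that identifies the currently executing backward call and returns -1 outside of one (which the assert in the diff relies on), so state from distinct backward passes is never conflated. A small illustration, assuming this private API's behavior matches its use in the diff:

import torch

print(torch._C._current_graph_task_id())  # -1: no backward call running

x = torch.randn(1, requires_grad=True)
x.register_hook(lambda g: print(torch._C._current_graph_task_id()))

x.sum().backward()  # prints the id of this backward's graph task
x.sum().backward()  # a second backward call gets a different id
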
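_get_grad_fn_or_grad_acc handles leaf tensors, which have no grad_fn of their own: a no-op view of a leaf does have a grad_fn, and that node's next function is the leaf's AccumulateGrad node, which fires when the leaf's gradient is accumulated. An illustrative check, not part of the commit:

import torch

leaf = torch.randn(2, requires_grad=True)
print(leaf.grad_fn)  # None -- leaf tensors have no grad_fn

# Viewing the leaf yields a ViewBackward node whose next function is
# the leaf's AccumulateGrad node.
acc = leaf.view_as(leaf).grad_fn.next_functions[0][0]
print(type(acc).__name__)  # AccumulateGrad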