Better communication computation overlap (#2811)
* patched torch

* fixed torch imports

* fixed torch imports

* fixed torch imports

* patching through composer

* patching through composer

* patching typing

* comment added

* don't patch torch 2.1.0

* patch torch 2.1.1 and 2.2.0

* linting fix
snarayan21 authored Jan 2, 2024
1 parent ee7cb69 commit 52ac18c
Showing 2 changed files with 222 additions and 3 deletions.
17 changes: 15 additions & 2 deletions composer/trainer/mosaic_fsdp.py
@@ -41,7 +41,7 @@ def patch_pytorch():
ChunkShardingSpec.build_metadata = build_metadata

elif version.parse(torch.__version__) < version.parse('2.1.1'):
# Monkey path for torch < 2.1.1 ie torch == 2.1.0
# Monkey patch for torch < 2.1.1 ie torch == 2.1.0

# Monkey patch sharding method
ChunkShardingSpec.build_metadata = build_metadata
@@ -55,8 +55,21 @@ def patch_pytorch():
_runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

elif version.parse(torch.__version__) < version.parse('2.2.0'):
# Monkey path for torch < 2.2.0 ie torch == 2.1.1, 2.1.2
# Monkey patch for torch < 2.2.0 ie torch == 2.1.1, 2.1.2

# Allow 2D HSDP
from torch.distributed.fsdp import _runtime_utils
_runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

# Better overlap communication and computation
from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p1
_runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p1

elif version.parse(torch.__version__) < version.parse('2.2.1'):
# Monkey patch for torch < 2.2.1 ie torch == 2.2.0

# Better overlap communication and computation
from torch.distributed.fsdp import _runtime_utils

from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p2
_runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p2
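For context, here is a condensed sketch of the version dispatch this file now performs, assuming torch and composer are importable. The function name apply_overlap_patch is a hypothetical stand-in; the real entry point is patch_pytorch above, which also applies the sharding and 2D-HSDP patches shown in the diff and handles older torch releases.

import torch
from packaging import version


def apply_overlap_patch() -> None:
    """Swap in a version-appropriate _share_state_and_init_handle_attrs replacement."""
    torch_version = version.parse(torch.__version__)

    if torch_version < version.parse('2.1.1'):
        # torch == 2.1.0: this commit does not apply the overlap patch here.
        return

    from torch.distributed.fsdp import _runtime_utils

    if torch_version < version.parse('2.2.0'):
        # torch == 2.1.1 / 2.1.2: use the 2.1-flavored patched function.
        from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p1
        _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p1
    elif torch_version < version.parse('2.2.1'):
        # torch == 2.2.0: use the 2.2-flavored patched function.
        from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p2
        _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p2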
208 changes: 207 additions & 1 deletion composer/trainer/mosaic_fsdp_utils.py
@@ -21,6 +21,7 @@
from torch.distributed import ProcessGroup
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._shard.sharding_spec._internals import get_chunked_dim_size, get_split_size
from torch.distributed.distributed_c10d import get_process_group_ranks
from torch.distributed.fsdp import (BackwardPrefetch, CPUOffload, FullyShardedDataParallel, MixedPrecision,
ShardingStrategy)
from torch.distributed.fsdp._fsdp_extensions import _ext_pre_load_state_dict_transform
@@ -31,7 +32,7 @@

if TYPE_CHECKING:
if version.parse(torch.__version__) >= version.parse('2.0.1') and version.parse(
torch.__version__) < version.parse('2.0.2'):
torch.__version__) < version.parse('2.2.0'):
from torch.distributed.fsdp._common_utils import _FSDPState

log = logging.getLogger(__name__)
@@ -753,3 +754,208 @@ def _sharded_pre_load_state_dict_hook(
state_dict[fqn_from_global_root] = param.to_local()

_enter_unshard_params_ctx(module, fsdp_state, writeback=True)


def fsdp_state_has_default_pg(state: '_FSDPState') -> bool:
"""Indicates whether FlatParamHandle has the default process group.
Args:
handle (_FSDPState): FSDP State object
Returns:
bool: True if the ProcessGroup of the _FSDPState object is the default process group. False
otherwise.
"""
if state.process_group is None:
# If no process group is attached to the _FSDPState, assume it uses default process group.
return True
return len(get_process_group_ranks(state.process_group)) == dist.get_world_size()


def fsdp_state_pg_ranks(state: '_FSDPState') -> Tuple[int, ...]:
"""Gets the ranks included in the ProcessGroup of an _FSDPState.
Args:
state (_FSDPState): FSDP State object
Returns:
Tuple[int]: Ranks for the FSDP State's process group.
"""
if state.process_group is None:
# If no process group is attached to the _FSDPState, assume it uses default process group.
return tuple(range(dist.get_world_size()))
else:
return tuple(get_process_group_ranks(state.process_group))


@no_type_check
def _share_state_and_init_handle_attrs_t2p1(
root_state: '_FSDPState',
root_module: nn.Module,
) -> None:
"""Shares state from ``root_state`` to other FSDP states.
Shares data structure state from the ``root_state`` to all FSDP states in
``root_module`` 's module tree, and initializes handle attributes. These are
done together to require a single loop over the states. This function has
been modified to assign a different unshard stream to each process group.
"""
from torch.distributed.fsdp._runtime_utils import (HOMOGENEOUS_ATTR_NAMES, _init_device_mesh,
_validate_and_get_hybrid_shard_state)
from torch.distributed.utils import _p_assert

handle = root_state._handle
if handle:
handle.init_flat_param_attributes()
_validate_and_get_hybrid_shard_state(root_module)
attr_name_to_values: Dict[str, Set[Any]] = {}
for attr_name in HOMOGENEOUS_ATTR_NAMES:
attr_name_to_values[attr_name] = set()
root_state._all_handles = root_state._exec_order_data.all_handles # share reference
root_state._device_mesh = _init_device_mesh(root_state)
# Update _has_optim_in_backward for each handle.
for handle in root_state._all_handles:
flat_param = handle.flat_param
if hasattr(flat_param, '_in_backward_optimizers'):
raise RuntimeError('FSDP optimizer in backward only supported with use_orig_params=True!')
handle._has_optim_in_backward = flat_param._params is not None and any(
hasattr(param, '_in_backward_optimizers') for param in flat_param._params)

# Patching so that _FSDPStates with different process groups have separate unshard streams.
# Keep track of any new unshard streams we may have to add for specific process groups.
fsdp_pg_unshard_streams = {}
unshard_priority = root_state._unshard_stream.priority
for fsdp_state in root_state._all_fsdp_states:
for attr_name in HOMOGENEOUS_ATTR_NAMES:
_p_assert(
hasattr(fsdp_state, attr_name),
f'FSDP state missing attribute {attr_name}',
)
attr_name_to_values[attr_name].add(getattr(fsdp_state, attr_name))
if fsdp_state is root_state:
continue
# Relax the assert for non-root FSDP instances in case the nested
# initialized module is wrapped again in FSDP later (e.g. after
# training to run inference)
_p_assert(
fsdp_state._is_root is None or not fsdp_state._is_root,
"Non-root FSDP instance's `_is_root` should not have been "
'set yet or should have been set to `False`',
)
fsdp_state._is_root = False

# Take care of any new unshard streams we have to create for non-default process groups.
if fsdp_state_has_default_pg(fsdp_state):
# If using default process group, unshard stream is the same as root fsdp instance.
fsdp_state._unshard_stream = root_state._unshard_stream
else:
# Otherwise, unshard stream is separate.
state_pg_ranks = fsdp_state_pg_ranks(fsdp_state)
if state_pg_ranks in fsdp_pg_unshard_streams:
# We have created the unshard stream for this process group already. Use it.
fsdp_state._unshard_stream = fsdp_pg_unshard_streams[state_pg_ranks]
else:
# We don't have an unshard stream for this process group yet. Make it.
fsdp_state._unshard_stream = fsdp_state._device_handle.Stream(priority=unshard_priority)
fsdp_pg_unshard_streams[state_pg_ranks] = fsdp_state._unshard_stream

# All other stream assignments stay common across all of FSDP.
fsdp_state._post_backward_stream = root_state._post_backward_stream
fsdp_state._pre_unshard_stream = root_state._pre_unshard_stream
fsdp_state._all_reduce_stream = root_state._all_reduce_stream
fsdp_state._default_stream = root_state._default_stream
fsdp_state._exec_order_data = root_state._exec_order_data
fsdp_state._free_event_queue = root_state._free_event_queue
fsdp_state._device_mesh = root_state._device_mesh
handle = fsdp_state._handle
if handle:
handle.init_flat_param_attributes()
for attr_name, attr_values in attr_name_to_values.items():
if len(attr_values) != 1:
raise ValueError(f'Expects one homogeneous value for {attr_name} but got {attr_values}')


@no_type_check
def _share_state_and_init_handle_attrs_t2p2(
root_state: '_FSDPState',
root_module: nn.Module,
) -> None:
"""Shares state from ``root_state`` to other FSDP states.
Shares data structure state from the ``root_state`` to all FSDP states in
``root_module`` 's module tree, and initializes handle attributes. These are
done together to require a single loop over the states. This function has
been modified to assign a different unshard stream to each process group.
"""
from torch.distributed.fsdp._runtime_utils import HOMOGENEOUS_ATTR_NAMES, _validate_and_get_hybrid_shard_state
from torch.distributed.utils import _p_assert

handle = root_state._handle
if handle:
handle.init_flat_param_attributes()
_validate_and_get_hybrid_shard_state(root_module)
attr_name_to_values: Dict[str, Set[Any]] = {}
for attr_name in HOMOGENEOUS_ATTR_NAMES:
attr_name_to_values[attr_name] = set()
root_state._all_handles = root_state._exec_order_data.all_handles # share reference
# Update _has_optim_in_backward for each handle.
for handle in root_state._all_handles:
flat_param = handle.flat_param
if hasattr(flat_param, '_in_backward_optimizers'):
raise RuntimeError('FSDP optimizer in backward only supported with use_orig_params=True!')
handle._has_optim_in_backward = flat_param._params is not None and any(
hasattr(param, '_in_backward_optimizers') for param in flat_param._params)
if handle._has_optim_in_backward:
torch._C._log_api_usage_once('fsdp.optimizer_in_backward')

# Patching so that _FSDPStates with different process groups have separate unshard streams.
# Keep track of any new unshard streams we may have to add for specific process groups.
fsdp_pg_unshard_streams = {}
unshard_priority = root_state._unshard_stream.priority
for fsdp_state in root_state._all_fsdp_states:
for attr_name in HOMOGENEOUS_ATTR_NAMES:
_p_assert(
hasattr(fsdp_state, attr_name),
f'FSDP state missing attribute {attr_name}',
)
attr_name_to_values[attr_name].add(getattr(fsdp_state, attr_name))
if fsdp_state is root_state:
continue
# Relax the assert for non-root FSDP instances in case the nested
# initialized module is wrapped again in FSDP later (e.g. after
# training to run inference)
_p_assert(
fsdp_state._is_root is None or not fsdp_state._is_root,
"Non-root FSDP instance's `_is_root` should not have been "
'set yet or should have been set to `False`',
)
fsdp_state._is_root = False

# Take care of any new unshard streams we have to create for non-default process groups.
if fsdp_state_has_default_pg(fsdp_state):
# If using default process group, unshard stream is the same as root fsdp instance.
fsdp_state._unshard_stream = root_state._unshard_stream
else:
# Otherwise, unshard stream is separate.
state_pg_ranks = fsdp_state_pg_ranks(fsdp_state)
if state_pg_ranks in fsdp_pg_unshard_streams:
# We have created the unshard stream for this process group already. Use it.
fsdp_state._unshard_stream = fsdp_pg_unshard_streams[state_pg_ranks]
else:
# We don't have an unshard stream for this process group yet. Make it.
fsdp_state._unshard_stream = fsdp_state._device_handle.Stream(priority=unshard_priority)
fsdp_pg_unshard_streams[state_pg_ranks] = fsdp_state._unshard_stream

# All other stream assignments stay common across all of FSDP.
fsdp_state._post_backward_stream = root_state._post_backward_stream
fsdp_state._pre_unshard_stream = root_state._pre_unshard_stream
fsdp_state._all_reduce_stream = root_state._all_reduce_stream
fsdp_state._default_stream = root_state._default_stream
fsdp_state._exec_order_data = root_state._exec_order_data
fsdp_state._free_event_queue = root_state._free_event_queue
handle = fsdp_state._handle
if handle:
handle.init_flat_param_attributes()
for attr_name, attr_values in attr_name_to_values.items():
if len(attr_values) != 1:
raise ValueError(f'Expects one homogeneous value for {attr_name} but got {attr_values}')
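As a usage note, the following minimal, self-contained sketch (not part of the commit) illustrates the stream-assignment pattern both patched functions implement: states on the default process group share the root unshard stream, while each distinct non-default process group gets one dedicated unshard stream, cached by its rank tuple. DummyState and the string "streams" are hypothetical stand-ins for _FSDPState and fsdp_state._device_handle.Stream.

from typing import Dict, List, Optional, Tuple


class DummyState:
    """Hypothetical stand-in for _FSDPState with only the fields this sketch needs."""

    def __init__(self, pg_ranks: Optional[Tuple[int, ...]], world_size: int) -> None:
        # None mimics an _FSDPState whose process_group is None (default process group).
        self.pg_ranks = pg_ranks
        self.world_size = world_size
        self.unshard_stream: Optional[str] = None


def assign_unshard_streams(root: DummyState, states: List[DummyState]) -> None:
    """Mirror the patched loop: share the root unshard stream for the default
    process group; cache one dedicated stream per distinct non-default rank tuple."""
    pg_unshard_streams: Dict[Tuple[int, ...], str] = {}
    root.unshard_stream = 'root_unshard_stream'
    for state in states:
        ranks = state.pg_ranks if state.pg_ranks is not None else tuple(range(state.world_size))
        if len(ranks) == state.world_size:
            # Default process group: reuse the root instance's unshard stream.
            state.unshard_stream = root.unshard_stream
        elif ranks in pg_unshard_streams:
            # A stream for this process group already exists: reuse it.
            state.unshard_stream = pg_unshard_streams[ranks]
        else:
            # New non-default process group: create (here, just name) a dedicated stream.
            pg_unshard_streams[ranks] = f'unshard_stream_for_{ranks}'
            state.unshard_stream = pg_unshard_streams[ranks]


# Example: world size 4 with two hybrid-shard groups of 2 ranks each.
root = DummyState(pg_ranks=None, world_size=4)
states = [DummyState((0, 1), 4), DummyState((2, 3), 4), DummyState((0, 1), 4), DummyState(None, 4)]
assign_unshard_streams(root, states)
print([s.unshard_stream for s in states])
# Two distinct per-group streams, the (0, 1) stream reused, and the root stream for the default group.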
