Better communication computation overlap #2811

Merged (13 commits) on Jan 2, 2024
8 changes: 6 additions & 2 deletions composer/trainer/mosaic_fsdp.py
@@ -41,7 +41,7 @@ def patch_pytorch():
         ChunkShardingSpec.build_metadata = build_metadata

     elif version.parse(torch.__version__) < version.parse('2.1.1'):
-        # Monkey path for torch < 2.1.1 ie torch == 2.1.0
+        # Monkey patch for torch < 2.1.1 ie torch == 2.1.0

         # Monkey patch sharding method
         ChunkShardingSpec.build_metadata = build_metadata
@@ -55,8 +55,12 @@ def patch_pytorch():
         _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

     elif version.parse(torch.__version__) < version.parse('2.2.0'):
-        # Monkey path for torch < 2.2.0 ie torch == 2.1.1, 2.1.2
+        # Monkey patch for torch < 2.2.0 ie torch == 2.1.1, 2.1.2

         # Allow 2D HSDP
         from torch.distributed.fsdp import _runtime_utils
         _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
+
+        # Better overlap communication and computation
+        from composer.trainer.mosaic_fsdp_utils import new_share_state_and_init_handle_attrs
+        _runtime_utils._share_state_and_init_handle_attrs = new_share_state_and_init_handle_attrs
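
For context, this branch swaps PyTorch's private _share_state_and_init_handle_attrs for Composer's patched version when 2.1.1 <= torch < 2.2.0. The snippet below is an illustrative sketch (not part of this PR) of how one could confirm the patch took effect after calling patch_pytorch(); it assumes only the names that appear in the diff above.

# Illustrative sanity check (not part of this PR): verify the monkey patch is installed
# for the torch versions covered by the branch above (2.1.1 <= torch < 2.2.0).
import torch
from packaging import version

from composer.trainer.mosaic_fsdp import patch_pytorch
from composer.trainer.mosaic_fsdp_utils import new_share_state_and_init_handle_attrs

patch_pytorch()

if version.parse('2.1.1') <= version.parse(torch.__version__) < version.parse('2.2.0'):
    from torch.distributed.fsdp import _runtime_utils
    # After patching, FSDP's state-sharing hook should point at Composer's replacement.
    assert _runtime_utils._share_state_and_init_handle_attrs is new_share_state_and_init_handle_attrs
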
122 changes: 121 additions & 1 deletion composer/trainer/mosaic_fsdp_utils.py
@@ -21,6 +21,7 @@
 from torch.distributed import ProcessGroup
 from torch.distributed._shard.sharding_spec import ShardMetadata
 from torch.distributed._shard.sharding_spec._internals import get_chunked_dim_size, get_split_size
+from torch.distributed.distributed_c10d import get_process_group_ranks
 from torch.distributed.fsdp import (BackwardPrefetch, CPUOffload, FullyShardedDataParallel, MixedPrecision,
                                     ShardingStrategy)
 from torch.distributed.fsdp._fsdp_extensions import _ext_pre_load_state_dict_transform
@@ -31,7 +32,7 @@

 if TYPE_CHECKING:
     if version.parse(torch.__version__) >= version.parse('2.0.1') and version.parse(
-            torch.__version__) < version.parse('2.0.2'):
+            torch.__version__) < version.parse('2.2.0'):
         from torch.distributed.fsdp._common_utils import _FSDPState

 log = logging.getLogger(__name__)
@@ -753,3 +754,122 @@ def _sharded_pre_load_state_dict_hook(
        state_dict[fqn_from_global_root] = param.to_local()

    _enter_unshard_params_ctx(module, fsdp_state, writeback=True)


def fsdp_state_has_default_pg(state: '_FSDPState') -> bool:
    """Indicates whether the _FSDPState's process group is the default process group.

    Args:
        state (_FSDPState): FSDP State object

    Returns:
        bool: True if the ProcessGroup of the _FSDPState object is the default process group. False
            otherwise.
    """
    if state.process_group is None:
        # If no process group is attached to the _FSDPState, assume it uses default process group.
        return True
    return len(get_process_group_ranks(state.process_group)) == dist.get_world_size()


def fsdp_state_pg_ranks(state: '_FSDPState') -> Tuple[int, ...]:
    """Gets the ranks included in the ProcessGroup of an _FSDPState.

    Args:
        state (_FSDPState): FSDP State object

    Returns:
        Tuple[int]: Ranks for the FSDP State's process group.
    """
    if state.process_group is None:
        # If no process group is attached to the _FSDPState, assume it uses default process group.
        return tuple(range(dist.get_world_size()))
    else:
        return tuple(get_process_group_ranks(state.process_group))


@no_type_check
def new_share_state_and_init_handle_attrs(
    root_state: '_FSDPState',
    root_module: nn.Module,
) -> None:
    """Shares state from ``root_state`` to other FSDP states.

    Shares data structure state from the ``root_state`` to all FSDP states in
    ``root_module`` 's module tree, and initializes handle attributes. These are
    done together to require a single loop over the states. This function has
    been modified to assign a different unshard stream to each process group.
    """
    from torch.distributed.fsdp._runtime_utils import (HOMOGENEOUS_ATTR_NAMES, _init_device_mesh,
                                                       _validate_and_get_hybrid_shard_state)
    from torch.distributed.utils import _p_assert

    handle = root_state._handle
    if handle:
        handle.init_flat_param_attributes()
    _validate_and_get_hybrid_shard_state(root_module)
    attr_name_to_values: Dict[str, Set[Any]] = {}
    for attr_name in HOMOGENEOUS_ATTR_NAMES:
        attr_name_to_values[attr_name] = set()
    root_state._all_handles = root_state._exec_order_data.all_handles  # share reference
    root_state._device_mesh = _init_device_mesh(root_state)
    # Update _has_optim_in_backward for each handle.
    for handle in root_state._all_handles:
        flat_param = handle.flat_param
        if hasattr(flat_param, '_in_backward_optimizers'):
            raise RuntimeError('FSDP optimizer in backward only supported with use_orig_params=True!')
        handle._has_optim_in_backward = flat_param._params is not None and any(
            hasattr(param, '_in_backward_optimizers') for param in flat_param._params)

    # Patching so that _FSDPStates with different process groups have separate unshard streams.
    # Keep track of any new unshard streams we may have to add for specific process groups.
    fsdp_pg_unshard_streams = {}
    unshard_priority = root_state._unshard_stream.priority
    for fsdp_state in root_state._all_fsdp_states:
        for attr_name in HOMOGENEOUS_ATTR_NAMES:
            _p_assert(
                hasattr(fsdp_state, attr_name),
                f'FSDP state missing attribute {attr_name}',
            )
            attr_name_to_values[attr_name].add(getattr(fsdp_state, attr_name))
        if fsdp_state is root_state:
            continue
        # Relax the assert for non-root FSDP instances in case the nested
        # initialized module is wrapped again in FSDP later (e.g. after
        # training to run inference)
        _p_assert(
            fsdp_state._is_root is None or not fsdp_state._is_root,
            "Non-root FSDP instance's `_is_root` should not have been "
            'set yet or should have been set to `False`',
        )
        fsdp_state._is_root = False

        # Take care of any new unshard streams we have to create for non-default process groups.
        if fsdp_state_has_default_pg(fsdp_state):
            # If using default process group, unshard stream is the same as root fsdp instance.
            fsdp_state._unshard_stream = root_state._unshard_stream
        else:
            # Otherwise, unshard stream is separate.
            state_pg_ranks = fsdp_state_pg_ranks(fsdp_state)
            if state_pg_ranks in fsdp_pg_unshard_streams:
                # We have created the unshard stream for this process group already. Use it.
                fsdp_state._unshard_stream = fsdp_pg_unshard_streams[state_pg_ranks]
            else:
                # We don't have an unshard stream for this process group yet. Make it.
                fsdp_state._unshard_stream = fsdp_state._device_handle.Stream(priority=unshard_priority)
                fsdp_pg_unshard_streams[state_pg_ranks] = fsdp_state._unshard_stream

        # All other stream assignments stay common across all of FSDP.
        fsdp_state._post_backward_stream = root_state._post_backward_stream
        fsdp_state._pre_unshard_stream = root_state._pre_unshard_stream
        fsdp_state._all_reduce_stream = root_state._all_reduce_stream
        fsdp_state._default_stream = root_state._default_stream
        fsdp_state._exec_order_data = root_state._exec_order_data
        fsdp_state._free_event_queue = root_state._free_event_queue
        fsdp_state._device_mesh = root_state._device_mesh
        handle = fsdp_state._handle
        if handle:
            handle.init_flat_param_attributes()
    for attr_name, attr_values in attr_name_to_values.items():
        if len(attr_values) != 1:
            raise ValueError(f'Expects one homogeneous value for {attr_name} but got {attr_values}')
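
The key mechanism in the patched function is the fsdp_pg_unshard_streams dictionary: unshard streams are cached by the tuple of ranks in each process group, so FSDP states that share a process group reuse one stream, while states on different process groups (for example the shard and replicate groups in a 2D or HSDP setup) get independent streams whose all-gather work can be scheduled concurrently with computation. The snippet below is a minimal, self-contained sketch of that caching pattern for illustration only; the helper name unshard_stream_for_pg is not part of the PR.

# Illustrative sketch (not part of this PR) of the per-process-group stream cache.
# Only the caching pattern is shown; FSDP wires the returned stream into each state.
from typing import Dict, Tuple

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import get_process_group_ranks


def unshard_stream_for_pg(
    pg: dist.ProcessGroup,
    cache: Dict[Tuple[int, ...], torch.cuda.Stream],
    priority: int = 0,
) -> torch.cuda.Stream:
    """Return the cached CUDA stream for ``pg``, creating one if none exists yet."""
    ranks = tuple(get_process_group_ranks(pg))
    if ranks not in cache:
        # One dedicated stream per distinct set of ranks, mirroring fsdp_pg_unshard_streams.
        cache[ranks] = torch.cuda.Stream(priority=priority)
    return cache[ranks]

In the patch itself the priority is taken from root_state._unshard_stream.priority, and the stream is created through fsdp_state._device_handle.Stream, which resolves to the CUDA stream type on GPU devices, rather than by calling torch.cuda.Stream directly as in this sketch.
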