
Commit b866dde: add fsdp model check
hanwen-sun committed Sep 24, 2024 (1 parent: ddfbf39)
Showing 1 changed file with 31 additions and 12 deletions.
torchacc/dist/fsdp.py (43 changes: 31 additions & 12 deletions)
@@ -225,7 +225,16 @@ def sharded_optim_state_dict(
         Dict[str, Any]: A :class:`dict` containing the optimizer state for
         fsdp model. Each rank get the sharded optim state added with shard_metadata.
     """
-    if not isinstance(model, xla_fsdp.XlaFullyShardedDataParallel):
+    if not isinstance(
+            model, xla_fsdp.XlaFullyShardedDataParallel) and not isinstance(
+                model, FullyShardedDataParallel):
         raise NotImplementedError(
             "The model must be torchacc or xla FSDP model")
+    assert isinstance(model,
+                      xla_fsdp.XlaFullyShardedDataParallel) or isinstance(
+                          model, FullyShardedDataParallel)
+
+    if isinstance(model, FullyShardedDataParallel):
+        model = model.model
 
     optimizer = {
@@ -271,11 +280,16 @@ def full_optim_state_dict(model: torch.nn.Module,
         :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=True``,
         then nonzero ranks return an :class:`dict` with keys but empty value.
     """
-    if not isinstance(model, xla_fsdp.XlaFullyShardedDataParallel):
-        if not hasattr(model, 'model'):
-            raise NotImplementedError(
-                "The model passed in must be torchacc or xla FSDP model")
-        assert hasattr(model, 'model')
+    if not isinstance(
+            model, xla_fsdp.XlaFullyShardedDataParallel) and not isinstance(
+                model, FullyShardedDataParallel):
+        raise NotImplementedError(
+            "The model must be torchacc or xla FSDP model")
+    assert isinstance(model,
+                      xla_fsdp.XlaFullyShardedDataParallel) or isinstance(
+                          model, FullyShardedDataParallel)
 
+    if isinstance(model, xla_fsdp.XlaFullyShardedDataParallel):
+        model = model.model
 
     shard_meta_data = model.get_shard_metadata()
@@ -346,18 +360,23 @@ def load_optim_state_dict(model: torch.nn.Module,
         passed into the optimizer whose state_dict is ``optim_state_dict``.
         optim_state_dict (Dict[str, Any]): The optimizer states to be loaded.
         rank0_only: (bool): control whether load state_dict only from
-            rank0 at the begining.(Default: ``True``) If set to True,
+            rank0 at the begining.(Default: ``True``). If set to True,
             nonzero ranks should pass None in.
     Returns:
         Dict[str, Any]: A :class:`dict` containing the optimizer state for
         model which is sharded.
     """
-    if not isinstance(model, xla_fsdp.XlaFullyShardedDataParallel):
-        if not hasattr(model, 'model'):
-            raise NotImplementedError(
-                "The model passed in must be torchacc or xla FSDP model")
-        assert hasattr(model, 'model')
+    if not isinstance(
+            model, xla_fsdp.XlaFullyShardedDataParallel) and not isinstance(
+                model, FullyShardedDataParallel):
+        raise NotImplementedError(
+            "The model must be torchacc or xla FSDP model")
+    assert isinstance(model,
+                      xla_fsdp.XlaFullyShardedDataParallel) or isinstance(
+                          model, FullyShardedDataParallel)
 
+    if isinstance(model, FullyShardedDataParallel):
+        model = model.model
 
     shard_meta_data = model.get_shard_metadata()
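The same check-and-unwrap pattern recurs in all three functions above. Below is a minimal, self-contained sketch of that pattern using hypothetical stand-in classes (XlaFSDPStub and TorchaccFSDPStub are illustrations, not torchacc APIs): accept either FSDP wrapper, reject everything else, and unwrap the torchacc wrapper before touching shard metadata. Note that isinstance with a tuple of types is equivalent to the two chained isinstance calls in the diff.

import torch


class XlaFSDPStub(torch.nn.Module):
    """Hypothetical stand-in for xla_fsdp.XlaFullyShardedDataParallel."""

    def get_shard_metadata(self):
        # The real method returns the metadata needed to interpret the shards.
        return {"shard_info": {}, "flatten_info": {}}


class TorchaccFSDPStub(torch.nn.Module):
    """Hypothetical stand-in for torchacc's FullyShardedDataParallel wrapper,
    which holds the underlying XLA FSDP module as `.model`."""

    def __init__(self, inner: XlaFSDPStub):
        super().__init__()
        self.model = inner


def check_and_unwrap(model: torch.nn.Module) -> XlaFSDPStub:
    # The pattern this commit adds to each helper: reject anything that is
    # not one of the two FSDP wrappers, then peel the torchacc wrapper down
    # to the module that actually carries the shard metadata.
    if not isinstance(model, (XlaFSDPStub, TorchaccFSDPStub)):
        raise NotImplementedError(
            "The model must be torchacc or xla FSDP model")
    if isinstance(model, TorchaccFSDPStub):
        model = model.model
    return model


inner = XlaFSDPStub()
assert check_and_unwrap(TorchaccFSDPStub(inner)) is inner
assert check_and_unwrap(inner) is inner
try:
    check_and_unwrap(torch.nn.Linear(2, 2))  # a plain module is rejected
except NotImplementedError as err:
    print(err)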

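For orientation, here is a hedged sketch of what an optimizer-state checkpoint round trip through these helpers might look like after this change. Only the names visible in this file (full_optim_state_dict, load_optim_state_dict, model, optim_state_dict, rank0_only) come from the source; the import path, the optimizer argument, the final optim.load_state_dict call, and the distributed setup are assumptions for illustration, not torchacc's documented API.

import torch
import torch.distributed as dist

# Assumed import path: these helpers live in torchacc/dist/fsdp.py, but
# whether they are exposed like this is an assumption for illustration.
from torchacc.dist import fsdp


def save_and_restore(model, optim, path="optim_full.pt"):
    # `model` must be a torchacc FullyShardedDataParallel or an
    # xla_fsdp.XlaFullyShardedDataParallel; anything else now raises
    # NotImplementedError (the check added by this commit).
    rank = dist.get_rank()

    # Gather the full optimizer state; with rank0_only=True nonzero ranks
    # get a dict with keys but empty values (per the docstring above).
    full_osd = fsdp.full_optim_state_dict(model, optim, rank0_only=True)  # optim arg assumed
    if rank == 0:
        torch.save(full_osd, path)

    # Restore: rank 0 loads the full state, other ranks pass None
    # (rank0_only=True); each rank gets back its re-sharded slice.
    full_osd = torch.load(path) if rank == 0 else None
    sharded_osd = fsdp.load_optim_state_dict(model, full_osd, rank0_only=True)
    optim.load_state_dict(sharded_osd)  # final step assumed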