Remove unused FSDP1 components #1933

Closed · wants to merge 4 commits

1 change: 0 additions & 1 deletion docs/source/api_ref_modules.rst
@@ -75,7 +75,6 @@ PEFT Components
peft.get_adapter_params
peft.set_trainable_params
peft.validate_missing_and_unexpected_for_lora
peft.validate_state_dict_for_lora
peft.disable_adapter


3 changes: 0 additions & 3 deletions docs/source/api_ref_training.rst
@@ -50,12 +50,9 @@ Utilities for enabling and working with distributed training.
:toctree: generated/
:nosignatures:

FSDPPolicyType
init_distributed
is_distributed
get_world_size_and_rank
get_full_finetune_fsdp_wrap_policy
lora_fsdp_wrap_policy

.. _ac_label:

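Note (added for context, not part of the diff): the wrap-policy helpers removed from this doc belong to the FSDP1 path; the FSDP2 path built around `fully_shard` does not use an `auto_wrap_policy` at all. A minimal sketch of the replacement style is below; the helper name `shard_transformer_layers` is made up, and it assumes a torchtune-style model exposing `.layers` and an already-initialized process group.

```python
# Illustrative sketch only, not part of this PR: FSDP2-style sharding that
# replaces an FSDP1 auto_wrap_policy. Each transformer layer is sharded
# individually with fully_shard, then the root module is sharded last.
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard


def shard_transformer_layers(model: nn.Module) -> nn.Module:
    # Assumes `model.layers` holds the transformer blocks (as in torchtune's
    # TransformerDecoder) and that torch.distributed is already initialized,
    # since fully_shard shards over the default device mesh.
    for layer in model.layers:
        fully_shard(layer)
    fully_shard(model)
    return model
```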
14 changes: 0 additions & 14 deletions recipes/lora_dpo_single_device.py
@@ -26,7 +26,6 @@
get_merged_lora_ckpt,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
validate_state_dict_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface

@@ -270,19 +269,6 @@ def _setup_model(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

validate_state_dict_for_lora(
lora_attn_modules=cfg_model.lora_attn_modules,
apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
full_model_state_dict_keys=model.state_dict().keys(),
lora_state_dict_keys=(
lora_weights_state_dict.keys()
if lora_weights_state_dict is not None
else None
),
base_model_state_dict_keys=base_model_state_dict.keys(),
)

base_missing, base_unexpected = model.load_state_dict(
base_model_state_dict, strict=False
)
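Note (added for context, not part of the diff): with `validate_state_dict_for_lora` gone, the recipe keeps validating checkpoints through `load_state_dict(..., strict=False)` plus `validate_missing_and_unexpected_for_lora`, which remains imported above. The sketch below reuses the recipe's own variable names (`model`, `cfg_model`, `base_model_state_dict`, `lora_weights_state_dict`); the keyword arguments `base_missing`, `base_unexpected`, `lora_missing`, and `lora_unexpected` are assumed from torchtune's signature and may differ.

```python
# Illustrative sketch only, not part of this PR: the retained validation path.
# Load both checkpoints non-strictly, then validate the missing/unexpected keys
# instead of comparing full key sets up front.
base_missing, base_unexpected = model.load_state_dict(
    base_model_state_dict, strict=False
)
if lora_weights_state_dict is not None:
    lora_missing, lora_unexpected = model.load_state_dict(
        lora_weights_state_dict, strict=False
    )
else:
    lora_missing, lora_unexpected = None, None

validate_missing_and_unexpected_for_lora(
    lora_attn_modules=cfg_model.lora_attn_modules,
    apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
    apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
    base_missing=base_missing,
    base_unexpected=base_unexpected,
    lora_missing=lora_missing,
    lora_unexpected=lora_unexpected,
)
```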
5 changes: 2 additions & 3 deletions tests/torchtune/modules/peft/test_utils.py
@@ -20,7 +20,6 @@
LoRALinear,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
validate_state_dict_for_lora,
)

N_LAYERS = 3
@@ -384,7 +383,7 @@ def test_validate_lora_state_dict(
)
if expected:
with pytest.raises(AssertionError, match=expected):
validate_state_dict_for_lora(
validate_missing_and_unexpected_for_lora(
lora_attn_modules,
apply_lora_to_mlp,
apply_lora_to_output,
@@ -393,7 +392,7 @@
base_model_state_dict_keys=base_model_state_dict_keys,
)
else:
validate_state_dict_for_lora(
validate_missing_and_unexpected_for_lora(
lora_attn_modules,
apply_lora_to_mlp,
apply_lora_to_output,
131 changes: 2 additions & 129 deletions tests/torchtune/training/test_distributed.py
@@ -7,26 +7,22 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import copy
from itertools import chain

import pytest
import torch
import torch.nn as nn
from packaging import version
from tests.test_utils import gpu_test, single_box_init
from tests.test_utils import gpu_test
from torch.distributed import launcher

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointWrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torchao.dtypes.nf4tensor import NF4Tensor
from torchtune import modules, training
from torchtune.models.llama2._component_builders import llama2, lora_llama2
from torchtune.models.llama3._component_builders import llama3
from torchtune.models.llama2._component_builders import lora_llama2
from torchtune.modules import TransformerSelfAttentionLayer
from torchtune.modules.peft import (
DoRALinear,
@@ -110,51 +106,6 @@ def test_validate_no_params_on_meta_device(self) -> None:
with pytest.raises(RuntimeError, match="Unexpected param or buffer"):
training.validate_no_params_on_meta_device(model)

def test_get_fsdp_wrap_policies(self) -> None:
with single_box_init():
llama3_policy = training.get_full_finetune_fsdp_wrap_policy(
memory_efficient_fsdp_wrap=True,
modules_to_wrap={modules.TransformerSelfAttentionLayer},
)
l3 = llama3(
vocab_size=64,
num_layers=1,
num_heads=4,
num_kv_heads=4,
embed_dim=64,
max_seq_len=128,
)
wrapped_l3 = FSDP(
l3, auto_wrap_policy=llama3_policy, device_id=torch.device("cpu")
)
# Ensure embedding, output proj, and transformer decoder blocks are wrapped
assert isinstance(wrapped_l3.tok_embeddings, FSDP)
assert isinstance(wrapped_l3.output, FSDP)
for layer in wrapped_l3.layers:
assert isinstance(layer, FSDP)

llama2_policy = training.get_full_finetune_fsdp_wrap_policy(
memory_efficient_fsdp_wrap=False,
modules_to_wrap={modules.TransformerSelfAttentionLayer},
)
l2 = llama2(
vocab_size=64,
num_layers=1,
num_heads=4,
num_kv_heads=4,
embed_dim=64,
max_seq_len=128,
)
wrapped_l2 = FSDP(
l2, auto_wrap_policy=llama2_policy, device_id=torch.device("cpu")
)
# Ensure embedding, output proj, and transformer decoder blocks are not wrapped
assert not isinstance(wrapped_l2.tok_embeddings, FSDP)
assert not isinstance(wrapped_l2.output, FSDP)
# Ensure transformer decoder blocks are wrapped
for layer in wrapped_l2.layers:
assert isinstance(layer, FSDP)


N_LAYERS = 3
IN_DIM = 5
@@ -181,84 +132,6 @@ def _get_n_lora_and_tformer_layers(model):
return num_lora_ab, num_transformer_layers


# TODO: figure out a permanent home for FSDP + LoRA code
class TestLoRAFSDP:
def test_lora_fsdp_wrap(self):
with torch.device("meta"):
model = lora_llama2(
lora_attn_modules=["q_proj", "v_proj"],
vocab_size=VOCAB_SIZE,
num_layers=N_LAYERS,
num_heads=NUM_HEADS,
num_kv_heads=NUM_KV_HEADS,
embed_dim=EMBED_DIM,
max_seq_len=MAX_SEQ_LEN,
lora_rank=4,
lora_alpha=1.0,
)

adapter_params = get_adapter_params(model)
set_trainable_params(model, adapter_params)
num_lora_ab, num_transformer_layers = _get_n_lora_and_tformer_layers(model)
with single_box_init():
lora_wrap_policy = training.lora_fsdp_wrap_policy(
modules_to_wrap={TransformerSelfAttentionLayer}
)
training.prepare_model_for_fsdp_with_meta_device(model)
wrapped_lora = FSDP(
model,
auto_wrap_policy=lora_wrap_policy,
device_id=torch.device("cpu"),
)

# After FSDP wrap, nothing should be left on meta device, and LoRA params
# should be initialized.
for p in chain(wrapped_lora.parameters(), wrapped_lora.buffers()):
assert not p.is_meta

for m in wrapped_lora.modules():
if isinstance(m, LoRALinear) or isinstance(m, DoRALinear):
torch.testing.assert_close(
m.lora_b.weight, torch.zeros_like(m.lora_b.weight)
)
# Total # FSDP modules should be num_transformer + num_lora_ab + 1
total_fsdp_submodules = len([m for m in FSDP.fsdp_modules(wrapped_lora)])
assert total_fsdp_submodules == (num_lora_ab + num_transformer_layers + 1)
# LoRA a & b linears should be individually wrapped.
# And TransformerSelfAttentionLayers should be individually wrapped.
for fsdp_submodule in FSDP.fsdp_modules(wrapped_lora):
if isinstance(fsdp_submodule.module, nn.Linear):
num_lora_ab -= 1
elif isinstance(fsdp_submodule.module, TransformerSelfAttentionLayer):
num_transformer_layers -= 1
assert num_lora_ab == 0
assert num_transformer_layers == 0

def test_lora_meta_device_init_fsdp(self):
with torch.device("meta"):
lora = lora_llama2(
lora_attn_modules=["q_proj", "v_proj"],
vocab_size=VOCAB_SIZE,
num_layers=N_LAYERS,
num_heads=NUM_HEADS,
num_kv_heads=NUM_KV_HEADS,
embed_dim=EMBED_DIM,
max_seq_len=MAX_SEQ_LEN,
lora_rank=4,
lora_alpha=8,
)
training.prepare_model_for_fsdp_with_meta_device(lora)
for m in lora.modules():
m.to_empty(device=torch.device("cpu"), recurse=False)
m.reset_parameters()
# No params should be left on meta device
for n, p in lora.named_parameters():
assert not p.is_meta, f"parameter {n} is still on meta device!"
# Neither should buffers
for n, b in lora.named_buffers():
assert not b.is_meta, f"buffer {n} is still on meta device!"


class TestFullyShardState(FSDPTest):
@property
def world_size(self) -> int:
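Note (added for context, not part of the diff): the remaining `TestFullyShardState` tests target the FSDP2 `fully_shard` API, where sharding shows up on the parameters themselves (they become `DTensor`s) rather than through `FSDP` wrapper modules, so the removed wrap-policy assertions have no direct equivalent. A minimal sketch of the kind of check that applies instead; the helper name is illustrative.

```python
# Illustrative sketch only, not part of this PR: with FSDP2's fully_shard there
# are no FSDP wrapper modules to introspect; sharded parameters are DTensors.
import torch.nn as nn
from torch.distributed._tensor import DTensor


def assert_fully_sharded(model: nn.Module) -> None:
    # Assumes fully_shard(...) has already been applied to `model`.
    for name, param in model.named_parameters():
        assert isinstance(param, DTensor), f"parameter {name} is not sharded"
```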
2 changes: 0 additions & 2 deletions torchtune/modules/peft/__init__.py
@@ -14,7 +14,6 @@
LORA_ATTN_MODULES,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
validate_state_dict_for_lora,
)
from .dora import DoRALinear
from .lora import LoRALinear
@@ -27,7 +26,6 @@
"get_adapter_params",
"set_trainable_params",
"validate_missing_and_unexpected_for_lora",
"validate_state_dict_for_lora",
"load_dora_magnitudes",
"disable_adapter",
"get_merged_lora_ckpt",
85 changes: 0 additions & 85 deletions torchtune/modules/peft/_utils.py
@@ -107,91 +107,6 @@ def get_lora_module_names(
return lora_module_keys


def validate_state_dict_for_lora(
lora_attn_modules: List[LORA_ATTN_MODULES],
apply_lora_to_mlp: bool,
apply_lora_to_output: bool,
full_model_state_dict_keys: List[str],
lora_state_dict_keys: Optional[List[str]] = None,
base_model_state_dict_keys: Optional[List[str]] = None,
) -> None:
"""
Validate that the state dict keys for a LoRA model are as expected.

(1) If lora_state_dict_keys are passed, this function will confirm that they match exactly the
LoRA param names from the full model (as determined by lora_modules).
(2) If base_model_state_dict_keys are passed, this function will confirm that they are exactly the
complement of the LoRA param names from the full model.
(3) If both lora_state_dict_keys and base_model_state_dict_keys are passed, this function will
confirm that the full model's params are exactly their disjoint union.

Args:
lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
LoRA should be applied to in each self-attention block. Options are
``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
apply_lora_to_mlp (bool): whether LoRA is applied to each MLP linear.
apply_lora_to_output (bool): whether LoRA is applied to the final output projection.
full_model_state_dict_keys (List[str]): List of keys in the full model state dict.
lora_state_dict_keys (Optional[List[str]]): List of keys in the LoRA state dict.
If none, LoRA state dict keys will not be validated.
base_model_state_dict_keys (Optional[List[str]]): List of keys in the base model state dict.
If none, base model keys will not be validated.

Returns:
None

Raises:
AssertionError: If base model state dict is missing any non-LoRA params from the full model.
AssertionError: If LoRA state dict is missing any LoRA params from the full model.
AssertionError: If base model state dict has any LoRA params.
AssertionError: If LoRA state dict has any non-LoRA params.
AssertionError: If base model and LoRA state dicts have overlapping keys.
AssertionError: If full model state dict is missing keys from either base model or LoRA state dict.

"""
lora_modules = get_lora_module_names(
lora_attn_modules, apply_lora_to_mlp, apply_lora_to_output
)
is_lora_param = lambda x: any(
[
".".join([k, "lora"]) in x or ".".join([k, "magnitude"]) in x
for k in lora_modules
]
)
for k in full_model_state_dict_keys:
if not is_lora_param(k):
if base_model_state_dict_keys is not None:
if k not in base_model_state_dict_keys:
raise AssertionError(
f"Missing non-LoRA key {k} from base model state dict"
)
if lora_state_dict_keys is not None:
if k in lora_state_dict_keys:
raise AssertionError(f"Non-LoRA key {k} found in LoRA state dict")
else:
if base_model_state_dict_keys is not None:
if k in base_model_state_dict_keys:
raise AssertionError(f"LoRA key {k} found in base model state dict")
if lora_state_dict_keys is not None:
if k not in lora_state_dict_keys:
raise AssertionError(f"Missing LoRA key {k} From LoRA state dict")

# Full model is disjoint union of base model and LoRA weights
if lora_state_dict_keys is not None and base_model_state_dict_keys is not None:
combined_state_dict_keys = set(lora_state_dict_keys).union(
base_model_state_dict_keys
)
shared_state_dict_keys = set(lora_state_dict_keys).intersection(
base_model_state_dict_keys
)
assert (
shared_state_dict_keys == set()
), "Base model and LoRA state dict have overlapping keys"
assert combined_state_dict_keys == set(
full_model_state_dict_keys
), "Extra keys not present in full model"


def _get_lora_modules(state_dict: Dict[str, Any]) -> Set[str]:
"""
Get the keys from a state dict that correspond to LoRALinear modules.
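Note (added for context, not part of the diff): the docstring of the deleted validator spells out the invariant it enforced: the base-model and LoRA key sets must be disjoint and their union must equal the full model's key set. A tiny self-contained illustration of that set relationship, with made-up key names:

```python
# Illustrative sketch only, not part of this PR: the set invariant that the
# removed validate_state_dict_for_lora enforced, expressed directly with sets.
full_model_keys = {
    "attn.q_proj.weight",
    "attn.q_proj.lora_a.weight",
    "attn.q_proj.lora_b.weight",
}
base_keys = {"attn.q_proj.weight"}
lora_keys = {"attn.q_proj.lora_a.weight", "attn.q_proj.lora_b.weight"}

# Base and LoRA checkpoints must not overlap...
assert base_keys & lora_keys == set()
# ...and together they must cover the full model exactly (a disjoint union).
assert (base_keys | lora_keys) == full_model_keys
```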
10 changes: 0 additions & 10 deletions torchtune/training/__init__.py
@@ -6,9 +6,6 @@
from torchtune.training._activation_offloading import NoOpManager, OffloadActivations
from torchtune.training._compile import compile_loss, compile_model
from torchtune.training._distributed import (
contains_fsdp,
FSDPPolicyType,
get_full_finetune_fsdp_wrap_policy,
get_full_model_state_dict,
get_full_optimizer_state_dict,
get_shard_conditions,
@@ -17,8 +14,6 @@
is_distributed,
load_from_full_model_state_dict,
load_from_full_optimizer_state_dict,
lora_fsdp_wrap_policy,
prepare_model_for_fsdp_with_meta_device,
set_torch_num_threads,
shard_model,
validate_no_params_on_meta_device,
@@ -108,12 +103,7 @@
"set_torch_num_threads",
"shard_model",
"get_shard_conditions",
"prepare_model_for_fsdp_with_meta_device",
"validate_no_params_on_meta_device",
"contains_fsdp",
"FSDPPolicyType",
"get_full_finetune_fsdp_wrap_policy",

Review comment (Collaborator):
Similar to my comment above, I think this will also need to be removed from the torchtune.training API ref doc.

Review comment (Collaborator, @SalmanMohammadi, Nov 3, 2024):
There's also a reference to this in the QAT recipe in the note just below this heading. Should this note just be removed? @ebsmothers

Review comment (Contributor):
Sorry, I started writing out a comment on this but I guess I never hit send. Yes, we should remove the note about memory_efficient_fsdp_wrap. You can leave the field in any QAT configs, as that'll get handled in #1854, but you can remove it from e.g. here.

"lora_fsdp_wrap_policy",

Review comment (Collaborator):
Same here, this should also be removed from the training API docs.

Review comment (Contributor):
Yeah, that would be good to remove as well.

"get_full_model_state_dict",
"get_full_optimizer_state_dict",
"load_from_full_model_state_dict",