Skip to content

Avoid TP helper imports for non-TP distributed LoRA loads#3261

Open
jiqing-feng wants to merge 3 commits into
huggingface:mainfrom
jiqing-feng:dev
Open

Avoid TP helper imports for non-TP distributed LoRA loads#3261
jiqing-feng wants to merge 3 commits into
huggingface:mainfrom
jiqing-feng:dev

Conversation

@jiqing-feng
Copy link
Copy Markdown
Contributor

Summary

Fix a ModuleNotFoundError in distributed LoRA adapter loading when torch.distributed is initialized but the model is not actually tensor-parallel sharded.

Problem

set_peft_model_state_dict calls _maybe_shard_state_dict_for_tp whenever distributed is initialized. This also happens for regular DDP runs. The helper imported transformers.integrations.tensor_parallel before checking whether any LoRA layer had TP metadata, so non-TP distributed runs could fail immediately with:

ModuleNotFoundError: No module named 'transformers.integrations.tensor_parallel'

Fix

Check for real TP-sharded LoRA layers first by looking for _hf_tp_plan and _hf_device_mesh on the base layers. If none are found, return early. Only import Transformers TP helpers when there is actually something to shard.

This keeps the existing TP path intact while avoiding false-positive TP handling for DDP/non-TP workloads.

Test

Added a regression test for a LoRA model with a model-level _tp_plan but no layer-level TP metadata:

python -m pytest -o addopts="" tests/test_other.py::test_maybe_shard_state_dict_for_tp_noops_without_tp_layers -q

Result: 1 passed.

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng
Copy link
Copy Markdown
Contributor Author

Hi @BenjaminBossan . Would you please review the PR? Thanks!

Copy link
Copy Markdown
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for this PR. I agree that we can check for TP layers first before importing the TP related class, so that part of the PR looks good. What I don't quite get is the test. Under what circumstances would we expect the test to fail? Could you perhaps add a comment to explain?

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng
Copy link
Copy Markdown
Contributor Author

Thank you for this PR. I agree that we can check for TP layers first before importing the TP related class, so that part of the PR looks good. What I don't quite get is the test. Under what circumstances would we expect the test to fail? Could you perhaps add a comment to explain?

Thanks for the feedback. I updated the test comment to make the intent clearer.

It is a unit test for _maybe_shard_state_dict_for_tp. It verifies that when no LoRA base layer has TP metadata (_hf_tp_plan / _hf_device_mesh), the helper is a no-op and leaves the state dict unchanged.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants