Avoid TP helper imports for non-TP distributed LoRA loads #3261
jiqing-feng wants to merge 3 commits
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Hi @BenjaminBossan. Would you please review the PR? Thanks!
BenjaminBossan left a comment
Thank you for this PR. I agree that we can check for TP layers first before importing the TP related class, so that part of the PR looks good. What I don't quite get is the test. Under what circumstances would we expect the test to fail? Could you perhaps add a comment to explain?
Thanks for the feedback. I updated the test comment to make the intent clearer. It is a unit test for |
Summary
Fix a `ModuleNotFoundError` in distributed LoRA adapter loading when `torch.distributed` is initialized but the model is not actually tensor-parallel sharded.
Problem
`set_peft_model_state_dict` calls `_maybe_shard_state_dict_for_tp` whenever `torch.distributed` is initialized. This also happens for regular DDP runs. The helper imported `transformers.integrations.tensor_parallel` before checking whether any LoRA layer had TP metadata, so non-TP distributed runs could fail immediately with a `ModuleNotFoundError`.
Fix
Check for real TP-sharded LoRA layers first by looking for `_hf_tp_plan` and `_hf_device_mesh` on the base layers. If none are found, return early. Only import Transformers TP helpers when there is actually something to shard.
This keeps the existing TP path intact while avoiding false-positive TP handling for DDP/non-TP workloads.
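For readers who want to see the ordering concretely, here is a minimal sketch of the check-first, import-later pattern described above. It is not the code from this PR: `has_tp_metadata`, `maybe_shard_sketch`, and the `tp_module` parameter are hypothetical names for illustration, layers are modeled as plain namespaces with a `base_layer` attribute, and the actual sharding step is left out.

```python
import importlib
from types import SimpleNamespace


def has_tp_metadata(base_layer):
    """Return True only if both TP attributes from the PR description are set."""
    return (
        getattr(base_layer, "_hf_tp_plan", None) is not None
        and getattr(base_layer, "_hf_device_mesh", None) is not None
    )


def maybe_shard_sketch(named_lora_layers, state_dict,
                       tp_module="transformers.integrations.tensor_parallel"):
    """Illustrative stand-in for the helper; the real sharding step is omitted."""
    tp_layers = [
        name
        for name, layer in named_lora_layers
        if has_tp_metadata(getattr(layer, "base_layer", None))
    ]
    if not tp_layers:
        # DDP or other non-TP run: return before touching any TP import.
        return state_dict
    # Deferred import: reached only when at least one layer is TP-sharded.
    importlib.import_module(tp_module)
    # Actual sharding of the matching entries would happen here (not shown).
    return state_dict


# A plain DDP-style layer: it has a base layer, but no TP attributes on it.
ddp_layers = [("q_proj", SimpleNamespace(base_layer=SimpleNamespace()))]
state = {"q_proj.lora_A.weight": [0.0]}

# The module name below does not exist, yet no import is attempted.
assert maybe_shard_sketch(ddp_layers, state, tp_module="no_such_tp_module") is state
```

Passing a deliberately missing module name makes the point visible: the non-TP call succeeds because the import line is never reached.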
Test
Added a regression test for a LoRA model with a model-level `_tp_plan` but no layer-level TP metadata:
    python -m pytest -o addopts="" tests/test_other.py::test_maybe_shard_state_dict_for_tp_noops_without_tp_layers -q
Result: 1 passed.
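As a rough illustration of when such a regression test would fail, the sketch below contrasts an import-first helper with a check-first helper while the TP module is made unimportable. It is not the test added in this PR. `TP_MODULE`, `eager_helper`, `lazy_helper`, and `raises_when_import_blocked` are hypothetical names, the model is a plain namespace, and the `"colwise"` plan value is only a placeholder.

```python
import importlib
import sys
from types import SimpleNamespace

TP_MODULE = "hypothetical_tp_helpers"  # stand-in name, not a real package


def eager_helper(model, state_dict):
    # Old order: the import runs before any layer is inspected.
    importlib.import_module(TP_MODULE)
    return state_dict


def lazy_helper(model, state_dict):
    # New order: inspect layer-level metadata first, import only if needed.
    has_tp_layer = any(
        hasattr(layer, "_hf_tp_plan") and hasattr(layer, "_hf_device_mesh")
        for layer in model.layers
    )
    if not has_tp_layer:
        return state_dict
    importlib.import_module(TP_MODULE)
    return state_dict


def raises_when_import_blocked(helper, model, state_dict):
    """Call helper with TP_MODULE made unimportable; report whether it raised."""
    had_entry = TP_MODULE in sys.modules
    saved = sys.modules.get(TP_MODULE)
    sys.modules[TP_MODULE] = None  # a None entry makes the import fail
    try:
        helper(model, state_dict)
        return False
    except ModuleNotFoundError:
        return True
    finally:
        # Restore sys.modules so the block leaves no global side effects.
        if had_entry:
            sys.modules[TP_MODULE] = saved
        else:
            del sys.modules[TP_MODULE]


# Model-level plan present, but the single layer has no TP attributes.
model = SimpleNamespace(_tp_plan={"q_proj": "colwise"}, layers=[SimpleNamespace()])
state = {"q_proj.lora_A.weight": [0.0]}

assert raises_when_import_blocked(eager_helper, model, state)      # old order fails
assert not raises_when_import_blocked(lazy_helper, model, state)   # new order is a no-op
```

Blocking the import through `sys.modules` keeps the outcome independent of whether the TP helpers happen to be installed in the test environment.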