Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix training of pipeline based peft's lora model #5477

Merged
merged 9 commits into from
Oct 29, 2024

Conversation

xuanhua
Copy link
Contributor

@xuanhua xuanhua commented Apr 29, 2024

Hi, guys

I find there is an assert failure when I train huggingface's lora based model in pipeline style.

Here is the whole steps that I created my model:

  1. Load the pre-trained chatglm-6b model from huggingface, as Model_A
  2. Use huggingface's peft's get_peft_model(...) and my LoraConfig(...) from Model_A to create the lora model, as Model_B
  3. Create my own pipeline based model Model_C from Model_B

And I run Model_C under 2 3090ti GPUs. And the assertion failure looks like this:

Traceback (most recent call last):
  File "/home/ubuntu/proj/chatglm-finetuning/train_pipeline.py", line 372, in <module>
    main()
  File "/home/ubuntu/proj/chatglm-finetuning/train_pipeline.py", line 351, in main
    loss = engine.train_batch(data_iter=train_dataloader)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py", line 375, in train_batch
    self._exec_schedule(sched)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py", line 1375, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py", line 276, in _exec_reduce_tied_grads
    dist.all_reduce(grad, group=group)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/comm/comm.py", line 117, in log_wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/comm/comm.py", line 496, in all_reduce
    return cdb.all_reduce(tensor, op, group, async_op)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/comm/torch.py", line 159, in all_reduce
    return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 1520, in all_reduce
    _check_single_tensor(tensor, "tensor")
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 463, in _check_single_tensor
    raise RuntimeError(
RuntimeError: Invalid function argument. Expected parameter `tensor` to be of type torch.Tensor.

After some debugging, I find out the root cause is that my configuration of lora (in below) only add extra lora layer(part) in qkv related layers but not the embedding layer. So the whole embedding layer's parameters are freezed.

lora_config = LoraConfig(r=8, # copied from finetuning_lora.py
                        lora_alpha=32,
                        target_modules=["query_key_value"],
                        lora_dropout=0.1,
                        bias="none",
                        task_type="CAUSAL_LM",
                        inference_mode=False,
                        )   

And in my implementation of pipeline based model, I declared the embeding layer as a tied-layer. So the whole thing is that there are no gradients at all for embedding layer, but embedding layer as the tied layer needs to be synced between two gpus. The value of gradient is None but is still passed to all_reduce operation.

Current, my fix is simple and add a check if this grad is None.

@xuanhua xuanhua requested a review from duli2012 as a code owner April 29, 2024 12:18
@xuanhua
Copy link
Contributor Author

xuanhua commented May 7, 2024

@duli2012 Hi, I'm not sure if this pull request meet the project's requirement ? Or any suggestions on this PR, expect your reply :)

@loadams loadams requested review from tjruwase and tohtana May 22, 2024 17:17
Copy link
Contributor

@tohtana tohtana left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xuanhua Sorry for the delay. Let's merge this after the tests pass.

@xuanhua
Copy link
Contributor Author

xuanhua commented Sep 23, 2024

@tohtana , thank you for your reply, I saw some unit test failures above, do I need to look into it ?

@tohtana
Copy link
Contributor

tohtana commented Sep 23, 2024

@xuanhua I wonder if this is an issue on our CI. Let us take a look and restart after it is fixed.

@loadams loadams self-requested a review as a code owner October 28, 2024 20:08
@loadams loadams added this pull request to the merge queue Oct 29, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Oct 29, 2024
@loadams loadams added this pull request to the merge queue Oct 29, 2024
Merged via the queue into microsoft:master with commit e4a247e Oct 29, 2024
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants