fix device mismatch issue for pe_audio_video model parallelism #42917
base: main
Conversation
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
zucchini-nlp left a comment
cc @eustlb for audio PE
_no_split_modules = [
    "PeAudioVideoEncoderLayer",
    "TimmWrapperForImageClassification",
]
Interesting, timm doesn't support accelerate. Usually we don't add a backbone model, since no_split_modules is unwrapped recursively for all children.
Since timm doesn't support accelerate, this is a possible workaround. Though we should add it in TimmWrapperPreTrainedModel._no_split_modules and let it be re-used in other multimodal LLMs.
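For reference, a minimal sketch of what that relocation could look like (this is not the actual transformers source, just an illustration of the idea):

from transformers import PreTrainedModel

class TimmWrapperPreTrainedModel(PreTrainedModel):
    # Declaring the whole wrapper as un-splittable keeps accelerate's device-map
    # inference from scattering the timm backbone's children across devices.
    # PE (and any other multimodal model wrapping timm) would then pick this up
    # instead of listing "TimmWrapperForImageClassification" itself.
    _no_split_modules = ["TimmWrapperForImageClassification"]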
If I put TimmWrapperForImageClassification into TimmWrapperPreTrainedModel._no_split_modules, it fails here: self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1}), throwing AssertionError: Items in the second set but not the first:. Apart from this, it also fails for pytest -rA tests/models/pe_audio_video/test_modeling_pe_audio_video.py::PeAudioVideoEncoderTest::test_cpu_offload, with this error:
src/transformers/models/pe_video/modeling_pe_video.py:182: in forward
    vision_encoder_outputs = self.vision_model(pixel_values_videos)
/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1778: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1789: in _call_impl
    return forward_call(*args, **kwargs)
/opt/venv/lib/python3.12/site-packages/accelerate/hooks.py:175: in new_forward
    output = module._old_forward(*args, **kwargs)
src/transformers/models/timm_wrapper/modeling_timm_wrapper.py:360: in forward
    logits = self.timm_model(pixel_values, **kwargs)
/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1778: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1789: in _call_impl
    return forward_call(*args, **kwargs)
/opt/venv/lib/python3.12/site-packages/accelerate/hooks.py:170: in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
/opt/venv/lib/python3.12/site-packages/accelerate/hooks.py:369: in pre_forward
    return send_to_device(args, self.execution_device), send_to_device(
/opt/venv/lib/python3.12/site-packages/accelerate/utils/operations.py:170: in send_to_device
    return honor_type(
/opt/venv/lib/python3.12/site-packages/accelerate/utils/operations.py:82: in honor_type
    return type(obj)(generator)
/opt/venv/lib/python3.12/site-packages/accelerate/utils/operations.py:171: in <genexpr>
    tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

tensor = tensor(..., device='meta', size=(288, 3, 14, 14)), device = 0, non_blocking = False, skip_keys = None

    def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
        """
        Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.

        Args:
            tensor (nested list/tuple/dictionary of `torch.Tensor`):
                The data to send to a given device.
            device (`torch.device`):
                The device to send the data to.

        Returns:
            The same data structure as `tensor` with all tensors sent to the proper device.
        """
        if is_torch_tensor(tensor) or hasattr(tensor, "to"):
            # `torch.Tensor.to("npu")` could not find context when called for the first time
            # (see this [issue](https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue)).
            if device == "npu":
                device = "npu:0"
            try:
>               return tensor.to(device, non_blocking=non_blocking)
E               NotImplementedError: Cannot copy out of meta tensor; no data!
So can we just skip the model parallelism tests here?
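For context, the failing model-parallelism check is essentially of this shape (a hedged sketch; "local/path/to/pe-audio-video" is a placeholder checkpoint, not a real repository):

from transformers import AutoModel

# Load with an automatically inferred device map spanning both visible GPUs.
new_model = AutoModel.from_pretrained("local/path/to/pe-audio-video", device_map="auto")

# The test expects the weights to be spread across both devices...
print(new_model.hf_device_map)
# ...but with the whole timm wrapper marked as un-splittable, every entry in
# hf_device_map can land on a single device, so this set comparison fails.
assert set(new_model.hf_device_map.values()) == {0, 1}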
It's failing for me even before moving no_split_module under a timm PreTrainedModel, so the issue is not exactly in the location of no_split_module
Looks like the parallelism test is failing because the layers are too big to fit on cuda:0, so accelerate is putting everything on cuda:1. I'd say we can skip the test and add a reason in the description.
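Something along these lines is what happens under the hood (a sketch using accelerate directly; `model` is assumed to be an already-instantiated PE audio-video model and the memory budgets are made up for illustration):

from accelerate import infer_auto_device_map

# With a tight budget on GPU 0, an un-splittable module that does not fit there
# is assigned entirely to GPU 1, together with everything that follows it.
device_map = infer_auto_device_map(
    model,                                   # assumed: the instantiated PE audio-video model
    max_memory={0: "200MB", 1: "4GB"},       # made-up budgets for illustration
    no_split_module_classes=["TimmWrapperForImageClassification"],
)
print(set(device_map.values()))              # can collapse to {1} instead of {0, 1}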
OK, have updated the code.
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
@unittest.skip(reason="TimmWrapperModel does not support model parallelism")
def test_model_parallelism(self):
    pass
No, no, I meant to keep the changes and skip the tests. With the proposed diff we can support model parallelism, but the tests fail because of the way they are designed.
Can you revert the previous diff and move no_split_module under timm's PreTrainedModel instead of PE?
Is it OK now?
[For maintainers] Suggested jobs to run (before merge): run-slow: pe_audio_video, pe_video, timm_wrapper
No description provided.