Patch for Cambricon MLUs test #1747

Merged
merged 14 commits on Jun 6, 2024
patch for MLUs pytest
huismiling committed May 20, 2024
commit 4447fe2276a92bb20a60192e2eb4a9c0856af3af
3 changes: 3 additions & 0 deletions tests/test_custom_models.py
@@ -897,6 +897,9 @@ def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, co
# check that there is a difference in results after training
assert not torch.allclose(outputs_before, outputs_after, atol=atol, rtol=rtol)

+ if self.torch_device in ['mlu'] and model_id in ['Conv2d']:
+     atol, rtol = 1e-3, 1e-2  # MLU

# unmerged or merged should make no difference
assert torch.allclose(outputs_after, outputs_unmerged, atol=atol, rtol=rtol)

2 changes: 2 additions & 0 deletions tests/test_lora_megatron.py
@@ -24,6 +24,7 @@

from peft import LoraConfig, PeftModel, get_peft_model, get_peft_model_state_dict

+ from .testing_utils import require_torch_gpu

def is_megatron_available() -> bool:
return importlib.util.find_spec("megatron") is not None
@@ -93,6 +94,7 @@ def forward(self, input):
x = self.lm_head(x)[0]
return x

+ @require_torch_gpu
class TestMegatronLora(unittest.TestCase):
def setUp(self):
initialize_model_parallel(1, 1)
13 changes: 9 additions & 4 deletions tests/testing_common.py
@@ -463,13 +463,13 @@ def _test_merge_layers_fp16(self, model_id, config_cls, config_kwargs):
if ("gpt2" in model_id.lower()) and (config_cls != LoraConfig):
self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)")

- model = self.transformers_class.from_pretrained(model_id, torch_dtype=torch.float16)
+ model = self.transformers_class.from_pretrained(model_id)
config = config_cls(
Contributor Author:

torch_dtype=torch.float16 leads to an error.
RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'

Member:

Hmm, I cannot replicate this, whether with or without GPU. The idea of this test is exactly to check that this error does not occur with fp16, so not using this dtype is counter-productive. Is this only occurring with MLU devices?

Contributor Author:

The reproduction code is as follows.
The issue can be reproduced using PyTorch 2.1, but it executes normally with PyTorch 2.3.

import torch

# On CPU with PyTorch 2.1 this matmul raises:
#   RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
# With PyTorch 2.3 it runs without error.
a = torch.rand(4, 4).to(torch.float16)
b = torch.rand(4, 4).to(torch.float16)
a @ b

Member:

Okay, so instead of changing the dtype, how about skipping the test if an old pytorch version is detected?

Contributor Author:

Hmm, maybe we can use fp16 with pt>=2.3 and fp32 with pt<2.3?

Member:

We really don't need to test merging with fp32 here, as it's tested extensively in other tests. This test is very specifically for merging with fp16, so if we don't use fp16, we can skip it.

Contributor Author:

Ha, got it! I will fix it.

Contributor Author:

Hmm, I found that it is the "cpu" device that leads to the error.
When the device is changed to self.torch_device, as in my fix, the MLU tests pass with torch.float16.
@BenjaminBossan Will this test use the "cpu" device? If not, there is no need to skip the test for pt2.1.

model = model.to(device="cpu", dtype=torch.float16)

Member:

So IIRC, there is an error when using CPU + float16 + old PyTorch. If we change either of those variables, there is no error. On CI, we have a new PyTorch version, so it passes, despite using CPU.

If we switch to self.torch_device, it depends, because the device is auto-inferred based on what devices are available. So on our CI, this would still be CPU. On yours, it might not, but then we don't really test what was the original intent, namely that float16 works on CPU.

I assume this fails on your CI because it uses an older PyTorch version. This is why I suggested to just skip the test with older PyTorch versions. If you want, you could add a specific test for merging float16 with MLU, which would be skipped if the device is not available.
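
As an illustration of that last suggestion, a separate MLU-only fp16 merge test could be gated roughly like the sketch below. The names and the availability check are hypothetical, not part of this PR; it assumes the Cambricon torch_mlu extension registers a torch.mlu namespace when installed.

import unittest

import torch


def mlu_available() -> bool:
    # Assumption: torch_mlu registers torch.mlu when installed; treat its
    # absence as "no MLU device".
    return hasattr(torch, "mlu") and torch.mlu.is_available()


@unittest.skipUnless(mlu_available(), "test requires a Cambricon MLU device")
class TestMergeLayersFp16MLU(unittest.TestCase):
    def test_merge_fp16_on_mlu(self):
        # Placeholder: load a small base model, wrap it with get_peft_model,
        # move it to the MLU device in torch.float16, then merge and compare outputs.
        ...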

Contributor Author (@huismiling, Jun 5, 2024):

@BenjaminBossan
Got it! Skipping test_merge_layers_fp16 for PyTorch 2.1 when the device is CPU should be OK.
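
For reference, a guard along these lines could implement that skip. This is only a sketch based on the discussion above; the helper name and the exact version bound are assumptions, not the code that landed in the PR.

import torch
from packaging import version


def should_skip_fp16_merge(torch_device: str) -> bool:
    # float16 matmul is not implemented on CPU in older PyTorch releases
    # (RuntimeError: "addmm_impl_cpu_" not implemented for 'Half', reproduced
    # with 2.1) but works from roughly 2.3 onward, so only the combination
    # of CPU and an old PyTorch needs to be skipped.
    return torch_device == "cpu" and version.parse(torch.__version__) < version.parse("2.3")

Inside _test_merge_layers_fp16 this would be called at the top, along the lines of: if should_skip_fp16_merge(self.torch_device): self.skipTest("fp16 merge on CPU needs torch >= 2.3").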

base_model_name_or_path=model_id,
**config_kwargs,
)
model = get_peft_model(model, config)
model = model.to(device="cpu", dtype=torch.float16)
model = model.to(device=self.torch_device, dtype=torch.float16)

model.eval()

@@ -560,6 +560,8 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs):
logits_merged_unloaded = model(**dummy_input)[0]

atol, rtol = 1e-4, 1e-4
+ if self.torch_device in ['mlu']:
+     atol, rtol = 1e-3, 1e-3  # MLU
if (config.peft_type == "IA3") and (model_id == "Conv2d"):
    # for some reason, the IA³ Conv2d introduces a larger error
    atol, rtol = 0.3, 0.01
@@ -688,14 +690,17 @@ def _test_safe_merge(self, model_id, config_cls, config_kwargs):
model = get_peft_model(model, config).eval()
logits_peft = model(**inputs)[0]

+ atol, rtol = 1e-6, 1e-6  # default
# sanity check that the logits are different
- assert not torch.allclose(logits_base, logits_peft, atol=1e-6, rtol=1e-6)
+ assert not torch.allclose(logits_base, logits_peft, atol=atol, rtol=rtol)

model_unloaded = model.merge_and_unload(safe_merge=True)
logits_unloaded = model_unloaded(**inputs)[0]

+ if self.torch_device in ['mlu']:
+     atol, rtol = 1e-3, 1e-3  # MLU
# check that the logits are the same after unloading
- assert torch.allclose(logits_peft, logits_unloaded, atol=1e-6, rtol=1e-6)
+ assert torch.allclose(logits_peft, logits_unloaded, atol=atol, rtol=rtol)

def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs):
# Test for mixing different adapters in a single batch by passing the adapter_names argument