Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch for Cambricon MLUs test #1747

Merged
merged 14 commits into from
Jun 6, 2024
Prev Previous commit
Next Next commit
make style for MLUs
  • Loading branch information
huismiling committed May 22, 2024
commit 99965ac1f6084c017260c82ffd982288e0df702e
4 changes: 2 additions & 2 deletions tests/test_custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,8 +956,8 @@ def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, co
# check that there is a difference in results after training
assert not torch.allclose(outputs_before, outputs_after, atol=atol, rtol=rtol)

if self.torch_device in ['mlu'] and model_id in ['Conv2d']:
atol, rtol = 1e-3, 1e-2 # MLU
if self.torch_device in ["mlu"] and model_id in ["Conv2d"]:
atol, rtol = 1e-3, 1e-2 # MLU

# unmerged or merged should make no difference
assert torch.allclose(outputs_after, outputs_unmerged, atol=atol, rtol=rtol)
Expand Down
1 change: 1 addition & 0 deletions tests/test_lora_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from .testing_utils import require_torch_gpu


def is_megatron_available() -> bool:
    """Return True if the ``megatron`` package can be imported, else False."""
    spec = importlib.util.find_spec("megatron")
    return spec is not None

Expand Down
10 changes: 5 additions & 5 deletions tests/testing_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,8 +561,8 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs):
logits_merged_unloaded = model(**dummy_input)[0]

atol, rtol = 1e-4, 1e-4
if self.torch_device in ['mlu']:
atol, rtol = 1e-3, 1e-3 # MLU
if self.torch_device in ["mlu"]:
atol, rtol = 1e-3, 1e-3 # MLU
if (config.peft_type == "IA3") and (model_id == "Conv2d"):
# for some reason, the IA³ Conv2d introduces a larger error
atol, rtol = 0.3, 0.01
Expand Down Expand Up @@ -691,7 +691,7 @@ def _test_safe_merge(self, model_id, config_cls, config_kwargs):
model = get_peft_model(model, config).eval()
logits_peft = model(**inputs)[0]

atol, rtol = 1e-6, 1e-6 # default
atol, rtol = 1e-6, 1e-6 # default
# Initializing with LN tuning cannot be configured to change the outputs (unlike init_lora_weights=False)
if not issubclass(config_cls, LNTuningConfig):
# sanity check that the logits are different
Expand All @@ -700,8 +700,8 @@ def _test_safe_merge(self, model_id, config_cls, config_kwargs):
model_unloaded = model.merge_and_unload(safe_merge=True)
logits_unloaded = model_unloaded(**inputs)[0]

if self.torch_device in ['mlu']:
atol, rtol = 1e-3, 1e-3 # MLU
if self.torch_device in ["mlu"]:
atol, rtol = 1e-3, 1e-3 # MLU
# check that the logits are the same after unloading
assert torch.allclose(logits_peft, logits_unloaded, atol=atol, rtol=rtol)

Expand Down
Loading