Skip to content

Commit 706470a

Browse files
committed
Update on "modify the conditions as PythonModuleVariable"
## Motivation The current code of `value in [torch.backends.cudnn, torch.ops]` requires `value` to implement `__eq__`. If the value is a custom object that does not implement `__eq__`, dynamo will throw an error. For example, for ConvolutionOpContext, the custom `torch._C.ScriptClass` object registered in IPEX, dynamo throws the following error: **torch._dynamo.exc.InternalTorchDynamoError: '__eq__' is not implemented for __torch__.torch.classes.ipex_prepack.ConvolutionOpContext** I think this is a common issue. To avoid it, this PR replaces the current code `value in [torch.backends.cudnn, torch.ops]` with `isinstance(value, (torch.backends.cudnn.CudnnModule, torch._ops._Ops))`. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv chenyang78 aakhundov kadeng [ghstack-poisoned]
2 parents d06ed5c + 41c090f commit 706470a

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

test/test_transformers.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
set_default_dtype,
3131
gradcheck,
3232
make_tensor,
33-
NOTEST_CPU,
34-
skipIfTorchDynamo
33+
NOTEST_CPU
3534
)
3635

3736

@@ -3290,8 +3289,7 @@ def test_causal_variants(self, device, causal_variant: CausalVariant, shape: Lis
32903289

32913290
self.run_test(device, False, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol)
32923291

3293-
@skipIfRocm # No support for the second variant for now
3294-
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
3292+
@unittest.skip("This test fails on some parameters and on some CI machines")
32953293
@parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
32963294
@parametrize(
32973295
"shape",
@@ -3317,11 +3315,7 @@ def test_causal_variants_compile(self, device, causal_variant: CausalVariant, sh
33173315
else:
33183316
attn_bias = causal_lower_right(seq_len_q, seq_len_kv)
33193317

3320-
if causal_variant == CausalVariant.LOWER_RIGHT and shape in [(16, 16, 128, 256, 32), (1, 1, 23, 56, 15)]:
3321-
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "call_function UserDefinedClassVariable"):
3322-
self.run_test(device, True, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol)
3323-
else:
3324-
self.run_test(device, True, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol)
3318+
self.run_test(device, True, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol)
33253319

33263320
@parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)])
33273321
def test_is_causal_equals_upper_left(self, device, shape: List[Tuple[int]]):

0 commit comments

Comments (0)