@@ -3774,6 +3774,18 @@ def patch_testing_methods_to_collect_info():
37743774 _patch_with_call_info (torch .testing , "assert_close" , _parse_call_info , target_args = ("actual" , "expected" ))
37753775
37763776 _patch_with_call_info (unittest .case .TestCase , "assertEqual" , _parse_call_info , target_args = ("first" , "second" ))
3777+ _patch_with_call_info (unittest .case .TestCase , "assertListEqual" , _parse_call_info , target_args = ("list1" , "list2" ))
3778+ _patch_with_call_info (
3779+ unittest .case .TestCase , "assertTupleEqual" , _parse_call_info , target_args = ("tuple1" , "tuple2" )
3780+ )
3781+ _patch_with_call_info (unittest .case .TestCase , "assertSetEqual" , _parse_call_info , target_args = ("set1" , "set1" ))
3782+ _patch_with_call_info (unittest .case .TestCase , "assertDictEqual" , _parse_call_info , target_args = ("d1" , "d2" ))
3783+ _patch_with_call_info (unittest .case .TestCase , "assertIn" , _parse_call_info , target_args = ("member" , "container" ))
3784+ _patch_with_call_info (unittest .case .TestCase , "assertNotIn" , _parse_call_info , target_args = ("member" , "container" ))
3785+ _patch_with_call_info (unittest .case .TestCase , "assertLess" , _parse_call_info , target_args = ("a" , "b" ))
3786+ _patch_with_call_info (unittest .case .TestCase , "assertLessEqual" , _parse_call_info , target_args = ("a" , "b" ))
3787+ _patch_with_call_info (unittest .case .TestCase , "assertGreater" , _parse_call_info , target_args = ("a" , "b" ))
3788+ _patch_with_call_info (unittest .case .TestCase , "assertGreaterEqual" , _parse_call_info , target_args = ("a" , "b" ))
37773789
37783790
37793791def torchrun (script : str , nproc_per_node : int , is_torchrun : bool = True , env : Optional [dict ] = None ):
0 commit comments