remove torch.equal usages (pytorch#89527)
Preparation for the next PR in this stack: pytorch#89559.

I replaced

- `self.assertTrue(torch.equal(...))` with `self.assertEqual(..., rtol=0, atol=0, exact_device=True)`,
- `self.assertFalse(torch.equal(...))` with `self.assertNotEqual(..., rtol=0, atol=0, exact_device=True)`, and
- `assert torch.equal(...)` with `torch.testing.assert_close(..., rtol=0, atol=0)` (note that we don't need to set `check_device=True` here since that is the default).

There were a few instances where the result of `torch.equal` was used directly. In those cases I replaced it with `(... == ...).all().item()`, sometimes dropping the `.item()` depending on the context.
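
For reference, here is a minimal sketch of the replacement patterns in one place. This is a hypothetical standalone test, not one of the files touched by this commit; it assumes PyTorch's internal `TestCase` / `run_tests` helpers from `torch.testing._internal.common_utils`, and the class name and tensors are made up for illustration.

```python
# Hypothetical illustration of the replacement patterns described above;
# the test class and tensor values are made up for this sketch.
import torch
import torch.testing
from torch.testing._internal.common_utils import TestCase, run_tests


class ReplacementPatternsTest(TestCase):
    def test_patterns(self):
        a = torch.arange(4.0)
        b = a.clone()
        c = a + 1

        # Before: self.assertTrue(torch.equal(a, b))
        self.assertEqual(a, b, rtol=0, atol=0, exact_device=True)

        # Before: self.assertFalse(torch.equal(a, c))
        self.assertNotEqual(a, c, rtol=0, atol=0, exact_device=True)

        # Before: assert torch.equal(a, b)
        torch.testing.assert_close(a, b, rtol=0, atol=0)

        # Before: a plain `torch.equal(a, b)` used directly as a Python bool
        same = (a == b).all().item()
        self.assertTrue(same)


if __name__ == "__main__":
    run_tests()
```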

Pull Request resolved: pytorch#89527
Approved by: https://github.com/mruberry
pmeier authored and pytorchmergebot committed Dec 1, 2022
1 parent 0acbcef commit 4095ef8
Showing 38 changed files with 169 additions and 154 deletions.
2 changes: 1 addition & 1 deletion docs/source/nested.rst
@@ -116,7 +116,7 @@ If all dimensions are regular, the NestedTensor is intended to be semantically i
torch.Size([2, 20, 128])
>>> torch.stack([a, a]).size()
torch.Size([2, 20, 128])
>>> torch.equal(torch.stack(nt.unbind()), torch.stack([a, a]))
>>> (torch.stack(nt.unbind()) == torch.stack([a, a])).all().item()
True

In the future we might make it easier to detect this condition and convert seamlessly.
@@ -132,7 +132,7 @@ def test_torch_equal(self):

spec, alt_spec = self.get_gpu_specs()
st1, st2 = self.get_random_tensors(spec, spec, 10, 10)
self.assertTrue(torch.equal(st1, st2))
self.assertEqual(st1, st2, rtol=0, atol=0, exact_device=True)

@with_comms
@skip_if_lt_x_gpu(4)
4 changes: 2 additions & 2 deletions test/distributed/_shard/sharded_tensor/ops/test_tensor_ops.py
@@ -58,9 +58,9 @@ def test_inplace_copy(self):
)
st = sharded_tensor.rand(spec, (12, 5))
ones_st = sharded_tensor.ones(spec, (12, 5))
self.assertFalse(torch.equal(ones_st, st))
self.assertNotEqual(ones_st, st, rtol=0, atol=0, exact_device=True)
st.copy_(ones_st)
self.assertTrue(torch.equal(st, ones_st))
self.assertEqual(st, ones_st, rtol=0, atol=0, exact_device=True)

# no grad inplace_copy should work between two with different requires_grad
st_with_grad = sharded_tensor.rand(spec, (12, 5), requires_grad=True)
8 changes: 4 additions & 4 deletions test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
@@ -1125,8 +1125,8 @@ def test_state_dict(self):
self.assertTrue("sharded_tensor1" in loaded_dict_keys)
self.assertTrue("submodule.sharded_tensor2" in loaded_dict_keys)
# Verify after load.
self.assertTrue(torch.equal(m.sharded_tensor1, module_load.sharded_tensor1))
self.assertTrue(torch.equal(m.submodule.sharded_tensor2, module_load.submodule.sharded_tensor2))
self.assertEqual(m.sharded_tensor1, module_load.sharded_tensor1, rtol=0, atol=0, exact_device=True)
self.assertEqual(m.submodule.sharded_tensor2, module_load.submodule.sharded_tensor2, rtol=0, atol=0, exact_device=True)

@with_comms
@skip_if_lt_x_gpu(4)
@@ -1161,8 +1161,8 @@ def test_state_dict_new_group(self):
module_load.load_state_dict(state_dict_deser, strict=False)

# Verify after load.
self.assertTrue(torch.equal(m.sharded_tensor1, module_load.sharded_tensor1))
self.assertTrue(torch.equal(m.submodule.sharded_tensor2, module_load.submodule.sharded_tensor2))
self.assertEqual(m.sharded_tensor1, module_load.sharded_tensor1, rtol=0, atol=0, exact_device=True)
self.assertEqual(m.submodule.sharded_tensor2, module_load.submodule.sharded_tensor2, rtol=0, atol=0, exact_device=True)

@with_comms
@skip_if_lt_x_gpu(4)
20 changes: 14 additions & 6 deletions test/distributed/checkpoint/test_file_system_checkpoint.py
@@ -73,14 +73,22 @@ def assert_state_dict_equal(
for local_shard_1, local_shard_2 in zip(
value_1.local_shards(), value_2.local_shards()
):
self.assertTrue(
torch.equal(local_shard_1.tensor, local_shard_1.tensor),
f"Key {key}'s shard does not match",
self.assertEqual(
local_shard_1.tensor,
local_shard_1.tensor,
rtol=0,
atol=0,
exact_device=True,
msg=f"Key {key}'s shard does not match"
)
elif isinstance(value_1, torch.Tensor):
self.assertTrue(
torch.equal(value_1, value_2),
f"Key {key}'s tensor does not match",
self.assertEqual(
value_1,
value_2,
rtol=0,
atol=0,
exact_device=True,
msg=f"Key {key}'s tensor does not match"
)

return True
20 changes: 14 additions & 6 deletions test/distributed/checkpoint/test_file_system_checkpoint_cpu.py
@@ -75,14 +75,22 @@ def assert_state_dict_equal(
for local_shard_1, local_shard_2 in zip(
value_1.local_shards(), value_2.local_shards()
):
self.assertTrue(
torch.equal(local_shard_1.tensor, local_shard_1.tensor),
f"Key {key}'s shard does not match",
self.assertEqual(
local_shard_1.tensor,
local_shard_1.tensor,
rtol=0,
atol=0,
exact_device=True,
msg=f"Key {key}'s shard does not match",
)
elif isinstance(value_1, torch.Tensor):
self.assertTrue(
torch.equal(value_1, value_2),
f"Key {key}'s tensor does not match",
self.assertEqual(
value_1,
value_2,
rtol=0,
atol=0,
exact_device=True,
msg=f"Key {key}'s tensor does not match",
)

return True
6 changes: 3 additions & 3 deletions test/distributed/fsdp/test_fsdp_clip_grad_norm.py
@@ -191,12 +191,12 @@ def _test_ddp_parity(

# Check that the gradients were modified by `clip_grad_norm_()`
for param, orig_grad in zip(ddp_model.parameters(), orig_ddp_grads):
assert not torch.equal(param.grad, orig_grad)
self.assertNotEqual(param.grad, orig_grad, rtol=0, atol=0, exact_device=True)
for param, orig_grad in zip(fsdp_model.parameters(), orig_fsdp_grads):
if param.grad is None:
self.assertEqual(param.grad, orig_grad) # `None`
self.assertIsNone(orig_grad)
else:
assert not torch.equal(param.grad, orig_grad)
self.assertNotEqual(param.grad, orig_grad, rtol=0, atol=0, exact_device=True)

# Run an optimizer step to ensure gradients matched after clipping
ddp_optim.step()
2 changes: 1 addition & 1 deletion test/distributed/fsdp/test_fsdp_misc.py
@@ -159,7 +159,7 @@ def _check_equal(local, fsdp):
# above check would be vacuously true.
self.assertTrue(
any(
not torch.equal(p1, p2)
(p1 != p2).all()
for p1, p2 in zip(prev_params, m_local.parameters())
)
)
4 changes: 1 addition & 3 deletions test/distributed/fsdp/test_fsdp_summon_full_params.py
@@ -155,9 +155,7 @@ def test_summon_full_param_shard_value(self, mixed_precision):

# shards are padded but the full_param tensor is not
a, b = my_shard[0 : my_slice.numel()], my_slice
self.assertTrue(
torch.equal(my_shard[0 : my_slice.numel()].cpu(), my_slice.cpu())
)
self.assertEqual(my_shard[0 : my_slice.numel()].cpu(), my_slice.cpu(), rtol=0, atol=0, exact_device=True)

@skip_if_lt_x_gpu(2)
@parametrize("recurse", [True, False])
2 changes: 1 addition & 1 deletion test/distributed/pipeline/sync/test_pipe.py
@@ -777,7 +777,7 @@ def forward(self, a, b):
model = Pipe(nn.Sequential(Module1().cuda(0), Module2().cuda(0)), chunks=2, checkpoint=checkpoint)
t = torch.rand(10)
res = model(t, t, t).local_value()
assert torch.equal(res, (t + t + t) + (t * t * t))
torch.testing.assert_close(res, (t + t + t) + (t * t * t), rtol=0, atol=0)

@skip_if_no_cuda
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need atleast two GPUs")
4 changes: 2 additions & 2 deletions test/fx/test_dce_pass.py
@@ -60,7 +60,7 @@ def is_leaf_module(self, m, qualname):
traced.recompile()
# Make sure we run and get the same results before/after DCE.
inputs = [torch.tensor([1.5])] * new_num_phs
self.assertTrue(torch.equal(m(*inputs), traced(*inputs)))
self.assertEqual(m(*inputs), traced(*inputs), rtol=0, atol=0, exact_device=True)

def test_simple(self):
"""
@@ -176,7 +176,7 @@ def __init__(self):
super().__init__()

def forward(self, a: torch.Tensor) -> torch.Tensor:
torch._assert(torch.equal(a, a), "a must equal a")
torch._assert((a == a).all(), "a must equal a")
return a * 2

# Note: Don't need to specify torch._assert as having side effects
48 changes: 24 additions & 24 deletions test/fx/test_fx_const_fold.py
@@ -79,7 +79,7 @@ def forward(self, x, y):
in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9])
base_result = mod(in_x, in_y)
fold_result = mod_folded(in_x, in_y)
self.assertTrue(torch.equal(fold_result, base_result))
self.assertEqual(fold_result, base_result, rtol=0, atol=0, exact_device=True)

def test_const_fold_basic_one_attr_name_collision(self):
r"""
@@ -125,7 +125,7 @@ def forward(self, x, y):
in_x, in_y = torch.tensor([[5.0]]), torch.tensor([4.0])
base_result = mod(in_x, in_y)
fold_result = mod_folded(in_x, in_y)
self.assertTrue(torch.equal(fold_result, base_result))
self.assertEqual(fold_result, base_result, rtol=0, atol=0, exact_device=True)

def test_const_fold_basic_placeholder_reordered(self):
"""
@@ -157,7 +157,7 @@ def forward(self, x, y):
in_y = torch.tensor([[0.45]])
base_result = mod(in_x, in_y)
fold_result = mod_folded(in_x, in_y)
self.assertTrue(torch.equal(fold_result, base_result))
self.assertEqual(fold_result, base_result, rtol=0, atol=0, exact_device=True)

def test_const_fold_noop(self):
r"""
@@ -188,7 +188,7 @@ def forward(self, x):
in_x = torch.tensor([[-0.45]])
base_result = mod(in_x)
fold_result = mod_folded(in_x)
self.assertTrue(torch.equal(fold_result, base_result))
self.assertEqual(fold_result, base_result, rtol=0, atol=0, exact_device=True)

def test_const_fold_basic_two_attr_three_input(self):
r"""
@@ -237,7 +237,7 @@ def forward(self, x, y, z):
)
base_result = mod(in_x, in_y, in_z)
fold_result = mod_folded(in_x, in_y, in_z)
self.assertTrue(torch.equal(fold_result, base_result))
self.assertEqual(fold_result, base_result, rtol=0, atol=0, exact_device=True)

def test_const_fold_basic_two_attr(self):
r"""
@@ -274,7 +274,7 @@ def forward(self, x):
in_x = torch.randn(2, 3)
fold_result = mod_folded(in_x)
base_result = mod(in_x)
self.assertTrue(torch.equal(fold_result, base_result))
self.assertEqual(fold_result, base_result, rtol=0, atol=0, exact_device=True)

def test_const_fold_multi_const_folded_attrs(self):
r"""
@@ -325,7 +325,7 @@ def forward(self, x, y):
in_x, in_y = torch.randn(4, 4), torch.randn(4)
fold_result = mod_folded(in_x, in_y)
base_result = mod(in_x, in_y)
self.assertTrue(torch.equal(fold_result, base_result))
self.assertEqual(fold_result, base_result, rtol=0, atol=0, exact_device=True)

def test_const_fold_submod_hierarchy(self):
r"""
@@ -359,7 +359,7 @@ def forward(self, x):
in_x = torch.randn(2, 3)
fold_result = mod_folded(in_x)
base_result = mod(in_x)
self.assertTrue(torch.equal(fold_result, base_result))
self.assertEqual(fold_result, base_result, rtol=0, atol=0, exact_device=True)

def test_retain_node_meta(self):
r"""
@@ -412,7 +412,7 @@ def forward(self, x):
in_x = torch.randn(2, 3)
fold_result = gm_folded(in_x)
base_result = mod(in_x)
self.assertTrue(torch.equal(fold_result, base_result))
self.assertEqual(fold_result, base_result, rtol=0, atol=0, exact_device=True)

def test_const_fold_has_inlined_call_module_node(self):
class ConstFoldTestModule(torch.nn.Module):
@@ -433,7 +433,7 @@ def forward(self, x):
in_x = torch.randn(2, 3)
fold_result = gm_folded(in_x)
base_result = mod(in_x)
self.assertTrue(torch.equal(fold_result, base_result))
self.assertEqual(fold_result, base_result, rtol=0, atol=0, exact_device=True)

def test_const_fold_module_attr(self):
class ConstFoldTestModule(torch.nn.Module):
@@ -455,7 +455,7 @@ def forward(self, x):
in_x = torch.randn(2, 3)
fold_result = gm_folded(in_x)
base_result = mod(in_x)
self.assertTrue(torch.equal(fold_result, base_result))
self.assertEqual(fold_result, base_result, rtol=0, atol=0, exact_device=True)

def test_const_fold_unused_placeholder(self):
class ConstFoldTestModule(torch.nn.Module):
@@ -474,7 +474,7 @@ def forward(self, x, y, z):
in_x = torch.randn(2, 3)
fold_result = gm_folded(in_x, in_x, in_x)
base_result = mod(in_x, in_x, in_x)
self.assertTrue(torch.equal(fold_result, base_result))
self.assertEqual(fold_result, base_result, rtol=0, atol=0, exact_device=True)

def test_dict_output(self):
class ConstFoldTestModule(torch.nn.Module):
@@ -493,7 +493,7 @@ def forward(self, x):
in_x = torch.randn(2, 3)
fold_result = gm_folded(in_x)
base_result = mod(in_x)
self.assertTrue(torch.equal(fold_result["result"], base_result["result"]))
self.assertEqual(fold_result["result"], base_result["result"], rtol=0, atol=0, exact_device=True)

def test_two_outputs(self):
class ConstFoldTestModule(torch.nn.Module):
@@ -512,8 +512,8 @@ def forward(self, x):
in_x = torch.randn(2, 3)
fold_result = gm_folded(in_x)
base_result = mod(in_x)
self.assertTrue(torch.equal(fold_result[0], base_result[0]))
self.assertTrue(torch.equal(fold_result[1], base_result[1]))
self.assertEqual(fold_result[0], base_result[0], rtol=0, atol=0, exact_device=True)
self.assertEqual(fold_result[1], base_result[1], rtol=0, atol=0, exact_device=True)

def test_three_outputs(self):
class ConstFoldTestModule(torch.nn.Module):
Expand All @@ -532,9 +532,9 @@ def forward(self, x):
in_x = torch.randn(2, 3)
fold_result = gm_folded(in_x)
base_result = mod(in_x)
self.assertTrue(torch.equal(fold_result[0], base_result[0]))
self.assertTrue(torch.equal(fold_result[1], base_result[1]))
self.assertTrue(torch.equal(fold_result[2], base_result[2]))
self.assertEqual(fold_result[0], base_result[0], rtol=0, atol=0, exact_device=True)
self.assertEqual(fold_result[1], base_result[1], rtol=0, atol=0, exact_device=True)
self.assertEqual(fold_result[2], base_result[2], rtol=0, atol=0, exact_device=True)

def test_check_inline_non_const(self):
r"""
@@ -566,7 +566,7 @@ def forward(self, x):
in_x = torch.randn(2, 3)
fold_result = gm_folded(in_x)
base_result = mod(in_x)
self.assertTrue(torch.equal(fold_result, base_result))
self.assertEqual(fold_result, base_result, rtol=0, atol=0, exact_device=True)

def test_check_inline_non_const_mult_return(self):
r"""
@@ -598,8 +598,8 @@ def forward(self, x):
in_x = torch.randn(2, 3)
fold_result = gm_folded(in_x)
base_result = mod(in_x)
self.assertTrue(torch.equal(fold_result[0], base_result[0]))
self.assertTrue(torch.equal(fold_result[1], base_result[1]))
self.assertEqual(fold_result[0], base_result[0], rtol=0, atol=0, exact_device=True)
self.assertEqual(fold_result[1], base_result[1], rtol=0, atol=0, exact_device=True)

def test_check_skip_folding_quant_dequant_pattern(self):
r"""
@@ -645,7 +645,7 @@ def skip_folding_quant_dequant(node: torch.fx.Node):
# Now run both folded and non-folded to check results equal.
fold_result = gm_folded(in_x)
base_result = mod(in_x)
self.assertTrue(torch.equal(fold_result, base_result))
self.assertEqual(fold_result, base_result, rtol=0, atol=0, exact_device=True)

def test_fold_module(self):
r"""
@@ -667,7 +667,7 @@ def forward(self, x):

# Now run both folded and non-folded to check results equal.
inp = torch.randn(4, 4)
self.assertTrue(torch.equal(mod_folded(inp), mod(inp)))
self.assertEqual(mod_folded(inp), mod(inp), rtol=0, atol=0, exact_device=True)

def test_const_fold_tensor_meta(self):
self._test_const_fold_tensor_meta(True)
@@ -708,4 +708,4 @@ def forward(self, x, y):
# Now run both folded and non-folded to check results equal.
base_result = mod(in_x, in_y)
fold_result = mod_folded(in_x, in_y)
self.assertTrue(torch.equal(fold_result, base_result))
self.assertEqual(fold_result, base_result, rtol=0, atol=0, exact_device=True)
4 changes: 2 additions & 2 deletions test/jit/test_save_load.py
@@ -404,7 +404,7 @@ def forward(self, a):
m2 = torch.jit.load(path)

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
self.assertTrue(torch.equal(m(x), m2(x)))
self.assertEqual(m(x), m2(x), rtol=0, atol=0, exact_device=True)

def test_save_nonexit_file(self):
class Foo(torch.nn.Module):
@@ -880,7 +880,7 @@ def forward(self, a):
m2 = torch.jit.load(path)

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
self.assertTrue(torch.equal(m(x), m2(x)))
self.assertEqual(m(x), m2(x), rtol=0, atol=0, exact_device=True)

def test_save_namedtuple_input_only(self):
"""
4 changes: 2 additions & 2 deletions test/nn/test_lazy_modules.py
@@ -118,7 +118,7 @@ def test_linear(self):
self.assertTrue(module.weight.shape == (10, 5))
self.assertTrue(module.bias.shape == (10,))
y = module(input)
self.assertTrue(torch.equal(torch.nn.functional.linear(input, module.weight, module.bias), y))
self.assertEqual(torch.nn.functional.linear(input, module.weight, module.bias), y, rtol=0, atol=0, exact_device=True)

@suppress_warnings
def test_lazy_linear_pickle(self):
@@ -170,7 +170,7 @@ def _check_lazy_conv(self, cls, lazy_cls, func, init_args, input_shape,
if module.bias is not None:
self.assertEqual(module.bias.shape, expected_bias_shape)
y = module(input)
self.assertTrue(torch.equal(func(input, module.weight, module.bias), y))
self.assertEqual(func(input, module.weight, module.bias), y, rtol=0, atol=0, exact_device=True)

def _check_lazy_conv_pickle(self, cls, lazy_cls, init_args, input_shape,
expected_weight_shape, expected_bias_shape):