Add scatter_add to amp promote list (pytorch#52133)
Summary:
Fixes pytorch#51730

I've added `scatter_add` and `scatter_add.dimname` to the promote list, as well as test cases for the former op.
However, it seems that `scatter_add` [doesn't support named tensors yet](https://github.com/pytorch/pytorch/blob/8b0cb5ede3eddb96aa0423b2c73c8560ab44788e/aten/src/ATen/native/NamedTensor.cpp#L356-L358) (thanks t-vi for the pointer):
```python
dev = 'cuda'
torch.scatter_add(torch.zeros(2, 2, 2, dtype=torch.float16, device=dev, names=('N', 'C', 'L')),
                  'C',
                  torch.randint(0, 2, (2, 2, 2), device=dev),
                  torch.randn((2, 2, 2), dtype=torch.float32, device=dev))
> RuntimeError: scatter_add: You passed a dimname (string) to this op in place of a dimension index but it does not yet support this behavior. Please pass a dimension index to work around this.
```
Adding this as a test case raised the error shown above.

I'm thus unsure whether I should also remove `scatter_add.dimname` from the promote list.
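
For reference, here is a minimal sketch (not part of this PR or its test suite) of the workaround the error message suggests: pass the positional dimension index instead of a dimname. The names are dropped entirely, since `scatter_add` doesn't accept named tensors yet; run under autocast, the new promote entry should cast the mixed `float16`/`float32` inputs to `float32`:
```python
# A minimal sketch (not part of this PR): the error message's suggested
# workaround, using the positional dim index instead of a dimname.
import torch

dev = 'cuda'
with torch.cuda.amp.autocast():
    out = torch.scatter_add(
        torch.zeros(2, 2, 2, dtype=torch.float16, device=dev),
        1,  # positional index of the former 'C' dimension
        torch.randint(0, 2, (2, 2, 2), device=dev),
        torch.randn((2, 2, 2), dtype=torch.float32, device=dev))
    # With scatter_add on the promote list, the mixed inputs are cast to
    # the widest dtype before the op runs.
    print(out.dtype)  # expected: torch.float32
```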

In any case, once named tensors are supported, a potential test could be added as:
```python
            ("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float16, device=dev, names=('N', 'C', 'L')),
                             'C',
                             torch.randint(0, 2, (2, 2, 2), device=dev),
                             torch.randn((2, 2, 2), dtype=torch.float32, device=dev))),
```

CC mcarilli ngimel

Pull Request resolved: pytorch#52133

Reviewed By: ejguan

Differential Revision: D26440392

Pulled By: ngimel

fbshipit-source-id: f4ee2d0b9e1f81afb6f94261c497cf2bf79ec115
ptrblck authored and facebook-github-bot committed Feb 25, 2021
1 parent 316eabe commit 39fa0b5
Showing 2 changed files with 9 additions and 0 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/autocast_mode.cpp
@@ -405,6 +405,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL(ADD_NS(index_put), "index_put", Tensor (const Tensor &, const torch::List<c10::optional<Tensor>>&, const Tensor &, bool), promote)
KERNEL(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote)
KERNEL(ADD_NS(tensordot), "tensordot", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef), promote)
KERNEL(ADD_NS(scatter_add), "scatter_add", Tensor (const Tensor&, int64_t, const Tensor&, const Tensor&), promote)

m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
8 changes: 8 additions & 0 deletions torch/testing/_internal/autocast_test_lists.py
@@ -195,6 +195,14 @@ def __init__(self, dev):
("stack", (pointwise0_fp16 + pointwise1_fp32,)),
("tensordot", (torch.randn((2, 2, 2), dtype=torch.float32, device=dev),
torch.randn((2, 2, 2), dtype=torch.float16, device=dev))),
("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float32, device=dev),
0,
torch.randint(0, 2, (2, 2, 2), device=dev),
torch.randn((2, 2, 2), dtype=torch.float16, device=dev))),
("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float16, device=dev),
0,
torch.randint(0, 2, (2, 2, 2), device=dev),
torch.randn((2, 2, 2), dtype=torch.float32, device=dev))),
]
self.nn_fp16 = [
("linear", mat0_fp32 + mat1_fp32 + mat2_fp32),
