Add AMP support to linalg.vecdot. (pytorch#108165)
We follow the same rules as matmul.

Fixes pytorch#108127

Pull Request resolved: pytorch#108165
Approved by: https://github.com/albanD
lezcano authored and pytorchmergebot committed Aug 29, 2023
1 parent 75884f4 commit 86bc50a
Showing 2 changed files with 3 additions and 0 deletions.
aten/src/ATen/autocast_mode.cpp: 2 additions, 0 deletions
@@ -273,6 +273,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL_CUDA(einsum, lower_precision_fp)
KERNEL_CUDA(mm, lower_precision_fp)
KERNEL_CUDA(mv, lower_precision_fp)
KERNEL_CUDA(linalg_vecdot, lower_precision_fp)
KERNEL_CUDA(linear, lower_precision_fp)
KERNEL_CUDA(addbmm, lower_precision_fp)
KERNEL_CUDA(baddbmm, lower_precision_fp)
@@ -395,6 +396,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
KERNEL_CPU2(conv3d, padding, lower_precision_fp)
KERNEL_CPU(bmm, lower_precision_fp)
KERNEL_CPU(mm, lower_precision_fp)
KERNEL_CPU(linalg_vecdot, lower_precision_fp)
KERNEL_CPU(baddbmm, lower_precision_fp)
KERNEL_CPU(addmm, lower_precision_fp)
KERNEL_CPU(addbmm, lower_precision_fp)
torch/testing/_internal/autocast_test_lists.py: 1 addition, 0 deletions
@@ -217,6 +217,7 @@ def __init__(self, dev):
("multi_margin_loss", mat0_fp16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
]
self.linalg_fp16 = [
("linalg_vecdot", mat0_fp32 + mat0_fp32),
("linalg_multi_dot", (mat0_fp32 + mat1_fp32 + mat2_fp32,)),
]
self.methods_fp16 = [
