17 changes: 2 additions & 15 deletions test/quantized_ops/test_dot_general.py
@@ -1,17 +1,11 @@
import os
import re
import sys

import torch
import torch_xla
import unittest

device = torch_xla.device()

torch.manual_seed(12345)


class DotGeneralTest(unittest.TestCase):

def test_dot_general(self):
x = torch.rand(10, 3, 4).to(torch.bfloat16)
w = torch.rand(10, 4, 5).to(torch.bfloat16)
@@ -21,7 +15,6 @@ def test_dot_general(self):
xla_out = torch_xla._XLAC._xla_dot_general(xla_x, xla_w,
(([2], [1]), ([0], [0])))
self.assertTrue(torch.allclose(xla_out.cpu(), expected_out))

def test_dot_general_negative_dim(self):
x = torch.rand(10, 3, 4).to(torch.bfloat16)
w = torch.rand(10, 4, 5).to(torch.bfloat16)
@@ -31,7 +24,6 @@ def test_dot_general_negative_dim(self):
xla_out = torch_xla._XLAC._xla_dot_general(xla_x, xla_w,
(([-1], [-2]), ([0], [0])))
self.assertTrue(torch.allclose(xla_out.cpu(), expected_out))

def test_dot_general_linear(self):
x = torch.rand(10, 3, 4).to(torch.bfloat16)
w = torch.rand(5, 4).to(torch.bfloat16)
@@ -40,7 +32,6 @@ def test_dot_general_linear(self):
xla_w = w.to(device)
xla_out = torch_xla._XLAC._xla_dot_general(xla_x, xla_w, (([2], [1]), ()))
self.assertTrue(torch.allclose(xla_out.cpu(), expected_out))

def test_dot_general_int32_dtype(self):
x = torch.randint(-15, 15, (10, 3, 4)).to(torch.int32)
w = torch.randint(-15, 15, (10, 4, 5)).to(torch.int32)
@@ -55,11 +46,9 @@ def test_dot_general_int32_dtype(self):
xla_w, (([2], [1]), ([0], [0])),
preferred_element_type=torch.int32)
self.assertTrue(torch.allclose(xla_out.cpu(), expected_out))

def test_raises_error_on_non_xla_tensor(self):
lhs = torch.rand(10, 3, 4, dtype=torch.bfloat16)
rhs = torch.rand(10, 4, 5, dtype=torch.bfloat16)

def test(args, non_xla_tensor_arg):
arg_number_to_str = ["first", "second"]
position = arg_number_to_str[non_xla_tensor_arg]
@@ -70,12 +59,10 @@ def test(args, non_xla_tensor_arg):
f"Expected input tensor ({position} argument) to be an actual XLA tensor. "
f"Got: CPUBFloat16Type. Consider moving it ({position} argument) to XLA."
)
self.assertEqual(str(err), error_message)

msg = str(err)
self.assertIn(error_message, msg)
test((lhs, rhs.to(device)), non_xla_tensor_arg=0)
test((lhs.to(device), rhs), non_xla_tensor_arg=1)


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
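
Note (not part of the diff): the tuple passed to torch_xla._XLAC._xla_dot_general follows XLA's DotGeneral dimension-numbers convention, ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)). Below is a minimal CPU-only sketch of the reference results the tests compare against; the variable names and the einsum formulation are illustrative assumptions, not the file's own expected_out code.

import torch

x = torch.rand(10, 3, 4, dtype=torch.bfloat16)
w = torch.rand(10, 4, 5, dtype=torch.bfloat16)

# (([2], [1]), ([0], [0])): contract x dim 2 with w dim 1, batch over dim 0
# of both operands -- a batched matmul, result shape (10, 3, 5). The negative
# form (([-1], [-2]), ([0], [0])) names the same dims.
batched_ref = torch.einsum('bik,bkj->bij', x, w)

# (([2], [1]), ()): contract x dim 2 with dim 1 of a (5, 4) weight, no batch
# dims -- a bias-free linear layer, result shape (10, 3, 5).
w_linear = torch.rand(5, 4, dtype=torch.bfloat16)
linear_ref = torch.einsum('bik,jk->bij', x, w_linear)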