Skip to content

Commit 678f180

Browse files
committed
Update
[ghstack-poisoned]
1 parent 17e45db commit 678f180

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

kernels/test/op_mul_test.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,21 @@ TEST_F(OpMulOutTest, DynamicShapeUnbound) {
746746
EXPECT_TENSOR_CLOSE(out, expected_result);
747747
}
748748

749+
// >>> torch.ops.aten.mul(torch.tensor([100], dtype=torch.int8),
750+
// torch.tensor([100], dtype=torch.int8), out=torch.zeros([1],
751+
// dtype=torch.long)) tensor([16])
752+
TEST_F(OpMulOutTest, MixedIntegerDtypeMatchesATen) {
753+
TensorFactory<ScalarType::Char> tf_in;
754+
TensorFactory<ScalarType::Long> tf_out;
755+
756+
Tensor in = tf_in.make({1}, {100});
757+
Tensor out = tf_out.zeros({1});
758+
Tensor ret = op_mul_out(in, in, out);
759+
760+
Tensor expected = tf_out.make({1}, {16});
761+
EXPECT_TENSOR_CLOSE(out, expected);
762+
}
763+
749764
TEST_F(OpMulScalarOutTest, SanityCheck) {
750765
TensorFactory<ScalarType::Bool> tf_a;
751766
TensorFactory<ScalarType::Float> tf_out;

0 commit comments

Comments
 (0)