Skip to content

Commit 652a8bf

Browse files
navahgar authored and facebook-github-bot committed
[nnc] Updated indices during broadcast to use int64_t (pytorch#64627)
Summary: Pull Request resolved: pytorch#64627. This fixes the root cause of S242719. Test Plan: Imported from OSS. Reviewed By: ZolotukhinM. Differential Revision: D30801686. Pulled By: navahgar. fbshipit-source-id: b6d3ebdc7eb57116eaced53c2f35c7798bb17e80
1 parent 459653a commit 652a8bf

File tree

3 files changed

+47
-1
lines changed

3 files changed

+47
-1
lines changed

test/cpp/tensorexpr/test_kernel.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,36 @@ TEST_F(Kernel, CatInputTypesPromotion) {
577577
}
578578
}
579579

580+
// Regression test: two chained aten::cat ops (the second over a single-element
// list) followed by a cast must compile and run through TensorExprKernel and
// match the eager-mode result. Exercises inlining of a cat with a constant dim.
TEST_F(Kernel, CatAndInlineWithAConstantDim) {
  // NOTE(review): reconstructed from a whitespace-mangled paste — the IR
  // string content is preserved verbatim; only surrounding layout is restyled.
  const auto ir = R"IR(
      graph(%0 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu),
            %1 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu)):
        %2 : bool = prim::Constant[value=0]()
        %3 : int = prim::Constant[value=1]()
        %4 : Tensor[] = prim::ListConstruct(%0, %1)
        %5 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%4, %3)
        %6 : Tensor[] = prim::ListConstruct(%5)
        %7 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%6, %3)
        %8 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::_cast_Float(%7, %2)
        return (%8, %7))IR";

  // Parse the IR into a graph and build an NNC kernel from it.
  auto graph = std::make_shared<Graph>();
  parseIR(ir, &*graph);
  TensorExprKernel kernel(graph);

  // Eager-mode reference: cat along dim 1, then the (no-op) float cast.
  auto lhs = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat));
  auto rhs = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat));
  auto expected = at::_cast_Float(at::cat({lhs, rhs}, 1), 0);

  // Run the compiled kernel on the same inputs.
  std::vector<at::Tensor> inputs = {lhs, rhs};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);

  auto actual = stack[0].toTensor();
  ASSERT_EQ(actual.sizes(), expected.sizes());
  ASSERT_EQ(actual.dtype(), expected.dtype());
  ASSERT_TRUE(at::allclose(actual, expected));
}
609+
580610
TEST_F(Kernel, CatWoConditionals) {
581611
getCatWoConditionals() = true;
582612
const auto graph_string = R"IR(

test/test_tensorexpr.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,6 +1108,22 @@ def foo(*args):
11081108
ref = foo(*values)
11091109
np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())
11101110

1111+
def test_cat_with_constant_dim(self):
    # Regression test: chained torch.cat calls with a constant dim (the
    # second cat over a single-element list) must still fully fuse and
    # agree with eager execution on every configured device.
    for device in self.devices:
        def fn(*tensors):
            first = torch.cat(tensors, dim=1)
            second = torch.cat([first], dim=1)
            return second * second

        # An empty tensor is included among the inputs on purpose —
        # presumably to exercise the empty-input path of cat; the two
        # random (1, 64) tensors carry the actual data.
        empty = torch.tensor([], device=device, dtype=torch.float32)
        inputs = [
            empty,
            torch.randn(1, 64, device=device),
            torch.randn(1, 64, device=device),
        ]
        traced = torch.jit.trace(fn, inputs)

        result = warmup_and_run_forward(traced, *inputs)
        self.assertLastGraphAllFused()
        expected = fn(*inputs)
        np.testing.assert_allclose(expected.cpu().numpy(), result.cpu().numpy())
11111127
def test_scalar(self):
11121128
@torch.jit.script
11131129
def test_float(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: float, b: float) -> torch.Tensor:

torch/csrc/jit/tensorexpr/kernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ std::vector<ExprHandle> computeIndicesToBroadcast(
466466
while (sizeIt != inputSizes.rend()) {
467467
auto const& size = intValue(*sizeIt);
468468
if (size && *size == 1) {
469-
bcast.emplace_back(0);
469+
bcast.emplace_back(LongImm::make(0));
470470
} else {
471471
bcast.emplace_back(*axisIt);
472472
}

0 commit comments

Comments
 (0)