Skip to content

Commit

Permalink
fixed zero numel init bug
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Apr 16, 2020
1 parent 66105d4 commit 1eef2be
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 8 deletions.
5 changes: 4 additions & 1 deletion csrc/cpu/scatter_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,11 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
arg_out_data = arg_out.value().data_ptr<int64_t>();
}

if (index.numel() == 0)
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out);
}

auto B = 1;
for (auto i = 0; i < dim; i++)
Expand Down
10 changes: 8 additions & 2 deletions csrc/cpu/segment_coo_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,11 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
arg_out = torch::zeros(sizes, out.options());
}

if (index.numel() == 0)
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out);
}

auto B = index.numel() / src.size(dim);
auto E = src.size(dim);
Expand Down Expand Up @@ -158,8 +161,11 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
out = torch::empty(sizes, src.options());
}

if (index.numel() == 0)
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return out;
}

auto B = index.numel() / out.size(dim);
auto E = index.size(dim);
Expand Down
10 changes: 8 additions & 2 deletions csrc/cpu/segment_csr_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
arg_out_data = arg_out.value().data_ptr<int64_t>();
}

if (src.numel() == 0)
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out);
}

auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
Expand Down Expand Up @@ -120,8 +123,11 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
out = torch::empty(sizes, src.options());
}

if (src.numel() == 0)
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return out;
}

auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N;
Expand Down
5 changes: 4 additions & 1 deletion csrc/cuda/scatter_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,11 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
arg_out_data = arg_out.value().data_ptr<int64_t>();
}

if (index.numel() == 0)
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out);
}

auto B = 1;
for (auto i = 0; i < dim; i++)
Expand Down
10 changes: 8 additions & 2 deletions csrc/cuda/segment_csr_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,11 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
arg_out_data = arg_out.value().data_ptr<int64_t>();
}

if (src.numel() == 0)
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out);
}

auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
Expand Down Expand Up @@ -251,8 +254,11 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
out = torch::empty(sizes, src.options());
}

if (src.numel() == 0)
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return out;
}

auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N;
Expand Down

0 comments on commit 1eef2be

Please sign in to comment.