From 1eef2be14127be2101a81ef3bc3c63bf585a3e1a Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 16 Apr 2020 10:53:41 +0200 Subject: [PATCH] fixed zero numel init bug --- csrc/cpu/scatter_cpu.cpp | 5 ++++- csrc/cpu/segment_coo_cpu.cpp | 10 ++++++++-- csrc/cpu/segment_csr_cpu.cpp | 10 ++++++++-- csrc/cuda/scatter_cuda.cu | 5 ++++- csrc/cuda/segment_csr_cuda.cu | 10 ++++++++-- 5 files changed, 32 insertions(+), 8 deletions(-) diff --git a/csrc/cpu/scatter_cpu.cpp b/csrc/cpu/scatter_cpu.cpp index 5f9da470..21e5f327 100644 --- a/csrc/cpu/scatter_cpu.cpp +++ b/csrc/cpu/scatter_cpu.cpp @@ -43,8 +43,11 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, arg_out_data = arg_out.value().data_ptr(); } - if (index.numel() == 0) + if (src.numel() == 0) { + if (!optional_out.has_value()) + out.fill_(0); return std::make_tuple(out, arg_out); + } auto B = 1; for (auto i = 0; i < dim; i++) diff --git a/csrc/cpu/segment_coo_cpu.cpp b/csrc/cpu/segment_coo_cpu.cpp index c59afd2b..074fe597 100644 --- a/csrc/cpu/segment_coo_cpu.cpp +++ b/csrc/cpu/segment_coo_cpu.cpp @@ -52,8 +52,11 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, arg_out = torch::zeros(sizes, out.options()); } - if (index.numel() == 0) + if (src.numel() == 0) { + if (!optional_out.has_value()) + out.fill_(0); return std::make_tuple(out, arg_out); + } auto B = index.numel() / src.size(dim); auto E = src.size(dim); @@ -158,8 +161,11 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index, out = torch::empty(sizes, src.options()); } - if (index.numel() == 0) + if (src.numel() == 0) { + if (!optional_out.has_value()) + out.fill_(0); return out; + } auto B = index.numel() / out.size(dim); auto E = index.size(dim); diff --git a/csrc/cpu/segment_csr_cpu.cpp b/csrc/cpu/segment_csr_cpu.cpp index 6dca23f6..ad258cec 100644 --- a/csrc/cpu/segment_csr_cpu.cpp +++ b/csrc/cpu/segment_csr_cpu.cpp @@ -44,8 +44,11 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, arg_out_data = arg_out.value().data_ptr(); } - if (src.numel() == 0) + if (src.numel() == 0) { + if (!optional_out.has_value()) + out.fill_(0); return std::make_tuple(out, arg_out); + } auto N = out.size(dim) * (indptr.numel() / indptr.size(-1)); auto K = out.numel() / N; @@ -120,8 +123,11 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr, out = torch::empty(sizes, src.options()); } - if (src.numel() == 0) + if (src.numel() == 0) { + if (!optional_out.has_value()) + out.fill_(0); return out; + } auto N = src.size(dim) * (indptr.numel() / indptr.size(-1)); auto K = src.numel() / N; diff --git a/csrc/cuda/scatter_cuda.cu b/csrc/cuda/scatter_cuda.cu index 1f2191c8..5578a90f 100644 --- a/csrc/cuda/scatter_cuda.cu +++ b/csrc/cuda/scatter_cuda.cu @@ -99,8 +99,11 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, arg_out_data = arg_out.value().data_ptr(); } - if (index.numel() == 0) + if (src.numel() == 0) { + if (!optional_out.has_value()) + out.fill_(0); return std::make_tuple(out, arg_out); + } auto B = 1; for (auto i = 0; i < dim; i++) diff --git a/csrc/cuda/segment_csr_cuda.cu b/csrc/cuda/segment_csr_cuda.cu index d08bdffd..c7f5e436 100644 --- a/csrc/cuda/segment_csr_cuda.cu +++ b/csrc/cuda/segment_csr_cuda.cu @@ -135,8 +135,11 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, arg_out_data = arg_out.value().data_ptr(); } - if (src.numel() == 0) + if (src.numel() == 0) { + if (!optional_out.has_value()) + out.fill_(0); return std::make_tuple(out, arg_out); + } auto N = out.size(dim) * (indptr.numel() / indptr.size(-1)); auto K = out.numel() / N; @@ -251,8 +254,11 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, out = torch::empty(sizes, src.options()); } - if (src.numel() == 0) + if (src.numel() == 0) { + if (!optional_out.has_value()) + out.fill_(0); return out; + } auto N = src.size(dim) * (indptr.numel() / indptr.size(-1)); auto K = src.numel() / N;