From d1f72ade78601100d78a592aaca5717cfa1daa58 Mon Sep 17 00:00:00 2001 From: gpzlx1 Date: Tue, 1 Aug 2023 05:33:16 +0800 Subject: [PATCH] append indexsearch --- src/pybind.cc | 1 + src/tensor_ops.cc | 10 ++++++++++ src/tensor_ops.h | 2 ++ 3 files changed, 13 insertions(+) diff --git a/src/pybind.cc b/src/pybind.cc index 4361726..9bd17dd 100644 --- a/src/pybind.cc +++ b/src/pybind.cc @@ -74,6 +74,7 @@ TORCH_LIBRARY(gs_ops, m) { m.def("_CAPI_BatchSplitByOffset", &gs::impl::batch::SplitByOffset); m.def("_CAPI_BatchIndptrSplitByOffset", &gs::impl::batch::SplitIndptrByOffsetCUDA); + m.def("_CAPI_IndexSearch", &IndexSearch); } namespace gs {} \ No newline at end of file diff --git a/src/tensor_ops.cc b/src/tensor_ops.cc index 73c7cbc..e06dbf1 100644 --- a/src/tensor_ops.cc +++ b/src/tensor_ops.cc @@ -27,4 +27,14 @@ std::tuple BatchListSampling( range); } +torch::Tensor IndexSearch(torch::Tensor origin_data, torch::Tensor keys) { + torch::Tensor key_buffer, value_buffer; + + std::tie(key_buffer, value_buffer) = + impl::IndexHashMapInsertCUDA(origin_data); + torch::Tensor result = + impl::IndexHashMapSearchCUDA(key_buffer, value_buffer, keys); + return result; +} + } // namespace gs diff --git a/src/tensor_ops.h b/src/tensor_ops.h index 303fd14..4d92693 100644 --- a/src/tensor_ops.h +++ b/src/tensor_ops.h @@ -26,5 +26,7 @@ std::tuple BatchListSamplingProbs( std::tuple BatchListSampling(int64_t num_picks, bool replace, torch::Tensor range); + +torch::Tensor IndexSearch(torch::Tensor origin_data, torch::Tensor keys); } // namespace gs #endif \ No newline at end of file