Skip to content

Commit

Permalink
append indexsearch
Browse files Browse the repository at this point in the history
  • Loading branch information
gpzlx1 committed Jul 31, 2023
1 parent 351b9da commit d1f72ad
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ TORCH_LIBRARY(gs_ops, m) {
m.def("_CAPI_BatchSplitByOffset", &gs::impl::batch::SplitByOffset);
m.def("_CAPI_BatchIndptrSplitByOffset",
&gs::impl::batch::SplitIndptrByOffsetCUDA);
m.def("_CAPI_IndexSearch", &IndexSearch);
}

namespace gs {}
10 changes: 10 additions & 0 deletions src/tensor_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,14 @@ std::tuple<torch::Tensor, torch::Tensor> BatchListSampling(
range);
}

torch::Tensor IndexSearch(torch::Tensor origin_data, torch::Tensor keys) {
torch::Tensor key_buffer, value_buffer;

std::tie(key_buffer, value_buffer) =
impl::IndexHashMapInsertCUDA(origin_data);
torch::Tensor result =
impl::IndexHashMapSearchCUDA(key_buffer, value_buffer, keys);
return result;
}

} // namespace gs
2 changes: 2 additions & 0 deletions src/tensor_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,7 @@ std::tuple<torch::Tensor, torch::Tensor> BatchListSamplingProbs(
std::tuple<torch::Tensor, torch::Tensor> BatchListSampling(int64_t num_picks,
bool replace,
torch::Tensor range);

torch::Tensor IndexSearch(torch::Tensor origin_data, torch::Tensor keys);
} // namespace gs
#endif

0 comments on commit d1f72ad

Please sign in to comment.