Remove some code duplication in bindings #416

Merged
merged 2 commits on Sep 20, 2022
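
In brief, the diff below removes duplicated input-handling code from the pybind11 bindings: the array-shape parsing and label checks repeated across `Index` and `BFIndex` (in the addItems and knnQuery paths) are factored into two shared helpers, `get_input_array_shapes` and `get_input_ids_and_check_shapes`. As a minimal, dependency-free sketch of the shape rule the first helper centralizes (here `BufferInfo` and `parse_shapes` are stand-ins for pybind11's `py::buffer_info` and the real helper, not code from this PR):

```cpp
// Minimal sketch of the shape rule centralized by get_input_array_shapes.
// BufferInfo is a stand-in for pybind11's py::buffer_info; only the fields
// this sketch needs are modeled.
#include <cstdio>
#include <stdexcept>
#include <vector>

struct BufferInfo {
    int ndim;
    std::vector<size_t> shape;
};

void parse_shapes(const BufferInfo& buffer, size_t* rows, size_t* features) {
    if (buffer.ndim != 2 && buffer.ndim != 1) {
        char msg[256];
        std::snprintf(msg, sizeof(msg),
                      "Input vector data wrong shape. Number of dimensions %d. "
                      "Data must be a 1D or 2D array.",
                      buffer.ndim);
        throw std::runtime_error(msg);
    }
    if (buffer.ndim == 2) {   // (rows, features) matrix
        *rows = buffer.shape[0];
        *features = buffer.shape[1];
    } else {                  // single 1D vector: treat as one row
        *rows = 1;
        *features = buffer.shape[0];
    }
}

int main() {
    size_t rows = 0, features = 0;
    parse_shapes({2, {100, 16}}, &rows, &features);  // 100 vectors of dimension 16
    std::printf("rows=%zu features=%zu\n", rows, features);
    parse_shapes({1, {16}}, &rows, &features);       // a single vector of dimension 16
    std::printf("rows=%zu features=%zu\n", rows, features);
}
```

Both classes now call the shared helpers instead of repeating these branches inline, which accounts for most of the deleted lines.
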
141 changes: 63 additions & 78 deletions python_bindings/bindings.cpp
@@ -42,7 +42,7 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn
while (true) {
size_t id = current.fetch_add(1);

if ((id >= end)) {
if (id >= end) {
break;
}

@@ -79,6 +79,54 @@ inline void assert_true(bool expr, const std::string & msg) {
}


inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows, size_t* features) {
if (buffer.ndim != 2 && buffer.ndim != 1) {
char msg[256];
snprintf(msg, sizeof(msg),
"Input vector data wrong shape. Number of dimensions %d. Data must be a 1D or 2D array.",
buffer.ndim);
throw std::runtime_error(msg);
}
if (buffer.ndim == 2) {
*rows = buffer.shape[0];
*features = buffer.shape[1];
} else {
*rows = 1;
*features = buffer.shape[0];
}
}


inline std::vector<size_t> get_input_ids_and_check_shapes(const py::object& ids_, size_t feature_rows) {
std::vector<size_t> ids;
if (!ids_.is_none()) {
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
auto ids_numpy = items.request();
// check shapes
if (!((ids_numpy.ndim == 1 && ids_numpy.shape[0] == feature_rows) ||
(ids_numpy.ndim == 0 && feature_rows == 1))) {
char msg[256];
snprintf(msg, sizeof(msg),
"The input label shape %d does not match the input data vector shape %d",
ids_numpy.ndim, feature_rows);
throw std::runtime_error(msg);
}
// extract data
if (ids_numpy.ndim == 1) {
std::vector<size_t> ids1(ids_numpy.shape[0]);
for (size_t i = 0; i < ids1.size(); i++) {
ids1[i] = items.data()[i];
}
ids.swap(ids1);
} else if (ids_numpy.ndim == 0) {
ids.push_back(*items.data());
}
}

return ids;
}


template<typename dist_t, typename data_t = float>
class Index {
public:
@@ -146,7 +194,7 @@ class Index {
void set_ef(size_t ef) {
default_ef = ef;
if (appr_alg)
appr_alg->ef_ = ef;
appr_alg->ef_ = ef;
}


@@ -188,41 +236,17 @@ class Index {
num_threads = num_threads_default;

size_t rows, features;

if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array");
if (buffer.ndim == 2) {
rows = buffer.shape[0];
features = buffer.shape[1];
} else {
rows = 1;
features = buffer.shape[0];
}
get_input_array_shapes(buffer, &rows, &features);

if (features != dim)
throw std::runtime_error("wrong dimensionality of the vectors");
throw std::runtime_error("Wrong dimensionality of the vectors");

// avoid using threads when the number of additions is small:
if (rows <= num_threads * 4) {
num_threads = 1;
}

std::vector<size_t> ids;

if (!ids_.is_none()) {
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
auto ids_numpy = items.request();
if (ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) {
std::vector<size_t> ids1(ids_numpy.shape[0]);
for (size_t i = 0; i < ids1.size(); i++) {
ids1[i] = items.data()[i];
}
ids.swap(ids1);
} else if (ids_numpy.ndim == 0 && rows == 1) {
ids.push_back(*items.data());
} else {
throw std::runtime_error("wrong dimensionality of the labels");
}
}
std::vector<size_t> ids = get_input_ids_and_check_shapes(ids_, rows);

{
int start = 0;
@@ -503,7 +527,7 @@ class Index {

for (size_t i = 0; i < appr_alg->cur_element_count; i++) {
if (label_lookup_val_npy.data()[i] < 0) {
throw std::runtime_error("internal id cannot be negative!");
throw std::runtime_error("Internal id cannot be negative!");
} else {
appr_alg->label_lookup_.insert(std::make_pair(label_lookup_key_npy.data()[i], label_lookup_val_npy.data()[i]));
}
@@ -561,15 +585,7 @@ class Index {

{
py::gil_scoped_release l;

if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array");
if (buffer.ndim == 2) {
rows = buffer.shape[0];
features = buffer.shape[1];
} else {
rows = 1;
features = buffer.shape[0];
}
get_input_array_shapes(buffer, &rows, &features);

// avoid using threads when the number of searches is small:
if (rows <= num_threads * 4) {
@@ -725,36 +741,12 @@ class BFIndex {
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
auto buffer = items.request();
size_t rows, features;

if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array");
if (buffer.ndim == 2) {
rows = buffer.shape[0];
features = buffer.shape[1];
} else {
rows = 1;
features = buffer.shape[0];
}
get_input_array_shapes(buffer, &rows, &features);

if (features != dim)
throw std::runtime_error("wrong dimensionality of the vectors");
throw std::runtime_error("Wrong dimensionality of the vectors");

std::vector<size_t> ids;

if (!ids_.is_none()) {
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
auto ids_numpy = items.request();
if (ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) {
std::vector<size_t> ids1(ids_numpy.shape[0]);
for (size_t i = 0; i < ids1.size(); i++) {
ids1[i] = items.data()[i];
}
ids.swap(ids1);
} else if (ids_numpy.ndim == 0 && rows == 1) {
ids.push_back(*items.data());
} else {
throw std::runtime_error("wrong dimensionality of the labels");
}
}
std::vector<size_t> ids = get_input_ids_and_check_shapes(ids_, rows);

{
for (size_t row = 0; row < rows; row++) {
@@ -802,14 +794,7 @@
{
py::gil_scoped_release l;

if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array");
if (buffer.ndim == 2) {
rows = buffer.shape[0];
features = buffer.shape[1];
} else {
rows = 1;
features = buffer.shape[0];
}
get_input_array_shapes(buffer, &rows, &features);

data_numpy_l = new hnswlib::labeltype[rows * k];
data_numpy_d = new dist_t[rows * k];
@@ -836,14 +821,14 @@ class BFIndex {

return py::make_tuple(
py::array_t<hnswlib::labeltype>(
{rows, k}, // shape
{k * sizeof(hnswlib::labeltype),
{ rows, k }, // shape
{ k * sizeof(hnswlib::labeltype),
sizeof(hnswlib::labeltype)}, // C-style contiguous strides for each index
data_numpy_l, // the data pointer
free_when_done_l),
py::array_t<dist_t>(
{rows, k}, // shape
{k * sizeof(dist_t), sizeof(dist_t)}, // C-style contiguous strides for each index
{ rows, k }, // shape
{ k * sizeof(dist_t), sizeof(dist_t) }, // C-style contiguous strides for each index
data_numpy_d, // the data pointer
free_when_done_d));
}
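
The second helper, `get_input_ids_and_check_shapes`, plays the same role for the optional label array: both `addItems` implementations previously open-coded the `ids_` extraction and shape check. Below is a simplified, dependency-free sketch of that check, with `std::optional<std::vector<size_t>>` standing in for the optional numpy array and the 0-d scalar case (a single label for a single row) folded into the size comparison; `check_ids` is an illustrative name, not the helper itself.

```cpp
// Simplified sketch of the label check centralized by get_input_ids_and_check_shapes.
#include <cstdio>
#include <optional>
#include <stdexcept>
#include <vector>

std::vector<size_t> check_ids(const std::optional<std::vector<size_t>>& ids_,
                              size_t feature_rows) {
    std::vector<size_t> ids;
    if (ids_.has_value()) {
        // The label count must match the number of input rows.
        if (ids_->size() != feature_rows) {
            char msg[256];
            std::snprintf(msg, sizeof(msg),
                          "The input label shape %zu does not match the input data vector shape %zu",
                          ids_->size(), feature_rows);
            throw std::runtime_error(msg);
        }
        ids = *ids_;   // copy the labels out
    }
    return ids;        // empty when no labels were supplied
}

int main() {
    std::vector<size_t> ids = check_ids(std::vector<size_t>{10, 11, 12}, 3);
    std::printf("got %zu labels\n", ids.size());
    std::printf("no labels -> %zu returned\n", check_ids(std::nullopt, 3).size());
}
```

Returning an empty vector when no labels are supplied mirrors the helper in the diff, which likewise returns an empty `ids` vector when `ids_` is None.
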