Skip to content

Commit

Permalink
adapt_c_buffers: Check if the array is 1D and C-contiguous
Browse files Browse the repository at this point in the history
  • Loading branch information
pthom committed Nov 7, 2024
1 parent 27f06bf commit ba8ff59
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 65 deletions.
36 changes: 36 additions & 0 deletions src/litgen/integration_tests/_pydef_nanobind/nanobind_mylib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ void py_init_module_lg_mylib(nb::module_& m)
{
auto add_inside_buffer_adapt_c_buffers = [](nb::ndarray<> & buffer, uint8_t number_to_add)
{
// Check if the array is 1D and C-contiguous
if (! (buffer.ndim() == 1 && buffer.stride(0) == 1))
throw std::runtime_error("The array must be 1D and contiguous");

// convert nb::ndarray to C standard buffer (mutable)
void * buffer_from_pyarray = buffer.data();
size_t buffer_count = buffer.shape(0);
Expand Down Expand Up @@ -308,6 +312,10 @@ void py_init_module_lg_mylib(nb::module_& m)
{
auto buffer_sum_adapt_c_buffers = [](const nb::ndarray<> & buffer, int stride = -1) -> int
{
// Check if the array is 1D and C-contiguous
if (! (buffer.ndim() == 1 && buffer.stride(0) == 1))
throw std::runtime_error("The array must be 1D and contiguous");

// convert nb::ndarray to C standard buffer (const)
const void * buffer_from_pyarray = buffer.data();
size_t buffer_count = buffer.shape(0);
Expand Down Expand Up @@ -346,6 +354,10 @@ void py_init_module_lg_mylib(nb::module_& m)
{
auto add_inside_two_buffers_adapt_c_buffers = [](nb::ndarray<> & buffer_1, nb::ndarray<> & buffer_2, uint8_t number_to_add)
{
// Check if the array is 1D and C-contiguous
if (! (buffer_1.ndim() == 1 && buffer_1.stride(0) == 1))
throw std::runtime_error("The array must be 1D and contiguous");

// convert nb::ndarray to C standard buffer (mutable)
void * buffer_1_from_pyarray = buffer_1.data();
size_t buffer_1_count = buffer_1.shape(0);
Expand All @@ -365,6 +377,10 @@ void py_init_module_lg_mylib(nb::module_& m)
Bad type! Size mismatch, while checking the size of the type (for param "buffer_1")!
)msg"));

// Check if the array is 1D and C-contiguous
if (! (buffer_2.ndim() == 1 && buffer_2.stride(0) == 1))
throw std::runtime_error("The array must be 1D and contiguous");

// convert nb::ndarray to C standard buffer (mutable)
void * buffer_2_from_pyarray = buffer_2.data();
size_t buffer_2_count = buffer_2.shape(0);
Expand Down Expand Up @@ -397,6 +413,10 @@ void py_init_module_lg_mylib(nb::module_& m)
{
auto templated_mul_inside_buffer_adapt_c_buffers = [](nb::ndarray<> & buffer, double factor)
{
// Check if the array is 1D and C-contiguous
if (! (buffer.ndim() == 1 && buffer.stride(0) == 1))
throw std::runtime_error("The array must be 1D and contiguous");

// convert nb::ndarray to C standard buffer (mutable)
void * buffer_from_pyarray = buffer.data();
size_t buffer_count = buffer.shape(0);
Expand Down Expand Up @@ -1425,6 +1445,10 @@ void py_init_module_lg_mylib(nb::module_& m)
{
auto add_inside_buffer_adapt_c_buffers = [&self](nb::ndarray<> & buffer, uint8_t number_to_add)
{
// Check if the array is 1D and C-contiguous
if (! (buffer.ndim() == 1 && buffer.stride(0) == 1))
throw std::runtime_error("The array must be 1D and contiguous");

// convert nb::ndarray to C standard buffer (mutable)
void * buffer_from_pyarray = buffer.data();
size_t buffer_count = buffer.shape(0);
Expand Down Expand Up @@ -1454,6 +1478,10 @@ void py_init_module_lg_mylib(nb::module_& m)
{
auto templated_mul_inside_buffer_adapt_c_buffers = [&self](nb::ndarray<> & buffer, double factor)
{
// Check if the array is 1D and C-contiguous
if (! (buffer.ndim() == 1 && buffer.stride(0) == 1))
throw std::runtime_error("The array must be 1D and contiguous");

// convert nb::ndarray to C standard buffer (mutable)
void * buffer_from_pyarray = buffer.data();
size_t buffer_count = buffer.shape(0);
Expand Down Expand Up @@ -1628,6 +1656,10 @@ void py_init_module_lg_mylib(nb::module_& m)
{
auto add_inside_buffer_adapt_c_buffers = [](nb::ndarray<> & buffer, uint8_t number_to_add)
{
// Check if the array is 1D and C-contiguous
if (! (buffer.ndim() == 1 && buffer.stride(0) == 1))
throw std::runtime_error("The array must be 1D and contiguous");

// convert nb::ndarray to C standard buffer (mutable)
void * buffer_from_pyarray = buffer.data();
size_t buffer_count = buffer.shape(0);
Expand Down Expand Up @@ -1658,6 +1690,10 @@ void py_init_module_lg_mylib(nb::module_& m)
{
auto templated_mul_inside_buffer_adapt_c_buffers = [](nb::ndarray<> & buffer, double factor)
{
// Check if the array is 1D and C-contiguous
if (! (buffer.ndim() == 1 && buffer.stride(0) == 1))
throw std::runtime_error("The array must be 1D and contiguous");

// convert nb::ndarray to C standard buffer (mutable)
void * buffer_from_pyarray = buffer.data();
size_t buffer_count = buffer.shape(0);
Expand Down
63 changes: 27 additions & 36 deletions src/litgen/integration_tests/_pydef_pybind11/pybind_mylib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,9 @@ void py_init_module_lg_mylib(py::module& m)
{
auto add_inside_buffer_adapt_c_buffers = [](py::array & buffer, uint8_t number_to_add)
{
// Check if the array is C-contiguous
if (!buffer.attr("flags").attr("c_contiguous").cast<bool>()) {
throw std::runtime_error("The array must be contiguous, i.e, `a.flags.c_contiguous` must be True. Hint: use `numpy.ascontiguousarray`.");
}
// Check if the array is 1D and C-contiguous
if (! (buffer.ndim() == 1 && buffer.strides(0) == buffer.itemsize()) )
throw std::runtime_error("The array must be 1D and contiguous");

// convert py::array to C standard buffer (mutable)
void * buffer_from_pyarray = buffer.mutable_data();
Expand Down Expand Up @@ -314,10 +313,9 @@ void py_init_module_lg_mylib(py::module& m)
{
auto buffer_sum_adapt_c_buffers = [](const py::array & buffer, int stride = -1) -> int
{
// Check if the array is C-contiguous
if (!buffer.attr("flags").attr("c_contiguous").cast<bool>()) {
throw std::runtime_error("The array must be contiguous, i.e, `a.flags.c_contiguous` must be True. Hint: use `numpy.ascontiguousarray`.");
}
// Check if the array is 1D and C-contiguous
if (! (buffer.ndim() == 1 && buffer.strides(0) == buffer.itemsize()) )
throw std::runtime_error("The array must be 1D and contiguous");

// convert py::array to C standard buffer (const)
const void * buffer_from_pyarray = buffer.data();
Expand Down Expand Up @@ -351,10 +349,9 @@ void py_init_module_lg_mylib(py::module& m)
{
auto add_inside_two_buffers_adapt_c_buffers = [](py::array & buffer_1, py::array & buffer_2, uint8_t number_to_add)
{
// Check if the array is C-contiguous
if (!buffer_1.attr("flags").attr("c_contiguous").cast<bool>()) {
throw std::runtime_error("The array must be contiguous, i.e, `a.flags.c_contiguous` must be True. Hint: use `numpy.ascontiguousarray`.");
}
// Check if the array is 1D and C-contiguous
if (! (buffer_1.ndim() == 1 && buffer_1.strides(0) == buffer_1.itemsize()) )
throw std::runtime_error("The array must be 1D and contiguous");

// convert py::array to C standard buffer (mutable)
void * buffer_1_from_pyarray = buffer_1.mutable_data();
Expand All @@ -369,10 +366,9 @@ void py_init_module_lg_mylib(py::module& m)
(using py::array::dtype().char_() as an id)
)msg"));

// Check if the array is C-contiguous
if (!buffer_2.attr("flags").attr("c_contiguous").cast<bool>()) {
throw std::runtime_error("The array must be contiguous, i.e, `a.flags.c_contiguous` must be True. Hint: use `numpy.ascontiguousarray`.");
}
// Check if the array is 1D and C-contiguous
if (! (buffer_2.ndim() == 1 && buffer_2.strides(0) == buffer_2.itemsize()) )
throw std::runtime_error("The array must be 1D and contiguous");

// convert py::array to C standard buffer (mutable)
void * buffer_2_from_pyarray = buffer_2.mutable_data();
Expand Down Expand Up @@ -400,10 +396,9 @@ void py_init_module_lg_mylib(py::module& m)
{
auto templated_mul_inside_buffer_adapt_c_buffers = [](py::array & buffer, double factor)
{
// Check if the array is C-contiguous
if (!buffer.attr("flags").attr("c_contiguous").cast<bool>()) {
throw std::runtime_error("The array must be contiguous, i.e, `a.flags.c_contiguous` must be True. Hint: use `numpy.ascontiguousarray`.");
}
// Check if the array is 1D and C-contiguous
if (! (buffer.ndim() == 1 && buffer.strides(0) == buffer.itemsize()) )
throw std::runtime_error("The array must be 1D and contiguous");

// convert py::array to C standard buffer (mutable)
void * buffer_from_pyarray = buffer.mutable_data();
Expand Down Expand Up @@ -1451,10 +1446,9 @@ void py_init_module_lg_mylib(py::module& m)
{
auto add_inside_buffer_adapt_c_buffers = [&self](py::array & buffer, uint8_t number_to_add)
{
// Check if the array is C-contiguous
if (!buffer.attr("flags").attr("c_contiguous").cast<bool>()) {
throw std::runtime_error("The array must be contiguous, i.e, `a.flags.c_contiguous` must be True. Hint: use `numpy.ascontiguousarray`.");
}
// Check if the array is 1D and C-contiguous
if (! (buffer.ndim() == 1 && buffer.strides(0) == buffer.itemsize()) )
throw std::runtime_error("The array must be 1D and contiguous");

// convert py::array to C standard buffer (mutable)
void * buffer_from_pyarray = buffer.mutable_data();
Expand All @@ -1479,10 +1473,9 @@ void py_init_module_lg_mylib(py::module& m)
{
auto templated_mul_inside_buffer_adapt_c_buffers = [&self](py::array & buffer, double factor)
{
// Check if the array is C-contiguous
if (!buffer.attr("flags").attr("c_contiguous").cast<bool>()) {
throw std::runtime_error("The array must be contiguous, i.e, `a.flags.c_contiguous` must be True. Hint: use `numpy.ascontiguousarray`.");
}
// Check if the array is 1D and C-contiguous
if (! (buffer.ndim() == 1 && buffer.strides(0) == buffer.itemsize()) )
throw std::runtime_error("The array must be 1D and contiguous");

// convert py::array to C standard buffer (mutable)
void * buffer_from_pyarray = buffer.mutable_data();
Expand Down Expand Up @@ -1641,10 +1634,9 @@ void py_init_module_lg_mylib(py::module& m)
{
auto add_inside_buffer_adapt_c_buffers = [](py::array & buffer, uint8_t number_to_add)
{
// Check if the array is C-contiguous
if (!buffer.attr("flags").attr("c_contiguous").cast<bool>()) {
throw std::runtime_error("The array must be contiguous, i.e, `a.flags.c_contiguous` must be True. Hint: use `numpy.ascontiguousarray`.");
}
// Check if the array is 1D and C-contiguous
if (! (buffer.ndim() == 1 && buffer.strides(0) == buffer.itemsize()) )
throw std::runtime_error("The array must be 1D and contiguous");

// convert py::array to C standard buffer (mutable)
void * buffer_from_pyarray = buffer.mutable_data();
Expand All @@ -1670,10 +1662,9 @@ void py_init_module_lg_mylib(py::module& m)
{
auto templated_mul_inside_buffer_adapt_c_buffers = [](py::array & buffer, double factor)
{
// Check if the array is C-contiguous
if (!buffer.attr("flags").attr("c_contiguous").cast<bool>()) {
throw std::runtime_error("The array must be contiguous, i.e, `a.flags.c_contiguous` must be True. Hint: use `numpy.ascontiguousarray`.");
}
// Check if the array is 1D and C-contiguous
if (! (buffer.ndim() == 1 && buffer.strides(0) == buffer.itemsize()) )
throw std::runtime_error("The array must be 1D and contiguous");

// convert py::array to C standard buffer (mutable)
void * buffer_from_pyarray = buffer.mutable_data();
Expand Down
14 changes: 14 additions & 0 deletions src/litgen/integration_tests/mylib/c_style_array_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations
import lg_mylib
import numpy as np
import pytest


def test_const_array2_add():
Expand Down Expand Up @@ -28,3 +30,15 @@ def test_array2_modify_mutable():
pt2 = lg_mylib.Point2(53, 54)
lg_mylib.array2_modify_mutable(pt1, pt2)
assert pt1.x == 0 and pt1.y == 1 and pt2.x == 2 and pt2.y == 3


def test_refuse_non_contiguous_array():
a = np.arange(10, dtype=np.float64)
a2 = a[1::2]
with pytest.raises(RuntimeError):
lg_mylib.templated_mul_inside_buffer(a2, 3.14)

# Also refuse arrays of dim > 1
a2 = a.reshape((2, 5))
with pytest.raises(RuntimeError):
lg_mylib.templated_mul_inside_buffer(a2, 3.14)
12 changes: 7 additions & 5 deletions src/litgen/internal/adapt_function_params/_adapt_c_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,13 +429,11 @@ def _lambda_input_buffer_standard_convert_part(self, idx_param: int) -> str:
mutable_or_const = "const" if self._is_const(idx_param) else "mutable"

_ = self
param_name = self._param(idx_param).decl.decl_name
if self.options.bind_library == BindLibraryType.pybind11:
template = f"""
// Check if the array is C-contiguous
if (!{param_name}.attr("flags").attr("c_contiguous").cast<bool>()) {{
throw std::runtime_error("The array must be contiguous, i.e, `a.flags.c_contiguous` must be True. Hint: use `numpy.ascontiguousarray`.");
}}
// Check if the array is 1D and C-contiguous
if (! ({_._param_name(idx_param)}.ndim() == 1 && {_._param_name(idx_param)}.strides(0) == {_._param_name(idx_param)}.itemsize()) )
throw std::runtime_error("The array must be 1D and contiguous");
// convert py::array to C standard buffer ({mutable_or_const})
{_._const_space_or_empty(idx_param)}void * {_._buffer_from_pyarray_name(idx_param)} = {_._param_name(idx_param)}.{mutable_or_empty}data();
Expand All @@ -444,6 +442,10 @@ def _lambda_input_buffer_standard_convert_part(self, idx_param: int) -> str:
else:
# TODO: implement contiguous check for nanobind
template = f"""
// Check if the array is 1D and C-contiguous
if (! ({_._param_name(idx_param)}.ndim() == 1 && {_._param_name(idx_param)}.stride(0) == 1))
throw std::runtime_error("The array must be 1D and contiguous");
// convert nb::ndarray to C standard buffer ({mutable_or_const})
{_._const_space_or_empty(idx_param)}void * {_._buffer_from_pyarray_name(idx_param)} = {_._param_name(idx_param)}.{mutable_or_empty}data();
size_t {_._pyarray_count(idx_param)} = {_._param_name(idx_param)}.shape(0);
Expand Down
Loading

0 comments on commit ba8ff59

Please sign in to comment.