Skip to content

Commit 830adda

Browse files
author
marc-chiesa
authored
Modified Vector STL bind initialization from a buffer type with optimization for simple arrays (#2298)
* Modified Vector STL bind initialization from a buffer type with optimization for simple arrays * Add subtests to demonstrate processing Python buffer protocol objects with step > 1 * Fixed memoryview step test to only run on Python 3+ * Modified Vector constructor from buffer to return by value for readability
1 parent 1534e17 commit 830adda

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

include/pybind11/stl_bind.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -397,14 +397,19 @@ vector_buffer(Class_& cl) {
397397
if (!detail::compare_buffer_info<T>::compare(info) || (ssize_t) sizeof(T) != info.itemsize)
398398
throw type_error("Format mismatch (Python: " + info.format + " C++: " + format_descriptor<T>::format() + ")");
399399

400-
auto vec = std::unique_ptr<Vector>(new Vector());
401-
vec->reserve((size_t) info.shape[0]);
402400
T *p = static_cast<T*>(info.ptr);
403401
ssize_t step = info.strides[0] / static_cast<ssize_t>(sizeof(T));
404402
T *end = p + info.shape[0] * step;
405-
for (; p != end; p += step)
406-
vec->push_back(*p);
407-
return vec.release();
403+
if (step == 1) {
404+
return Vector(p, end);
405+
}
406+
else {
407+
Vector vec;
408+
vec.reserve((size_t) info.shape[0]);
409+
for (; p != end; p += step)
410+
vec.push_back(*p);
411+
return vec;
412+
}
408413
}));
409414

410415
return;

tests/test_stl_binders.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ def test_vector_buffer():
8585
mv[2] = '\x06'
8686
assert v[2] == 6
8787

88+
if sys.version_info.major > 2:
89+
mv = memoryview(b)
90+
v = m.VectorUChar(mv[::2])
91+
assert v[1] == 3
92+
8893
with pytest.raises(RuntimeError) as excinfo:
8994
m.create_undeclstruct() # Undeclared struct contents, no buffer interface
9095
assert "NumPy type info missing for " in str(excinfo.value)
@@ -119,6 +124,10 @@ def test_vector_buffer_numpy():
119124
('y', 'float64'), ('z', 'bool')], align=True)))
120125
assert len(v) == 3
121126

127+
b = np.array([1, 2, 3, 4], dtype=np.uint8)
128+
v = m.VectorUChar(b[::2])
129+
assert v[1] == 3
130+
122131

123132
def test_vector_bool():
124133
import pybind11_cross_module_tests as cm

0 commit comments

Comments
 (0)