diff --git a/python/src/buffer.h b/python/src/buffer.h index 112fd7aaf..a5f426a54 100644 --- a/python/src/buffer.h +++ b/python/src/buffer.h @@ -104,7 +104,7 @@ extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) { view->internal = info; view->buf = a.data(); view->itemsize = a.itemsize(); - view->len = a.size(); + view->len = a.nbytes(); view->readonly = false; if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { view->format = const_cast(info->format.c_str()); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 912f3bbb1..4084a2693 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2,6 +2,7 @@ import operator import pickle +import sys import unittest import weakref from copy import copy, deepcopy @@ -1497,6 +1498,17 @@ def test_buffer_protocol(self): e = cm.exception self.assertTrue("Item size 2 for PEP 3118 buffer format string" in str(e)) + # Test buffer protocol with non-arrays ie bytes + a = ord("a") * 257 + mx.arange(10).astype(mx.int16) + ab = bytes(a) + self.assertEqual(len(ab), 20) + if sys.byteorder == "little": + self.assertEqual(b"aaaaaaaaaa", ab[1::2]) + self.assertEqual(b"abcdefghij", ab[::2]) + else: + self.assertEqual(b"aaaaaaaaaa", ab[::2]) + self.assertEqual(b"abcdefghij", ab[1::2]) + def test_buffer_protocol_ref_counting(self): a = mx.arange(3) wr = weakref.ref(a)