Skip to content

Commit

Permalink
support buffer protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
moriyoshi committed Apr 23, 2022
1 parent 97cc18b commit c299472
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 45 deletions.
3 changes: 3 additions & 0 deletions dartsclone/_dartsclone.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,6 @@ cdef extern from "darts.h":

cdef class DoubleArray:
cdef CppDoubleArray *wrapped
cdef Py_buffer _buf
cdef Py_ssize_t _shape[1]
cdef Py_ssize_t _strides[1]
160 changes: 119 additions & 41 deletions dartsclone/_dartsclone.pyx
Original file line number Diff line number Diff line change
@@ -1,26 +1,58 @@
from libc.stdlib cimport malloc, free
from libc.stdlib cimport calloc, free

cdef extern from "Python.h":
ctypedef struct PyObject
int PyObject_GetBuffer(PyObject *exporter, Py_buffer *view, int flags)
void PyBuffer_Release(Py_buffer *view)
const int PyBUF_C_CONTIGUOUS


cdef class DoubleArray:
def __cinit__(self):
self.wrapped = new CppDoubleArray()
self._strides[0] = 1

def __dealloc__(self):
if <PyObject *>self._buf.obj != NULL:
PyBuffer_Release(&self._buf)
del self.wrapped

def __getstate__(self):
return self.array()
return bytes(self.array())

def __setstate__(self, array):
self.set_array(array)

def array(self):
cdef size_t total_size = self.wrapped.total_size()
cdef char[:] data = <char[:total_size]>self.wrapped.array()
return bytes(data)
def __getbuffer__(self, Py_buffer *buffer, int flags):
buffer.buf = <char *>self.wrapped.array()
buffer.obj = self
buffer.len = self._shape[0] = self.wrapped.total_size()
buffer.readonly = True
buffer.itemsize = 1
buffer.format = 'B'
buffer.ndim = 1
buffer.shape = self._shape
buffer.strides = self._strides
buffer.suboffsets = NULL
buffer.internal = NULL

def __releasebuffer__(self, Py_buffer *buffer):
pass

def set_array(self, const unsigned char[::1] array, size_t size=0):
self.wrapped.set_array(<const void*> &array[0], size)
def array(self):
return memoryview(self)

def set_array(self, array, size_t size=0):
cdef Py_buffer _buf
if PyObject_GetBuffer(<PyObject *>array, &_buf, PyBUF_C_CONTIGUOUS) < 0:
return
if _buf.buf == self.wrapped.array():
PyBuffer_Release(&_buf)
raise ValueError("passed buffer refers to itself")
if <PyObject *>self._buf.obj != NULL:
PyBuffer_Release(&self._buf)
self._buf = _buf
self.wrapped.set_array(_buf.buf, size)

def clear(self):
self.wrapped.clear()
Expand All @@ -41,26 +73,45 @@ cdef class DoubleArray:
lengths = None,
values = None):
cdef size_t num_keys = len(keys)
cdef const char** _keys = <const char**> malloc(num_keys * sizeof(char*))
cdef const char** _keys = NULL
cdef Py_buffer* _buf = NULL
cdef size_t *_lengths = NULL
cdef int *_values = NULL
for i, key in enumerate(keys):
_keys[i] = key
if lengths is not None:
_lengths = <size_t *> malloc(num_keys * sizeof(size_t))
for i, length in enumerate(lengths):
_lengths[i] = length
if values is not None:
_values = <int *> malloc(num_keys * sizeof(int))
for i, value in enumerate(values):
_values[i] = value

try:
_keys = <const char**> calloc(num_keys, sizeof(char*))
if _keys == NULL:
raise MemoryError("failed to allocate memory for key array")
_buf = <Py_buffer *> calloc(num_keys, sizeof(Py_buffer))
if _buf == NULL:
raise MemoryError("failed to allocate memory for buffer")
for i, key in enumerate(keys):
if PyObject_GetBuffer(<PyObject *>key, &_buf[i], PyBUF_C_CONTIGUOUS) < 0:
return
_keys[i] = <const char *> _buf[i].buf
if lengths is not None:
_lengths = <size_t *> calloc(num_keys, sizeof(size_t))
if _lengths == NULL:
raise MemoryError("failed to allocate memory for length array")
for i, length in enumerate(lengths):
_lengths[i] = length
if values is not None:
_values = <int *> calloc(num_keys, sizeof(int))
if _values == NULL:
raise MemoryError("failed to allocate memory for value array")
for i, value in enumerate(values):
_values[i] = value
self.wrapped.build(num_keys, _keys, <const size_t*> _lengths, <const int*> _values, NULL)
finally:
free(_keys)
if lengths is not None:
if _keys != NULL:
free(_keys)
if _buf != NULL:
for i in range(num_keys):
PyBuffer_Release(&_buf[i])
free(_buf)
if _lengths != NULL:
free(_lengths)
if values is not None:
if _values != NULL:
free(_values)

def open(self, file_name,
Expand Down Expand Up @@ -88,39 +139,66 @@ cdef class DoubleArray:
size_t length = 0,
size_t node_pos = 0,
pair_type=True):
cdef const char *_key = key
if pair_type:
return self.__exact_match_search_pair_type(_key, length, node_pos)
else:
return self.__exact_match_search(_key, length, node_pos)
cdef Py_buffer buf
if PyObject_GetBuffer(<PyObject *>key, &buf, PyBUF_C_CONTIGUOUS) < 0:
return
try:
if length == 0:
if buf.len == 0:
raise ValueError("buffer cannot be empty")
length = buf.len
if pair_type:
return self.__exact_match_search_pair_type(<const char *>buf.buf, length, node_pos)
else:
return self.__exact_match_search(<const char *>buf.buf, length, node_pos)
finally:
PyBuffer_Release(&buf)

def common_prefix_search(self, key,
size_t max_num_results = 0,
size_t length = 0,
size_t node_pos = 0,
pair_type=True):
cdef const char *_key = key
if max_num_results == 0:
max_num_results = len(key)
if pair_type:
return self.__common_prefix_search_pair_type(_key, max_num_results, length, node_pos)
else:
return self.__common_prefix_search(_key, max_num_results, length, node_pos)
cdef Py_buffer buf
if PyObject_GetBuffer(<PyObject *>key, &buf, PyBUF_C_CONTIGUOUS) < 0:
return
try:
if length == 0:
if buf.len == 0:
raise ValueError("buffer cannot be empty")
length = buf.len
if max_num_results == 0:
max_num_results = len(key)
if pair_type:
return self.__common_prefix_search_pair_type(<const char *>buf.buf, max_num_results, length, node_pos)
else:
return self.__common_prefix_search(<const char *>buf.buf, max_num_results, length, node_pos)
finally:
PyBuffer_Release(&buf)

def traverse(self, key,
size_t node_pos,
size_t key_pos,
size_t length = 0):
cdef const char *_key = key
cdef Py_buffer buf
cdef int result
with nogil:
result = self.wrapped.traverse(_key, node_pos, key_pos, length)
return result
if PyObject_GetBuffer(<PyObject *>key, &buf, PyBUF_C_CONTIGUOUS) < 0:
return
try:
if length == 0:
if buf.len == 0:
raise ValueError("buffer cannot be empty")
length = buf.len
with nogil:
result = self.wrapped.traverse(<const char *>buf.buf, node_pos, key_pos, length)
return result
finally:
PyBuffer_Release(&buf)

def __exact_match_search(self, const char *key,
size_t length = 0,
size_t node_pos = 0):
cdef int result
cdef int result = 0
with nogil:
self.wrapped.exact_match_search(key, result, length, node_pos)
return result
Expand All @@ -137,7 +215,7 @@ cdef class DoubleArray:
size_t max_num_results,
size_t length,
size_t node_pos):
cdef int *results = <int *> malloc(max_num_results * sizeof(int))
cdef int *results = <int *> calloc(max_num_results, sizeof(int))
cdef int result_len
try:
with nogil:
Expand All @@ -153,7 +231,7 @@ cdef class DoubleArray:
size_t max_num_results,
size_t length,
size_t node_pos):
cdef result_pair_type *results = <result_pair_type *> malloc(max_num_results * sizeof(result_pair_type))
cdef result_pair_type *results = <result_pair_type *> calloc(max_num_results, sizeof(result_pair_type))
cdef result_pair_type result
cdef int result_len
try:
Expand Down
15 changes: 11 additions & 4 deletions test/test_darts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class DoubleArrayTest(unittest.TestCase):
def test_darts_no_values(self):
keys = ['test', 'テスト', 'テストケース']
darts = DoubleArray()
darts.build(sorted([key.encode() for key in keys]))
darts.build([key.encode() for key in keys])
self.assertEqual(1, darts.exact_match_search('テスト'.encode(), pair_type=False))
self.assertEqual(0, darts.common_prefix_search('testcase'.encode(), pair_type=False)[0])
self.assertEqual(0, darts.exact_match_search('test'.encode(), pair_type=False))
Expand All @@ -21,7 +21,7 @@ def test_darts_no_values(self):
def test_darts_with_values(self):
keys = ['test', 'テスト', 'テストケース']
darts = DoubleArray()
darts.build(sorted([key.encode() for key in keys]), values=[3, 5, 1])
darts.build([key.encode() for key in keys], values=[3, 5, 1])
self.assertEqual(5, darts.exact_match_search('テスト'.encode(), pair_type=False))
self.assertEqual(3, darts.common_prefix_search('testcase'.encode(), pair_type=False)[0])
self.assertEqual(1, darts.exact_match_search('テストケース'.encode(), pair_type=False))
Expand All @@ -30,7 +30,7 @@ def test_darts_with_values(self):
def test_darts_save(self):
keys = ['test', 'テスト', 'テストケース']
darts = DoubleArray()
darts.build(sorted([key.encode() for key in keys]), values=[3, 5, 1])
darts.build([key.encode() for key in keys], values=[3, 5, 1])
with tempfile.NamedTemporaryFile('wb') as output_file:
darts.save(output_file.name)
output_file.flush()
Expand All @@ -54,13 +54,20 @@ def test_darts_pickle(self):
def test_darts_array(self):
keys = ['test', 'テスト', 'テストケース']
darts = DoubleArray()
darts.build(sorted([key.encode() for key in keys]), values=[3, 5, 1])
darts.build([key.encode() for key in keys], values=[3, 5, 1])
array = darts.array()
darts = DoubleArray()
darts.set_array(array)
self.assertEqual(5, darts.exact_match_search('テスト'.encode(), pair_type=False))
self.assertEqual(3, darts.common_prefix_search('testcase'.encode(), pair_type=False)[0])

def test_darts_buffers(self):
keys = ['test', 'テスト', 'テストケース']
darts = DoubleArray()
darts.build([memoryview(key.encode()) for key in keys], values=[3, 5, 1])
self.assertEqual(5, darts.exact_match_search(memoryview('テスト'.encode()), pair_type=False))
self.assertEqual(3, darts.common_prefix_search(memoryview('testcase'.encode()), pair_type=False)[0])


if __name__ == "__main__":
unittest.main()

0 comments on commit c299472

Please sign in to comment.