diff --git a/dartsclone/_dartsclone.pxd b/dartsclone/_dartsclone.pxd index ddb2780..29db18b 100644 --- a/dartsclone/_dartsclone.pxd +++ b/dartsclone/_dartsclone.pxd @@ -51,3 +51,6 @@ cdef extern from "darts.h": cdef class DoubleArray: cdef CppDoubleArray *wrapped + cdef Py_buffer _buf + cdef Py_ssize_t _shape[1] + cdef Py_ssize_t _strides[1] diff --git a/dartsclone/_dartsclone.pyx b/dartsclone/_dartsclone.pyx index f9d80e5..41adf70 100644 --- a/dartsclone/_dartsclone.pyx +++ b/dartsclone/_dartsclone.pyx @@ -1,26 +1,58 @@ -from libc.stdlib cimport malloc, free +from libc.stdlib cimport calloc, free + +cdef extern from "Python.h": + ctypedef struct PyObject + int PyObject_GetBuffer(PyObject *exporter, Py_buffer *view, int flags) + void PyBuffer_Release(Py_buffer *view) + const int PyBUF_C_CONTIGUOUS cdef class DoubleArray: def __cinit__(self): self.wrapped = new CppDoubleArray() + self._strides[0] = 1 def __dealloc__(self): + if self._buf.obj != NULL: + PyBuffer_Release(&self._buf) del self.wrapped def __getstate__(self): - return self.array() + return bytes(self.array()) def __setstate__(self, array): self.set_array(array) - def array(self): - cdef size_t total_size = self.wrapped.total_size() - cdef char[:] data = self.wrapped.array() - return bytes(data) + def __getbuffer__(self, Py_buffer *buffer, int flags): + buffer.buf = self.wrapped.array() + buffer.obj = self + buffer.len = self._shape[0] = self.wrapped.total_size() + buffer.readonly = True + buffer.itemsize = 1 + buffer.format = 'B' + buffer.ndim = 1 + buffer.shape = self._shape + buffer.strides = self._strides + buffer.suboffsets = NULL + buffer.internal = NULL + + def __releasebuffer__(self, Py_buffer *buffer): + pass - def set_array(self, const unsigned char[::1] array, size_t size=0): - self.wrapped.set_array( &array[0], size) + def array(self): + return memoryview(self) + + def set_array(self, array, size_t size=0): + cdef Py_buffer _buf + if PyObject_GetBuffer(array, &_buf, PyBUF_C_CONTIGUOUS) < 0: + return + if _buf.buf == self.wrapped.array(): + PyBuffer_Release(&_buf) + raise ValueError("passed buffer refers to itself") + if self._buf.obj != NULL: + PyBuffer_Release(&self._buf) + self._buf = _buf + self.wrapped.set_array(_buf.buf, size) def clear(self): self.wrapped.clear() @@ -41,26 +73,45 @@ cdef class DoubleArray: lengths = None, values = None): cdef size_t num_keys = len(keys) - cdef const char** _keys = malloc(num_keys * sizeof(char*)) + cdef const char** _keys = NULL + cdef Py_buffer* _buf = NULL cdef size_t *_lengths = NULL cdef int *_values = NULL - for i, key in enumerate(keys): - _keys[i] = key - if lengths is not None: - _lengths = malloc(num_keys * sizeof(size_t)) - for i, length in enumerate(lengths): - _lengths[i] = length - if values is not None: - _values = malloc(num_keys * sizeof(int)) - for i, value in enumerate(values): - _values[i] = value + try: + _keys = calloc(num_keys, sizeof(char*)) + if _keys == NULL: + raise MemoryError("failed to allocate memory for key array") + _buf = calloc(num_keys, sizeof(Py_buffer)) + if _buf == NULL: + raise MemoryError("failed to allocate memory for buffer") + for i, key in enumerate(keys): + if PyObject_GetBuffer(key, &_buf[i], PyBUF_C_CONTIGUOUS) < 0: + return + _keys[i] = _buf[i].buf + if lengths is not None: + _lengths = calloc(num_keys, sizeof(size_t)) + if _lengths == NULL: + raise MemoryError("failed to allocate memory for length array") + for i, length in enumerate(lengths): + _lengths[i] = length + if values is not None: + _values = calloc(num_keys, sizeof(int)) + if _values == NULL: + raise MemoryError("failed to allocate memory for value array") + for i, value in enumerate(values): + _values[i] = value self.wrapped.build(num_keys, _keys, _lengths, _values, NULL) finally: - free(_keys) - if lengths is not None: + if _keys != NULL: + free(_keys) + if _buf != NULL: + for i in range(num_keys): + PyBuffer_Release(&_buf[i]) + free(_buf) + if _lengths != NULL: free(_lengths) - if values is not None: + if _values != NULL: free(_values) def open(self, file_name, @@ -88,39 +139,66 @@ cdef class DoubleArray: size_t length = 0, size_t node_pos = 0, pair_type=True): - cdef const char *_key = key - if pair_type: - return self.__exact_match_search_pair_type(_key, length, node_pos) - else: - return self.__exact_match_search(_key, length, node_pos) + cdef Py_buffer buf + if PyObject_GetBuffer(key, &buf, PyBUF_C_CONTIGUOUS) < 0: + return + try: + if length == 0: + if buf.len == 0: + raise ValueError("buffer cannot be empty") + length = buf.len + if pair_type: + return self.__exact_match_search_pair_type(buf.buf, length, node_pos) + else: + return self.__exact_match_search(buf.buf, length, node_pos) + finally: + PyBuffer_Release(&buf) def common_prefix_search(self, key, size_t max_num_results = 0, size_t length = 0, size_t node_pos = 0, pair_type=True): - cdef const char *_key = key - if max_num_results == 0: - max_num_results = len(key) - if pair_type: - return self.__common_prefix_search_pair_type(_key, max_num_results, length, node_pos) - else: - return self.__common_prefix_search(_key, max_num_results, length, node_pos) + cdef Py_buffer buf + if PyObject_GetBuffer(key, &buf, PyBUF_C_CONTIGUOUS) < 0: + return + try: + if length == 0: + if buf.len == 0: + raise ValueError("buffer cannot be empty") + length = buf.len + if max_num_results == 0: + max_num_results = len(key) + if pair_type: + return self.__common_prefix_search_pair_type(buf.buf, max_num_results, length, node_pos) + else: + return self.__common_prefix_search(buf.buf, max_num_results, length, node_pos) + finally: + PyBuffer_Release(&buf) def traverse(self, key, size_t node_pos, size_t key_pos, size_t length = 0): - cdef const char *_key = key + cdef Py_buffer buf cdef int result - with nogil: - result = self.wrapped.traverse(_key, node_pos, key_pos, length) - return result + if PyObject_GetBuffer(key, &buf, PyBUF_C_CONTIGUOUS) < 0: + return + try: + if length == 0: + if buf.len == 0: + raise ValueError("buffer cannot be empty") + length = buf.len + with nogil: + result = self.wrapped.traverse(buf.buf, node_pos, key_pos, length) + return result + finally: + PyBuffer_Release(&buf) def __exact_match_search(self, const char *key, size_t length = 0, size_t node_pos = 0): - cdef int result + cdef int result = 0 with nogil: self.wrapped.exact_match_search(key, result, length, node_pos) return result @@ -137,7 +215,7 @@ cdef class DoubleArray: size_t max_num_results, size_t length, size_t node_pos): - cdef int *results = malloc(max_num_results * sizeof(int)) + cdef int *results = calloc(max_num_results, sizeof(int)) cdef int result_len try: with nogil: @@ -153,7 +231,7 @@ cdef class DoubleArray: size_t max_num_results, size_t length, size_t node_pos): - cdef result_pair_type *results = malloc(max_num_results * sizeof(result_pair_type)) + cdef result_pair_type *results = calloc(max_num_results, sizeof(result_pair_type)) cdef result_pair_type result cdef int result_len try: diff --git a/test/test_darts.py b/test/test_darts.py index a9501e9..020a49f 100644 --- a/test/test_darts.py +++ b/test/test_darts.py @@ -12,7 +12,7 @@ class DoubleArrayTest(unittest.TestCase): def test_darts_no_values(self): keys = ['test', 'テスト', 'テストケース'] darts = DoubleArray() - darts.build(sorted([key.encode() for key in keys])) + darts.build([key.encode() for key in keys]) self.assertEqual(1, darts.exact_match_search('テスト'.encode(), pair_type=False)) self.assertEqual(0, darts.common_prefix_search('testcase'.encode(), pair_type=False)[0]) self.assertEqual(0, darts.exact_match_search('test'.encode(), pair_type=False)) @@ -21,7 +21,7 @@ def test_darts_no_values(self): def test_darts_with_values(self): keys = ['test', 'テスト', 'テストケース'] darts = DoubleArray() - darts.build(sorted([key.encode() for key in keys]), values=[3, 5, 1]) + darts.build([key.encode() for key in keys], values=[3, 5, 1]) self.assertEqual(5, darts.exact_match_search('テスト'.encode(), pair_type=False)) self.assertEqual(3, darts.common_prefix_search('testcase'.encode(), pair_type=False)[0]) self.assertEqual(1, darts.exact_match_search('テストケース'.encode(), pair_type=False)) @@ -30,7 +30,7 @@ def test_darts_with_values(self): def test_darts_save(self): keys = ['test', 'テスト', 'テストケース'] darts = DoubleArray() - darts.build(sorted([key.encode() for key in keys]), values=[3, 5, 1]) + darts.build([key.encode() for key in keys], values=[3, 5, 1]) with tempfile.NamedTemporaryFile('wb') as output_file: darts.save(output_file.name) output_file.flush() @@ -54,13 +54,20 @@ def test_darts_pickle(self): def test_darts_array(self): keys = ['test', 'テスト', 'テストケース'] darts = DoubleArray() - darts.build(sorted([key.encode() for key in keys]), values=[3, 5, 1]) + darts.build([key.encode() for key in keys], values=[3, 5, 1]) array = darts.array() darts = DoubleArray() darts.set_array(array) self.assertEqual(5, darts.exact_match_search('テスト'.encode(), pair_type=False)) self.assertEqual(3, darts.common_prefix_search('testcase'.encode(), pair_type=False)[0]) + def test_darts_buffers(self): + keys = ['test', 'テスト', 'テストケース'] + darts = DoubleArray() + darts.build([memoryview(key.encode()) for key in keys], values=[3, 5, 1]) + self.assertEqual(5, darts.exact_match_search(memoryview('テスト'.encode()), pair_type=False)) + self.assertEqual(3, darts.common_prefix_search(memoryview('testcase'.encode()), pair_type=False)[0]) + if __name__ == "__main__": unittest.main()