|
1 | 1 | import numpy as np |
2 | 2 | import numba |
3 | 3 |
|
4 | | - |
5 | 4 | def convert_to_flat(inds, shape, axisptr): |
6 | | - |
7 | 5 | inds = [np.array(ind) for ind in inds] |
8 | 6 | if any(ind.ndim > 1 for ind in inds): |
9 | 7 | raise IndexError('Only one-dimensional iterable indices supported.') |
10 | | - col_shapes = np.array(shape[axisptr:]) |
11 | | - col_idx_size = np.prod([ind.size for ind in inds[axisptr:]]) |
12 | | - col_inds = inds[axisptr:] |
13 | | - if len(col_inds) == 1: |
14 | | - return col_inds[0] |
15 | | - cols = np.empty(col_idx_size, dtype=int) |
16 | | - col_operations = np.prod( |
17 | | - [ind.size for ind in inds[axisptr:-1]]) if len(inds[axisptr:]) > 1 else 1 |
18 | | - col_key_vals = np.array([int(col_inds[i][0]) for i in range( |
19 | | - len(col_inds[:-1]))] if len(col_inds) > 1 else [int(col_inds[0][0])]) |
20 | | - positions = np.zeros(len(col_shapes) - 1, dtype=int) |
21 | | - cols = convert_to_2d( |
22 | | - col_inds, |
23 | | - col_key_vals, |
24 | | - transform_shape(col_shapes), |
25 | | - col_operations, |
26 | | - cols, |
27 | | - positions) |
| 8 | + uncompressed_inds = inds[axisptr:] |
| 9 | + cols = np.empty(np.prod([ind.size for ind in uncompressed_inds]),dtype=np.intp) |
| 10 | + shape_bins = transform_shape(shape[axisptr:]) |
| 11 | + increments = [uncompressed_inds[i] * shape_bins[i] for i in range(len(uncompressed_inds))] |
| 12 | + operations = np.prod([ind.shape[0] for ind in increments[:-1]]) |
| 13 | + return compute_flat(increments,cols,operations) |
| 14 | + |
| 15 | +@numba.jit(nopython=True,nogil=True) |
| 16 | +def compute_flat(increments,cols,operations): |
| 17 | + start = 0 |
| 18 | + end = increments[-1].shape[0] |
| 19 | + positions = np.zeros(len(increments)-1,dtype=np.intp) |
| 20 | + pos = len(increments)-2 |
| 21 | + for i in range(operations): |
| 22 | + if i != 0 and positions[pos] == increments[pos].shape[0]: |
| 23 | + positions[pos] = 0 |
| 24 | + pos -= 1 |
| 25 | + positions[pos] += 1 |
| 26 | + pos += 1 |
| 27 | + to_add = np.array([increments[i][positions[i]] for i in range(len(increments)-1)]).sum() |
| 28 | + cols[start:end] = increments[-1] + to_add |
| 29 | + positions[pos] += 1 |
| 30 | + start += increments[-1].shape[0] |
| 31 | + end += increments[-1].shape[0] |
28 | 32 | return cols |
29 | | - |
| 33 | + |
| 34 | + |
30 | 35 | def transform_shape(shape): |
| 36 | + """ |
| 37 | + turns a shape into the linearized increments that |
| 38 | + it represents. For example, given (5,5,5), it returns |
| 39 | + np.array([25,5,1]). |
| 40 | + """ |
31 | 41 | shape_bins = np.empty(len(shape),dtype=np.intp) |
32 | 42 | shape_bins[-1] = 1 |
33 | 43 | for i in range(len(shape)-1): |
34 | 44 | shape_bins[i] = np.prod(shape[i:-1]) |
35 | 45 | return shape_bins |
36 | 46 |
|
37 | | - |
38 | | -@numba.jit(nopython=True, nogil=True) |
39 | | -def convert_to_2d(inds, key_vals, shape_bins, operations, indices, positions): |
40 | | - |
41 | | - pos = len(key_vals) - 1 |
42 | | - increment = 0 |
43 | | - |
44 | | - for i in range(operations): |
45 | | - if i != 0 and key_vals[pos] == inds[pos][-1]: |
46 | | - key_vals[pos] = inds[pos][0] |
47 | | - positions[pos] = 0 |
48 | | - pos -= 1 |
49 | | - positions[pos] += 1 |
50 | | - key_vals[pos] = inds[pos][positions[pos]] |
51 | | - pos = len(key_vals) - 1 |
52 | | - positions[pos] += 1 |
53 | | - linearized = ((key_vals + np.array([inds[-1][0]])) * shape_bins).sum() |
54 | | - indices[increment:increment + len(inds[-1])] = inds[-1] + linearized - inds[-1][0] |
55 | | - increment += len(inds[-1]) |
56 | | - |
57 | | - return indices |
58 | | - |
59 | | - |
60 | 47 | @numba.jit(nopython=True, nogil=True) |
61 | 48 | def uncompress_dimension(indptr): |
62 | 49 | """converts an index pointer array into an array of coordinates""" |
|
0 commit comments