Skip to content

Commit fd42256

Browse files
authored
change linearization algorithm
1 parent 64d6bd5 commit fd42256

File tree

1 file changed

+31
-44
lines changed

1 file changed

+31
-44
lines changed

sparse/compressed/convert.py

Lines changed: 31 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,49 @@
11
import numpy as np
22
import numba
33

4-
54
def convert_to_flat(inds, shape, axisptr):
6-
75
inds = [np.array(ind) for ind in inds]
86
if any(ind.ndim > 1 for ind in inds):
97
raise IndexError('Only one-dimensional iterable indices supported.')
10-
col_shapes = np.array(shape[axisptr:])
11-
col_idx_size = np.prod([ind.size for ind in inds[axisptr:]])
12-
col_inds = inds[axisptr:]
13-
if len(col_inds) == 1:
14-
return col_inds[0]
15-
cols = np.empty(col_idx_size, dtype=int)
16-
col_operations = np.prod(
17-
[ind.size for ind in inds[axisptr:-1]]) if len(inds[axisptr:]) > 1 else 1
18-
col_key_vals = np.array([int(col_inds[i][0]) for i in range(
19-
len(col_inds[:-1]))] if len(col_inds) > 1 else [int(col_inds[0][0])])
20-
positions = np.zeros(len(col_shapes) - 1, dtype=int)
21-
cols = convert_to_2d(
22-
col_inds,
23-
col_key_vals,
24-
transform_shape(col_shapes),
25-
col_operations,
26-
cols,
27-
positions)
8+
uncompressed_inds = inds[axisptr:]
9+
cols = np.empty(np.prod([ind.size for ind in uncompressed_inds]),dtype=np.intp)
10+
shape_bins = transform_shape(shape[axisptr:])
11+
increments = [uncompressed_inds[i] * shape_bins[i] for i in range(len(uncompressed_inds))]
12+
operations = np.prod([ind.shape[0] for ind in increments[:-1]])
13+
return compute_flat(increments,cols,operations)
14+
15+
@numba.jit(nopython=True,nogil=True)
16+
def compute_flat(increments,cols,operations):
17+
start = 0
18+
end = increments[-1].shape[0]
19+
positions = np.zeros(len(increments)-1,dtype=np.intp)
20+
pos = len(increments)-2
21+
for i in range(operations):
22+
if i != 0 and positions[pos] == increments[pos].shape[0]:
23+
positions[pos] = 0
24+
pos -= 1
25+
positions[pos] += 1
26+
pos += 1
27+
to_add = np.array([increments[i][positions[i]] for i in range(len(increments)-1)]).sum()
28+
cols[start:end] = increments[-1] + to_add
29+
positions[pos] += 1
30+
start += increments[-1].shape[0]
31+
end += increments[-1].shape[0]
2832
return cols
29-
33+
34+
3035
def transform_shape(shape):
36+
"""
37+
turns a shape into the linearized increments that
38+
it represents. For example, given (5,5,5), it returns
39+
np.array([25,5,1]).
40+
"""
3141
shape_bins = np.empty(len(shape),dtype=np.intp)
3242
shape_bins[-1] = 1
3343
for i in range(len(shape)-1):
3444
shape_bins[i] = np.prod(shape[i:-1])
3545
return shape_bins
3646

37-
38-
@numba.jit(nopython=True, nogil=True)
39-
def convert_to_2d(inds, key_vals, shape_bins, operations, indices, positions):
40-
41-
pos = len(key_vals) - 1
42-
increment = 0
43-
44-
for i in range(operations):
45-
if i != 0 and key_vals[pos] == inds[pos][-1]:
46-
key_vals[pos] = inds[pos][0]
47-
positions[pos] = 0
48-
pos -= 1
49-
positions[pos] += 1
50-
key_vals[pos] = inds[pos][positions[pos]]
51-
pos = len(key_vals) - 1
52-
positions[pos] += 1
53-
linearized = ((key_vals + np.array([inds[-1][0]])) * shape_bins).sum()
54-
indices[increment:increment + len(inds[-1])] = inds[-1] + linearized - inds[-1][0]
55-
increment += len(inds[-1])
56-
57-
return indices
58-
59-
6047
@numba.jit(nopython=True, nogil=True)
6148
def uncompress_dimension(indptr):
6249
"""converts an index pointer array into an array of coordinates"""

0 commit comments

Comments
 (0)