Skip to content

Commit 4b523b9

Browse files
authored
Update indexing.py
1 parent 20216a7 commit 4b523b9

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

sparse/compressed/indexing.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ def getitem(x, key):
2222

2323
# zip_longest so things like x[..., None] are picked up.
2424
if len(key) != 0 and all(isinstance(k, slice) and k == slice(0, dim, 1)
25-
for k, dim in zip_longest(key, x.shape)):
25+
for k, dim in zip_longest(key, x.shape)):
2626
return x
27-
27+
2828
# return a single element
2929
if all(isinstance(k, int) for k in key): # indexing for a single element
3030
key = np.array(key)[x.axis_order] # reordering the input
@@ -44,7 +44,7 @@ def getitem(x, key):
4444

4545
Nones_removed = [k for k in key if k is not None]
4646
count = 0
47-
for i,ind in enumerate(Nones_removed):
47+
for i, ind in enumerate(Nones_removed):
4848
if isinstance(ind, Integral):
4949
continue
5050
elif ind is None:
@@ -65,19 +65,20 @@ def getitem(x, key):
6565
else:
6666
uncompressed_inds[i] = True
6767
count += 1
68-
68+
6969
reordered_key = [Nones_removed[i] for i in x.axis_order]
70-
70+
7171
# prepare for converting to flat indices
7272
for i, ind in enumerate(reordered_key[:x.axisptr]):
73-
if isinstance(ind,slice):
74-
reordered_key[i] = range(ind.start,ind.stop,ind.step)
73+
if isinstance(ind, slice):
74+
reordered_key[i] = range(ind.start, ind.stop, ind.step)
7575
for i, ind in enumerate(reordered_key[x.axisptr:]):
7676
if isinstance(ind, Integral):
7777
reordered_key[i + x.axisptr] = [ind]
7878
elif isinstance(ind, slice):
79-
reordered_key[i + x.axisptr] = np.arange(ind.start, ind.stop, ind.step)
80-
79+
reordered_key[i +
80+
x.axisptr] = np.arange(ind.start, ind.stop, ind.step)
81+
8182
# find starts and ends of rows
8283
a = x.indptr[:-1].reshape(x.reordered_shape[:x.axisptr])
8384
b = x.indptr[1:].reshape(x.reordered_shape[:x.axisptr])
@@ -133,23 +134,22 @@ def getitem(x, key):
133134
np.cumsum(np.bincount(uncompressed,
134135
minlength=shape[0]), out=indptr[1:])
135136
indices = indices % size
136-
137+
137138
arg = (data, indices, indptr)
138139

139140
compressed_axes = np.array(compressed_axes)
140141
shape = shape.tolist()
141142
for i in range(len(key)):
142143
if key[i] is None:
143-
shape.insert(i,1)
144-
compressed_axes[compressed_axes>=i] += 1
144+
shape.insert(i, 1)
145+
compressed_axes[compressed_axes >= i] += 1
145146

146147
compressed_axes = tuple(compressed_axes)
147148
shape = tuple(shape)
148-
149+
149150
if len(shape) == 1:
150151
compressed_axes = None
151152

152-
153153
return GXCS(
154154
arg,
155155
shape=shape,

0 commit comments

Comments
 (0)