-
-
Notifications
You must be signed in to change notification settings - Fork 125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add N-ary broadcasting operations. #98
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I left a few comments, mostly on style. I need to go over your earlier conversation with @shoyer before I'm able to properly review this.
sparse/tests/test_coo.py
Outdated
|
||
assert_eq(sparse.elemwise(func, xs, ys, zs), func(x, y, z)) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some extra checks in the removed tests that we may want to maintain, for example that the result of elemwise is a COO
, and that its non-zeros are as expected
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might also consider having tests for some of the following:
- N-ary broadcasting where the arguments have different shapes
- N-ary broadcasting including arguments that are scalars and zero-dimensional arrays
sparse/coo.py
Outdated
|
||
__doc__ = func.__doc__ | ||
|
||
return Partial() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thoughts on replacing this with a just a functools.partial
on top of a normal function?
This is our solution for dask
def partial_by_order(*args, **kwargs):
"""
>>> from operator import add
>>> partial_by_order(5, function=add, other=[(1, 10)])
15
"""
function = kwargs.pop('function')
other = kwargs.pop('other')
args2 = list(args)
for i, arg in other:
args2.insert(i, arg)
return function(*args2, **kwargs)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could have, but our situation is slightly unique:
- We're using
str(func)
for exceptions.functools.wraps
doesn't work on that for all callables (e.g.ufunc
s), and breaks a few docstrings. This leads to illegible names in exceptions like_posarg_partial.<locals>.wrapper
(and the same for debugging). - We're replacing a number of arguments in different positions.
I guess I could turn it into a class rather than a decorator style function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I turned it into a callable class.
sparse/coo.py
Outdated
@@ -2426,80 +2468,39 @@ def _elemwise_unary(func, self, *args, **kwargs): | |||
sorted=self.sorted) | |||
|
|||
|
|||
def _get_matching_coords(coords1, coords2, shape1, shape2): | |||
def _get_nary_matching_coords(coords, params, shape): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe just call this _get_matching_coords
and drop the nary. Presumably there wll be no need to distinguish any longer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
sparse/coo.py
Outdated
matching_coords : np.ndarray | ||
The coordinates of the output array for which both inputs will be nonzero. | ||
numpy.ndarray | ||
The broacasted coordinates. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Style nit, there is no need to place a period at the end of a phrase like this. We tend to reserve periods for full sentences.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
sparse/coo.py
Outdated
result_shape = _get_broadcast_shape(self.shape, other.shape) | ||
Parameters | ||
---------- | ||
args : tuple[COO] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you're trying for parametrized python type annotations then I think it's supposed to be standard to use capitalized types like List[COO]
or Tuple[np.ndarray]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know though, this is somewhat new to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tuple
and list
tend to work better with intersphinx and code type annotations, so I tend to prefer those. Of course, I could import in something, but then that gives me PEP8 failures as I don't use it in code, just in docstrings.
sparse/coo.py
Outdated
other_data = other_data[i] | ||
# Filter out scalars as they are 'baked' into the function. | ||
func = _posarg_partial(func, pos, posargs) | ||
args = list(filter(lambda arg: not isscalar(arg), args)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You might consider toolz.remove
here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd prefer not to introduce a dependency for something as simple as this.
sparse/coo.py
Outdated
args = list(args) | ||
posargs = [] | ||
pos = [] | ||
for i in range(len(args)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You might consider for i, arg in enumerate(args)
, which might be a bit more idiomatic for Python readers
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
sparse/coo.py
Outdated
@@ -1954,6 +1894,20 @@ def tril(x, k=0): | |||
return COO(coords, data, x.shape, x.has_duplicates, x.sorted) | |||
|
|||
|
|||
def _nary_match(*arrays): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not able to quickly figure out what this function does. Can I ask you for a small docstring? If possible I find small example sections in docstrings to be very helpful when learning codebases that others have written.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whoops, must have missed that one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like this function is no longer used. Delete?
sparse/coo.py
Outdated
ci, di = _unmatch_coo(func, args, mask, **kwargs) | ||
|
||
coords_list.extend(ci) | ||
data_list.extend(di) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This confuses me and seems concerning. I see that this was a main point of your conversation with @shoyer earlier. I probably have some thinking to do on this problem before I'm able to reasonably comment on this.
I think that it would be good to see a more comprehensive test suite that fully explains the complexity of what we're trying to accomplish here. I think that that will make it more clear as we discuss different possibilities here. We might ask "why are we doing X" and the answer can be "see test_X". I get the sense that you've thought deeply about this problem and know all of the problems that might arise. It would be very valuable to encode that deep thinking and all of those corner cases into a test suite. |
I plan to make more comprehensive tests, yes. But the issue is some of the complexity can't be directly tested: For example, the optimizations are just that: Optimizations. We can design the tests so the optimizations are hit but we can't know that they kicked in without weird monkey-patching of some sort. |
03a552f
to
84e4e6c
Compare
It seems there's a slight bug for number of inputs >2 and broadcasting, nothing unfixable, but will have to think a bit. I'm on it. |
I think that there are probably a lot of correctness tests that could be written as well. In #1 you discussed many situations that might arise for which a system like this would be necessary to catch. Ideally we would encode all of those situations as tests to ensure that future developers don't change code to alter correct behavior here. |
sparse/coo.py
Outdated
other_data = other_data[i] | ||
# Filter out scalars as they are 'baked' into the function. | ||
func = PositinalArgumentPartial(func, pos, posargs) | ||
args = list(filter(lambda arg: not isscalar(arg), args)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
optional: consider using a list comprehension instead
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
sparse/coo.py
Outdated
matched_coords : np.ndarray | ||
The overall coordinates that match from both arrays. | ||
args : tuple[COO] | ||
The input :obj:`COO` arrays. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add in func
, mask
and **kwargs
to the docstring?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
sparse/coo.py
Outdated
|
||
coords_list = [] | ||
data_list = [] | ||
pos, = np.where([not m for m in mask]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe use np.flatnonzero()
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't really a numerical operation. I've converted it to a tuple(generator comprehension)
form and avoided np.where
altogether. The exact code is
pos = tuple(i for i, m in enumerate(mask) if not m)
I agree with @mrocklin that a more extensive test suite is vital here. This logic is complicated and fixing bugs later will be hard. I haven't seriously tried to follow it yet. I would suggest parametric tests verifying proper broadcasting with 2 or 3 arguments with:
|
d96e178
to
a715e9c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some small coverage comments
pos.append(i) | ||
posargs.append(args[i]) | ||
elif isinstance(arg, SparseArray) and not isinstance(arg, COO): | ||
args[i] = COO(arg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line doesn't get hit by tests. Should we add a small DOK test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added!
posargs.append(args[i]) | ||
elif isinstance(arg, SparseArray) and not isinstance(arg, COO): | ||
args[i] = COO(arg) | ||
elif not isinstance(arg, COO): | ||
raise ValueError("Performing this operation would produce " | ||
"a dense result: %s" % str(func)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here. No test triggers this error-handling code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a small test that hits this.
args = [arg for arg in args if not isscalar(arg)] | ||
|
||
if len(args) == 0: | ||
return func(**kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also here. No test operates on no args
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added another small test for this.
sparse/coo.py
Outdated
raise ValueError('Unknown kwargs %s' % kwargs.keys()) | ||
|
||
if return_midx and (len(args) != 2 or cache is not None): | ||
raise NotImplementedError('Matching only supported for two args, and no cache.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we still need this option?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, we don't. I'm not omniscient, so I went ahead and added this check in case someone tried to trigger caching on return_midx
(which we don't cache, it's never repeated); or tried to match indices for len(args) != 2
(I'm not sure if we'll need this in the future, but we might, and it's useful to err rather than have it return incorrect results).
fs = sparse.elemwise(func, *args) | ||
assert isinstance(fs, COO) | ||
|
||
assert_eq(fs, func(*dense_args)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice to test and verify that we are not creating unnecessary zeroes in the data attribute. We might either test that explicitly here, or we might put it into assert_eq
. I've gone ahead and pushed a commit to your branch that adds a check into assert_eq
. Please remove if you prefer not to add this here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like to verify we don't create additional zeros for all our operations, so that seems like a rather useful addition.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Although I would prefer to use np.count_nonzero
.
Edit: I reconsidered, this might be more useful for fill values.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On a related note, maybe it would be useful to add a test that verifies "sparse" broadcasting is actually done in a sparse way?
I think you could do this by mocking the underlying functions (e.g., np.mul
) and then verifying that the calls match expectations.
sparse/tests/test_coo.py
Outdated
def value_array(n): | ||
ar = np.empty((n,), dtype=np.float_) | ||
ar[:] = value | ||
return ar |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might want just a few of the values to be pathological instead of all of them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll modify the test to match that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modified.
I've incorporated more or less all of your suggestions about coverage, with one exception (see comments!) |
@shoyer do you have a chance to look at this? "Nope" is a fine answer. |
I'm guessing @shoyer doesn't work weekends. :-) If there's no reply or a "Nope" by the end of Monday, we can decide what to do next. |
It's a mixed bag on weekends, but this weekend my wife is away so I have time for open source :). I'll take a look. |
(2,), | ||
(3, 2), | ||
(4, 3, 2), | ||
], lambda x, y, z: (x + y) * z), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider doing a full cross-product of shapes and functions here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
sparse/tests/test_coo.py
Outdated
(4, 4), | ||
(4, 4, 4), | ||
], lambda x, y, z: x - y + z), | ||
]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be good to add checks for a few more variations on the broadcasting logic to exercise the matching logic:
- Dimensions of size 1, e.g.,
(3, 1)
+(3, 4)
->(3, 4)
- Output shapes that don't match one of the inputs, e.g.,
(3, 1)
+(1, 4)
->(3, 4)
. - Outputs that require matching across three inputs, e.g.,
(1, 1, 2)
+(1, 3, 1)
+(4, 1, 1)
->(4, 3, 2)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The first two were already covered in test_broadcasting
. I renamed that to test_binary_broadcasting
and moved it closer to these.
The third, I also added.
fs = sparse.elemwise(func, *args) | ||
assert isinstance(fs, COO) | ||
|
||
assert_eq(fs, func(*dense_args)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On a related note, maybe it would be useful to add a test that verifies "sparse" broadcasting is actually done in a sparse way?
I think you could do this by mocking the underlying functions (e.g., np.mul
) and then verifying that the calls match expectations.
I can't seem to be able to respond to your "sparse" broadcasting comment, so I'm responding here. I monkey-patched one of our own functions and verified the behavior is correct there. I also verified Edit: However; I will add that like all monkey patching, it's implementation dependent, not (just) API dependent. |
Yeah, I was trying to do that as well. I haven't seen that before Is checking for the right number of non-zeros in the output not sufficient? Do we have code paths that would re-sparsify a dense intermediate result? |
We're actually talking about the "optimized" code path for things like |
Ah, happy to retract my comment |
No, let's leave it up for other people who can't follow the terminology. |
Are there any further comments or is this good to merge? |
I haven't reviewed the logic in detail, but the implementation looks relatively sane and I am satisfied with the test coverage. 👍 |
Merged! |
This PR adds N-ary broadcasting operations (in preparation for where) and simplifies code for the N-ary case.
This PR adds N-ary broadcasting operations (in preparation for
where
) and simplifies code for the N-ary case.Discussed in #1