-
Notifications
You must be signed in to change notification settings - Fork 41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve GPU functionality #780
Conversation
will this also work with Metal.jl? |
Is the extension really needed? Cant we just always do that with a more generic I had also assumed that this already worked, we just can't test on GPUs in CI so it can break. But Maybe something has changed... but there is surely a more elegant way to do this. |
Co-authored-by: Rafael Schouten <rafaelschouten@gmail.com>
Co-authored-by: Rafael Schouten <rafaelschouten@gmail.com>
The JLArray extension should catch most of these. The only one it can't is
I added the methods when An an example, I tried to replace the extension with function Base.copyto!(dest::AbstractDimArray, bc::Broadcasted{<:Broadcast.AbstractArrayStyle})
copyto!(parent(dest), bc)
dest
end But I get the following ambiguity error Candidates:
copyto!(dest::AbstractDimArray, bc::Base.Broadcast.Broadcasted{<:Base.Broadcast.AbstractArrayStyle})
@ DimensionalData ~/.julia/dev/DimensionalData/src/array/broadcast.jl:81
copyto!(dest::AbstractArray, bc::Base.Broadcast.Broadcasted{<:GPUArraysCore.AbstractGPUArrayStyle})
@ GPUArrays ~/.julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:44
Possible fix, define
copyto!(::AbstractDimArray, ::Base.Broadcast.Broadcasted{<:GPUArraysCore.AbstractGPUArrayStyle}) which is annoying... |
@lazarusA I believe it should work for any Julia supported GPU. I don't have a Mac to try this on, however. |
This is a potential workaround without dependencies: @inline function Base.Broadcast.materialize!(dest::AbstractDimArray, bc::Base.Broadcast.Broadcasted{<:Any})
style = DimensionalData.DimensionalStyle(Base.Broadcast.combine_styles(parent(dest), bc))
return Base.materialize!(style, parent(dest), bc)
end We will need to handle the dimension checks befor we do |
That works great for me! I just made the following modifications @inline function Base.Broadcast.materialize!(dest::AbstractDimArray, bc::Base.Broadcast.Broadcasted{<:Any})
style = DimensionalData.DimensionalStyle(Base.Broadcast.combine_styles(parent(dest), bc))
Base.Broadcast.materialize!(style, parent(dest), bc)
return dest
end
What dimension checks do you mean? I thought the dimension checks in DimensionalData.jl/src/array/broadcast.jl Lines 49 to 58 in fe39de7
should cover us since materialize will always call that. |
Oh you're right, that's it then! A few more tests that all this works with JLArrays could be nice. |
No actually dest won't be a DimArray in |
Ya I just realized that as well. Working on the fix. |
Quick question. Is it expected behavior that ab = DimArray(rand(2,2), (X, Y))
ba = DimArray(rand(2,2), (Y, X))
z = zeros(2,2)
function inplace!(z, ab, ba)
z .= ab .+ ba
end returns a DimArray? It seems strange that the output type of an in place broadcast changes, but this may be wanted. The only reason I bring this up is that is indeed the wanted behavior then I need the following implementation function Base.copyto!(dest::AbstractArray, bc::Broadcasted{DimensionalStyle{S}}) where S
_dims = comparedims(dims(dest), _broadcasted_dims(bc); ignore_length_one=true, order=true)
copyto!(dest, _unwrap_broadcasted(bc))
A = _firstdimarray(bc)
if A isa Nothing || _dims isa Nothing
dest
else
rebuild(A, dest, _dims, refdims(A))
end
end
@inline function Base.Broadcast.materialize!(dest::AbstractDimArray, bc::Base.Broadcast.Broadcasted{<:Any})
# needed because we need to check whether the dims are compatible in dest which are already
# stripped when sent to copyto!
_dims = comparedims(dims(dest), _broadcasted_dims(bc); ignore_length_one=true, order=true)
style = DimensionalData.DimensionalStyle(Base.Broadcast.combine_styles(parent(dest), bc))
Base.Broadcast.materialize!(style, parent(dest), bc)
A = _firstdimarray(bc)
if A isa Nothing || _dims isa Nothing
dest
else
rebuild(A, parent(dest), _dims, refdims(A))
end
end to do get the tests to pass. The is slightly non-ideal because it means |
Ya actually this is probably needed otherwise some dimension matches will happen. |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #780 +/- ##
==========================================
+ Coverage 82.64% 85.01% +2.36%
==========================================
Files 44 47 +3
Lines 4132 4277 +145
==========================================
+ Hits 3415 3636 +221
+ Misses 717 641 -76 ☔ View full report in Codecov by Sentry. |
Ahh probably your right it should just return We may need to look at some base broadcast docs to see the expected behavior |
Turns out not returning |
Co-authored-by: Rafael Schouten <rafaelschouten@gmail.com>
One thought... Should we instead loop over test objects rather than duplicating all that code? |
Probably. The annoying part is that some of the tests won't work on the GPU because there isn't a compatible implementation using the more generic GPUArrays, e.g., it only works on CUDA or it is missing all together. Some examples I ran into are:
*: missing JLArray method but works on CUDA. I can try to rewrite the tests to be more generic, but I probably won't have time for a little bit, given all the cases. I am about to head out for a conference next week so a lot of my attention will be on that. |
Ahh typical the test array isn't a full working implementation. That makes sense to do separately then, let's merge as is |
Ok thanks for all your help with the PR btw! |
Apologies I actually need to revert this PR, it has breaking changes. I'll lump it with the performance PR instead. |
* Improve GPU functionality * Add missing weakdeps * Update src/array/broadcast.jl Co-authored-by: Rafael Schouten <rafaelschouten@gmail.com> * Update src/array/broadcast.jl Co-authored-by: Rafael Schouten <rafaelschouten@gmail.com> * Push materialize fix * Clean up mapreduce and add a bunch of tests for JLArray broadcast * Add some more JLArray tests * Just return dest in broadcast * Update src/array/methods.jl Co-authored-by: Rafael Schouten <rafaelschouten@gmail.com> * Format --------- Co-authored-by: Rafael Schouten <rafaelschouten@gmail.com>
This PR adds some missing methods required to ensure the DimensionalData works on the GPU.
The PR makes the following possible.
I've also added some tests based on JLArrays to ensure test for potential scalar indexing problems in the future.
The biggest thing I am unsure of is the changes to
mapreduce
. The new implementation is pretty simple so I was uncertain if there was a specific case I missed. All the tests pass, but maybe this impacts a downstream package?