Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve GPU functionality #780

Merged
merged 10 commits into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Interfaces = "0.3"
IntervalSets = "0.5, 0.6, 0.7"
InvertedIndices = "1"
IteratorInterfaceExtensions = "1"
JLArrays = "0.1"
LinearAlgebra = "1"
Makie = "0.19, 0.20, 0.21"
OffsetArrays = "1"
Expand Down Expand Up @@ -85,6 +86,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ImageFiltering = "6a3955dd-da59-5b1f-98d4-e7296123deb5"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Expand All @@ -95,4 +97,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["Aqua", "ArrayInterface", "BenchmarkTools", "CategoricalArrays", "ColorTypes", "Combinatorics", "CoordinateTransformations", "DataFrames", "Distributions", "Documenter", "ImageFiltering", "ImageTransformations", "CairoMakie", "OffsetArrays", "Plots", "Random", "SafeTestsets", "StatsPlots", "Test", "Unitful"]
test = ["Aqua", "ArrayInterface", "BenchmarkTools", "CategoricalArrays", "ColorTypes", "Combinatorics", "CoordinateTransformations", "DataFrames", "Distributions", "Documenter", "ImageFiltering", "ImageTransformations", "JLArrays", "CairoMakie", "OffsetArrays", "Plots", "Random", "SafeTestsets", "StatsPlots", "Test", "Unitful"]
30 changes: 14 additions & 16 deletions src/array/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,26 +47,24 @@ function Broadcast.copy(bc::Broadcasted{DimensionalStyle{S}}) where S
end

function Base.copyto!(dest::AbstractArray, bc::Broadcasted{DimensionalStyle{S}}) where S
_dims = comparedims(dims(dest), _broadcasted_dims(bc); ignore_length_one=true, order=true)
#TODO: this will cause a comparisson to happen twice. We should avoid that
comparedims(dims(dest), _broadcasted_dims(bc); ignore_length_one=true, order=true)
copyto!(dest, _unwrap_broadcasted(bc))
A = _firstdimarray(bc)
if A isa Nothing || _dims isa Nothing
dest
else
rebuild(A, dest, _dims, refdims(A))
end
return dest
end
function Base.copyto!(dest::AbstractDimArray, bc::Broadcasted{DimensionalStyle{S}}) where S
_dims = comparedims(dims(dest), _broadcasted_dims(bc); ignore_length_one=true, order=true)
copyto!(parent(dest), _unwrap_broadcasted(bc))
A = _firstdimarray(bc)
if A isa Nothing || _dims isa Nothing
dest
else
rebuild(A, parent(dest), _dims, refdims(A))
end

ptiede marked this conversation as resolved.
Show resolved Hide resolved

ptiede marked this conversation as resolved.
Show resolved Hide resolved
@inline function Base.Broadcast.materialize!(dest::AbstractDimArray, bc::Base.Broadcast.Broadcasted{<:Any})
# needed because we need to check whether the dims are compatible in dest which are already
# stripped when sent to copyto!
comparedims(dims(dest), _broadcasted_dims(bc); ignore_length_one=true, order=true)
style = DimensionalData.DimensionalStyle(Base.Broadcast.combine_styles(parent(dest), bc))
Base.Broadcast.materialize!(style, parent(dest), bc)
return dest
end



function Base.similar(bc::Broadcast.Broadcasted{DimensionalStyle{S}}, ::Type{T}) where {S,T}
A = _firstdimarray(bc)
rebuildsliced(A, similar(_unwrap_broadcasted(bc), T, axes(bc)...), axes(bc), Symbol(""))
Expand Down
22 changes: 3 additions & 19 deletions src/array/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,9 @@ for (m, f) in ((:Statistics, :median), (:Base, :any), (:Base, :all))
end
end

# These are not exported but it makes a lot of things easier using them
function Base._mapreduce_dim(f, op, nt::NamedTuple{(),<:Tuple}, A::AbstractDimArray, dims)
rebuild(A, Base._mapreduce_dim(f, op, nt, parent(A), dimnum(A, _astuple(dims))), reducedims(A, dims))
end
function Base._mapreduce_dim(f, op, nt::NamedTuple{(),<:Tuple}, A::AbstractDimArray, dims::Colon)
Base._mapreduce_dim(f, op, nt, parent(A), dims)
end
function Base._mapreduce_dim(f, op, nt, A::AbstractDimArray, dims)
rebuild(A, Base._mapreduce_dim(f, op, nt, parent(A), dimnum(A, dims)), reducedims(A, dims))
end
function Base._mapreduce_dim(f, op, nt, A::AbstractDimArray, dims::Colon)
rebuild(A, Base._mapreduce_dim(f, op, nt, parent(A), dimnum(A, dims)), reducedims(A, dims))
end

function Base._mapreduce_dim(f, op, nt::Base._InitialValue, A::AbstractDimArray, dims)
rebuild(A, Base._mapreduce_dim(f, op, nt, parent(A), dimnum(A, dims)), reducedims(A, dims))
end
function Base._mapreduce_dim(f, op, nt::Base._InitialValue, A::AbstractDimArray, dims::Colon)
Base._mapreduce_dim(f, op, nt, parent(A), dims)
function Base.mapreduce(f, op, A::AbstractDimArray; dims=Base.Colon(), kwargs...)
dims === Colon() && return mapreduce(f, op, parent(A); kwargs...)
rebuild(A, mapreduce(f, op, parent(A); dims=dimnum(A, dims), kwargs...), reducedims(A, dims))
ptiede marked this conversation as resolved.
Show resolved Hide resolved
end


Expand Down
135 changes: 129 additions & 6 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
using DimensionalData, Test

using JLArrays
using DimensionalData: NoLookup

# Tests taken from NamedDims. Thanks @oxinabox

da = ones(X(3))
dajl = rebuild(da, JLArray(parent(da)));
@test Base.BroadcastStyle(typeof(da)) isa DimensionalData.DimensionalStyle

@testset "standard case" begin
Expand All @@ -19,18 +20,37 @@ end
@test da2 .* da2[:, 1:1] == [1, 4, 9, 16] * (1:2:8)'
end

@testset "JLArray broadcast over length one dimension" begin
da2 = DimArray(JLArray((1:4) * (1:2:8)'), (X, Y))
@test Array(da2 .* da2[:, 1:1]) == [1, 4, 9, 16] * (1:2:8)'
end


@testset "in place" begin
@test parent(da .= 1 .* da .+ 7) == 8 * ones(3)
@test dims(da .= 1 .* da .+ 7) == dims(da)
end

@testset "JLArray in place" begin
@test Array(parent(dajl .= 1 .* dajl .+ 7)) == 8 * ones(3)
@test dims(dajl .= 1 .* dajl .+ 7) == dims(da)
end

@testset "Dimension disagreement" begin
@test_throws DimensionMismatch begin
DimArray(zeros(3, 3, 3), (X, Y, Z)) .+
DimArray(ones(3, 3, 3), (Y, Z, X))
end
end

@testset "JLArray Dimension disagreement" begin
@test_throws DimensionMismatch begin
DimArray(JLArray(zeros(3, 3, 3)), (X, Y, Z)) .+
DimArray(JLArray(ones(3, 3, 3)), (Y, Z, X))
end
end


@testset "dims and regular" begin
da = DimArray(ones(3, 3, 3), (X, Y, Z))
left_sum = da .+ ones(3, 3, 3)
Expand All @@ -41,6 +61,17 @@ end
@test dims(right_sum) == dims(da)
end

@testset "JLArray dims and regular" begin
da = DimArray(JLArray(ones(3, 3, 3)), (X, Y, Z))
left_sum = da .+ ones(3, 3, 3)
@test Array(left_sum) == fill(2, 3, 3, 3)
@test dims(left_sum) == dims(da)
right_sum = ones(3, 3, 3) .+ da
@test Array(right_sum) == fill(2, 3, 3, 3)
@test dims(right_sum) == dims(da)
end


@testset "changing type" begin
@test (da .> 0) isa DimArray
@test (da .* da .> 0) isa DimArray
Expand All @@ -51,6 +82,16 @@ end
@test (rand(3) .> 1 .> 0 .* da) isa DimArray
end

@testset "JLArray changing type" begin
@test (dajl .> 0) isa DimArray
@test (dajl .* dajl .> 0) isa DimArray
@test (dajl .> 0 .> rand(3)) isa DimArray
@test (dajl .* rand(3) .> 0.0) isa DimArray
@test (0 .> dajl .> 0 .> rand(3)) isa DimArray
@test (rand(3) .> dajl .> 0 .* rand(3)) isa DimArray
@test (rand(3) .> 1 .> 0 .* dajl) isa DimArray
end

@testset "trailng dimensions" begin
@test zeros(X(10), Y(5)) .* zeros(X(10), Y(1)) ==
zeros(X(10), Y(5)) .* zeros(X(1), Y(1)) ==
Expand Down Expand Up @@ -79,6 +120,18 @@ end
@test dims(s .+ v .+ m) == dims(m .+ s .+ v)
end

@testset "JLArray broadcasting" begin
v = DimArray(JLArray(zeros(3,)), X)
m = DimArray(JLArray(ones(3, 3)), (X, Y))
s = 0
@test Array(v .+ m) == ones(3, 3) == Array(m .+ v)
@test Array(s .+ m) == ones(3, 3) == Array(m .+ s)
@test Array(s .+ v .+ m) == ones(3, 3) == Array(m .+ s .+ v)
@test dims(v .+ m) == dims(m .+ v)
@test dims(s .+ m) == dims(m .+ s)
@test dims(s .+ v .+ m) == dims(m .+ s .+ v)
end

@testset "adjoint broadcasting" begin
a = DimArray(reshape(1:12, (4, 3)), (X, Y))
b = DimArray(1:3, Y)
Expand All @@ -88,6 +141,17 @@ end
@test dims(a .* b') == dims(a)
end

@testset "JLArray adjoint broadcasting" begin
a = DimArray(JLArray(reshape(1:12, (4, 3))), (X, Y))
b = DimArray(JLArray(1:3), Y)
@test_throws DimensionMismatch a .* b
@test_throws DimensionMismatch parent(a) .* parent(b)
@test Array(parent(a) .* parent(b)') == Array(parent(a .* b'))
@test dims(a .* b') == dims(a)
end



@testset "Mixed array types" begin
casts = (
A -> DimArray(A, (X, Y)), # Named Matrix
Expand Down Expand Up @@ -121,22 +185,54 @@ end
@test_throws DimensionMismatch ac .= ab .+ ba

# check that dest is written into:
@test dims(z .= ab .+ ba') == dims(ab .+ ba')
z .= ab .+ ba'
@test z == (ab.data .+ ba.data')

@test dims(z .= ab .+ a_) ==
(X(NoLookup(Base.OneTo(2))), Y(NoLookup(Base.OneTo(2))))
@test dims(a_ .= ba' .+ ab) ==
(X(NoLookup(Base.OneTo(2))), Y(NoLookup(Base.OneTo(2))))
# @test dims(z .= ab .+ a_) ==
# (X(NoLookup(Base.OneTo(2))), Y(NoLookup(Base.OneTo(2))))
# @test dims(a_ .= ba' .+ ab) ==
# (X(NoLookup(Base.OneTo(2))), Y(NoLookup(Base.OneTo(2))))
end

@testset "JLArray in-place assignment .=" begin
ab = DimArray(JLArray(rand(2,2)), (X, Y))
ba = DimArray(JLArray(rand(2,2)), (Y, X))
ac = DimArray(JLArray(rand(2,2)), (X, Z))
a_ = DimArray(JLArray(rand(2,2)), (X(), DimensionalData.AnonDim()))
z = JLArray(zeros(2,2))

@test_throws DimensionMismatch z .= ab .+ ba
@test_throws DimensionMismatch z .= ab .+ ac
@test_throws DimensionMismatch a_ .= ab .+ ac
@test_throws DimensionMismatch ab .= a_ .+ ac
@test_throws DimensionMismatch ac .= ab .+ ba

# check that dest is written into:
z .= ab .+ ba'
@test z == (ab.data .+ ba.data')

# @test dims(z .= ab .+ a_) ==
# (X(NoLookup(Base.OneTo(2))), Y(NoLookup(Base.OneTo(2))))
# @test dims(a_ .= ba' .+ ab) ==
# (X(NoLookup(Base.OneTo(2))), Y(NoLookup(Base.OneTo(2))))
end


@testset "assign using named indexing and dotview" begin
A = DimArray(zeros(3,2), (X, Y))
A[X=1:2] .= [1, 2]
A[X=3] .= 7
@test A == [1.0 1.0; 2.0 2.0; 7.0 7.0]
end

@testset "JLArray assign using named indexing and dotview" begin
A = DimArray(JLArray(zeros(3,2)), (X, Y))
A[X=1:2] .= JLArray([1, 2])
A[X=3] .= 7
@test Array(A) == [1.0 1.0; 2.0 2.0; 7.0 7.0]
end


@testset "0-dimensional array broadcasting" begin
x = DimArray(fill(3), ())
y = DimArray(fill(4), ())
Expand All @@ -148,6 +244,7 @@ end
@test @inferred(z .+ x) === 6
end


@testset "DimIndices broadcasting" begin
ds = X(1.0:0.2:2.0), Y(10:2:20)
A = rand(ds)
Expand All @@ -168,6 +265,32 @@ end
@test A[DimSelectors(sub)] == C[DimSelectors(sub)]
end

@testset "JLArray DimIndices broadcasting" begin
ds = X(1.0:0.2:2.0), Y(10:2:20)
_A = (rand(ds))
_B = (zeros(ds))
_C = (zeros(ds))

A = rebuild(_A, JLArray(parent(_A)))
B = rebuild(_B, JLArray(parent(_B)))
C = rebuild(_C, JLArray(parent(_C)))

B[DimIndices(B)] .+= A
C[DimSelectors(C)] .+= A
@test Array(A) == Array(B) == Array(C)
sub = A[1:4, 1:3]
B .= 0
C .= 0
B[DimIndices(sub)] .+= sub
C[DimSelectors(sub)] .+= sub
@test Array(A[DimIndices(sub)]) == Array(B[DimIndices(sub)]) == Array(C[DimIndices(sub)])
sub = A[2:4, 2:5]
C .= 0
C[DimSelectors(sub)] .+= sub
@test Array(A[DimSelectors(sub)]) == Array(C[DimSelectors(sub)])
end


# @testset "Competing Wrappers" begin
# da = DimArray(ones(4), X)
# ta = TrackedArray(5 * ones(4))
Expand Down
Loading
Loading