
Fix #16 #17


Open · wants to merge 2 commits into main
28 changes: 22 additions & 6 deletions src/onehot.jl
@@ -83,20 +83,36 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficiently
"""
onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...)

-function _onehotbatch(data, labels)
-    indices = UInt32[something(_findval(i, labels), 0) for i in data]
-    if 0 in indices
+function _onehotbatch(data, labels) # this accepts any iterator
+    indices = UInt32[something(_findval(x, labels), 0) for x in data]
+    if any(iszero, indices)
         for x in data
-            isnothing(_findval(x, labels)) && error("Value $x not found in labels")
+            isnothing(_findval(x, labels)) && throw(ArgumentError("Value x = $x not found in labels = $labels"))
Member Author:
I changed these error types partly so that tests can distinguish scalar indexing errors from helpful messages.
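For illustration, a minimal sketch of the kind of check this enables (assumes CUDA.jl and a working GPU; mirrors the new tests in test/gpu.jl below):

    using Test, CUDA, OneHotArrays
    # The new methods throw a specific ArgumentError for out-of-label values,
    # so the test can match it exactly; a scalar-indexing failure would surface
    # as a different error type and fail this @test_throws.
    @test_throws ArgumentError onehotbatch(cu([1, 2, 3]), [1, 2])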

        end
    end
    return OneHotArray(indices, length(labels))
end
+function _onehotbatch(data::AbstractArray, labels) # this works for GPUArrays too
Member Author:
One reason to change this is to avoid ever making an MVector or something weird like that:

Suggested change:
-function _onehotbatch(data::AbstractArray, labels) # this works for GPUArrays too
+function _onehotbatch(data::AbstractGPUArray, labels)
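To make the concern concrete, a small sketch of what similar can return for non-GPU array types (assumes StaticArrays and CUDA are loaded; exact return types may vary by package version):

    using StaticArrays, CUDA
    # The ::AbstractArray method calls similar(data, UInt32), so a static-array input
    # would allocate a mutable static vector here instead of taking the generic iterator path:
    similar(SA[1, 2, 3], UInt32)    # MVector{3, UInt32}
    similar(cu([1, 2, 3]), UInt32)  # CuVector{UInt32}, which is what the GPU path wants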

+    indices = similar(data, UInt32)
+    map!(x -> something(_findval(x, labels), 0), indices, data)
+    if any(iszero, indices)
+        badx = @allowscalar data[findfirst(iszero, indices)]
+        throw(ArgumentError("Value x = $badx not found in labels = $labels"))
+    end
+    return OneHotArray(indices, length(labels))
+end

function _onehotbatch(data, labels, default)
    default_index = _findval(default, labels)
-    isnothing(default_index) && error("Default value $default is not in labels")
-    indices = UInt32[something(_findval(i, labels), default_index) for i in data]
+    isnothing(default_index) && throw(ArgumentError("Default value $default is not in labels = $labels"))
+    indices = UInt32[something(_findval(x, labels), default_index) for x in data]
    return OneHotArray(indices, length(labels))
end
+function _onehotbatch(data::AbstractArray, labels, default)
Member Author:
Suggested change:
-function _onehotbatch(data::AbstractArray, labels, default)
+function _onehotbatch(data::AbstractGPUArray, labels, default)

+    default_index = _findval(default, labels)
+    isnothing(default_index) && throw(ArgumentError("Default value $default is not in labels = $labels"))
+    indices = similar(data, UInt32)
+    map!(x -> something(_findval(x, labels), default_index), indices, data)
+    return OneHotArray(indices, length(labels))
+end
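For context, a usage sketch of the new GPU-array paths through onehotbatch (hypothetical session mirroring the new tests; assumes CUDA.jl is loaded and functional):

    using CUDA, OneHotArrays
    y = onehotbatch(cu([3f0, 1f0, 2f0, 2f0]), (1.0, 2f0, 3))  # indices are computed with map! and stay on the GPU
    y.indices isa CuArray                                      # true
    z = onehotbatch(cu([3f0, 1f0, 2f0, 2f0]), [1, 2], 2)       # values outside the labels map to the default 2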

49 changes: 27 additions & 22 deletions test/gpu.jl
@@ -1,29 +1,23 @@

# Tests from Flux, probably not the optimal testset organisation!

@testset "CUDA" begin
x = randn(5, 5)
cx = cu(x)
@test cx isa CuArray

@test_skip onecold(cu([1.0, 2.0, 3.0])) == 3 # passes with CuArray with Julia 1.6, but fails with JLArray

x = onehotbatch([1, 2, 3], 1:3)
cx = cu(x)
@test cx isa OneHotMatrix && cx.indices isa CuArray
@test (cx .+ 1) isa CuArray

@testset "onehotbatch gpu" begin
# move to GPU after construction
x = onehotbatch([1, 2, 3, 2], 1:3)
@test cu(x) isa OneHotMatrix
@test cu(x).indices isa CuArray

# broadcast style works:
@test (cu(x) .+ 1) isa CuArray
    xs = rand(5, 5)
-    ys = onehotbatch(1:5,1:5)
+    ys = onehotbatch(rand(1:5, 5), 1:5)
    @test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
-end

@testset "onehot gpu" begin
y = onehotbatch(ones(3), 1:2) |> cu;
@test (repr("text/plain", y); true)

gA = rand(3, 2) |> cu;
@test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote?
+    # move to GPU before construction
+    z1 = onehotbatch(cu([3f0, 1f0, 2f0, 2f0]), (1.0, 2f0, 3))
+    @test z1.indices isa CuArray
+    z2 = onehotbatch(cu([3f0, 1f0, 2f0, 2f0]), [1, 2], 2) # with default
+    @test z2.indices isa CuArray
+    @test_throws ArgumentError onehotbatch(cu([1, 2, 3]), [1, 2]) # friendly error, not scalar indexing
+    @test_throws ArgumentError onehotbatch(cu([1, 2, 3]), [1, 2], 5)
end

@testset "onecold gpu" begin
@@ -32,6 +26,17 @@ end
    @test onecold(y) isa CuArray
    @test y[3,:] isa CuArray
    @test onecold(y, l) == ['a', 'a', 'a']

+    @test_skip onecold(cu([1.0, 2.0, 3.0])) == 3 # passes with CuArray with Julia 1.6, but fails with JLArray
end

@testset "matrix multiplication gpu" begin
y = onehotbatch([1, 2, 1], [1, 2]) |> cu;
A = rand(3, 2) |> cu;

@test_broken collect(A * y) ≈ collect(A) * collect(y)

@test_broken gradient(A -> sum(abs, A * y), A)[1] isa CuArray # gather!(dst::JLArray, ...) fails
Comment on lines +37 to +39
Member:
Is the solution here to use gather (and take on a dep)?

Member Author:
I think this needs gather/scatter from NNlibCUDA to work on the GPU. And since there's no corresponding code for non-CuArray GPUArrays, I think it can't work with this fake JLArray.

For testing it, you could set up the whole buildkite story to run honest CUDA tests. But perhaps it's not worth it, and this package should just trust NNlib + NNlibCUDA to test things. And perhaps Flux to test the integration?
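For reference, a rough CPU-side sketch of the relationship being discussed (illustrative only, not the package's exact implementation):

    using NNlib, OneHotArrays
    A = rand(3, 2)
    y = onehotbatch([1, 2, 1], [1, 2])
    # Multiplying by a one-hot matrix just selects columns of A, which is what
    # NNlib.gather does; the gradient of gather is scatter, and on CuArrays both
    # need the kernels from NNlibCUDA.
    A * y ≈ NNlib.gather(A, [1, 2, 1])  # true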

Member Author:
Xref FluxML/NNlib.jl#427 too --- it would be nice if forgetting to load NNlibCUDA gave friendly errors, not scalar indexing.

It would be nicer if that could be loaded automatically, of course.

+end

@testset "onehot forward map to broadcast" begin