-
-
Notifications
You must be signed in to change notification settings - Fork 8
Fix #16 #17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Fix #16 #17
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -83,20 +83,36 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl | |||||
""" | ||||||
onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) | ||||||
|
||||||
function _onehotbatch(data, labels) | ||||||
indices = UInt32[something(_findval(i, labels), 0) for i in data] | ||||||
if 0 in indices | ||||||
function _onehotbatch(data, labels) # this accepts any iterator | ||||||
indices = UInt32[something(_findval(x, labels), 0) for x in data] | ||||||
if any(iszero, indices) | ||||||
for x in data | ||||||
isnothing(_findval(x, labels)) && error("Value $x not found in labels") | ||||||
isnothing(_findval(x, labels)) && throw(ArgumentError("Value x = $x not found in labels = $labels")) | ||||||
end | ||||||
end | ||||||
return OneHotArray(indices, length(labels)) | ||||||
end | ||||||
function _onehotbatch(data::AbstractArray, labels) # this works for GPUArrays too | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One reason to change this is to avoid ever making a
Suggested change
|
||||||
indices = similar(data, UInt32) | ||||||
map!(x -> something(_findval(x, labels), 0), indices, data) | ||||||
if any(iszero, indices) | ||||||
badx = @allowscalar data[findfirst(iszero, indices)] | ||||||
throw(ArgumentError("Value x = $badx not found in labels = $labels")) | ||||||
end | ||||||
return OneHotArray(indices, length(labels)) | ||||||
end | ||||||
|
||||||
function _onehotbatch(data, labels, default) | ||||||
default_index = _findval(default, labels) | ||||||
isnothing(default_index) && error("Default value $default is not in labels") | ||||||
indices = UInt32[something(_findval(i, labels), default_index) for i in data] | ||||||
isnothing(default_index) && throw(ArgumentError("Default value $default is not in labels = $labels")) | ||||||
indices = UInt32[something(_findval(x, labels), default_index) for x in data] | ||||||
return OneHotArray(indices, length(labels)) | ||||||
end | ||||||
function _onehotbatch(data::AbstractArray, labels, default) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
default_index = _findval(default, labels) | ||||||
isnothing(default_index) && throw(ArgumentError("Default value $default is not in labels = $labels")) | ||||||
indices = similar(data, UInt32) | ||||||
map!(x -> something(_findval(x, labels), default_index), indices, data) | ||||||
return OneHotArray(indices, length(labels)) | ||||||
end | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,23 @@ | ||
|
||
# Tests from Flux, probably not the optimal testset organisation! | ||
|
||
@testset "CUDA" begin | ||
x = randn(5, 5) | ||
cx = cu(x) | ||
@test cx isa CuArray | ||
|
||
@test_skip onecold(cu([1.0, 2.0, 3.0])) == 3 # passes with CuArray with Julia 1.6, but fails with JLArray | ||
|
||
x = onehotbatch([1, 2, 3], 1:3) | ||
cx = cu(x) | ||
@test cx isa OneHotMatrix && cx.indices isa CuArray | ||
@test (cx .+ 1) isa CuArray | ||
|
||
@testset "onehotbatch gpu" begin | ||
# move to GPU after construction | ||
x = onehotbatch([1, 2, 3, 2], 1:3) | ||
@test cu(x) isa OneHotMatrix | ||
@test cu(x).indices isa CuArray | ||
|
||
# broadcast style works: | ||
@test (cu(x) .+ 1) isa CuArray | ||
xs = rand(5, 5) | ||
ys = onehotbatch(1:5,1:5) | ||
ys = onehotbatch(rand(1:5, 5), 1:5) | ||
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys) | ||
end | ||
|
||
@testset "onehot gpu" begin | ||
y = onehotbatch(ones(3), 1:2) |> cu; | ||
@test (repr("text/plain", y); true) | ||
|
||
gA = rand(3, 2) |> cu; | ||
@test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote? | ||
# move to GPU before construction | ||
z1 = onehotbatch(cu([3f0, 1f0, 2f0, 2f0]), (1.0, 2f0, 3)) | ||
@test z1.indices isa CuArray | ||
z2 = onehotbatch(cu([3f0, 1f0, 2f0, 2f0]), [1, 2], 2) # with default | ||
@test z2.indices isa CuArray | ||
@test_throws ArgumentError onehotbatch(cu([1, 2, 3]), [1, 2]) # friendly error, not scalar indexing | ||
@test_throws ArgumentError onehotbatch(cu([1, 2, 3]), [1, 2], 5) | ||
end | ||
|
||
@testset "onecold gpu" begin | ||
|
@@ -32,6 +26,17 @@ end | |
@test onecold(y) isa CuArray | ||
@test y[3,:] isa CuArray | ||
@test onecold(y, l) == ['a', 'a', 'a'] | ||
|
||
@test_skip onecold(cu([1.0, 2.0, 3.0])) == 3 # passes with CuArray with Julia 1.6, but fails with JLArray | ||
end | ||
|
||
@testset "matrix multiplication gpu" begin | ||
y = onehotbatch([1, 2, 1], [1, 2]) |> cu; | ||
A = rand(3, 2) |> cu; | ||
|
||
@test_broken collect(A * y) ≈ collect(A) * collect(y) | ||
|
||
@test_broken gradient(A -> sum(abs, A * y), A)[1] isa CuArray # gather!(dst::JLArray, ...) fails | ||
Comment on lines
+37
to
+39
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the solution here to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this needs gather/scatter from NNlibCUDA to work on the GPU. And since there's no corresponding code for non-CuArray GPUArrays, I think it can't work with this fake JLArray. For testing it, you could set up the whole buildkite story to run honest CUDA tests. But perhaps it's not worth it, and this package should just trust NNlib + NNlibCUDA to test things. And perhaps Flux to test the integration? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Xref FluxML/NNlib.jl#427 too --- it would be nice if forgetting to load NNlibCUDA gave friendly errors, not scalar indexing. It would be nicer if that could be loaded automatically, of course. |
||
end | ||
|
||
@testset "onehot forward map to broadcast" begin | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed these error types partly so that tests can distinguish scalar indexing errors from helpful messages.