-
-
Notifications
You must be signed in to change notification settings - Fork 608
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Speedup and fix of multiplication by OneHotMatrix #1756
Conversation
Also fixed gradient calculation on GPU.
Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
@racinmat do you want to include #1355 (comment) as well? |
Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
yes. I want, pushing right now. |
@racinmat gpu tests (buildkite) are still failing, not sure why |
I see, I messed up dimension validation before multiplicating by the adjoint, fixing it and adding it also to cpu tests. |
Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
I'm a bit unsure about adding code which would work for ReshapedArrays without actually adding testcase which would exercise that code and called the method with reshaped array. Is it ok, or should I add tests for it? |
That's a good question, are we clear on what the semantics would be with reshaped arrays? |
Reshaped arrays in this context means N-d So using reshape(OneHotArray(rand(1:10, 5, 5), 10), 10, :) should work for adding tests to exercise it. I do think we should add tests. |
During testing reshaped arrays I realized there is no julia> b4 = reshape(Flux.OneHotMatrix([1 2 3; 2 2 1], 3), 3, :)
3×6 reshape(OneHotArray(::Matrix{Int64}), 3, 6) with eltype Bool:
1 ⋅ ⋅ ⋅ ⋅ 1
⋅ 1 1 1 ⋅ ⋅
⋅ ⋅ ⋅ ⋅ 1 ⋅
julia> b5 = reshape(b4, 6, :)
6×3 reshape(OneHotArray(::Matrix{Int64}), 6, 3) with eltype Bool:
1 0 0
0 1 0
0 0 1
0 0 1
1 1 0
0 0 0
julia> b5'
3×6 adjoint(reshape(OneHotArray(::Matrix{Int64}), 6, 3)) with eltype Bool:
1 0 0 0 1 0
0 1 0 0 1 0
0 0 1 1 0 0 I thought I could check I'm currently not sure which way to go:
|
In the end I decided I will keep it as |
bors r+ |
Build succeeded: |
PR Checklist
Fixes #1355 .
Also fixes bug mentioned in #1355 (comment).
Adds tests for both gpu and cpu.
Adds multiplication by sparse matrix to benchmarks.