-
-
Notifications
You must be signed in to change notification settings - Fork 608
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added support for SparseArray input in Dense layer #987
Conversation
Can you describe why this is needed a bit more? It seems like it's the same definition that's already covered by |
For some reason gradients of Dense layer with an Arrays as it's weight and a SparseMatrix as input fails.
I tried using |
Right you are, but this error actually comes from trying to convert the sparse array to Float32 to match the layer: julia> gradient(x -> sum(Float32.(x)), sparse(randn(2)))
ERROR: MethodError: no method matching zero(::Type{Any}) Your patch avoids the conversion, but really we should just fix this error in Zygote. |
So what can be defined for |
Well, at some point we're calling |
The main problem is caused by mixing Float32 and Float64, which results in W = randn(Float32, 2,2)
b = randn(Float32, 2)
md = Dense(W, b)
xs = sparse(randn(Float32, 2, 2))
gradient(() -> sum(md(xs)), Flux.params(md))
W = randn(Float64, 2,2)
b = randn(Float64, 2)
md = Dense(W, b)
xs = sparse(randn(Float64, 2, 2))
gradient(() -> sum(md(xs)), Flux.params(md)) are working. But
throw the above-mentioned error. |
For Zygote itself, it's working, see:
all of them are passing, so I don't think this is problem with Zygote. |
Hm, it seems the conversion is problematic. Dense layer explicitly casts input data to type of parameter in https://github.com/FluxML/Flux.jl/blob/master/src/layers/basic.jl#L138. And Zygote can't pullback through it, following Zygote code
crashes. The question is: should this be fixed in Zygote or in Flux? |
I guess if mixed-precision compute is supported this shouldn't be an issue. |
I made issue on Zygote with MWE, we'll see FluxML/Zygote.jl#810 |
Since this turned out to be about eltype mismatches instead of sparse array support, I think it can be safely closed. |
Fixes Issue #965
Dense Layer with is able to take SparseArray as input