Description

I think the `update!` API should be presented up front in addition to, or instead of, the `Flux.train!` API. This will help significantly with attracting deep learning researchers, who I see as the bridge to wider adoption.
Motivation. I first encountered Flux.jl roughly 1.5 years ago. At the time, I skimmed the docs, saw the `Flux.train!` API on the README.md quickstart page, and wrote off the entire package as another one of those super high-level deep learning libraries: easy to write things in the high-level API, but nearly impossible to tweak the internals. (Many others out there likely make the same quick first-impressions evaluation, even though a package maintainer's dream is that every user reads all the docs.)
Today, I decided to take another look through the docs in more detail: I wanted to find something equivalent to what the PyTorch and JAX deep learning frameworks offer, where you can work directly with gradient updates and parameters. (This is important for many areas of deep learning research, as I am sure you know!)
I found the `update!` API (and `withgradient`) after a lot of digging through the docs. I am really happy with this API, as it gives me the low-level control over my deep learning models that I need for my research! So now I am actually planning to use Flux.jl for research.
Conclusion. It took me two passes at the docs, the second one very deep, before I actually found this API. Even then, I only found the API reference for `update!`, rather than an easy-to-find example I could copy and start working with. This kind of user experience risks losing potential users.
Proposal. Therefore, I propose that the `update!` API be demonstrated in the quickstart example: both on the README and up front in the documentation. I think this is key to attracting deep learning researchers as users, since the most popular deep learning packages expose this slightly lower-level API by default. It needs to be extremely obvious that one can do a similar thing with Flux.jl!

Here's an example I propose, similar in style to a PyTorch training loop (and so a great way to convert some PyTorch users!):
```julia
using Flux
import Flux: withgradient, update!

# Chain of dense layers with ReLU activations:
mlp = Chain(
    Dense(5 => 128, relu),
    Dense(128 => 128, relu),
    Dense(128 => 128, relu),
    Dense(128 => 128, relu),
    Dense(128 => 1),
)

# Set up the optimizer over the model's parameters:
p = params(mlp)
opt = Adam(1e-3)

n_steps = 10_000
for i in 1:n_steps
    # Batch of example data: 100 samples of 5 features in [-5, 5):
    X = rand(5, 100) .* 10 .- 5
    y = cos.(X[[3], :] * 1.5) .- 0.2

    # Compute the loss and its gradient with respect to the parameters:
    loss, grad = withgradient(p) do
        # Forward pass:
        y_pred = mlp(X)
        # Squared-error loss:
        sum((y_pred .- y) .^ 2)
    end

    # Gradient step:
    update!(opt, p, grad)

    # Logging:
    println(loss)
end
```
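For completeness: recent Flux versions (those that provide `Flux.setup`) also expose an explicit-gradient form of the same `update!` workflow, where the optimiser state is created from the model and gradients are taken with respect to the model itself. A minimal sketch under that assumption, with purely illustrative variable names:

```julia
using Flux

# Same MLP as above, with the activation inside each Dense layer:
mlp = Chain(
    Dense(5 => 128, relu),
    Dense(128 => 128, relu),
    Dense(128 => 128, relu),
    Dense(128 => 128, relu),
    Dense(128 => 1),
)

# Optimiser state is attached to the model rather than to implicit params:
opt_state = Flux.setup(Adam(1e-3), mlp)

for i in 1:10_000
    # Batch of example data (same toy regression task as above):
    X = rand(5, 100) .* 10 .- 5
    y = cos.(X[[3], :] * 1.5) .- 0.2

    # Gradient is taken with respect to the model itself:
    loss, grads = Flux.withgradient(mlp) do m
        sum((m(X) .- y) .^ 2)
    end

    # Update the model in place; grads[1] is the gradient w.r.t. mlp:
    Flux.update!(opt_state, mlp, grads[1])

    println(loss)
end
```

Either form keeps the gradient-then-`update!` step explicit, which is exactly what researchers coming from PyTorch or JAX will look for in a quickstart.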