Description
I just discovered the Diffrax package and it's great! However, I'm encountering an issue where the gradient evaluation is 40-80 times slower than the forward pass for my particular network (a potential network whose gradient is the vector field). The slowdown factor also seems to grow with the number of steps used in the integration loop. When the model is still near the identity map and the number of steps is low, the forward pass takes ~50 ms and the backward + update pass takes ~2 s (x40). Later on, the forward pass takes ~500 ms while the backward + update pass takes ~40 s (x80). These times are on a single GPU.

I expect the gradient to take longer than the forward evaluation, but 80 times slower seems extreme, even with the extra compute from checkpointing. For reference, evaluating the vector field takes ~2 ms while calculating the gradient of the norm of the vector field takes ~12 ms.

I'll create a minimal working example shortly. I am using Flax, so I wonder if that has something to do with it.
EDIT: Example in Colab: a single conditioned 3x3 convolution followed by logsumexp pooling and some final ResNet layers. Similar behavior is reproduced on CPU. I'm curious whether it has something to do with the convolutions.
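For reference, here is a minimal sketch (not the Colab example above, and not my actual model) of how I'm measuring forward vs. gradient timings: a small hand-rolled MLP vector field solved with Diffrax's default checkpointed adjoint, with JIT compilation triggered once before timing.

```python
# Minimal timing sketch: compare a Diffrax forward solve against its gradient.
# The MLP vector field and the sizes below are illustrative placeholders.
import time

import jax
import jax.numpy as jnp
import diffrax


def init_mlp(key, sizes):
    # MLP parameters as a list of (W, b) pairs.
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
        for k, m, n in zip(keys, sizes[:-1], sizes[1:])
    ]


def mlp(params, x):
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b


def vector_field(t, y, params):
    return mlp(params, y)


def solve(params, y0):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=0.0,
        t1=1.0,
        dt0=0.01,
        y0=y0,
        args=params,
        saveat=diffrax.SaveAt(t1=True),
        max_steps=4096,
    )
    return sol.ys[-1]


def loss(params, y0):
    return jnp.sum(solve(params, y0) ** 2)


key = jax.random.PRNGKey(0)
params = init_mlp(key, [2, 64, 64, 2])
y0 = jnp.ones(2)

fwd = jax.jit(loss)
bwd = jax.jit(jax.grad(loss))

# Warm up so the timings measure runtime, not tracing/compilation.
fwd(params, y0).block_until_ready()
jax.block_until_ready(bwd(params, y0))

t = time.perf_counter()
fwd(params, y0).block_until_ready()
print("forward: ", time.perf_counter() - t)

t = time.perf_counter()
jax.block_until_ready(bwd(params, y0))
print("gradient:", time.perf_counter() - t)
```

In my actual setup the vector field is a Flax module with convolutions, and the loss includes an optimizer update, but the timing pattern is the same shape as above.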