
Large speed difference between forward and backward passes  #179

@deasmhumhna

Description


I just discovered the Diffrax package and it's great! However, I'm encountering an issue where gradient evaluation is 40–80× slower than the forward pass for my particular network (a potential network whose gradient is the vector field). The slowdown factor also seems to grow with the number of steps used in the integration loop. Early on, when the model is still near the identity map and the step count is low, the forward pass takes ~50 ms and the backward + update pass takes ~2 s (40×). Later, the forward pass takes ~500 ms while the backward + update pass takes ~40 s (80×). These times are on a single GPU. I expect the gradient to take longer than the forward evaluation, but 80× slower seems extreme, even with the extra compute from checkpointing. For reference, evaluating the vector field alone takes ~2 ms, while computing the gradient of the norm of the vector field takes ~12 ms. I'll create a minimal working example shortly. I am using Flax, so I wonder if that has something to do with it.
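In the meantime, here is a minimal sketch of the kind of timing comparison I mean, with a toy MLP potential standing in for my actual Flax network (the sizes, solver, and step count here are illustrative, not my real configuration):

```python
import time

import diffrax
import jax
import jax.numpy as jnp


def potential(params, y):
    # Toy scalar potential standing in for the real network.
    w1, b1, w2 = params
    h = jnp.tanh(y @ w1 + b1)
    return jnp.sum(h @ w2)


def vector_field(t, y, params):
    # The vector field is the gradient of the potential w.r.t. the state y.
    return jax.grad(potential, argnums=1)(params, y)


def loss(params, y0):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=0.0,
        t1=1.0,
        dt0=0.01,  # ~100 fixed steps
        y0=y0,
        args=params,
        adjoint=diffrax.RecursiveCheckpointAdjoint(),  # Diffrax's default
    )
    return jnp.sum(sol.ys ** 2)


key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
dim, hidden = 64, 128
params = (
    jax.random.normal(k1, (dim, hidden)) / jnp.sqrt(dim),
    jnp.zeros(hidden),
    jax.random.normal(k2, (hidden, 1)) / jnp.sqrt(hidden),
)
y0 = jax.random.normal(k3, (dim,))

fwd = jax.jit(loss)
bwd = jax.jit(jax.grad(loss))

# Compile both before timing.
fwd(params, y0).block_until_ready()
jax.block_until_ready(bwd(params, y0))

t = time.perf_counter()
fwd(params, y0).block_until_ready()
print(f"forward:  {time.perf_counter() - t:.4f} s")

t = time.perf_counter()
jax.block_until_ready(bwd(params, y0))
print(f"backward: {time.perf_counter() - t:.4f} s")
```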

EDIT: Example in Colab: a single conditioned 3x3 convolution followed by logsumexp pooling and some final ResNet layers. The same behavior reproduces on CPU. I'm curious whether it has something to do with the convolutions.
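For reference, the potential in the Colab example is shaped roughly like this (a loose Flax sketch of the description above, not the exact notebook code; the conditioning mechanism and layer sizes here are assumptions):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class Potential(nn.Module):
    """Loose reconstruction of the Colab architecture; sizes are illustrative."""

    @nn.compact
    def __call__(self, x, t):
        # "Conditioned" 3x3 convolution: here t is broadcast into an extra
        # input channel (an assumption about the conditioning mechanism).
        t_chan = jnp.broadcast_to(t, x.shape[:-1] + (1,))
        h = nn.Conv(features=32, kernel_size=(3, 3))(
            jnp.concatenate([x, t_chan], axis=-1)
        )
        # logsumexp pooling over the spatial dimensions.
        h = jax.nn.logsumexp(h, axis=(-3, -2))
        # Dense layers standing in for the final ResNet layers.
        h = h + nn.relu(nn.Dense(32)(h))  # simple residual block
        return jnp.sum(nn.Dense(1)(h))
```

The vector field handed to Diffrax is then the gradient of this scalar output with respect to x, so each solver step already differentiates through the convolution once, and the backward pass differentiates through that gradient again.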
