I want to integrate a dopri5 solver for training a small flow matching model, so I gave torchdiffeq a try. However, I ran into extreme performance degradation (40x slower) with batch sizes around 200, compared to doing Euler manually (with max_num_steps set to match the number of Euler steps). In the debugger I saw that the batch is unwrapped and run sequentially.
Is this intended, and what's the reason? Is there a way to run the batch in parallel?
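For comparison, here is a minimal sketch of the manual Euler baseline, assuming the model maps a whole batch of shape (batch, dim) to velocities of the same shape. It uses NumPy as a stand-in for PyTorch tensors, and `velocity` is a hypothetical placeholder for the trained network (here simply v(t, x) = -x). Each step is one vectorized call over all 200 samples, with no per-sample loop, which is the behavior I would expect from an adaptive solver too.

```python
import numpy as np


def velocity(t, x):
    # Hypothetical stand-in for the learned flow-matching velocity field v_theta(t, x).
    # A real model would be a neural network applied to the whole batch at once.
    return -x


def euler_integrate(v, x0, num_steps, t0=0.0, t1=1.0):
    """Fixed-step explicit Euler; every step updates the whole batch in one call."""
    h = (t1 - t0) / num_steps
    x = x0.copy()
    t = t0
    for _ in range(num_steps):
        x = x + h * v(t, x)  # one batched evaluation per step
        t += h
    return x


rng = np.random.default_rng(0)
x0 = rng.standard_normal((200, 2))  # batch of 200 two-dimensional samples
x1 = euler_integrate(velocity, x0, num_steps=20)
```

With a fixed step count the cost is exactly `num_steps` batched model evaluations, which is the budget I tried to reproduce through max_num_steps.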
Also, there should be a way to return the last solution when max_num_steps is hit, instead of just raising an exception. Tuning rtol and atol for different training stages is hard. Maybe an option like "abort_num_steps" that simply stops integrating and returns the current state?
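To illustrate the proposed behavior, here is a rough sketch. It is not torchdiffeq's code and not dopri5: to keep it short it swaps in the much simpler adaptive Heun-Euler pair (order 2 with an embedded order-1 error estimate), uses NumPy instead of PyTorch, and `abort_num_steps` is only the option name proposed above, not an existing torchdiffeq parameter. When the step budget runs out, it returns the current state, the time reached, and a flag instead of raising.

```python
import numpy as np


def heun_euler_adaptive(v, x0, t0, t1, rtol=1e-3, atol=1e-6, h0=0.1, abort_num_steps=100):
    """Adaptive Heun-Euler integration of a whole batch with a step budget.

    abort_num_steps is the hypothetical option proposed in the post. Instead of
    raising when that many attempted steps are used up, return the partial state.
    Returns (x, t_reached, finished).
    """
    t, x, h = t0, x0.copy(), h0
    n_steps = 0  # attempted steps, accepted or rejected
    while t < t1:
        if n_steps >= abort_num_steps:
            return x, t, False  # budget exhausted: hand back the partial solution
        last = h >= t1 - t
        if last:
            h = t1 - t  # clip the final step so it lands exactly on t1
        k1 = v(t, x)
        k2 = v(t + h, x + h * k1)
        x_low = x + h * k1                # Euler, order 1 (error estimate only)
        x_high = x + 0.5 * h * (k1 + k2)  # Heun, order 2 (propagated solution)
        scale = atol + rtol * np.maximum(np.abs(x), np.abs(x_high))
        # One RMS norm over every element of the batch -> one shared step size.
        err = np.sqrt(np.mean(((x_high - x_low) / scale) ** 2))
        n_steps += 1
        if err <= 1.0:
            x = x_high
            t = t1 if last else t + h
        # Grow or shrink the step with a safety factor, clipped to [0.2, 5].
        factor = 0.9 * (1.0 / max(err, 1e-10)) ** 0.5
        h = h * min(5.0, max(0.2, factor))
    return x, t, True


def velocity(t, x):
    # Hypothetical stand-in for the trained velocity network.
    return -x


x0 = np.ones((200, 2))
x_full, t_full, done_full = heun_euler_adaptive(velocity, x0, 0.0, 1.0, abort_num_steps=1000)
x_part, t_part, done_part = heun_euler_adaptive(velocity, x0, 0.0, 1.0, abort_num_steps=5)
```

In this sketch the whole batch shares one step size because the error is a single RMS norm over all elements, so one hard sample shrinks the step for everyone. Since an aborted solve may stop short of t1, the caller should check the returned flag or time before treating the result as a sample at t1.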