solver unwraps batch and runs a python loop #278

@TobiasPfeifer

I want to integrate a dopri5 solver for training a small flow matching model, so I gave torchdiffeq a try. However, I encountered an extreme performance degradation (40x slower) when running with batch sizes around 200, compared to doing the Euler steps manually (with max_num_steps set to match the number of Euler steps). In the debugger I saw that the batch is unwrapped and processed sequentially.

Is this intended and what's the reason? Is there a way to run the batch in parallel?
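
For reference, the kind of manual Euler baseline I am comparing against can be sketched roughly as below. This is only an illustration in plain PyTorch, not torchdiffeq code: the names `euler_batched` and `velocity` are hypothetical, and the toy field stands in for the actual flow matching network. The point is that each step makes a single call on the whole `(batch, dim)` tensor, with no Python loop over samples.

```python
import torch

def euler_batched(f, x0, t0=0.0, t1=1.0, num_steps=20):
    """Fixed-step Euler on a whole batch; x0 has shape (batch, dim)."""
    x = x0
    dt = (t1 - t0) / num_steps
    for i in range(num_steps):
        # One time value per sample, so f can be conditioned on t per row.
        t = torch.full((x.shape[0],), t0 + i * dt, dtype=x.dtype, device=x.device)
        # Single batched evaluation of the vector field per step.
        x = x + dt * f(t, x)
    return x

# Hypothetical toy vector field standing in for the trained model.
def velocity(t, x):
    return -x

x0 = torch.randn(200, 2)
x1 = euler_batched(velocity, x0)
print(x1.shape)
```

The only Python-level loop here runs over the 20 time steps, so the cost per step should be roughly independent of how the batch is laid out.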

Also, there should be a way to return the last solution when max_num_steps is hit, instead of just raising an exception. Tuning rtol and atol for different training stages is hard. Maybe an option such as "abort_num_steps" that simply stops integrating and returns the current state?
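
To make the requested behaviour concrete, here is a minimal sketch of what I mean by a step budget. It is not torchdiffeq code and not dopri5: it uses a simple adaptive Heun-Euler pair in plain Python, and the function name `integrate_with_budget`, its `max_steps` argument, and the returned `finished` flag are all hypothetical. It assumes `t1 > t0` and a state that supports ordinary arithmetic (the example uses a float).

```python
def integrate_with_budget(f, y0, t0, t1, rtol=1e-3, atol=1e-6,
                          max_steps=1000, norm=abs):
    """Adaptive Heun-Euler integration that stops quietly on a step budget.

    Returns (t, y, finished). If the budget runs out, the current state is
    returned with finished=False instead of raising an exception.
    """
    t, y = t0, y0
    h = (t1 - t0) / 10.0  # initial step size guess
    steps = 0
    while t < t1 and steps < max_steps:
        last = h >= t1 - t
        if last:
            h = t1 - t  # do not step past the end point
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        y_low = y + h * k1                # Euler (order 1)
        y_high = y + 0.5 * h * (k1 + k2)  # Heun (order 2)
        err = norm(y_high - y_low)        # local error estimate
        tol = atol + rtol * max(norm(y), norm(y_high))
        steps += 1  # rejected attempts count against the budget too
        if err <= tol:
            # Accept the step; land exactly on t1 for the final step.
            t = t1 if last else t + h
            y = y_high
        # Standard step size controller for an order-1 error estimate.
        factor = 0.9 * (tol / err) ** 0.5 if err > 0 else 2.0
        h *= min(2.0, max(0.2, factor))
    return t, y, t >= t1

# Usage: dy/dt = -y, once with a generous budget and once with a tiny one.
t_full, y_full, done_full = integrate_with_budget(lambda t, y: -y, 1.0, 0.0, 1.0)
t_cut, y_cut, done_cut = integrate_with_budget(lambda t, y: -y, 1.0, 0.0, 1.0, max_steps=3)
```

With the tiny budget the call returns early with `done_cut == False` and the state reached so far, which the training loop could then use or discard as it sees fit.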
