[Bug] Infinite loop in HMC._find_reasonable_step_size
#3156
Closed
Issue Description
In some edge cases, the while loop in HMC._find_reasonable_step_size never exits. The search direction never changes, so the step size keeps being halved until it underflows to 0. Halving 0 leaves it at 0, and because the loop only exits when the direction changes, it runs forever. The printout of t and step_size below shows an example of this.
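The underflow is easy to check in isolation. The snippet below is a standalone sketch that does not call Pyro. It halves a Python float until it reaches 0.0, then shows that a NaN energy difference never passes a comparison of the form threshold < -delta_energy. The NaN input and the 0.8 target acceptance probability are assumptions about one way the direction could stay stuck. The issue itself only shows that the direction does not change.

```python
import math

# Repeatedly halve a Python float (IEEE-754 double) until it underflows to 0.0.
step_size = 1.0
halvings = 0
while step_size != 0.0:
    step_size *= 0.5
    halvings += 1
print(halvings)                # 1075
print(step_size * 0.5 == 0.0)  # halving 0.0 stays 0.0 -> True

# Assumed direction test: keep going in the same direction while this is False.
# A NaN never satisfies "<", so the result can never flip to True.
threshold = math.log(0.8)
delta_energy = float("nan")
print(threshold < -delta_energy)  # False
```

Once the step size is exactly 0.0, nothing the loop does can make it grow again. A stuck direction test is therefore enough to hang the search.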
Environment
For any bugs, please provide the following:
- OS and Python version: CentOS 8, Python 3.8
- PyTorch version, or if relevant, output of pip freeze: latest main.
- Pyro version: output of python -c 'import pyro; print(pyro.__version__)': 1.5.2 (I also checked the latest version, and the error is not fixed there.)
Code Snippet
Unfortunately, I do not have a clean code snippet that reproduces the issue. Instead, I extracted the HMC object and the input z that trigger the infinite loop and saved them with torch.save. With the attached file, run:
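import torch

# The file holds a dict with the HMC kernel ("hmc") and the offending input ("z").
# On PyTorch >= 2.6, pass weights_only=False so the full HMC object can be unpickled.
save_dict = torch.load("/home/saitcakmak/pyro_debug.pt")  # update the file path
save_dict["hmc"]._find_reasonable_step_size(save_dict["z"])  # never returns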
Quick fix
As a hot-fix, I replaced line 193 in 8b7e564 with

while direction_new == direction and step_size > 1e-100:

which prevents the infinite loop.
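For reference, here is a simplified, standalone sketch of a doubling/halving step-size search with the hot-fix's lower bound built in. It is not Pyro's implementation. delta_energy_fn is a hypothetical stand-in for one leapfrog step plus the energy difference. The 0.8 target acceptance probability is an assumed default. The upper bound max_step_size is an extra assumption, not part of the hot-fix, added so the doubling direction also terminates.

```python
import math
from typing import Callable


def find_reasonable_step_size(
    delta_energy_fn: Callable[[float], float],
    step_size: float = 1.0,
    target_accept_prob: float = 0.8,
    min_step_size: float = 1e-100,
    max_step_size: float = 1e100,
) -> float:
    """Double or halve step_size until the accept/reject direction flips.

    delta_energy_fn(step_size) is a hypothetical stand-in for "take one
    leapfrog step of this size and return energy_new - energy_current".
    """
    threshold = math.log(target_accept_prob)

    def direction_for(delta_energy: float) -> int:
        # exp(-delta_energy) > target  <=>  threshold < -delta_energy.
        # A NaN delta_energy fails the comparison and yields -1 (shrink).
        return 1 if threshold < -delta_energy else -1

    direction = direction_for(delta_energy_fn(step_size))
    scale = 2.0 ** direction  # 2.0 when growing, 0.5 when shrinking
    direction_new = direction
    # The lower bound mirrors the hot-fix; the upper bound is an extra assumption.
    while direction_new == direction and min_step_size < step_size < max_step_size:
        step_size *= scale
        direction_new = direction_for(delta_energy_fn(step_size))
    return step_size


# Energy error that grows with the step size: the search settles at 0.25.
print(find_reasonable_step_size(lambda s: s ** 2))  # 0.25
# NaN energy: without the lower bound this would halve forever.
print(find_reasonable_step_size(lambda s: float("nan")))
```

With both bounds the loop always terminates. A returned step size near 1e-100 (or 1e100) still suggests that something upstream, such as a NaN potential energy, needs attention.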